OpenSearch updates (#906)

* updates

* add uuid filter functionality

* update

* updates

* bump-version

* update

* fix typo

* use async function

* update unit tests

* update delete

* update deletion

* async update

* update

* update

* update

* update
This commit is contained in:
Preston Rasmussen 2025-09-14 01:43:37 -04:00 committed by GitHub
parent 4dab259217
commit 3efe085a92
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 479 additions and 191 deletions

View file

@ -17,16 +17,19 @@ limitations under the License.
import asyncio
import copy
import logging
import os
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from datetime import datetime
from enum import Enum
from typing import Any
from dotenv import load_dotenv
from graphiti_core.embedder.client import EMBEDDING_DIM
try:
from opensearchpy import OpenSearch, helpers
from opensearchpy import AsyncOpenSearch, helpers
_HAS_OPENSEARCH = True
except ImportError:
@ -38,6 +41,13 @@ logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10
load_dotenv()
ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
class GraphProvider(Enum):
NEO4J = 'neo4j'
@ -48,20 +58,19 @@ class GraphProvider(Enum):
aoss_indices = [
{
'index_name': 'entities',
'index_name': ENTITY_INDEX_NAME,
'body': {
'settings': {'index': {'knn': True}},
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'summary': {'type': 'text'},
'group_id': {'type': 'text'},
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'group_id': {'type': 'keyword'},
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'name_embedding': {
'type': 'knn_vector',
'dims': EMBEDDING_DIM,
'index': True,
'similarity': 'cosine',
'dimension': EMBEDDING_DIM,
'method': {
'engine': 'faiss',
'space_type': 'cosinesimil',
@ -70,23 +79,23 @@ aoss_indices = [
},
},
}
}
},
},
},
{
'index_name': 'communities',
'index_name': COMMUNITY_INDEX_NAME,
'body': {
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'group_id': {'type': 'text'},
'group_id': {'type': 'keyword'},
}
}
},
},
{
'index_name': 'episodes',
'index_name': EPISODE_INDEX_NAME,
'body': {
'mappings': {
'properties': {
@ -94,31 +103,30 @@ aoss_indices = [
'content': {'type': 'text'},
'source': {'type': 'text'},
'source_description': {'type': 'text'},
'group_id': {'type': 'text'},
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'group_id': {'type': 'keyword'},
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
}
}
},
},
{
'index_name': 'entity_edges',
'index_name': ENTITY_EDGE_INDEX_NAME,
'body': {
'settings': {'index': {'knn': True}},
'mappings': {
'properties': {
'uuid': {'type': 'keyword'},
'name': {'type': 'text'},
'fact': {'type': 'text'},
'group_id': {'type': 'text'},
'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
'group_id': {'type': 'keyword'},
'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
'fact_embedding': {
'type': 'knn_vector',
'dims': EMBEDDING_DIM,
'index': True,
'similarity': 'cosine',
'dimension': EMBEDDING_DIM,
'method': {
'engine': 'faiss',
'space_type': 'cosinesimil',
@ -127,7 +135,7 @@ aoss_indices = [
},
},
}
}
},
},
},
]
@ -163,7 +171,7 @@ class GraphDriver(ABC):
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
)
_database: str
aoss_client: OpenSearch | None # type: ignore
aoss_client: AsyncOpenSearch | None # type: ignore
@abstractmethod
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@ -205,7 +213,7 @@ class GraphDriver(ABC):
alias_name = index['index_name']
# If alias already exists, skip (idempotent behavior)
if client.indices.exists_alias(name=alias_name):
if await client.indices.exists_alias(name=alias_name):
continue
# Build a physical index name with timestamp
@ -213,27 +221,67 @@ class GraphDriver(ABC):
physical_index_name = f'{alias_name}_{ts_suffix}'
# Create the index
client.indices.create(index=physical_index_name, body=index['body'])
await client.indices.create(index=physical_index_name, body=index['body'])
# Point alias to it
client.indices.put_alias(index=physical_index_name, name=alias_name)
await client.indices.put_alias(index=physical_index_name, name=alias_name)
# Allow some time for index creation
await asyncio.sleep(60)
await asyncio.sleep(1)
async def delete_aoss_indices(self):
    """Delete the physical indices that back every alias in ``aoss_indices``.

    Each entry's ``index_name`` is treated as an alias: it is resolved to the
    index names it points at, and each of those is deleted if it still exists.
    When no OpenSearch client is configured a warning is logged and nothing
    else happens. A failure while handling one alias is logged and does not
    stop the remaining aliases from being processed.
    """
    opensearch = self.aoss_client
    if not opensearch:
        logger.warning('No OpenSearch client found')
        return

    for spec in aoss_indices:
        alias_name = spec['index_name']
        try:
            # The alias lookup is keyed by the index names behind the alias.
            resolved = await opensearch.indices.get_alias(name=alias_name)
            backing_indices = list(resolved.keys())

            if not backing_indices:
                logger.info(f"No indices found for alias '{alias_name}'")
                continue

            for index in backing_indices:
                still_there = await opensearch.indices.exists(index=index)
                if not still_there:
                    logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
                    continue
                await opensearch.indices.delete(index=index)
                logger.info(f"Deleted index '{index}' (alias: {alias_name})")
        except Exception as e:
            # Best effort: report the failure and move on to the next alias.
            logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
async def clear_aoss_indices(self):
    """Remove all documents from each index in ``aoss_indices``, keeping the index.

    When no OpenSearch client is configured a warning is logged and the method
    returns. For every configured index name that exists, a ``match_all``
    delete-by-query is issued; names that do not exist are logged as warnings.
    A failure while clearing one index is logged and the loop continues with
    the next one.
    """
    client = self.aoss_client

    # Resolve the client once; it does not change between iterations.
    if not client:
        logger.warning('No OpenSearch client found')
        return

    for index in aoss_indices:
        index_name = index['index_name']

        if not await client.indices.exists(index=index_name):
            logger.warning(f"Index '{index_name}' does not exist")
            continue

        try:
            # Delete all documents but keep the index
            response = await client.delete_by_query(
                index=index_name,
                body={'query': {'match_all': {}}},
            )
            logger.info(f"Cleared index '{index_name}': {response}")
        except Exception as e:
            # Best effort: report the failure and move on to the next index.
            logger.error(f"Error clearing index '{index_name}': {e}")
def save_to_aoss(self, name: str, data: list[dict]) -> int:
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
client = self.aoss_client
if not client or not helpers:
logger.warning('No OpenSearch client found')
@ -243,16 +291,22 @@ class GraphDriver(ABC):
if name.lower() == index['index_name']:
to_index = []
for d in data:
item = {
'_index': name,
'_routing': d.get('group_id'), # shard routing
}
doc = {}
for p in index['body']['mappings']['properties']:
if p in d: # protect against missing fields
item[p] = d[p]
doc[p] = d[p]
item = {
'_index': name,
'_id': d['uuid'],
'_routing': d.get('group_id'),
'_source': doc,
}
to_index.append(item)
success, failed = helpers.bulk(client, to_index, stats_only=True)
success, failed = await helpers.async_bulk(
client, to_index, stats_only=True, request_timeout=60
)
return success if failed == 0 else success

View file

@ -28,7 +28,13 @@ logger = logging.getLogger(__name__)
try:
import boto3
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
from opensearchpy import (
AIOHttpConnection,
AsyncOpenSearch,
AWSV4SignerAuth,
Urllib3AWSV4SignerAuth,
Urllib3HttpConnection,
)
_HAS_OPENSEARCH = True
except ImportError:
@ -50,6 +56,9 @@ class Neo4jDriver(GraphDriver):
database: str = 'neo4j',
aoss_host: str | None = None,
aoss_port: int | None = None,
aws_profile_name: str | None = None,
aws_region: str | None = None,
aws_service: str | None = None,
):
super().__init__()
self.client = AsyncGraphDatabase.driver(
@ -61,15 +70,17 @@ class Neo4jDriver(GraphDriver):
self.aoss_client = None
if aoss_host and aoss_port and boto3 is not None:
try:
session = boto3.Session()
self.aoss_client = OpenSearch( # type: ignore
region = aws_region
service = aws_service
credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
auth = AWSV4SignerAuth(credentials, region or '', service or '')
self.aoss_client = AsyncOpenSearch(
hosts=[{'host': aoss_host, 'port': aoss_port}],
http_auth=Urllib3AWSV4SignerAuth( # type: ignore
session.get_credentials(), session.region_name, 'aoss'
),
auth=auth,
use_ssl=True,
verify_certs=True,
connection_class=Urllib3HttpConnection,
connection_class=AIOHttpConnection,
pool_maxsize=20,
) # type: ignore
except Exception as e:

View file

@ -237,12 +237,12 @@ class NeptuneDriver(GraphDriver):
'You must provide an AOSS endpoint to create an OpenSearch driver.'
)
if not client.indices.exists(index=index_name):
client.indices.create(index=index_name, body=index['body'])
await client.indices.create(index=index_name, body=index['body'])
alias_name = index.get('alias_name', index_name)
if not client.indices.exists_alias(name=alias_name, index=index_name):
client.indices.put_alias(index=index_name, name=alias_name)
await client.indices.put_alias(index=index_name, name=alias_name)
# Sleep for 1 minute to let the index creation complete
await asyncio.sleep(60)

View file

@ -25,7 +25,7 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
@ -77,6 +77,13 @@ class Edge(BaseModel, ABC):
uuid=self.uuid,
)
if driver.aoss_client:
await driver.aoss_client.delete(
index=ENTITY_EDGE_INDEX_NAME,
id=self.uuid,
params={'routing': self.group_id},
)
logger.debug(f'Deleted Edge: {self.uuid}')
@classmethod
@ -108,6 +115,12 @@ class Edge(BaseModel, ABC):
uuids=uuids,
)
if driver.aoss_client:
await driver.aoss_client.delete_by_query(
index=ENTITY_EDGE_INDEX_NAME,
body={'query': {'terms': {'uuid': uuids}}},
)
logger.debug(f'Deleted Edges: {uuids}')
def __hash__(self):
@ -256,13 +269,13 @@ class EntityEdge(Edge):
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
"""
elif driver.aoss_client:
resp = driver.aoss_client.search(
resp = await driver.aoss_client.search(
body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1,
},
index='entity_edges',
routing=self.group_id,
index=ENTITY_EDGE_INDEX_NAME,
params={'routing': self.group_id},
)
if resp['hits']['hits']:
@ -314,7 +327,7 @@ class EntityEdge(Edge):
edge_data.update(self.attributes or {})
if driver.aoss_client:
driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),
@ -351,6 +364,35 @@ class EntityEdge(Edge):
raise EdgeNotFoundError(uuid)
return edges[0]
# Return the edges that point from one Entity node to another.
#
# Parameters:
#   driver: graph driver used to run the query; its `provider` selects both
#       the MATCH pattern and the RETURN projection.
#   source_node_uuid: uuid of the Entity the edge starts from.
#   target_node_uuid: uuid of the Entity the edge points to.
#
# Returns the list built from the query records. The list is returned as-is,
# so it is empty when the query yields no records; this method itself does not
# raise in that case. Only the source -> target direction is matched.
@classmethod
async def get_between_nodes(
cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
):
# Default pattern: `e` is a RELATES_TO relationship between the two Entity nodes.
match_query = """
MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
"""
if driver.provider == GraphProvider.KUZU:
# Kuzu pattern: `e` is an intermediate RelatesToNode_ node reached through
# two RELATES_TO hops instead of being a relationship itself.
match_query = """
MATCH (n:Entity {uuid: $source_node_uuid})
-[:RELATES_TO]->(e:RelatesToNode_)
-[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
"""
# The projection after RETURN is supplied by get_entity_edge_return_query for this provider.
records, _, _ = await driver.execute_query(
match_query
+ """
RETURN
"""
+ get_entity_edge_return_query(driver.provider),
source_node_uuid=source_node_uuid,
target_node_uuid=target_node_uuid,
# NOTE(review): routing_='r' presumably requests read routing — confirm against the driver.
routing_='r',
)
# Convert each raw record into an edge object using the provider-specific mapping.
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@classmethod
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
if len(uuids) == 0:

View file

@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_edge_invalidation_candidates,
get_mentioned_nodes,
get_relevant_edges,
)
from graphiti_core.telemetry import capture_event
from graphiti_core.utils.bulk_utils import (
@ -1037,10 +1035,28 @@ class Graphiti:
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
valid_edges = await EntityEdge.get_between_nodes(
self.driver, edge.source_node_uuid, edge.target_node_uuid
)
related_edges = (
await search(
self.clients,
updated_edge.fact,
group_ids=[updated_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
)
).edges
existing_edges = (
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
)[0]
await search(
self.clients,
updated_edge.fact,
group_ids=[updated_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(),
)
).edges
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
self.llm_client,

View file

@ -26,7 +26,14 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.driver.driver import (
COMMUNITY_INDEX_NAME,
ENTITY_EDGE_INDEX_NAME,
ENTITY_INDEX_NAME,
EPISODE_INDEX_NAME,
GraphDriver,
GraphProvider,
)
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import parse_db_date
@ -94,13 +101,39 @@ class Node(BaseModel, ABC):
async def delete(self, driver: GraphDriver):
match driver.provider:
case GraphProvider.NEO4J:
await driver.execute_query(
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
MATCH (n {uuid: $uuid})
WHERE n:Entity OR n:Episodic OR n:Community
OPTIONAL MATCH (n)-[r]-()
WITH collect(r.uuid) AS edge_uuids, n
DETACH DELETE n
RETURN edge_uuids
""",
uuid=self.uuid,
)
edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else []
if driver.aoss_client:
# Delete the node from OpenSearch indices
for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
await driver.aoss_client.delete(
index=index,
id=self.uuid,
params={'routing': self.group_id},
)
# Bulk delete the detached edges
if edge_uuids:
actions = []
for eid in edge_uuids:
actions.append(
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
)
await driver.aoss_client.bulk(body=actions)
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
@ -162,6 +195,32 @@ class Node(BaseModel, ABC):
group_id=group_id,
batch_size=batch_size,
)
if driver.aoss_client:
await driver.aoss_client.delete_by_query(
index=EPISODE_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}},
params={'routing': group_id},
)
await driver.aoss_client.delete_by_query(
index=ENTITY_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}},
params={'routing': group_id},
)
await driver.aoss_client.delete_by_query(
index=COMMUNITY_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}},
params={'routing': group_id},
)
await driver.aoss_client.delete_by_query(
index=ENTITY_EDGE_INDEX_NAME,
body={'query': {'term': {'group_id': group_id}}},
params={'routing': group_id},
)
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
@ -240,6 +299,23 @@ class Node(BaseModel, ABC):
)
case _: # Neo4J, Neptune
async with driver.session() as session:
# Collect all edge UUIDs before deleting nodes
result = await session.run(
"""
MATCH (n:Entity|Episodic|Community)
WHERE n.uuid IN $uuids
MATCH (n)-[r]-()
RETURN collect(r.uuid) AS edge_uuids
""",
uuids=uuids,
)
record = await result.single()
edge_uuids: list[str] = (
record['edge_uuids'] if record and record['edge_uuids'] else []
)
# Now delete the nodes in batches
await session.run(
"""
MATCH (n:Entity|Episodic|Community)
@ -253,6 +329,20 @@ class Node(BaseModel, ABC):
batch_size=batch_size,
)
if driver.aoss_client:
for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
await driver.aoss_client.delete_by_query(
index=index,
body={'query': {'terms': {'uuid': uuids}}},
)
if edge_uuids:
actions = [
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
for eid in edge_uuids
]
await driver.aoss_client.bulk(body=actions)
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@ -286,7 +376,7 @@ class EpisodicNode(Node):
}
if driver.aoss_client:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'episodes',
[episode_args],
)
@ -426,13 +516,13 @@ class EntityNode(Node):
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
"""
elif driver.aoss_client:
resp = driver.aoss_client.search(
resp = await driver.aoss_client.search(
body={
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
'size': 1,
},
index='entities',
routing=self.group_id,
index=ENTITY_INDEX_NAME,
params={'routing': self.group_id},
)
if resp['hits']['hits']:
@ -479,7 +569,7 @@ class EntityNode(Node):
labels = ':'.join(self.labels + ['Entity'])
if driver.aoss_client:
driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
@ -577,7 +667,7 @@ class CommunityNode(Node):
async def save(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE:
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'communities',
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
)

View file

@ -52,6 +52,7 @@ class SearchFilters(BaseModel):
invalid_at: list[list[DateFilter]] | None = Field(default=None)
created_at: list[list[DateFilter]] | None = Field(default=None)
expired_at: list[list[DateFilter]] | None = Field(default=None)
edge_uuids: list[str] | None = Field(default=None)
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
@ -108,6 +109,10 @@ def edge_search_filter_query_constructor(
filter_queries.append('e.name in $edge_types')
filter_params['edge_types'] = edge_types
if filters.edge_uuids is not None:
filter_queries.append('e.uuid in $edge_uuids')
filter_params['edge_uuids'] = filters.edge_uuids
if filters.node_labels is not None:
if provider == GraphProvider.KUZU:
node_label_filter = (
@ -261,6 +266,9 @@ def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters)
if search_filters.edge_types:
filters.append({'terms': {'edge_types': search_filters.edge_types}})
if search_filters.edge_uuids:
filters.append({'terms': {'uuid': search_filters.edge_uuids}})
for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
ranges = getattr(search_filters, field)
if ranges:

View file

@ -23,7 +23,13 @@ import numpy as np
from numpy._typing import NDArray
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.driver.driver import (
ENTITY_EDGE_INDEX_NAME,
ENTITY_INDEX_NAME,
EPISODE_INDEX_NAME,
GraphDriver,
GraphProvider,
)
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.graph_queries import (
get_nodes_query,
@ -209,11 +215,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values
query = (
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
+ filter_query
+ """
AND id(e)=id
@ -248,17 +254,21 @@ async def edge_fulltext_search(
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=route,
_source=['uuid'],
query={
'bool': {
'filter': filters,
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
}
res = await driver.aoss_client.search(
index=ENTITY_EDGE_INDEX_NAME,
params={'routing': route},
body={
'size': limit,
'_source': ['uuid'],
'query': {
'bool': {
'filter': filters,
'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
}
},
},
)
if res['hits']['total']['value'] > 0:
input_uuids = {}
for r in res['hits']['hits']:
@ -344,8 +354,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -406,17 +416,22 @@ async def edge_similarity_search(
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_edge_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entity_edges',
routing=route,
_source=['uuid'],
knn={
'field': 'fact_embedding',
'query_vector': search_vector,
'k': limit,
'num_candidates': 1000,
res = await driver.aoss_client.search(
index=ENTITY_EDGE_INDEX_NAME,
params={'routing': route},
body={
'size': limit,
'_source': ['uuid'],
'query': {
'knn': {
'fact_embedding': {
'vector': list(map(float, search_vector)),
'k': limit,
'filter': {'bool': {'filter': filters}},
}
}
},
},
query={'bool': {'filter': filters}},
)
if res['hits']['total']['value'] > 0:
@ -428,6 +443,7 @@ async def edge_similarity_search(
entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_edges
return []
else:
query = (
@ -622,11 +638,11 @@ async def node_fulltext_search(
# Match the edge ides and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -646,25 +662,27 @@ async def node_fulltext_search(
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
'entities',
routing=route,
_source=['uuid'],
query={
'bool': {
'filter': filters,
'must': [
{
'multi_match': {
'query': query,
'field': ['name', 'summary'],
'operator': 'or',
res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME,
params={'routing': route},
body={
'_source': ['uuid'],
'size': limit,
'query': {
'bool': {
'filter': filters,
'must': [
{
'multi_match': {
'query': query,
'fields': ['name', 'summary'],
'operator': 'or',
}
}
}
],
}
],
}
},
},
limit=limit,
)
if res['hits']['total']['value'] > 0:
@ -734,8 +752,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -764,11 +782,11 @@ async def node_similarity_search(
# Match the edge ides and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -789,17 +807,22 @@ async def node_similarity_search(
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
filters = build_aoss_node_filters(group_ids or [], search_filter)
res = driver.aoss_client.search(
index='entities',
routing=route,
_source=['uuid'],
knn={
'field': 'fact_embedding',
'query_vector': search_vector,
'k': limit,
'num_candidates': 1000,
res = await driver.aoss_client.search(
index=ENTITY_INDEX_NAME,
params={'routing': route},
body={
'size': limit,
'_source': ['uuid'],
'query': {
'knn': {
'name_embedding': {
'vector': list(map(float, search_vector)),
'k': limit,
'filter': {'bool': {'filter': filters}},
}
}
},
},
query={'bool': {'filter': filters}},
)
if res['hits']['total']['value'] > 0:
@ -811,11 +834,12 @@ async def node_similarity_search(
entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
return entity_nodes
return []
else:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
WITH n, """
@ -988,11 +1012,12 @@ async def episode_fulltext_search(
return []
elif driver.aoss_client:
route = group_ids[0] if group_ids else None
res = driver.aoss_client.search(
'episodes',
routing=route,
_source=['uuid'],
query={
res = await driver.aoss_client.search(
index=EPISODE_INDEX_NAME,
params={'routing': route},
body={
'size': limit,
'_source': ['uuid'],
'bool': {
'filter': {'terms': group_ids},
'must': [
@ -1004,9 +1029,8 @@ async def episode_fulltext_search(
}
}
],
}
},
},
limit=limit,
)
if res['hits']['total']['value'] > 0:
@ -1147,8 +1171,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Community)
"""
MATCH (n:Community)
"""
+ group_filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1207,8 +1231,8 @@ async def community_similarity_search(
query = (
"""
MATCH (c:Community)
"""
MATCH (c:Community)
"""
+ group_filter_query
+ """
WITH c,
@ -1350,9 +1374,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1397,9 +1421,9 @@ async def get_relevant_nodes(
else:
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1488,9 +1512,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge
@ -1560,9 +1584,9 @@ async def get_relevant_edges(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, n, m, """
@ -1598,9 +1622,9 @@ async def get_relevant_edges(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, """
@ -1673,10 +1697,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH e, edge
@ -1746,10 +1770,10 @@ async def get_edge_invalidation_candidates(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
+ filter_query
+ """
WITH edge, e, n, m, """
@ -1785,10 +1809,10 @@ async def get_edge_invalidation_candidates(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH edge, e, """

View file

@ -23,7 +23,14 @@ import numpy as np
from pydantic import BaseModel, Field
from typing_extensions import Any
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.driver.driver import (
ENTITY_EDGE_INDEX_NAME,
ENTITY_INDEX_NAME,
EPISODE_INDEX_NAME,
GraphDriver,
GraphDriverSession,
GraphProvider,
)
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graphiti_types import GraphitiClients
@ -203,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx(
)
if driver.aoss_client:
driver.save_to_aoss('episodes', episodes)
driver.save_to_aoss('entities', nodes)
driver.save_to_aoss('entity_edges', edges)
await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
async def extract_nodes_and_edges_bulk(

View file

@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search import search
from graphiti_core.search.search_config import SearchResults
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
logger = logging.getLogger(__name__)
@ -258,12 +260,44 @@ async def resolve_extracted_edges(
embedder = clients.embedder
await create_entity_edge_embeddings(embedder, extracted_edges)
search_results = await semaphore_gather(
get_relevant_edges(driver, extracted_edges, SearchFilters()),
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
*[
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
for edge in extracted_edges
]
)
related_edges_lists, edge_invalidation_candidates = search_results
related_edges_results: list[SearchResults] = await semaphore_gather(
*[
search(
clients,
extracted_edge.fact,
group_ids=[extracted_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
)
for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
]
)
related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
*[
search(
clients,
extracted_edge.fact,
group_ids=[extracted_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(),
)
for extracted_edge in extracted_edges
]
)
edge_invalidation_candidates: list[list[EntityEdge]] = [
result.edges for result in edge_invalidation_candidate_results
]
logger.debug(
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'

View file

@ -95,6 +95,8 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
async def delete_all(tx):
await tx.run('MATCH (n) DETACH DELETE n')
if driver.aoss_client:
await driver.clear_aoss_indices()
async def delete_group_ids(tx):
labels = ['Entity', 'Episodic', 'Community']
@ -151,9 +153,9 @@ async def retrieve_episodes(
query: LiteralString = (
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
MATCH (e:Episodic)
WHERE e.valid_at <= $reference_time
"""
+ query_filter
+ """
RETURN

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.21.0pre1"
version = "0.21.0pre2"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },

2
uv.lock generated
View file

@ -783,7 +783,7 @@ wheels = [
[[package]]
name = "graphiti-core"
version = "0.21.0rc1"
version = "0.21.0rc2"
source = { editable = "." }
dependencies = [
{ name = "diskcache" },