From 6b579798691e58a1489453b61b61c305bb453014 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Mon, 8 Sep 2025 10:14:25 -0400 Subject: [PATCH] updates --- graphiti_core/driver/driver.py | 5 +- graphiti_core/driver/falkordb_driver.py | 2 +- graphiti_core/driver/kuzu_driver.py | 2 +- graphiti_core/driver/neo4j_driver.py | 10 +- graphiti_core/driver/neptune_driver.py | 4 + graphiti_core/search/search_utils.py | 127 ++++++++++++------------ 6 files changed, 84 insertions(+), 66 deletions(-) diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index d5291e0d..7999bc25 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -21,7 +21,7 @@ from abc import ABC, abstractmethod from collections.abc import Coroutine from datetime import datetime from enum import Enum -from typing import Any +from typing import TYPE_CHECKING, Any try: from opensearchpy import OpenSearch, helpers @@ -32,6 +32,9 @@ except ImportError: helpers = None _HAS_OPENSEARCH = False +if TYPE_CHECKING: + from opensearchpy import OpenSearch, helpers + logger = logging.getLogger(__name__) DEFAULT_SIZE = 10 diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index 984b070b..95605cd2 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -39,7 +39,7 @@ logger = logging.getLogger(__name__) class FalkorDriverSession(GraphDriverSession): provider = GraphProvider.FALKORDB - aoss_client: None + aoss_client: None = None def __init__(self, graph: FalkorGraph): self.graph = graph diff --git a/graphiti_core/driver/kuzu_driver.py b/graphiti_core/driver/kuzu_driver.py index 2e0edbe3..8a04a4ac 100644 --- a/graphiti_core/driver/kuzu_driver.py +++ b/graphiti_core/driver/kuzu_driver.py @@ -92,7 +92,7 @@ SCHEMA_QUERIES = """ class KuzuDriver(GraphDriver): provider: GraphProvider = GraphProvider.KUZU - aoss_client: None + aoss_client: None = None def __init__( self, diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 00a7371c..7bf5b4a7 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -22,6 +22,7 @@ from neo4j import AsyncGraphDatabase, EagerResult from typing_extensions import LiteralString from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from graphiti_core.helpers import semaphore_gather logger = logging.getLogger(__name__) @@ -98,9 +99,14 @@ class Neo4jDriver(GraphDriver): async def close(self) -> None: return await self.client.close() - def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]: + def delete_all_indexes(self) -> Coroutine: if self.aoss_client: - return self.delete_aoss_indices() + return semaphore_gather( + self.client.execute_query( + 'CALL db.indexes() YIELD name DROP INDEX name', + ), + self.delete_aoss_indices(), + ) return self.client.execute_query( 'CALL db.indexes() YIELD name DROP INDEX name', ) diff --git a/graphiti_core/driver/neptune_driver.py b/graphiti_core/driver/neptune_driver.py index 6ecadccd..cb343163 100644 --- a/graphiti_core/driver/neptune_driver.py +++ b/graphiti_core/driver/neptune_driver.py @@ -232,6 +232,10 @@ class NeptuneDriver(GraphDriver): for index in neptune_aoss_indices: index_name = index['index_name'] client = self.aoss_client + if not client: + raise ValueError( + 'You must provide an AOSS endpoint to create an OpenSearch driver.' + ) if not client.indices.exists(index=index_name): client.indices.create(index=index_name, body=index['body']) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index c9da7512..40eeaa7e 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -209,11 +209,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -246,10 +246,11 @@ async def edge_fulltext_search( else: return [] elif driver.aoss_client: - filters = build_aoss_edge_filters(group_ids, search_filter) + route = group_ids[0] if group_ids else None + filters = build_aoss_edge_filters(group_ids or [], search_filter) res = driver.aoss_client.search( index='entity_edges', - routing=group_ids[0], + routing=route, _source=['uuid'], query={ 'bool': { @@ -343,8 +344,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -403,10 +404,11 @@ async def edge_similarity_search( else: return [] elif driver.aoss_client: - filters = build_aoss_edge_filters(group_ids, search_filter) + route = group_ids[0] if group_ids else None + filters = build_aoss_edge_filters(group_ids or [], search_filter) res = driver.aoss_client.search( index='entity_edges', - routing=group_ids[0], + routing=route, _source=['uuid'], knn={ 'field': 'fact_embedding', @@ -620,11 +622,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -642,10 +644,11 @@ async def node_fulltext_search( else: return [] elif driver.aoss_client: - filters = build_aoss_node_filters(group_ids, search_filter) + route = group_ids[0] if group_ids else None + filters = build_aoss_node_filters(group_ids or [], search_filter) res = driver.aoss_client.search( 'entities', - routing=group_ids[0], + routing=route, _source=['uuid'], query={ 'bool': { @@ -731,8 +734,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -761,11 +764,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -784,10 +787,11 @@ async def node_similarity_search( else: return [] elif driver.aoss_client: - filters = build_aoss_node_filters(group_ids, search_filter) + route = group_ids[0] if group_ids else None + filters = build_aoss_node_filters(group_ids or [], search_filter) res = driver.aoss_client.search( index='entities', - routing=group_ids[0], + routing=route, _source=['uuid'], knn={ 'field': 'fact_embedding', @@ -810,8 +814,8 @@ async def node_similarity_search( else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -983,9 +987,10 @@ async def episode_fulltext_search( else: return [] elif driver.aoss_client: + route = group_ids[0] if group_ids else None res = driver.aoss_client.search( 'episodes', - routing=group_ids[0], + routing=route, _source=['uuid'], query={ 'bool': { @@ -1142,8 +1147,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1202,8 +1207,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1345,9 +1350,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1392,9 +1397,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1483,9 +1488,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1555,9 +1560,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1593,9 +1598,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1668,10 +1673,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1741,10 +1746,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1780,10 +1785,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """