diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 89d5403e..d5291e0d 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -195,6 +195,9 @@ class GraphDriver(ABC): async def create_aoss_indices(self): client = self.aoss_client + if not client: + logger.warning('No OpenSearch client found') + return for index in aoss_indices: alias_name = index['index_name'] @@ -220,10 +223,20 @@ class GraphDriver(ABC): for index in aoss_indices: index_name = index['index_name'] client = self.aoss_client + + if not client: + logger.warning('No OpenSearch client found') + return + if client.indices.exists(index=index_name): client.indices.delete(index=index_name) def save_to_aoss(self, name: str, data: list[dict]) -> int: + client = self.aoss_client + if not client: + logger.warning('No OpenSearch client found') + return 0 + for index in aoss_indices: if name.lower() == index['index_name']: to_index = [] @@ -237,7 +250,7 @@ class GraphDriver(ABC): item[p] = d[p] to_index.append(item) - success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True) + success, failed = helpers.bulk(client, to_index, stats_only=True) return success if failed == 0 else success diff --git a/graphiti_core/driver/falkordb_driver.py b/graphiti_core/driver/falkordb_driver.py index 00c39342..984b070b 100644 --- a/graphiti_core/driver/falkordb_driver.py +++ b/graphiti_core/driver/falkordb_driver.py @@ -39,6 +39,7 @@ logger = logging.getLogger(__name__) class FalkorDriverSession(GraphDriverSession): provider = GraphProvider.FALKORDB + aoss_client: None def __init__(self, graph: FalkorGraph): self.graph = graph diff --git a/graphiti_core/driver/kuzu_driver.py b/graphiti_core/driver/kuzu_driver.py index af371b2f..2e0edbe3 100644 --- a/graphiti_core/driver/kuzu_driver.py +++ b/graphiti_core/driver/kuzu_driver.py @@ -92,6 +92,7 @@ SCHEMA_QUERIES = """ class KuzuDriver(GraphDriver): provider: GraphProvider = GraphProvider.KUZU + aoss_client: None def __init__( self, diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 8c3745d8..c9da7512 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -209,11 +209,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -343,8 +343,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -620,11 +620,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -731,8 +731,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -761,11 +761,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -810,8 +810,8 @@ async def node_similarity_search( else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -989,7 +989,7 @@ async def episode_fulltext_search( _source=['uuid'], query={ 'bool': { - 'filter': [{'terms': {'group_id': group_ids}}], + 'filter': {'terms': group_ids}, 'must': [ { 'multi_match': { @@ -1005,37 +1005,14 @@ async def episode_fulltext_search( ) if res['hits']['total']['value'] > 0: - # Calculate Cosine similarity then return the edge ids - input_uuids = [] + input_uuids = {} for r in res['hits']['hits']: - input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']}) + input_uuids[r['_source']['uuid']] = r['_score'] - # Match the edge ides and return the values - query = """ - UNWIND $uuids as i - MATCH (e:Episodic) - WHERE e.uuid=i.uuid - RETURN - e.content AS content, - e.created_at AS created_at, - e.valid_at AS valid_at, - e.uuid AS uuid, - e.name AS name, - e.group_id AS group_id, - e.source_description AS source_description, - e.source AS source, - e.entity_edges AS entity_edges - ORDER BY i.score DESC - LIMIT $limit - """ - records, _, _ = await driver.execute_query( - query, - uuids=input_uuids, - query=fuzzy_query, - limit=limit, - routing_='r', - **filter_params, - ) + # Get nodes + episodes = await EpisodicNode.get_by_uuids(driver, list(input_uuids.keys())) + episodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) + return episodes else: return [] else: @@ -1165,8 +1142,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1225,8 +1202,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1368,9 +1345,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1415,9 +1392,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1506,9 +1483,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1578,9 +1555,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1616,9 +1593,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1691,10 +1668,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1764,10 +1741,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1803,10 +1780,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """