diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 6ece60bb..4c926d08 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -46,12 +46,25 @@ aoss_indices = [ 'name': {'type': 'text'}, 'summary': {'type': 'text'}, 'group_id': {'type': 'text'}, + 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'name_embedding': { + 'type': 'dense_vector', + 'dims': 1024, + 'index': True, + 'similarity': 'cosine', + }, } } }, 'query': { 'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}}, 'size': DEFAULT_SIZE, + 'knn': { + 'field': 'name_embedding', + 'query_vector': [], + 'k': DEFAULT_SIZE, + 'num_candidates': 100, + }, }, }, { @@ -80,6 +93,8 @@ aoss_indices = [ 'source': {'type': 'text'}, 'source_description': {'type': 'text'}, 'group_id': {'type': 'text'}, + 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, } } }, @@ -102,12 +117,28 @@ aoss_indices = [ 'name': {'type': 'text'}, 'fact': {'type': 'text'}, 'group_id': {'type': 'text'}, + 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'fact_embedding': { + 'type': 'dense_vector', + 'dims': 1024, + 'index': True, + 'similarity': 'cosine', + }, } } }, 'query': { 'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}}, 'size': DEFAULT_SIZE, + 'knn': { + 'field': 'fact_embedding', + 'query_vector': [], # supply vector at runtime + 'k': DEFAULT_SIZE, + 'num_candidates': 100, + }, }, }, ] diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index a427d65e..a8394bb2 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -298,7 +298,7 @@ class 
EntityEdge(Edge): else: edge_data.update(self.attributes or {}) - if driver.provider == GraphProvider.NEPTUNE: + if driver.aoss_client: driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 4c2bbf36..b15977e7 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -273,7 +273,7 @@ class EpisodicNode(Node): ) async def save(self, driver: GraphDriver): - if driver.provider == GraphProvider.NEPTUNE: + if driver.aoss_client: driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue 'episode_content', [ @@ -470,7 +470,7 @@ class EntityNode(Node): entity_data.update(self.attributes or {}) labels = ':'.join(self.labels + ['Entity']) - if driver.provider == GraphProvider.NEPTUNE: + if driver.aoss_client: driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 379662d5..611fb6d4 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -208,11 +208,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -244,6 +244,41 @@ async def edge_fulltext_search( ) else: return [] + elif driver.aoss_client: + res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue + if res['hits']['total']['value'] > 0: + # Calculate Cosine similarity then return the edge ids + input_uuids = [] + for r in res['hits']['hits']: + input_uuids.append({'uuid': 
r['_source']['uuid'], 'score': r['_score']}) + + # Match the edge ids and return the values + query = ( + """ + UNWIND $uuids as i + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND e.uuid=i.uuid + """ + + filter_query + + """ + AND e.uuid=i.uuid + WITH e, i.score as score, startNode(e) AS n, endNode(e) AS m""" + + get_entity_edge_return_query(driver.provider) + + """ORDER BY score DESC LIMIT $limit + """ + ) + + records, _, _ = await driver.execute_query( + query, + query=fuzzy_query, + uuids=input_uuids, + limit=limit, + routing_='r', + **filter_params, + ) + else: + return [] else: query = ( get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider) @@ -318,8 +353,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -571,11 +606,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -592,6 +627,38 @@ async def node_fulltext_search( ) else: return [] + elif driver.aoss_client: + res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue + if res['hits']['total']['value'] > 0: + # Collect the uuid and relevance score of each AOSS hit + input_uuids = [] + for r in res['hits']['hits']: + input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']}) + + # Match the node uuids and return the values + query = ( + """ + UNWIND $uuids as i + MATCH (n:Entity) + WHERE n.uuid=i.uuid + RETURN + """ + + get_entity_node_return_query(driver.provider) + + """ +
ORDER BY i.score DESC + LIMIT $limit + """ + ) + records, _, _ = await driver.execute_query( + query, + uuids=input_uuids, + query=fuzzy_query, + limit=limit, + routing_='r', + **filter_params, + ) + else: + return [] else: query = ( get_nodes_query( @@ -648,8 +715,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -678,11 +745,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -703,8 +770,8 @@ async def node_similarity_search( else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -852,7 +919,7 @@ async def episode_fulltext_search( query = """ UNWIND $ids as i MATCH (e:Episodic) - WHERE e.uuid=i.id + WHERE e.uuid=i.uuid RETURN e.content AS content, e.created_at AS created_at, @@ -876,6 +943,42 @@ async def episode_fulltext_search( ) else: return [] + elif driver.aoss_client: + res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue + if res['hits']['total']['value'] > 0: + # Calculate Cosine similarity then return the edge ids + input_uuids = [] + for r in res['hits']['hits']: + input_uuids.append({'uuid': r['_source']['uuid'], 'score': r['_score']}) + + # Match the edge ides and return the values + query = """ + UNWIND $uuids as i + MATCH (e:Episodic) + WHERE e.uuid=i.uuid + RETURN + e.content AS content, + e.created_at AS created_at, + e.valid_at AS valid_at, + e.uuid AS uuid, + e.name AS name, + e.group_id AS group_id, + e.source_description AS source_description, + e.source AS source, + e.entity_edges AS 
entity_edges + ORDER BY i.score DESC + LIMIT $limit + """ + records, _, _ = await driver.execute_query( + query, + uuids=input_uuids, + query=fuzzy_query, + limit=limit, + routing_='r', + **filter_params, + ) + else: + return [] else: query = ( get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider) @@ -1003,8 +1106,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1063,8 +1166,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1206,9 +1309,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1253,9 +1356,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1344,9 +1447,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1416,9 +1519,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - 
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1454,9 +1557,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1529,10 +1632,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1602,10 +1705,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR 
m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1641,10 +1744,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 76494800..57bbdd45 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -194,6 +194,11 @@ async def add_nodes_and_edges_bulk_tx( ) await tx.run(get_entity_edge_save_bulk_query(driver.provider), entity_edges=edges) + if driver.aoss_client: + driver.save_to_aoss('episode_content', episodes) + driver.save_to_aoss('node_name_and_summary', nodes) + driver.save_to_aoss('edge_name_and_fact', edges) + async def extract_nodes_and_edges_bulk( clients: GraphitiClients,