diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 157ef4a1..1f47822c 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -54,7 +54,7 @@ class GraphProvider(Enum): aoss_indices = [ { - 'index_name': ENTTITY_INDEX_NAME, + 'index_name': ENTITY_INDEX_NAME, 'body': { 'settings': {'index': {'knn': True}}, 'mappings': { diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index ea40c242..a11d7400 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -29,7 +29,7 @@ from typing_extensions import LiteralString from graphiti_core.driver.driver import ( COMMUNITY_INDEX_NAME, ENTITY_EDGE_INDEX_NAME, - ENTTITY_INDEX_NAME, + ENTITY_INDEX_NAME, EPISODE_INDEX_NAME, GraphDriver, GraphProvider, @@ -117,7 +117,7 @@ class Node(BaseModel, ABC): if driver.aoss_client: # Delete the node from OpenSearch indices - for index in (EPISODE_INDEX_NAME, ENTTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): + for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): await driver.aoss_client.delete( index=index, id=self.uuid, routing=self.group_id ) @@ -202,7 +202,7 @@ class Node(BaseModel, ABC): ) await driver.aoss_client.delete_by_query( - index=ENTTITY_INDEX_NAME, + index=ENTITY_INDEX_NAME, body={'query': {'term': {'group_id': group_id}}}, routing=group_id, ) @@ -328,7 +328,7 @@ class Node(BaseModel, ABC): ) if driver.aoss_client: - for index in (EPISODE_INDEX_NAME, ENTTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): + for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): await driver.aoss_client.delete_by_query( index=index, body={'query': {'terms': {'uuid': uuids}}}, @@ -519,7 +519,7 @@ class EntityNode(Node): 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'size': 1, }, - index=ENTTITY_INDEX_NAME, + index=ENTITY_INDEX_NAME, routing=self.group_id, ) @@ -567,7 +567,7 @@ class EntityNode(Node): labels = ':'.join(self.labels + ['Entity']) if driver.aoss_client: - driver.save_to_aoss(ENTTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue + driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 9fd66bf2..3c475a20 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -25,7 +25,7 @@ from typing_extensions import LiteralString from graphiti_core.driver.driver import ( ENTITY_EDGE_INDEX_NAME, - ENTTITY_INDEX_NAME, + ENTITY_INDEX_NAME, EPISODE_INDEX_NAME, GraphDriver, GraphProvider, @@ -215,11 +215,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -353,8 +353,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -637,11 +637,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -662,7 +662,7 @@ async def node_fulltext_search( route = group_ids[0] if group_ids else None filters = build_aoss_node_filters(group_ids or [], search_filter) res = driver.aoss_client.search( - index=ENTTITY_INDEX_NAME, + index=ENTITY_INDEX_NAME, routing=route, _source=['uuid'], size=limit, @@ -751,8 +751,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -781,11 +781,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -807,7 +807,7 @@ async def node_similarity_search( route = group_ids[0] if group_ids else None filters = build_aoss_node_filters(group_ids or [], search_filter) res = driver.aoss_client.search( - index=ENTTITY_INDEX_NAME, + index=ENTITY_INDEX_NAME, routing=route, _source=['uuid'], size=limit, @@ -837,8 +837,8 @@ async def node_similarity_search( else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -1170,8 +1170,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1230,8 +1230,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1373,9 +1373,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1420,9 +1420,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1511,9 +1511,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1583,9 +1583,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1621,9 +1621,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1696,10 +1696,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1769,10 +1769,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1808,10 +1808,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index e1b77dd4..e9b0a31d 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -25,7 +25,7 @@ from typing_extensions import Any from graphiti_core.driver.driver import ( ENTITY_EDGE_INDEX_NAME, - ENTTITY_INDEX_NAME, + ENTITY_INDEX_NAME, EPISODE_INDEX_NAME, GraphDriver, GraphDriverSession, @@ -211,7 +211,7 @@ async def add_nodes_and_edges_bulk_tx( if driver.aoss_client: driver.save_to_aoss(EPISODE_INDEX_NAME, episodes) - driver.save_to_aoss(ENTTITY_INDEX_NAME, nodes) + driver.save_to_aoss(ENTITY_INDEX_NAME, nodes) driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)