diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index e3448682..62e3b485 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -80,22 +80,11 @@ class IsPresidentOf(BaseModel): async def main(use_bulk: bool = False): setup_logging() - graph_driver = Neo4jDriver( + client = Graphiti( neo4j_uri, neo4j_user, neo4j_password, - aoss_host=aoss_host, - aoss_port=aoss_port, - region='us-west-2', - service='es', ) - # client = Graphiti( - # neo4j_uri, - # neo4j_user, - # neo4j_password, - # ) - client = Graphiti(graph_driver=graph_driver) - await client.driver.create_aoss_indices() await clear_data(client.driver) await client.build_indices_and_constraints() messages = parse_podcast_messages() diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 25bb71ae..63b77442 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -38,6 +38,11 @@ logger = logging.getLogger(__name__) DEFAULT_SIZE = 10 +EPISODE_INDEX_NAME = 'episodes-test' +ENTTITY_INDEX_NAME = 'entities_test' +COMMUNITY_INDEX_NAME = 'communities-test' +ENTITY_EDGE_INDEX_NAME = 'entity_edges_test' + class GraphProvider(Enum): NEO4J = 'neo4j' @@ -48,20 +53,19 @@ class GraphProvider(Enum): aoss_indices = [ { - 'index_name': 'entities_test', + 'index_name': ENTTITY_INDEX_NAME, 'body': { + 'settings': {'index': {'knn': True}}, 'mappings': { 'properties': { 'uuid': {'type': 'keyword'}, 'name': {'type': 'text'}, 'summary': {'type': 'text'}, 'group_id': {'type': 'text'}, - 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'name_embedding': { 'type': 'knn_vector', - 'dims': EMBEDDING_DIM, - 'index': True, - 'similarity': 'cosine', + 'dimension': EMBEDDING_DIM, 'method': { 'engine': 'faiss', 'space_type': 'cosinesimil', @@ -70,11 +74,11 @@ aoss_indices = [ }, }, } - } + }, }, }, { - 'index_name': 'communities_test', + 'index_name': COMMUNITY_INDEX_NAME, 'body': { 'mappings': { 'properties': { @@ -86,7 +90,7 @@ aoss_indices = [ }, }, { - 'index_name': 'episodes', + 'index_name': EPISODE_INDEX_NAME, 'body': { 'mappings': { 'properties': { @@ -95,30 +99,29 @@ aoss_indices = [ 'source': {'type': 'text'}, 'source_description': {'type': 'text'}, 'group_id': {'type': 'text'}, - 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, - 'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, + 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, } } }, }, { - 'index_name': 'entity_edges_test', + 'index_name': ENTITY_EDGE_INDEX_NAME, 'body': { + 'settings': {'index': {'knn': True}}, 'mappings': { 'properties': { 'uuid': {'type': 'keyword'}, 'name': {'type': 'text'}, 'fact': {'type': 'text'}, 'group_id': {'type': 'text'}, - 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, - 'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, - 'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, - 'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"}, + 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, + 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, + 'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, + 'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'}, 'fact_embedding': { 'type': 'knn_vector', - 'dims': EMBEDDING_DIM, - 'index': True, - 'similarity': 'cosine', + 'dimension': EMBEDDING_DIM, 'method': { 'engine': 'faiss', 'space_type': 'cosinesimil', @@ -127,7 +130,7 @@ aoss_indices = [ }, }, } - } + }, }, }, ] @@ -219,19 +222,36 @@ class GraphDriver(ABC): client.indices.put_alias(index=physical_index_name, name=alias_name) # Allow some time for index creation - await asyncio.sleep(60) + await asyncio.sleep(1) async def delete_aoss_indices(self): - for index in aoss_indices: - index_name = index['index_name'] - client = self.aoss_client + client = self.aoss_client - if not client: - logger.warning('No OpenSearch client found') - return + if not client: + logger.warning('No OpenSearch client found') + return - if client.indices.exists(index=index_name): - client.indices.delete(index=index_name) + for entry in aoss_indices: + alias_name = entry['index_name'] + + try: + # Resolve alias → indices + alias_info = client.indices.get_alias(name=alias_name) + indices = list(alias_info.keys()) + + if not indices: + logger.info(f"No indices found for alias '{alias_name}'") + continue + + for index in indices: + if client.indices.exists(index=index): + client.indices.delete(index=index) + logger.info(f"Deleted index '{index}' (alias: {alias_name})") + else: + logger.warning(f"Index '{index}' not found for alias '{alias_name}'") + + except Exception as e: + logger.error(f"Error deleting indices for alias '{alias_name}': {e}") async def clear_aoss_indices(self): client = self.aoss_client @@ -277,7 +297,9 @@ class GraphDriver(ABC): item[p] = d[p] to_index.append(item) - success, failed = helpers.bulk(client, to_index, stats_only=True) + success, failed = helpers.bulk( + client, to_index, stats_only=True, request_timeout=60 + ) return success if failed == 0 else success diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index c8be811d..7e47e7b3 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -56,8 +56,9 @@ class Neo4jDriver(GraphDriver): database: str = 'neo4j', aoss_host: str | None = None, aoss_port: int | None = None, - region: str | None = None, - service: str | None = None, + aws_profile_name: str | None = None, + aws_region: str | None = None, + aws_service: str | None = None, ): super().__init__() self.client = AsyncGraphDatabase.driver( @@ -69,9 +70,9 @@ class Neo4jDriver(GraphDriver): self.aoss_client = None if aoss_host and aoss_port and boto3 is not None: try: - region = region - service = service - credentials = boto3.Session(profile_name='zep-development').get_credentials() + region = aws_region + service = aws_service + credentials = boto3.Session(profile_name=aws_profile_name).get_credentials() auth = AWSV4SignerAuth(credentials, region, service) self.aoss_client = OpenSearch( diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index b30d6a32..cda88595 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -25,7 +25,7 @@ from uuid import uuid4 from pydantic import BaseModel, Field from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver, GraphProvider +from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError from graphiti_core.helpers import parse_db_date @@ -79,7 +79,7 @@ class Edge(BaseModel, ABC): if driver.aoss_client: await driver.aoss_client.delete( - index='entity_edges', id=self.uuid, routing=self.group_id + index=ENTITY_EDGE_INDEX_NAME, id=self.uuid, routing=self.group_id ) logger.debug(f'Deleted Edge: {self.uuid}') @@ -115,7 +115,7 @@ class Edge(BaseModel, ABC): if driver.aoss_client: await driver.aoss_client.delete_by_query( - index='entity_edges', + index=ENTITY_EDGE_INDEX_NAME, body={'query': {'terms': {'uuid': uuids}}}, ) @@ -272,7 +272,7 @@ class EntityEdge(Edge): 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'size': 1, }, - index='entity_edges', + index=ENTITY_EDGE_INDEX_NAME, routing=self.group_id, ) @@ -325,7 +325,7 @@ class EntityEdge(Edge): edge_data.update(self.attributes or {}) if driver.aoss_client: - driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue + driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_edge_save_query(driver.provider), diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index e30a0a66..ea40c242 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -26,7 +26,14 @@ from uuid import uuid4 from pydantic import BaseModel, Field from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver, GraphProvider +from graphiti_core.driver.driver import ( + COMMUNITY_INDEX_NAME, + ENTITY_EDGE_INDEX_NAME, + ENTTITY_INDEX_NAME, + EPISODE_INDEX_NAME, + GraphDriver, + GraphProvider, +) from graphiti_core.embedder import EmbedderClient from graphiti_core.errors import NodeNotFoundError from graphiti_core.helpers import parse_db_date @@ -110,7 +117,7 @@ class Node(BaseModel, ABC): if driver.aoss_client: # Delete the node from OpenSearch indices - for index in ('episodes', 'entities', 'communities'): + for index in (EPISODE_INDEX_NAME, ENTTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): await driver.aoss_client.delete( index=index, id=self.uuid, routing=self.group_id ) @@ -119,7 +126,9 @@ class Node(BaseModel, ABC): if edge_uuids: actions = [] for eid in edge_uuids: - actions.append({'delete': {'_index': 'entity_edges', '_id': eid}}) + actions.append( + {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}} + ) await driver.aoss_client.bulk(body=actions) @@ -187,25 +196,25 @@ class Node(BaseModel, ABC): if driver.aoss_client: await driver.aoss_client.delete_by_query( - index='episodes', + index=EPISODE_INDEX_NAME, body={'query': {'term': {'group_id': group_id}}}, routing=group_id, ) await driver.aoss_client.delete_by_query( - index='entities', + index=ENTTITY_INDEX_NAME, body={'query': {'term': {'group_id': group_id}}}, routing=group_id, ) await driver.aoss_client.delete_by_query( - index='communities', + index=COMMUNITY_INDEX_NAME, body={'query': {'term': {'group_id': group_id}}}, routing=group_id, ) await driver.aoss_client.delete_by_query( - index='entity_edges', + index=ENTITY_EDGE_INDEX_NAME, body={'query': {'term': {'group_id': group_id}}}, routing=group_id, ) @@ -319,7 +328,7 @@ class Node(BaseModel, ABC): ) if driver.aoss_client: - for index in ('episodes', 'entities', 'communities'): + for index in (EPISODE_INDEX_NAME, ENTTITY_INDEX_NAME, COMMUNITY_INDEX_NAME): await driver.aoss_client.delete_by_query( index=index, body={'query': {'terms': {'uuid': uuids}}}, @@ -327,7 +336,8 @@ class Node(BaseModel, ABC): if edge_uuids: actions = [ - {'delete': {'_index': 'entity_edges', '_id': eid}} for eid in edge_uuids + {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}} + for eid in edge_uuids ] await driver.aoss_client.bulk(body=actions) @@ -509,7 +519,7 @@ class EntityNode(Node): 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}}, 'size': 1, }, - index='entities', + index=ENTTITY_INDEX_NAME, routing=self.group_id, ) @@ -557,7 +567,7 @@ class EntityNode(Node): labels = ':'.join(self.labels + ['Entity']) if driver.aoss_client: - driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue + driver.save_to_aoss(ENTTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue result = await driver.execute_query( get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)), diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 27aefa89..efc9f581 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -23,7 +23,13 @@ import numpy as np from numpy._typing import NDArray from typing_extensions import LiteralString -from graphiti_core.driver.driver import GraphDriver, GraphProvider +from graphiti_core.driver.driver import ( + ENTITY_EDGE_INDEX_NAME, + ENTTITY_INDEX_NAME, + EPISODE_INDEX_NAME, + GraphDriver, + GraphProvider, +) from graphiti_core.edges import EntityEdge, get_entity_edge_from_record from graphiti_core.graph_queries import ( get_nodes_query, @@ -209,11 +215,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -249,16 +255,19 @@ async def edge_fulltext_search( route = group_ids[0] if group_ids else None filters = build_aoss_edge_filters(group_ids or [], search_filter) res = driver.aoss_client.search( - index='entity_edges', + index=ENTITY_EDGE_INDEX_NAME, routing=route, _source=['uuid'], - query={ - 'bool': { - 'filter': filters, - 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}], + body={ + 'query': { + 'bool': { + 'filter': filters, + 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}], + } } }, ) + if res['hits']['total']['value'] > 0: input_uuids = {} for r in res['hits']['hits']: @@ -344,8 +353,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -407,16 +416,17 @@ async def edge_similarity_search( route = group_ids[0] if group_ids else None filters = build_aoss_edge_filters(group_ids or [], search_filter) res = driver.aoss_client.search( - index='entity_edges', + index=ENTITY_EDGE_INDEX_NAME, routing=route, _source=['uuid'], - knn={ - 'field': 'fact_embedding', - 'query_vector': search_vector, - 'k': limit, - 'num_candidates': 1000, + size=limit, + body={ + 'query': { + 'knn': { + 'fact_embedding': {'vector': list(map(float, search_vector)), 'k': limit} + } + } }, - query={'bool': {'filter': filters}}, ) if res['hits']['total']['value'] > 0: @@ -428,6 +438,7 @@ async def edge_similarity_search( entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys())) entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) return entity_edges + return [] else: query = ( @@ -622,11 +633,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -647,24 +658,26 @@ async def node_fulltext_search( route = group_ids[0] if group_ids else None filters = build_aoss_node_filters(group_ids or [], search_filter) res = driver.aoss_client.search( - 'entities', + index=ENTTITY_INDEX_NAME, routing=route, _source=['uuid'], - query={ - 'bool': { - 'filter': filters, - 'must': [ - { - 'multi_match': { - 'query': query, - 'field': ['name', 'summary'], - 'operator': 'or', + size=limit, + body={ + 'query': { + 'bool': { + 'filter': filters, + 'must': [ + { + 'multi_match': { + 'query': query, + 'fields': ['name', 'summary'], # ✅ fixed key + 'operator': 'or', + } } - } - ], + ], + } } }, - limit=limit, ) if res['hits']['total']['value'] > 0: @@ -734,8 +747,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -764,11 +777,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -790,16 +803,17 @@ async def node_similarity_search( route = group_ids[0] if group_ids else None filters = build_aoss_node_filters(group_ids or [], search_filter) res = driver.aoss_client.search( - index='entities', + index=ENTTITY_INDEX_NAME, routing=route, _source=['uuid'], - knn={ - 'field': 'fact_embedding', - 'query_vector': search_vector, - 'k': limit, - 'num_candidates': 1000, + size=limit, + body={ + 'query': { + 'knn': { + 'name_embedding': {'vector': list(map(float, search_vector)), 'k': limit} + } + } }, - query={'bool': {'filter': filters}}, ) if res['hits']['total']['value'] > 0: @@ -811,11 +825,12 @@ async def node_similarity_search( entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys())) entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True) return entity_nodes + return [] else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -989,7 +1004,7 @@ async def episode_fulltext_search( elif driver.aoss_client: route = group_ids[0] if group_ids else None res = driver.aoss_client.search( - 'episodes', + EPISODE_INDEX_NAME, routing=route, _source=['uuid'], query={ @@ -1147,8 +1162,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1207,8 +1222,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1350,9 +1365,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1397,9 +1412,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1488,9 +1503,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1560,9 +1575,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1599,9 +1614,9 @@ async def get_relevant_edges( # First get edge candidates query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ RETURN @@ -1647,9 +1662,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1722,10 +1737,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1795,10 +1810,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1834,10 +1849,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index ff76a930..e1b77dd4 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -23,7 +23,14 @@ import numpy as np from pydantic import BaseModel, Field from typing_extensions import Any -from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from graphiti_core.driver.driver import ( + ENTITY_EDGE_INDEX_NAME, + ENTTITY_INDEX_NAME, + EPISODE_INDEX_NAME, + GraphDriver, + GraphDriverSession, + GraphProvider, +) from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings from graphiti_core.embedder import EmbedderClient from graphiti_core.graphiti_types import GraphitiClients @@ -203,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx( ) if driver.aoss_client: - driver.save_to_aoss('episodes', episodes) - driver.save_to_aoss('entities', nodes) - driver.save_to_aoss('entity_edges', edges) + driver.save_to_aoss(EPISODE_INDEX_NAME, episodes) + driver.save_to_aoss(ENTTITY_INDEX_NAME, nodes) + driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges) async def extract_nodes_and_edges_bulk(