From 37715f6261f3e4d65735d03e71b21176e0dd0057 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Thu, 11 Sep 2025 14:36:23 -0400 Subject: [PATCH] updates --- examples/podcast/podcast_runner.py | 16 +- graphiti_core/driver/driver.py | 31 +++- graphiti_core/driver/neo4j_driver.py | 4 +- graphiti_core/edges.py | 40 +++++ graphiti_core/nodes.py | 82 +++++++++- graphiti_core/search/search_utils.py | 153 ++++++++++++------ .../maintenance/graph_data_operations.py | 8 +- 7 files changed, 268 insertions(+), 66 deletions(-) diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index 70201b9b..a77cc5fc 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -25,6 +25,7 @@ from pydantic import BaseModel, Field from transcript_parser import parse_podcast_messages from graphiti_core import Graphiti +from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.nodes import EpisodeType from graphiti_core.utils.bulk_utils import RawEpisode from graphiti_core.utils.maintenance.graph_data_operations import clear_data @@ -34,6 +35,8 @@ load_dotenv() neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' +aoss_host = os.environ.get('AOSS_HOST') or None +aoss_port = os.environ.get('AOSS_PORT') or None def setup_logging(): @@ -77,11 +80,16 @@ class IsPresidentOf(BaseModel): async def main(use_bulk: bool = False): setup_logging() - client = Graphiti( - neo4j_uri, - neo4j_user, - neo4j_password, + graph_driver = Neo4jDriver( + neo4j_uri, neo4j_user, neo4j_password, aoss_host=aoss_host, aoss_port=aoss_port ) + # client = Graphiti( + # neo4j_uri, + # neo4j_user, + # neo4j_password, + # ) + client = Graphiti(graph_driver=graph_driver) + await client.driver.create_aoss_indices() await clear_data(client.driver) await client.build_indices_and_constraints() messages = parse_podcast_messages() diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 40ce2540..25bb71ae 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -48,7 +48,7 @@ class GraphProvider(Enum): aoss_indices = [ { - 'index_name': 'entities', + 'index_name': 'entities_test', 'body': { 'mappings': { 'properties': { @@ -74,7 +74,7 @@ aoss_indices = [ }, }, { - 'index_name': 'communities', + 'index_name': 'communities_test', 'body': { 'mappings': { 'properties': { @@ -102,7 +102,7 @@ aoss_indices = [ }, }, { - 'index_name': 'entity_edges', + 'index_name': 'entity_edges_test', 'body': { 'mappings': { 'properties': { @@ -233,6 +233,31 @@ class GraphDriver(ABC): if client.indices.exists(index=index_name): client.indices.delete(index=index_name) + async def clear_aoss_indices(self): + client = self.aoss_client + + if not client: + logger.warning('No OpenSearch client found') + return + + for index in aoss_indices: + index_name = index['index_name'] + + if client.indices.exists(index=index_name): + try: + # Delete all documents but keep the index + response = client.delete_by_query( + index=index_name, + body={'query': {'match_all': {}}}, + refresh=True, + conflicts='proceed', + ) + logger.info(f"Cleared index '{index_name}': {response}") + except Exception as e: + logger.error(f"Error clearing index '{index_name}': {e}") + else: + logger.warning(f"Index '{index_name}' does not exist") + def save_to_aoss(self, name: str, data: list[dict]) -> int: client = self.aoss_client if not client or not helpers: diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index f10d176b..efeadcc3 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -63,8 +63,8 @@ class Neo4jDriver(GraphDriver): try: session = boto3.Session() self.aoss_client = OpenSearch( # type: ignore - hosts=[{'host': aoss_host, 'port': aoss_port}], - http_auth=Urllib3AWSV4SignerAuth( # type: ignore + hosts=[{'host': aoss_host, 'port': aoss_port, 'scheme': 'https'}], + http_auth=Urllib3AWSV4SignerAuth( session.get_credentials(), session.region_name, 'aoss' ), use_ssl=True, diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index e1c8e44e..b30d6a32 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -77,6 +77,11 @@ class Edge(BaseModel, ABC): uuid=self.uuid, ) + if driver.aoss_client: + await driver.aoss_client.delete( + index='entity_edges', id=self.uuid, routing=self.group_id + ) + logger.debug(f'Deleted Edge: {self.uuid}') @classmethod @@ -108,6 +113,12 @@ class Edge(BaseModel, ABC): uuids=uuids, ) + if driver.aoss_client: + await driver.aoss_client.delete_by_query( + index='entity_edges', + body={'query': {'terms': {'uuid': uuids}}}, + ) + logger.debug(f'Deleted Edges: {uuids}') def __hash__(self): @@ -351,6 +362,35 @@ class EntityEdge(Edge): raise EdgeNotFoundError(uuid) return edges[0] + @classmethod + async def get_between_nodes( + cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str + ): + match_query = """ + MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid}) + """ + if driver.provider == GraphProvider.KUZU: + match_query = """ + MATCH (n:Entity {uuid: $source_node_uuid}) + -[:RELATES_TO]->(e:RelatesToNode_) + -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid}) + """ + + records, _, _ = await driver.execute_query( + match_query + + """ + RETURN + """ + + get_entity_edge_return_query(driver.provider), + source_node_uuid=source_node_uuid, + target_node_uuid=target_node_uuid, + routing_='r', + ) + + edges = [get_entity_edge_from_record(record, driver.provider) for record in records] + + return edges + @classmethod async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): if len(uuids) == 0: diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 3adef419..e30a0a66 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -94,13 +94,35 @@ class Node(BaseModel, ABC): async def delete(self, driver: GraphDriver): match driver.provider: case GraphProvider.NEO4J: - await driver.execute_query( + result = await driver.execute_query( """ - MATCH (n:Entity|Episodic|Community {uuid: $uuid}) + MATCH (n:Entity|Episodic|Community {uuid: $uuid})-[r]-() + WITH collect(r.uuid) AS edge_uuids, n DETACH DELETE n + RETURN edge_uuids """, uuid=self.uuid, ) + + edge_uuids: list[str] = [] + if result and result[0].get('edge_uuids'): + edge_uuids = result[0]['edge_uuids'] + + if driver.aoss_client: + # Delete the node from OpenSearch indices + for index in ('episodes', 'entities', 'communities'): + await driver.aoss_client.delete( + index=index, id=self.uuid, routing=self.group_id + ) + + # Bulk delete the detached edges + if edge_uuids: + actions = [] + for eid in edge_uuids: + actions.append({'delete': {'_index': 'entity_edges', '_id': eid}}) + + await driver.aoss_client.bulk(body=actions) + case GraphProvider.KUZU: for label in ['Episodic', 'Community']: await driver.execute_query( @@ -162,6 +184,32 @@ class Node(BaseModel, ABC): group_id=group_id, batch_size=batch_size, ) + + if driver.aoss_client: + await driver.aoss_client.delete_by_query( + index='episodes', + body={'query': {'term': {'group_id': group_id}}}, + routing=group_id, + ) + + await driver.aoss_client.delete_by_query( + index='entities', + body={'query': {'term': {'group_id': group_id}}}, + routing=group_id, + ) + + await driver.aoss_client.delete_by_query( + index='communities', + body={'query': {'term': {'group_id': group_id}}}, + routing=group_id, + ) + + await driver.aoss_client.delete_by_query( + index='entity_edges', + body={'query': {'term': {'group_id': group_id}}}, + routing=group_id, + ) + case GraphProvider.KUZU: for label in ['Episodic', 'Community']: await driver.execute_query( @@ -240,6 +288,23 @@ class Node(BaseModel, ABC): ) case _: # Neo4J, Neptune async with driver.session() as session: + # Collect all edge UUIDs before deleting nodes + result = await session.run( + """ + MATCH (n:Entity|Episodic|Community) + WHERE n.uuid IN $uuids + MATCH (n)-[r]-() + RETURN collect(r.uuid) AS edgeUuids + """, + uuids=uuids, + ) + + record = await result.single() + edge_uuids: list[str] = ( + record['edgeUuids'] if record and record['edgeUuids'] else [] + ) + + # Now delete the nodes in batches await session.run( """ MATCH (n:Entity|Episodic|Community) @@ -253,6 +318,19 @@ class Node(BaseModel, ABC): batch_size=batch_size, ) + if driver.aoss_client: + for index in ('episodes', 'entities', 'communities'): + await driver.aoss_client.delete_by_query( + index=index, + body={'query': {'terms': {'uuid': uuids}}}, + ) + + if edge_uuids: + actions = [ + {'delete': {'_index': 'entity_edges', '_id': eid}} for eid in edge_uuids + ] + await driver.aoss_client.bulk(body=actions) + @classmethod async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ... diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 40eeaa7e..15790f5f 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -209,11 +209,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -344,8 +344,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -622,11 +622,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -734,8 +734,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -764,11 +764,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -814,8 +814,8 @@ async def node_similarity_search( else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -1147,8 +1147,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1207,8 +1207,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1350,9 +1350,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1397,9 +1397,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1488,9 +1488,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1560,9 +1560,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1595,12 +1595,61 @@ async def get_relevant_edges( }) AS matches """ ) + elif driver.aoss_client: + # First get edge candidates + query = ( + """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + + filter_query + + """ + RETURN + e.uuid AS search_edge_uuid, + collect({ + uuid: e.uuid, + source_node_uuid: startNode(e).uuid, + target_node_uuid: endNode(e).uuid, + created_at: e.created_at, + name: e.name, + group_id: e.group_id, + fact: e.fact, + fact_embedding: e.fact_embedding, + episodes: e.episodes, + expired_at: e.expired_at, + valid_at: e.valid_at, + invalid_at: e.invalid_at, + attributes: properties(e) + }) AS matches + """ + ) + + results, _, _ = await driver.execute_query( + query, + edges=[edge.model_dump() for edge in edges], + limit=limit, + min_score=min_score, + routing_='r', + **filter_params, + ) + + relevant_edges_dict: dict[str, list[EntityEdge]] = { + result['search_edge_uuid']: [ + get_entity_edge_from_record(record, driver.provider) + for record in result['matches'] + ] + for result in results + } + + group_id = edges[0].group_id + # semaphore_gather(*[edge_similarity_search(driver, )]) + else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1673,10 +1722,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1746,10 +1795,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1785,10 +1834,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/maintenance/graph_data_operations.py b/graphiti_core/utils/maintenance/graph_data_operations.py index e8bbc86c..f62fc6a2 100644 --- a/graphiti_core/utils/maintenance/graph_data_operations.py +++ b/graphiti_core/utils/maintenance/graph_data_operations.py @@ -95,6 +95,8 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None): async def delete_all(tx): await tx.run('MATCH (n) DETACH DELETE n') + if driver.aoss_client: + await driver.clear_aoss_indices() async def delete_group_ids(tx): labels = ['Entity', 'Episodic', 'Community'] @@ -151,9 +153,9 @@ async def retrieve_episodes( query: LiteralString = ( """ - MATCH (e:Episodic) - WHERE e.valid_at <= $reference_time - """ + MATCH (e:Episodic) + WHERE e.valid_at <= $reference_time + """ + query_filter + """ RETURN