From fd1c360e8c7766781faeccdcb37c19a72b409fdd Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Sun, 14 Sep 2025 01:23:13 -0400 Subject: [PATCH] update --- examples/podcast/podcast_runner.py | 18 +-- graphiti_core/driver/driver.py | 7 +- graphiti_core/driver/neo4j_driver.py | 1 - graphiti_core/driver/neptune_driver.py | 4 +- graphiti_core/nodes.py | 3 +- graphiti_core/search/search_utils.py | 106 +++++++++--------- .../utils/maintenance/edge_operations.py | 1 - 7 files changed, 60 insertions(+), 80 deletions(-) diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index 6a4d5468..70201b9b 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -25,7 +25,6 @@ from pydantic import BaseModel, Field from transcript_parser import parse_podcast_messages from graphiti_core import Graphiti -from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.nodes import EpisodeType from graphiti_core.utils.bulk_utils import RawEpisode from graphiti_core.utils.maintenance.graph_data_operations import clear_data @@ -35,8 +34,6 @@ load_dotenv() neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' -aoss_host = os.environ.get('AOSS_HOST') or None -aoss_port = os.environ.get('AOSS_PORT') or None def setup_logging(): @@ -80,25 +77,12 @@ class IsPresidentOf(BaseModel): async def main(use_bulk: bool = False): setup_logging() - graph_driver = Neo4jDriver( + client = Graphiti( neo4j_uri, neo4j_user, neo4j_password, - aoss_host=aoss_host, - aoss_port=int(aoss_port), - aws_profile_name='zep-development', - aws_region='us-west-2', - aws_service='es', ) - # client = Graphiti( - # neo4j_uri, - # neo4j_user, - # neo4j_password, - # ) - client = Graphiti(graph_driver=graph_driver) await clear_data(client.driver) - await client.driver.delete_aoss_indices() - await client.driver.create_aoss_indices() await client.build_indices_and_constraints() messages = parse_podcast_messages() group_id = str(uuid4()) diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index ba99513f..d359f11f 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -274,10 +274,7 @@ class GraphDriver(ABC): response = await client.delete_by_query( index=index_name, body={'query': {'match_all': {}}}, - refresh=True, - conflicts='proceed', - wait_for_completion=True, - slices='auto', # improves coverage/concurrency + slices='auto', ) logger.info(f"Cleared index '{index_name}': {response}") except Exception as e: @@ -287,7 +284,7 @@ class GraphDriver(ABC): async def save_to_aoss(self, name: str, data: list[dict]) -> int: client = self.aoss_client - if not client: + if not client or not helpers: logger.warning('No OpenSearch client found') return 0 diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 2365100e..e88e1c8c 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -32,7 +32,6 @@ try: AIOHttpConnection, AsyncOpenSearch, AWSV4SignerAuth, - RequestsHttpConnection, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, ) diff --git a/graphiti_core/driver/neptune_driver.py b/graphiti_core/driver/neptune_driver.py index cb343163..355fded1 100644 --- a/graphiti_core/driver/neptune_driver.py +++ b/graphiti_core/driver/neptune_driver.py @@ -237,12 +237,12 @@ class NeptuneDriver(GraphDriver): 'You must provide an AOSS endpoint to create an OpenSearch driver.' ) if not client.indices.exists(index=index_name): - client.indices.create(index=index_name, body=index['body']) + await client.indices.create(index=index_name, body=index['body']) alias_name = index.get('alias_name', index_name) if not client.indices.exists_alias(name=alias_name, index=index_name): - client.indices.put_alias(index=index_name, name=alias_name) + await client.indices.put_alias(index=index_name, name=alias_name) # Sleep for 1 minute to let the index creation complete await asyncio.sleep(60) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 524dbd0f..45599b01 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -103,8 +103,9 @@ class Node(BaseModel, ABC): case GraphProvider.NEO4J: records, _, _ = await driver.execute_query( """ - MATCH (n {uuid: $uuid})-[r]-() + MATCH (n {uuid: $uuid}) WHERE n:Entity OR n:Episodic OR n:Community + OPTIONAL MATCH (n)-[r]-() WITH collect(r.uuid) AS edge_uuids, n DETACH DELETE n RETURN edge_uuids diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index f35261bc..2af8dd10 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -215,11 +215,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -353,8 +353,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -637,11 +637,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -751,8 +751,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -781,11 +781,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -808,7 +808,7 @@ async def node_similarity_search( filters = build_aoss_node_filters(group_ids or [], search_filter) res = await driver.aoss_client.search( index=ENTITY_INDEX_NAME, - routing=route, + params={'routing': route}, _source=['uuid'], size=limit, body={ @@ -837,8 +837,8 @@ async def node_similarity_search( else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -1170,8 +1170,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1230,8 +1230,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1373,9 +1373,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1420,9 +1420,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1511,9 +1511,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1583,9 +1583,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1621,9 +1621,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1696,10 +1696,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1769,10 +1769,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1808,10 +1808,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 19ad339c..5de8d22a 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -17,7 +17,6 @@ limitations under the License. import logging from datetime import datetime from time import time -from xml.dom.minidom import Entity from pydantic import BaseModel from typing_extensions import LiteralString