From 836668e9ee63c1a636a6b8c94037b850956c413e Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Fri, 12 Sep 2025 12:08:13 -0400 Subject: [PATCH] update --- examples/podcast/podcast_runner.py | 3 - graphiti_core/driver/driver.py | 9 +- graphiti_core/driver/neo4j_driver.py | 2 +- graphiti_core/graphiti.py | 4 +- graphiti_core/search/search_utils.py | 165 +++++++----------- .../utils/maintenance/edge_operations.py | 2 +- uv.lock | 2 +- 7 files changed, 72 insertions(+), 115 deletions(-) diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index 62e3b485..70201b9b 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -25,7 +25,6 @@ from pydantic import BaseModel, Field from transcript_parser import parse_podcast_messages from graphiti_core import Graphiti -from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.nodes import EpisodeType from graphiti_core.utils.bulk_utils import RawEpisode from graphiti_core.utils.maintenance.graph_data_operations import clear_data @@ -35,8 +34,6 @@ load_dotenv() neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687' neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j' neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password' -aoss_host = os.environ.get('AOSS_HOST') or None -aoss_port = os.environ.get('AOSS_PORT') or None def setup_logging(): diff --git a/graphiti_core/driver/driver.py b/graphiti_core/driver/driver.py index 63b77442..157ef4a1 100644 --- a/graphiti_core/driver/driver.py +++ b/graphiti_core/driver/driver.py @@ -17,6 +17,7 @@ limitations under the License. import asyncio import copy import logging +import os from abc import ABC, abstractmethod from collections.abc import Coroutine from datetime import datetime @@ -38,10 +39,10 @@ logger = logging.getLogger(__name__) DEFAULT_SIZE = 10 -EPISODE_INDEX_NAME = 'episodes-test' -ENTTITY_INDEX_NAME = 'entities_test' -COMMUNITY_INDEX_NAME = 'communities-test' -ENTITY_EDGE_INDEX_NAME = 'entity_edges_test' +ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX', 'entities') +EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX', 'episodes') +COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities') +ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges') class GraphProvider(Enum): diff --git a/graphiti_core/driver/neo4j_driver.py b/graphiti_core/driver/neo4j_driver.py index 7e47e7b3..9cf9041c 100644 --- a/graphiti_core/driver/neo4j_driver.py +++ b/graphiti_core/driver/neo4j_driver.py @@ -73,7 +73,7 @@ class Neo4jDriver(GraphDriver): region = aws_region service = aws_service credentials = boto3.Session(profile_name=aws_profile_name).get_credentials() - auth = AWSV4SignerAuth(credentials, region, service) + auth = AWSV4SignerAuth(credentials, region or '', service or '') self.aoss_client = OpenSearch( hosts=[{'host': aoss_host, 'port': aoss_port}], diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 55f4fe38..f168deda 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -1035,7 +1035,7 @@ class Graphiti: updated_edge = resolve_edge_pointers([edge], uuid_map)[0] - valid_uuids = await EntityEdge.get_between_nodes( + valid_edges = await EntityEdge.get_between_nodes( self.driver, edge.source_node_uuid, edge.target_node_uuid ) @@ -1045,7 +1045,7 @@ class Graphiti: updated_edge.fact, group_ids=[updated_edge.group_id], config=EDGE_HYBRID_SEARCH_RRF, - search_filter=SearchFilters(uuids=valid_uuids), + search_filter=SearchFilters(uuids=[edge.uuid for edge in valid_edges]), ) ).edges existing_edges = ( diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index efc9f581..9fd66bf2 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -215,11 +215,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -353,8 +353,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -423,7 +423,11 @@ async def edge_similarity_search( body={ 'query': { 'knn': { - 'fact_embedding': {'vector': list(map(float, search_vector)), 'k': limit} + 'fact_embedding': { + 'vector': list(map(float, search_vector)), + 'k': limit, + 'filter': {'bool': {'filter': filters}}, + } } } }, @@ -633,11 +637,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -747,8 +751,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -777,11 +781,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -810,7 +814,11 @@ async def node_similarity_search( body={ 'query': { 'knn': { - 'name_embedding': {'vector': list(map(float, search_vector)), 'k': limit} + 'name_embedding': { + 'vector': list(map(float, search_vector)), + 'k': limit, + 'filter': {'bool': {'filter': filters}}, + } } } }, @@ -829,8 +837,8 @@ async def node_similarity_search( else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -1162,8 +1170,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1222,8 +1230,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1365,9 +1373,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1412,9 +1420,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1503,9 +1511,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1575,9 +1583,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1610,61 +1618,12 @@ async def get_relevant_edges( }) AS matches """ ) - elif driver.aoss_client: - # First get edge candidates - query = ( - """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ - + filter_query - + """ - RETURN - e.uuid AS search_edge_uuid, - collect({ - uuid: e.uuid, - source_node_uuid: startNode(e).uuid, - target_node_uuid: endNode(e).uuid, - created_at: e.created_at, - name: e.name, - group_id: e.group_id, - fact: e.fact, - fact_embedding: e.fact_embedding, - episodes: e.episodes, - expired_at: e.expired_at, - valid_at: e.valid_at, - invalid_at: e.invalid_at, - attributes: properties(e) - }) AS matches - """ - ) - - results, _, _ = await driver.execute_query( - query, - edges=[edge.model_dump() for edge in edges], - limit=limit, - min_score=min_score, - routing_='r', - **filter_params, - ) - - relevant_edges_dict: dict[str, list[EntityEdge]] = { - result['search_edge_uuid']: [ - get_entity_edge_from_record(record, driver.provider) - for record in result['matches'] - ] - for result in results - } - - group_id = edges[0].group_id - # semaphore_gather(*[edge_similarity_search(driver, )]) - else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1737,10 +1696,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1810,10 +1769,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1849,10 +1808,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 14ec4a69..1a38bd55 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -276,7 +276,7 @@ async def resolve_extracted_edges( config=EDGE_HYBRID_SEARCH_RRF, search_filter=SearchFilters(uuids=valid_uuids), ) - for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list) + for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list, strict=True) ] ) diff --git a/uv.lock b/uv.lock index 8228d4a8..bad253b8 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.21.0rc1" +version = "0.21.0rc2" source = { editable = "." } dependencies = [ { name = "diskcache" },