From 427f917614cb9d46af36b4c4b11236cd3fddae09 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Thu, 11 Sep 2025 15:39:40 -0400 Subject: [PATCH] add uuid filter functionality --- graphiti_core/graphiti.py | 26 ++++- graphiti_core/search/search_filters.py | 8 ++ graphiti_core/search/search_utils.py | 110 +++++++++--------- .../utils/maintenance/edge_operations.py | 44 ++++++- 4 files changed, 123 insertions(+), 65 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index d217d924..55f4fe38 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import ( from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_utils import ( RELEVANT_SCHEMA_LIMIT, - get_edge_invalidation_candidates, get_mentioned_nodes, - get_relevant_edges, ) from graphiti_core.telemetry import capture_event from graphiti_core.utils.bulk_utils import ( @@ -1037,10 +1035,28 @@ class Graphiti: updated_edge = resolve_edge_pointers([edge], uuid_map)[0] - related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0] + valid_uuids = await EntityEdge.get_between_nodes( + self.driver, edge.source_node_uuid, edge.target_node_uuid + ) + + related_edges = ( + await search( + self.clients, + updated_edge.fact, + group_ids=[updated_edge.group_id], + config=EDGE_HYBRID_SEARCH_RRF, + search_filter=SearchFilters(edge_uuids=valid_uuids), + ) + ).edges existing_edges = ( - await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters()) - )[0] + await search( + self.clients, + updated_edge.fact, + group_ids=[updated_edge.group_id], + config=EDGE_HYBRID_SEARCH_RRF, + search_filter=SearchFilters(), + ) + ).edges resolved_edge, invalidated_edges, _ = await resolve_extracted_edge( self.llm_client, diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index f5f2252c..37cf7e82 100644 --- 
a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -52,6 +52,7 @@ class SearchFilters(BaseModel): invalid_at: list[list[DateFilter]] | None = Field(default=None) created_at: list[list[DateFilter]] | None = Field(default=None) expired_at: list[list[DateFilter]] | None = Field(default=None) + edge_uuids: list[str] | None = Field(default=None) def cypher_to_opensearch_operator(op: ComparisonOperator) -> str: @@ -108,6 +109,10 @@ def edge_search_filter_query_constructor( filter_queries.append('e.name in $edge_types') filter_params['edge_types'] = edge_types + if filters.edge_uuids is not None: + filter_queries.append('e.uuid in $edge_uuids') + filter_params['edge_uuids'] = filters.edge_uuids + if filters.node_labels is not None: if provider == GraphProvider.KUZU: node_label_filter = ( @@ -261,6 +266,9 @@ def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) if search_filters.edge_types: filters.append({'terms': {'edge_types': search_filters.edge_types}}) + if search_filters.edge_uuids: + filters.append({'terms': {'uuid': search_filters.edge_uuids}}) + for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']: ranges = getattr(search_filters, field) if ranges: diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 15790f5f..27aefa89 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -209,11 +209,11 @@ async def edge_fulltext_search( # Match the edge ids and return the values query = ( """ - UNWIND $ids as id - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) - WHERE e.group_id IN $group_ids - AND id(e)=id - """ + UNWIND $ids as id + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.group_id IN $group_ids + AND id(e)=id + """ + filter_query + """ AND id(e)=id @@ -344,8 +344,8 @@ async def edge_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) 
- """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + """ + filter_query + """ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding @@ -622,11 +622,11 @@ async def node_fulltext_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE n.uuid=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE n.uuid=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -734,8 +734,8 @@ async def node_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -764,11 +764,11 @@ async def node_similarity_search( # Match the edge ides and return the values query = ( """ - UNWIND $ids as i - MATCH (n:Entity) - WHERE id(n)=i.id - RETURN - """ + UNWIND $ids as i + MATCH (n:Entity) + WHERE id(n)=i.id + RETURN + """ + get_entity_node_return_query(driver.provider) + """ ORDER BY i.score DESC @@ -814,8 +814,8 @@ async def node_similarity_search( else: query = ( """ - MATCH (n:Entity) - """ + MATCH (n:Entity) + """ + filter_query + """ WITH n, """ @@ -1147,8 +1147,8 @@ async def community_similarity_search( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - MATCH (n:Community) - """ + MATCH (n:Community) + """ + group_filter_query + """ RETURN DISTINCT id(n) as id, n.name_embedding as embedding @@ -1207,8 +1207,8 @@ async def community_similarity_search( query = ( """ - MATCH (c:Community) - """ + MATCH (c:Community) + """ + group_filter_query + """ WITH c, @@ -1350,9 +1350,9 @@ async def get_relevant_nodes( # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver. 
query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1397,9 +1397,9 @@ async def get_relevant_nodes( else: query = ( """ - UNWIND $nodes AS node - MATCH (n:Entity {group_id: $group_id}) - """ + UNWIND $nodes AS node + MATCH (n:Entity {group_id: $group_id}) + """ + filter_query + """ WITH node, n, """ @@ -1488,9 +1488,9 @@ async def get_relevant_edges( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge @@ -1560,9 +1560,9 @@ async def get_relevant_edges( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, n, m, """ @@ -1599,9 +1599,9 @@ async def get_relevant_edges( # First get edge candidates query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ RETURN @@ -1647,9 +1647,9 @@ async def get_relevant_edges( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity {uuid: 
edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) - """ + UNWIND $edges AS edge + MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid}) + """ + filter_query + """ WITH e, edge, """ @@ -1722,10 +1722,10 @@ async def get_edge_invalidation_candidates( if driver.provider == GraphProvider.NEPTUNE: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH e, edge @@ -1795,10 +1795,10 @@ async def get_edge_invalidation_candidates( query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) - WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity) + WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]) + """ + filter_query + """ WITH edge, e, n, m, """ @@ -1834,10 +1834,10 @@ async def get_edge_invalidation_candidates( else: query = ( """ - UNWIND $edges AS edge - MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) - WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] - """ + UNWIND $edges AS edge + MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity) + WHERE 
n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid] + """ + filter_query + """ WITH edge, e, """ diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 55cea243..14ec4a69 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_edges import EdgeDuplicate from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts +from graphiti_core.search.search import search +from graphiti_core.search.search_config import SearchResults +from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF from graphiti_core.search.search_filters import SearchFilters -from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges from graphiti_core.utils.datetime_utils import ensure_utc, utc_now logger = logging.getLogger(__name__) @@ -258,12 +260,44 @@ async def resolve_extracted_edges( embedder = clients.embedder await create_entity_edge_embeddings(embedder, extracted_edges) - search_results = await semaphore_gather( - get_relevant_edges(driver, extracted_edges, SearchFilters()), - get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2), + valid_uuids_list: list[list[str]] = await semaphore_gather( + *[ + EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid) + for edge in extracted_edges + ] ) - related_edges_lists, edge_invalidation_candidates = search_results + related_edges_results: list[SearchResults] = await semaphore_gather( + *[ + search( + clients, + extracted_edge.fact, + group_ids=[extracted_edge.group_id], + config=EDGE_HYBRID_SEARCH_RRF, + search_filter=SearchFilters(edge_uuids=valid_uuids), + ) 
+ for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list) + ] + ) + + related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results] + + edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather( + *[ + search( + clients, + extracted_edge.fact, + group_ids=[extracted_edge.group_id], + config=EDGE_HYBRID_SEARCH_RRF, + search_filter=SearchFilters(), + ) + for extracted_edge in extracted_edges + ] + ) + + edge_invalidation_candidates: list[list[EntityEdge]] = [ + result.edges for result in edge_invalidation_candidate_results + ] logger.debug( f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'