add uuid filter functionality

This commit is contained in:
prestonrasmussen 2025-09-11 15:39:40 -04:00
parent 37715f6261
commit 427f917614
4 changed files with 123 additions and 65 deletions

View file

@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_edge_invalidation_candidates,
get_mentioned_nodes,
get_relevant_edges,
)
from graphiti_core.telemetry import capture_event
from graphiti_core.utils.bulk_utils import (
@ -1037,10 +1035,28 @@ class Graphiti:
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
valid_uuids = await EntityEdge.get_between_nodes(
self.driver, edge.source_node_uuid, edge.target_node_uuid
)
related_edges = (
await search(
self.clients,
updated_edge.fact,
group_ids=[updated_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(uuids=valid_uuids),
)
).edges
existing_edges = (
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
)[0]
await search(
self.clients,
updated_edge.fact,
group_ids=[updated_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(),
)
).edges
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
self.llm_client,

View file

@ -52,6 +52,7 @@ class SearchFilters(BaseModel):
invalid_at: list[list[DateFilter]] | None = Field(default=None)
created_at: list[list[DateFilter]] | None = Field(default=None)
expired_at: list[list[DateFilter]] | None = Field(default=None)
edge_uuids: list[str] | None = Field(default=None)
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
@ -108,6 +109,10 @@ def edge_search_filter_query_constructor(
filter_queries.append('e.name in $edge_types')
filter_params['edge_types'] = edge_types
if filters.edge_uuids is not None:
filter_queries.append('e.uuid in $edge_uuids')
filter_params['edge_uuids'] = filters.edge_uuids
if filters.node_labels is not None:
if provider == GraphProvider.KUZU:
node_label_filter = (
@ -261,6 +266,9 @@ def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters)
if search_filters.edge_types:
filters.append({'terms': {'edge_types': search_filters.edge_types}})
if search_filters.edge_uuids:
filters.append({'terms': {'uuid': search_filters.edge_uuids}})
for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
ranges = getattr(search_filters, field)
if ranges:

View file

@ -209,11 +209,11 @@ async def edge_fulltext_search(
# Match the edge ids and return the values
query = (
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
UNWIND $ids as id
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
AND id(e)=id
"""
+ filter_query
+ """
AND id(e)=id
@ -344,8 +344,8 @@ async def edge_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@ -622,11 +622,11 @@ async def node_fulltext_search(
# Match the node ids and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE n.uuid=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -734,8 +734,8 @@ async def node_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -764,11 +764,11 @@ async def node_similarity_search(
# Match the node ids and return the values
query = (
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
UNWIND $ids as i
MATCH (n:Entity)
WHERE id(n)=i.id
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ """
ORDER BY i.score DESC
@ -814,8 +814,8 @@ async def node_similarity_search(
else:
query = (
"""
MATCH (n:Entity)
"""
MATCH (n:Entity)
"""
+ filter_query
+ """
WITH n, """
@ -1147,8 +1147,8 @@ async def community_similarity_search(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
MATCH (n:Community)
"""
MATCH (n:Community)
"""
+ group_filter_query
+ """
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@ -1207,8 +1207,8 @@ async def community_similarity_search(
query = (
"""
MATCH (c:Community)
"""
MATCH (c:Community)
"""
+ group_filter_query
+ """
WITH c,
@ -1350,9 +1350,9 @@ async def get_relevant_nodes(
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1397,9 +1397,9 @@ async def get_relevant_nodes(
else:
query = (
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
UNWIND $nodes AS node
MATCH (n:Entity {group_id: $group_id})
"""
+ filter_query
+ """
WITH node, n, """
@ -1488,9 +1488,9 @@ async def get_relevant_edges(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge
@ -1560,9 +1560,9 @@ async def get_relevant_edges(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, n, m, """
@ -1599,9 +1599,9 @@ async def get_relevant_edges(
# First get edge candidates
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
RETURN
@ -1647,9 +1647,9 @@ async def get_relevant_edges(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
UNWIND $edges AS edge
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
"""
+ filter_query
+ """
WITH e, edge, """
@ -1722,10 +1722,10 @@ async def get_edge_invalidation_candidates(
if driver.provider == GraphProvider.NEPTUNE:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH e, edge
@ -1795,10 +1795,10 @@ async def get_edge_invalidation_candidates(
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
"""
+ filter_query
+ """
WITH edge, e, n, m, """
@ -1834,10 +1834,10 @@ async def get_edge_invalidation_candidates(
else:
query = (
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
UNWIND $edges AS edge
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
"""
+ filter_query
+ """
WITH edge, e, """

View file

@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search import search
from graphiti_core.search.search_config import SearchResults
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
logger = logging.getLogger(__name__)
@ -258,12 +260,44 @@ async def resolve_extracted_edges(
embedder = clients.embedder
await create_entity_edge_embeddings(embedder, extracted_edges)
search_results = await semaphore_gather(
get_relevant_edges(driver, extracted_edges, SearchFilters()),
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
valid_uuids_list: list[list[str]] = await semaphore_gather(
*[
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
for edge in extracted_edges
]
)
related_edges_lists, edge_invalidation_candidates = search_results
related_edges_results: list[SearchResults] = await semaphore_gather(
*[
search(
clients,
extracted_edge.fact,
group_ids=[extracted_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(uuids=valid_uuids),
)
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list)
]
)
related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
*[
search(
clients,
extracted_edge.fact,
group_ids=[extracted_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(),
)
for extracted_edge in extracted_edges
]
)
edge_invalidation_candidates: list[list[EntityEdge]] = [
result.edges for result in edge_invalidation_candidate_results
]
logger.debug(
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'