add uuid filter functionality

This commit is contained in:
prestonrasmussen 2025-09-11 15:39:40 -04:00
parent 37715f6261
commit 427f917614
4 changed files with 123 additions and 65 deletions

View file

@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import ( from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT, RELEVANT_SCHEMA_LIMIT,
get_edge_invalidation_candidates,
get_mentioned_nodes, get_mentioned_nodes,
get_relevant_edges,
) )
from graphiti_core.telemetry import capture_event from graphiti_core.telemetry import capture_event
from graphiti_core.utils.bulk_utils import ( from graphiti_core.utils.bulk_utils import (
@ -1037,10 +1035,28 @@ class Graphiti:
updated_edge = resolve_edge_pointers([edge], uuid_map)[0] updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0] valid_uuids = await EntityEdge.get_between_nodes(
self.driver, edge.source_node_uuid, edge.target_node_uuid
)
related_edges = (
await search(
self.clients,
updated_edge.fact,
group_ids=[updated_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(uuids=valid_uuids),
)
).edges
existing_edges = ( existing_edges = (
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters()) await search(
)[0] self.clients,
updated_edge.fact,
group_ids=[updated_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(),
)
).edges
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge( resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
self.llm_client, self.llm_client,

View file

@ -52,6 +52,7 @@ class SearchFilters(BaseModel):
invalid_at: list[list[DateFilter]] | None = Field(default=None) invalid_at: list[list[DateFilter]] | None = Field(default=None)
created_at: list[list[DateFilter]] | None = Field(default=None) created_at: list[list[DateFilter]] | None = Field(default=None)
expired_at: list[list[DateFilter]] | None = Field(default=None) expired_at: list[list[DateFilter]] | None = Field(default=None)
edge_uuids: list[str] | None = Field(default=None)
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str: def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
@ -108,6 +109,10 @@ def edge_search_filter_query_constructor(
filter_queries.append('e.name in $edge_types') filter_queries.append('e.name in $edge_types')
filter_params['edge_types'] = edge_types filter_params['edge_types'] = edge_types
if filters.edge_uuids is not None:
filter_queries.append('e.uuid in $edge_uuids')
filter_params['edge_uuids'] = filters.edge_uuids
if filters.node_labels is not None: if filters.node_labels is not None:
if provider == GraphProvider.KUZU: if provider == GraphProvider.KUZU:
node_label_filter = ( node_label_filter = (
@ -261,6 +266,9 @@ def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters)
if search_filters.edge_types: if search_filters.edge_types:
filters.append({'terms': {'edge_types': search_filters.edge_types}}) filters.append({'terms': {'edge_types': search_filters.edge_types}})
if search_filters.edge_uuids:
filters.append({'terms': {'uuid': search_filters.edge_uuids}})
for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']: for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
ranges = getattr(search_filters, field) ranges = getattr(search_filters, field)
if ranges: if ranges:

View file

@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search import search
from graphiti_core.search.search_config import SearchResults
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -258,12 +260,44 @@ async def resolve_extracted_edges(
embedder = clients.embedder embedder = clients.embedder
await create_entity_edge_embeddings(embedder, extracted_edges) await create_entity_edge_embeddings(embedder, extracted_edges)
search_results = await semaphore_gather( valid_uuids_list: list[list[str]] = await semaphore_gather(
get_relevant_edges(driver, extracted_edges, SearchFilters()), *[
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2), EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
for edge in extracted_edges
]
) )
related_edges_lists, edge_invalidation_candidates = search_results related_edges_results: list[SearchResults] = await semaphore_gather(
*[
search(
clients,
extracted_edge.fact,
group_ids=[extracted_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(uuids=valid_uuids),
)
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list)
]
)
related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
*[
search(
clients,
extracted_edge.fact,
group_ids=[extracted_edge.group_id],
config=EDGE_HYBRID_SEARCH_RRF,
search_filter=SearchFilters(),
)
for extracted_edge in extracted_edges
]
)
edge_invalidation_candidates: list[list[EntityEdge]] = [
result.edges for result in edge_invalidation_candidate_results
]
logger.debug( logger.debug(
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}' f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'