add uuid filter functionality
This commit is contained in:
parent
37715f6261
commit
427f917614
4 changed files with 123 additions and 65 deletions
|
|
@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
|
||||||
from graphiti_core.search.search_filters import SearchFilters
|
from graphiti_core.search.search_filters import SearchFilters
|
||||||
from graphiti_core.search.search_utils import (
|
from graphiti_core.search.search_utils import (
|
||||||
RELEVANT_SCHEMA_LIMIT,
|
RELEVANT_SCHEMA_LIMIT,
|
||||||
get_edge_invalidation_candidates,
|
|
||||||
get_mentioned_nodes,
|
get_mentioned_nodes,
|
||||||
get_relevant_edges,
|
|
||||||
)
|
)
|
||||||
from graphiti_core.telemetry import capture_event
|
from graphiti_core.telemetry import capture_event
|
||||||
from graphiti_core.utils.bulk_utils import (
|
from graphiti_core.utils.bulk_utils import (
|
||||||
|
|
@ -1037,10 +1035,28 @@ class Graphiti:
|
||||||
|
|
||||||
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
||||||
|
|
||||||
related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
|
valid_uuids = await EntityEdge.get_between_nodes(
|
||||||
|
self.driver, edge.source_node_uuid, edge.target_node_uuid
|
||||||
|
)
|
||||||
|
|
||||||
|
related_edges = (
|
||||||
|
await search(
|
||||||
|
self.clients,
|
||||||
|
updated_edge.fact,
|
||||||
|
group_ids=[updated_edge.group_id],
|
||||||
|
config=EDGE_HYBRID_SEARCH_RRF,
|
||||||
|
search_filter=SearchFilters(uuids=valid_uuids),
|
||||||
|
)
|
||||||
|
).edges
|
||||||
existing_edges = (
|
existing_edges = (
|
||||||
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
|
await search(
|
||||||
)[0]
|
self.clients,
|
||||||
|
updated_edge.fact,
|
||||||
|
group_ids=[updated_edge.group_id],
|
||||||
|
config=EDGE_HYBRID_SEARCH_RRF,
|
||||||
|
search_filter=SearchFilters(),
|
||||||
|
)
|
||||||
|
).edges
|
||||||
|
|
||||||
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
|
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
|
||||||
self.llm_client,
|
self.llm_client,
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,7 @@ class SearchFilters(BaseModel):
|
||||||
invalid_at: list[list[DateFilter]] | None = Field(default=None)
|
invalid_at: list[list[DateFilter]] | None = Field(default=None)
|
||||||
created_at: list[list[DateFilter]] | None = Field(default=None)
|
created_at: list[list[DateFilter]] | None = Field(default=None)
|
||||||
expired_at: list[list[DateFilter]] | None = Field(default=None)
|
expired_at: list[list[DateFilter]] | None = Field(default=None)
|
||||||
|
edge_uuids: list[str] | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
|
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
|
||||||
|
|
@ -108,6 +109,10 @@ def edge_search_filter_query_constructor(
|
||||||
filter_queries.append('e.name in $edge_types')
|
filter_queries.append('e.name in $edge_types')
|
||||||
filter_params['edge_types'] = edge_types
|
filter_params['edge_types'] = edge_types
|
||||||
|
|
||||||
|
if filters.edge_uuids is not None:
|
||||||
|
filter_queries.append('e.uuid in $edge_uuids')
|
||||||
|
filter_params['edge_uuids'] = filters.edge_uuids
|
||||||
|
|
||||||
if filters.node_labels is not None:
|
if filters.node_labels is not None:
|
||||||
if provider == GraphProvider.KUZU:
|
if provider == GraphProvider.KUZU:
|
||||||
node_label_filter = (
|
node_label_filter = (
|
||||||
|
|
@ -261,6 +266,9 @@ def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters)
|
||||||
if search_filters.edge_types:
|
if search_filters.edge_types:
|
||||||
filters.append({'terms': {'edge_types': search_filters.edge_types}})
|
filters.append({'terms': {'edge_types': search_filters.edge_types}})
|
||||||
|
|
||||||
|
if search_filters.edge_uuids:
|
||||||
|
filters.append({'terms': {'uuid': search_filters.edge_uuids}})
|
||||||
|
|
||||||
for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
|
for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
|
||||||
ranges = getattr(search_filters, field)
|
ranges = getattr(search_filters, field)
|
||||||
if ranges:
|
if ranges:
|
||||||
|
|
|
||||||
|
|
@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
|
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
|
||||||
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
||||||
|
from graphiti_core.search.search import search
|
||||||
|
from graphiti_core.search.search_config import SearchResults
|
||||||
|
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
|
||||||
from graphiti_core.search.search_filters import SearchFilters
|
from graphiti_core.search.search_filters import SearchFilters
|
||||||
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
|
||||||
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
|
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -258,12 +260,44 @@ async def resolve_extracted_edges(
|
||||||
embedder = clients.embedder
|
embedder = clients.embedder
|
||||||
await create_entity_edge_embeddings(embedder, extracted_edges)
|
await create_entity_edge_embeddings(embedder, extracted_edges)
|
||||||
|
|
||||||
search_results = await semaphore_gather(
|
valid_uuids_list: list[list[str]] = await semaphore_gather(
|
||||||
get_relevant_edges(driver, extracted_edges, SearchFilters()),
|
*[
|
||||||
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
|
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
|
||||||
|
for edge in extracted_edges
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
related_edges_lists, edge_invalidation_candidates = search_results
|
related_edges_results: list[SearchResults] = await semaphore_gather(
|
||||||
|
*[
|
||||||
|
search(
|
||||||
|
clients,
|
||||||
|
extracted_edge.fact,
|
||||||
|
group_ids=[extracted_edge.group_id],
|
||||||
|
config=EDGE_HYBRID_SEARCH_RRF,
|
||||||
|
search_filter=SearchFilters(uuids=valid_uuids),
|
||||||
|
)
|
||||||
|
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
|
||||||
|
|
||||||
|
edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
|
||||||
|
*[
|
||||||
|
search(
|
||||||
|
clients,
|
||||||
|
extracted_edge.fact,
|
||||||
|
group_ids=[extracted_edge.group_id],
|
||||||
|
config=EDGE_HYBRID_SEARCH_RRF,
|
||||||
|
search_filter=SearchFilters(),
|
||||||
|
)
|
||||||
|
for extracted_edge in extracted_edges
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
edge_invalidation_candidates: list[list[EntityEdge]] = [
|
||||||
|
result.edges for result in edge_invalidation_candidate_results
|
||||||
|
]
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue