add uuid filter functionality
This commit is contained in:
parent
37715f6261
commit
427f917614
4 changed files with 123 additions and 65 deletions
|
|
@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
|
|||
from graphiti_core.search.search_filters import SearchFilters
|
||||
from graphiti_core.search.search_utils import (
|
||||
RELEVANT_SCHEMA_LIMIT,
|
||||
get_edge_invalidation_candidates,
|
||||
get_mentioned_nodes,
|
||||
get_relevant_edges,
|
||||
)
|
||||
from graphiti_core.telemetry import capture_event
|
||||
from graphiti_core.utils.bulk_utils import (
|
||||
|
|
@ -1037,10 +1035,28 @@ class Graphiti:
|
|||
|
||||
updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
|
||||
|
||||
related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
|
||||
valid_uuids = await EntityEdge.get_between_nodes(
|
||||
self.driver, edge.source_node_uuid, edge.target_node_uuid
|
||||
)
|
||||
|
||||
related_edges = (
|
||||
await search(
|
||||
self.clients,
|
||||
updated_edge.fact,
|
||||
group_ids=[updated_edge.group_id],
|
||||
config=EDGE_HYBRID_SEARCH_RRF,
|
||||
search_filter=SearchFilters(uuids=valid_uuids),
|
||||
)
|
||||
).edges
|
||||
existing_edges = (
|
||||
await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
|
||||
)[0]
|
||||
await search(
|
||||
self.clients,
|
||||
updated_edge.fact,
|
||||
group_ids=[updated_edge.group_id],
|
||||
config=EDGE_HYBRID_SEARCH_RRF,
|
||||
search_filter=SearchFilters(),
|
||||
)
|
||||
).edges
|
||||
|
||||
resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
|
||||
self.llm_client,
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class SearchFilters(BaseModel):
|
|||
invalid_at: list[list[DateFilter]] | None = Field(default=None)
|
||||
created_at: list[list[DateFilter]] | None = Field(default=None)
|
||||
expired_at: list[list[DateFilter]] | None = Field(default=None)
|
||||
edge_uuids: list[str] | None = Field(default=None)
|
||||
|
||||
|
||||
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
|
||||
|
|
@ -108,6 +109,10 @@ def edge_search_filter_query_constructor(
|
|||
filter_queries.append('e.name in $edge_types')
|
||||
filter_params['edge_types'] = edge_types
|
||||
|
||||
if filters.edge_uuids is not None:
|
||||
filter_queries.append('e.uuid in $edge_uuids')
|
||||
filter_params['edge_uuids'] = filters.edge_uuids
|
||||
|
||||
if filters.node_labels is not None:
|
||||
if provider == GraphProvider.KUZU:
|
||||
node_label_filter = (
|
||||
|
|
@ -261,6 +266,9 @@ def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters)
|
|||
if search_filters.edge_types:
|
||||
filters.append({'terms': {'edge_types': search_filters.edge_types}})
|
||||
|
||||
if search_filters.edge_uuids:
|
||||
filters.append({'terms': {'uuid': search_filters.edge_uuids}})
|
||||
|
||||
for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
|
||||
ranges = getattr(search_filters, field)
|
||||
if ranges:
|
||||
|
|
|
|||
|
|
@ -209,11 +209,11 @@ async def edge_fulltext_search(
|
|||
# Match the edge ids and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
UNWIND $ids as id
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
WHERE e.group_id IN $group_ids
|
||||
AND id(e)=id
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
AND id(e)=id
|
||||
|
|
@ -344,8 +344,8 @@ async def edge_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
||||
|
|
@ -622,11 +622,11 @@ async def node_fulltext_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE n.uuid=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -734,8 +734,8 @@ async def node_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -764,11 +764,11 @@ async def node_similarity_search(
|
|||
# Match the edge ides and return the values
|
||||
query = (
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
UNWIND $ids as i
|
||||
MATCH (n:Entity)
|
||||
WHERE id(n)=i.id
|
||||
RETURN
|
||||
"""
|
||||
+ get_entity_node_return_query(driver.provider)
|
||||
+ """
|
||||
ORDER BY i.score DESC
|
||||
|
|
@ -814,8 +814,8 @@ async def node_similarity_search(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
MATCH (n:Entity)
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH n, """
|
||||
|
|
@ -1147,8 +1147,8 @@ async def community_similarity_search(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
MATCH (n:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
||||
|
|
@ -1207,8 +1207,8 @@ async def community_similarity_search(
|
|||
|
||||
query = (
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
MATCH (c:Community)
|
||||
"""
|
||||
+ group_filter_query
|
||||
+ """
|
||||
WITH c,
|
||||
|
|
@ -1350,9 +1350,9 @@ async def get_relevant_nodes(
|
|||
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1397,9 +1397,9 @@ async def get_relevant_nodes(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
UNWIND $nodes AS node
|
||||
MATCH (n:Entity {group_id: $group_id})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH node, n, """
|
||||
|
|
@ -1488,9 +1488,9 @@ async def get_relevant_edges(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1560,9 +1560,9 @@ async def get_relevant_edges(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, n, m, """
|
||||
|
|
@ -1599,9 +1599,9 @@ async def get_relevant_edges(
|
|||
# First get edge candidates
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
RETURN
|
||||
|
|
@ -1647,9 +1647,9 @@ async def get_relevant_edges(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge, """
|
||||
|
|
@ -1722,10 +1722,10 @@ async def get_edge_invalidation_candidates(
|
|||
if driver.provider == GraphProvider.NEPTUNE:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH e, edge
|
||||
|
|
@ -1795,10 +1795,10 @@ async def get_edge_invalidation_candidates(
|
|||
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
||||
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, n, m, """
|
||||
|
|
@ -1834,10 +1834,10 @@ async def get_edge_invalidation_candidates(
|
|||
else:
|
||||
query = (
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
UNWIND $edges AS edge
|
||||
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
||||
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
||||
"""
|
||||
+ filter_query
|
||||
+ """
|
||||
WITH edge, e, """
|
||||
|
|
|
|||
|
|
@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
|||
from graphiti_core.prompts import prompt_library
|
||||
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
|
||||
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
|
||||
from graphiti_core.search.search import search
|
||||
from graphiti_core.search.search_config import SearchResults
|
||||
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
|
||||
from graphiti_core.search.search_filters import SearchFilters
|
||||
from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
|
||||
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -258,12 +260,44 @@ async def resolve_extracted_edges(
|
|||
embedder = clients.embedder
|
||||
await create_entity_edge_embeddings(embedder, extracted_edges)
|
||||
|
||||
search_results = await semaphore_gather(
|
||||
get_relevant_edges(driver, extracted_edges, SearchFilters()),
|
||||
get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
|
||||
valid_uuids_list: list[list[str]] = await semaphore_gather(
|
||||
*[
|
||||
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
|
||||
for edge in extracted_edges
|
||||
]
|
||||
)
|
||||
|
||||
related_edges_lists, edge_invalidation_candidates = search_results
|
||||
related_edges_results: list[SearchResults] = await semaphore_gather(
|
||||
*[
|
||||
search(
|
||||
clients,
|
||||
extracted_edge.fact,
|
||||
group_ids=[extracted_edge.group_id],
|
||||
config=EDGE_HYBRID_SEARCH_RRF,
|
||||
search_filter=SearchFilters(uuids=valid_uuids),
|
||||
)
|
||||
for extracted_edge, valid_uuids in zip(extracted_edges, valid_uuids_list)
|
||||
]
|
||||
)
|
||||
|
||||
related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
|
||||
|
||||
edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
|
||||
*[
|
||||
search(
|
||||
clients,
|
||||
extracted_edge.fact,
|
||||
group_ids=[extracted_edge.group_id],
|
||||
config=EDGE_HYBRID_SEARCH_RRF,
|
||||
search_filter=SearchFilters(),
|
||||
)
|
||||
for extracted_edge in extracted_edges
|
||||
]
|
||||
)
|
||||
|
||||
edge_invalidation_candidates: list[list[EntityEdge]] = [
|
||||
result.edges for result in edge_invalidation_candidate_results
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue