- Add new dependencies boto3, opensearch-py, and langchain-aws to pyproject.toml.
- Modify the Graphiti class to handle additional parameters in edge resolution.
- Improve edge type handling in deduplication logic by introducing custom edge type names.
- Enhance edge resolution tests to cover new scenarios and ensure correct behavior.

This update improves the flexibility and functionality of edge operations while ensuring compatibility with the new libraries.
"""
|
||
Copyright 2024, Zep Software, Inc.
|
||
|
||
Licensed under the Apache License, Version 2.0 (the "License");
|
||
you may not use this file except in compliance with the License.
|
||
You may obtain a copy of the License at
|
||
|
||
http://www.apache.org/licenses/LICENSE-2.0
|
||
|
||
Unless required by applicable law or agreed to in writing, software
|
||
distributed under the License is distributed on an "AS IS" BASIS,
|
||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
See the License for the specific language governing permissions and
|
||
limitations under the License.
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
import typing
|
||
from datetime import datetime
|
||
|
||
import numpy as np
|
||
from pydantic import BaseModel, Field
|
||
from typing_extensions import Any
|
||
|
||
from graphiti_core.driver.driver import (
|
||
ENTITY_EDGE_INDEX_NAME,
|
||
ENTITY_INDEX_NAME,
|
||
EPISODE_INDEX_NAME,
|
||
GraphDriver,
|
||
GraphDriverSession,
|
||
GraphProvider,
|
||
)
|
||
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
|
||
from graphiti_core.embedder import EmbedderClient
|
||
from graphiti_core.graphiti_types import GraphitiClients
|
||
from graphiti_core.helpers import normalize_l2, semaphore_gather
|
||
from graphiti_core.models.edges.edge_db_queries import (
|
||
get_entity_edge_save_bulk_query,
|
||
get_episodic_edge_save_bulk_query,
|
||
)
|
||
from graphiti_core.models.nodes.node_db_queries import (
|
||
get_entity_node_save_bulk_query,
|
||
get_episode_node_save_bulk_query,
|
||
)
|
||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
|
||
from graphiti_core.utils.maintenance.dedup_helpers import (
|
||
DedupResolutionState,
|
||
_build_candidate_indexes,
|
||
_normalize_string_exact,
|
||
_resolve_with_similarity,
|
||
)
|
||
from graphiti_core.utils.maintenance.edge_operations import (
|
||
extract_edges,
|
||
resolve_extracted_edge,
|
||
)
|
||
from graphiti_core.utils.maintenance.graph_data_operations import (
|
||
EPISODE_WINDOW_LEN,
|
||
retrieve_episodes,
|
||
)
|
||
from graphiti_core.utils.maintenance.node_operations import (
|
||
extract_nodes,
|
||
resolve_extracted_nodes,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
CHUNK_SIZE = 10
|
||
|
||
|
||
def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
|
||
"""Collapse alias -> canonical chains while preserving direction.
|
||
|
||
The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
|
||
union-find with iterative path compression to ensure every source UUID resolves to its ultimate
|
||
canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
|
||
"""
|
||
|
||
parent: dict[str, str] = {}
|
||
|
||
def find(uuid: str) -> str:
|
||
"""Directed union-find lookup using iterative path compression."""
|
||
parent.setdefault(uuid, uuid)
|
||
root = uuid
|
||
while parent[root] != root:
|
||
root = parent[root]
|
||
|
||
while parent[uuid] != root:
|
||
next_uuid = parent[uuid]
|
||
parent[uuid] = root
|
||
uuid = next_uuid
|
||
|
||
return root
|
||
|
||
for source_uuid, target_uuid in pairs:
|
||
parent.setdefault(source_uuid, source_uuid)
|
||
parent.setdefault(target_uuid, target_uuid)
|
||
parent[find(source_uuid)] = find(target_uuid)
|
||
|
||
return {uuid: find(uuid) for uuid in parent}
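
# Illustrative example (added, not part of the original source): pairs are directed
# "alias -> canonical" mappings, so [('b', 'a'), ('c', 'b')] records that b duplicates a and
# c duplicates b. The chain collapses so every alias resolves to its ultimate target:
#
#     _build_directed_uuid_map([('b', 'a'), ('c', 'b')])
#     # -> {'a': 'a', 'b': 'a', 'c': 'a'}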


class RawEpisode(BaseModel):
    name: str
    uuid: str | None = Field(default=None)
    content: str
    source_description: str
    source: EpisodeType
    reference_time: datetime


async def retrieve_previous_episodes_bulk(
    driver: GraphDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
    previous_episodes_list = await semaphore_gather(
        *[
            retrieve_episodes(
                driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
            )
            for episode in episodes
        ]
    )
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
        (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
    ]

    return episode_tuples
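
# Illustrative note (added): for episodes [e1, e2] the result has the shape
# [(e1, [episodes before e1]), (e2, [episodes before e2])], where each context list holds at
# most EPISODE_WINDOW_LEN prior episodes drawn from the same group_id as the episode it
# accompanies.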


async def add_nodes_and_edges_bulk(
    driver: GraphDriver,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
    embedder: EmbedderClient,
):
    session = driver.session()
    try:
        await session.execute_write(
            add_nodes_and_edges_bulk_tx,
            episodic_nodes,
            episodic_edges,
            entity_nodes,
            entity_edges,
            embedder,
            driver=driver,
        )
    finally:
        await session.close()
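
# Minimal usage sketch (illustrative; assumes an initialized driver and embedder, variable
# names are not from the original source):
#
#     await add_nodes_and_edges_bulk(
#         driver,
#         episodic_nodes=episodes,
#         episodic_edges=episodic_edges,
#         entity_nodes=nodes,
#         entity_edges=edges,
#         embedder=embedder,
#     )
#
# The whole batch is written through a single `session.execute_write` call, and the session is
# closed afterwards even if the write raises.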


async def add_nodes_and_edges_bulk_tx(
    tx: GraphDriverSession,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
    embedder: EmbedderClient,
    driver: GraphDriver,
):
    episodes = [dict(episode) for episode in episodic_nodes]
    for episode in episodes:
        episode['source'] = str(episode['source'].value)
        episode.pop('labels', None)

    nodes = []

    for node in entity_nodes:
        if node.name_embedding is None:
            await node.generate_name_embedding(embedder)

        entity_data: dict[str, Any] = {
            'uuid': node.uuid,
            'name': node.name,
            'group_id': node.group_id,
            'summary': node.summary,
            'created_at': node.created_at,
        }

        if not bool(driver.aoss_client):
            entity_data['name_embedding'] = node.name_embedding

        entity_data['labels'] = list(set(node.labels + ['Entity']))
        if driver.provider == GraphProvider.KUZU:
            attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
            entity_data['attributes'] = json.dumps(attributes)
        else:
            entity_data.update(node.attributes or {})

        nodes.append(entity_data)

    edges = []
    for edge in entity_edges:
        if edge.fact_embedding is None:
            await edge.generate_embedding(embedder)
        edge_data: dict[str, Any] = {
            'uuid': edge.uuid,
            'source_node_uuid': edge.source_node_uuid,
            'target_node_uuid': edge.target_node_uuid,
            'name': edge.name,
            'fact': edge.fact,
            'group_id': edge.group_id,
            'episodes': edge.episodes,
            'created_at': edge.created_at,
            'expired_at': edge.expired_at,
            'valid_at': edge.valid_at,
            'invalid_at': edge.invalid_at,
        }

        if not bool(driver.aoss_client):
            edge_data['fact_embedding'] = edge.fact_embedding

        if driver.provider == GraphProvider.KUZU:
            attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
            edge_data['attributes'] = json.dumps(attributes)
        else:
            edge_data.update(edge.attributes or {})

        edges.append(edge_data)

    if driver.provider == GraphProvider.KUZU:
        # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
        episode_query = get_episode_node_save_bulk_query(driver.provider)
        for episode in episodes:
            await tx.run(episode_query, **episode)
        entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
        for node in nodes:
            await tx.run(entity_node_query, **node)
        entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
        for edge in edges:
            await tx.run(entity_edge_query, **edge)
        episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
        for edge in episodic_edges:
            await tx.run(episodic_edge_query, **edge.model_dump())
    else:
        await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
        await tx.run(
            get_entity_node_save_bulk_query(
                driver.provider, nodes, has_aoss=bool(driver.aoss_client)
            ),
            nodes=nodes,
        )
        await tx.run(
            get_episodic_edge_save_bulk_query(driver.provider),
            episodic_edges=[edge.model_dump() for edge in episodic_edges],
        )
        await tx.run(
            get_entity_edge_save_bulk_query(driver.provider, has_aoss=bool(driver.aoss_client)),
            entity_edges=edges,
        )

    if bool(driver.aoss_client):
        for node_data, entity_node in zip(nodes, entity_nodes, strict=True):
            if node_data.get('uuid') == entity_node.uuid:
                node_data['name_embedding'] = entity_node.name_embedding

        for edge_data, entity_edge in zip(edges, entity_edges, strict=True):
            if edge_data.get('uuid') == entity_edge.uuid:
                edge_data['fact_embedding'] = entity_edge.fact_embedding

        await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
        await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
        await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)


async def extract_nodes_and_edges_bulk(
    clients: GraphitiClients,
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    edge_type_map: dict[tuple[str, str], list[str]],
    entity_types: dict[str, type[BaseModel]] | None = None,
    excluded_entity_types: list[str] | None = None,
    edge_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
    extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
        *[
            extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
            for episode, previous_episodes in episode_tuples
        ]
    )

    extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
        *[
            extract_edges(
                clients,
                episode,
                extracted_nodes_bulk[i],
                previous_episodes,
                edge_type_map=edge_type_map,
                group_id=episode.group_id,
                edge_types=edge_types,
            )
            for i, (episode, previous_episodes) in enumerate(episode_tuples)
        ]
    )

    return extracted_nodes_bulk, extracted_edges_bulk


async def dedupe_nodes_bulk(
    clients: GraphitiClients,
    extracted_nodes: list[list[EntityNode]],
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    entity_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.

    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
       reconciled against the live graph just like the non-batch flow.
    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
       can apply to edges and persistence.
    """

    first_pass_results = await semaphore_gather(
        *[
            resolve_extracted_nodes(
                clients,
                nodes,
                episode_tuples[i][0],
                episode_tuples[i][1],
                entity_types,
            )
            for i, nodes in enumerate(extracted_nodes)
        ]
    )

    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
    per_episode_uuid_maps: list[dict[str, str]] = []
    duplicate_pairs: list[tuple[str, str]] = []

    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
        first_pass_results, episode_tuples, strict=True
    ):
        episode_resolutions.append((episode.uuid, resolved_nodes))
        per_episode_uuid_maps.append(uuid_map)
        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)

    canonical_nodes: dict[str, EntityNode] = {}
    for _, resolved_nodes in episode_resolutions:
        for node in resolved_nodes:
            # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
            # the MinHash index for the accumulated canonical pool each time. The LRU-backed
            # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
            # but if batches grow significantly we should switch to an incremental index or chunked
            # processing.
            if not canonical_nodes:
                canonical_nodes[node.uuid] = node
                continue

            existing_candidates = list(canonical_nodes.values())
            normalized = _normalize_string_exact(node.name)
            exact_match = next(
                (
                    candidate
                    for candidate in existing_candidates
                    if _normalize_string_exact(candidate.name) == normalized
                ),
                None,
            )
            if exact_match is not None:
                if exact_match.uuid != node.uuid:
                    duplicate_pairs.append((node.uuid, exact_match.uuid))
                continue

            indexes = _build_candidate_indexes(existing_candidates)
            state = DedupResolutionState(
                resolved_nodes=[None],
                uuid_map={},
                unresolved_indices=[],
            )
            _resolve_with_similarity([node], indexes, state)

            resolved = state.resolved_nodes[0]
            if resolved is None:
                canonical_nodes[node.uuid] = node
                continue

            canonical_uuid = resolved.uuid
            canonical_nodes.setdefault(canonical_uuid, resolved)
            if canonical_uuid != node.uuid:
                duplicate_pairs.append((node.uuid, canonical_uuid))

    union_pairs: list[tuple[str, str]] = []
    for uuid_map in per_episode_uuid_maps:
        union_pairs.extend(uuid_map.items())
    union_pairs.extend(duplicate_pairs)

    compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)

    nodes_by_episode: dict[str, list[EntityNode]] = {}
    for episode_uuid, resolved_nodes in episode_resolutions:
        deduped_nodes: list[EntityNode] = []
        seen: set[str] = set()
        for node in resolved_nodes:
            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
            if canonical_uuid in seen:
                continue
            seen.add(canonical_uuid)
            canonical_node = canonical_nodes.get(canonical_uuid)
            if canonical_node is None:
                logger.error(
                    'Canonical node %s missing during batch dedupe; falling back to %s',
                    canonical_uuid,
                    node.uuid,
                )
                canonical_node = node
            deduped_nodes.append(canonical_node)

        nodes_by_episode[episode_uuid] = deduped_nodes

    return nodes_by_episode, compressed_map
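
# Hedged follow-up sketch (added; variable names illustrative): callers typically apply the
# returned UUID map to the edges extracted for the same batch so that edge endpoints point at
# canonical nodes, using resolve_edge_pointers defined later in this module:
#
#     nodes_by_episode, uuid_map = await dedupe_nodes_bulk(clients, extracted_nodes, episode_tuples)
#     extracted_edges_bulk = [resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk]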


async def dedupe_edges_bulk(
    clients: GraphitiClients,
    extracted_edges: list[list[EntityEdge]],
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    _entities: list[EntityNode],
    edge_types: dict[str, type[BaseModel]],
    _edge_type_map: dict[tuple[str, str], list[str]],
) -> dict[str, list[EntityEdge]]:
    embedder = clients.embedder
    min_score = 0.6

    # generate embeddings
    await semaphore_gather(
        *[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
    )

    # Find similar results
    dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
    for i, edges_i in enumerate(extracted_edges):
        existing_edges: list[EntityEdge] = []
        for j, edges_j in enumerate(extracted_edges):
            if i == j:
                continue
            existing_edges += edges_j

        for edge in edges_i:
            candidates: list[EntityEdge] = []
            for existing_edge in existing_edges:
                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
                # This approach will cast a wider net than BM25, which is ideal for this use case
                if (
                    edge.source_node_uuid != existing_edge.source_node_uuid
                    or edge.target_node_uuid != existing_edge.target_node_uuid
                ):
                    continue

                edge_words = set(edge.fact.lower().split())
                existing_edge_words = set(existing_edge.fact.lower().split())
                has_overlap = not edge_words.isdisjoint(existing_edge_words)
                if has_overlap:
                    candidates.append(existing_edge)
                    continue

                # Check for semantic similarity even if there is no overlap
                similarity = np.dot(
                    normalize_l2(edge.fact_embedding or []),
                    normalize_l2(existing_edge.fact_embedding or []),
                )
                if similarity >= min_score:
                    candidates.append(existing_edge)

            dedupe_tuples.append((episode_tuples[i][0], edge, candidates))

    bulk_edge_resolutions: list[
        tuple[EntityEdge, EntityEdge, list[EntityEdge]]
    ] = await semaphore_gather(
        *[
            resolve_extracted_edge(
                clients.llm_client,
                edge,
                candidates,
                candidates,
                episode,
                edge_types,
                set(edge_types),
                clients.ensure_ascii,
            )
            for episode, edge, candidates in dedupe_tuples
        ]
    )

    # For now we won't track edge invalidation
    duplicate_pairs: list[tuple[str, str]] = []
    for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
        episode, edge, candidates = dedupe_tuples[i]
        for duplicate in duplicates:
            duplicate_pairs.append((edge.uuid, duplicate.uuid))

    # Now we compress the duplicate map, so that a chain like 3 -> 2 and 2 -> 1 collapses to
    # 3 -> 1 and 2 -> 1 (each uuid maps to the lexicographically smallest uuid in its set)
    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)

    edge_uuid_map: dict[str, EntityEdge] = {
        edge.uuid: edge for edges in extracted_edges for edge in edges
    }

    edges_by_episode: dict[str, list[EntityEdge]] = {}
    for i, edges in enumerate(extracted_edges):
        episode = episode_tuples[i][0]

        edges_by_episode[episode.uuid] = [
            edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
        ]

    return edges_by_episode
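
# Worked illustration (added comment): two facts on the same source/target pair, e.g.
# "Alice works at Acme" and "Alice is employed by Acme", share the lowercase words
# {"alice", "acme"}, so the word-overlap check above already marks them as dedupe candidates
# without touching the embeddings; the cosine check (dot product of the two L2-normalized
# fact embeddings) is only consulted when there is no lexical overlap at all.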


class UnionFind:
    def __init__(self, elements):
        # start each element in its own set
        self.parent = {e: e for e in elements}

    def find(self, x):
        # path compression
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        # attach the lexicographically larger root under the smaller
        if ra < rb:
            self.parent[rb] = ra
        else:
            self.parent[ra] = rb


def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
    """
    duplicate_pairs: iterable of (id1, id2) pairs marking two entity IDs as duplicates
    returns: dict mapping each id seen in the pairs -> lexicographically smallest id in its duplicate set
    """
    all_uuids = set()
    for pair in duplicate_pairs:
        all_uuids.add(pair[0])
        all_uuids.add(pair[1])

    uf = UnionFind(all_uuids)
    for a, b in duplicate_pairs:
        uf.union(a, b)
    # ensure full path compression before mapping
    return {uuid: uf.find(uuid) for uuid in all_uuids}
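
# Worked example (illustrative): duplicate pairs [('3', '2'), ('2', '1')] put all three ids in
# one set, and each id maps to the lexicographically smallest member:
#
#     compress_uuid_map([('3', '2'), ('2', '1')])
#     # -> {'1': '1', '2': '1', '3': '1'}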


E = typing.TypeVar('E', bound=Edge)


def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
    for edge in edges:
        source_uuid = edge.source_node_uuid
        target_uuid = edge.target_node_uuid
        edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
        edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)

    return edges
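
# Hedged example (added): with uuid_map == {'alias-uuid': 'canonical-uuid'}, any edge whose
# source or target equals 'alias-uuid' is rewritten in place to point at 'canonical-uuid';
# endpoints not present in the map are left unchanged, and the same list is returned so calls
# can be chained.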