chore: Update dependencies and enhance edge resolution logic

- Add new dependencies: boto3, opensearch-py, and langchain-aws to pyproject.toml.
- Modify Graphiti class to handle additional parameters in edge resolution.
- Improve edge type handling in deduplication logic by introducing custom edge type names.
- Enhance tests for edge resolution to cover new scenarios and ensure correct behavior.

This update improves the flexibility and functionality of edge operations while ensuring compatibility with new libraries.
This commit is contained in:
Daniel Chalef 2025-09-29 21:10:15 -07:00
parent 3fcd587276
commit e5337d3504
5 changed files with 177 additions and 6 deletions

View file

@ -1070,6 +1070,7 @@ class Graphiti:
group_id=edge.group_id,
),
None,
None,
self.ensure_ascii,
)

View file

@ -477,6 +477,7 @@ async def dedupe_edges_bulk(
candidates,
episode,
edge_types,
set(edge_types),
clients.ensure_ascii,
)
for episode, edge, candidates in dedupe_tuples

View file

@ -283,8 +283,12 @@ async def resolve_extracted_edges(
# Build entity hash table
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
# Determine which edge types are relevant for each edge
# Determine which edge types are relevant for each edge.
# `edge_types_lst` stores the subset of custom edge definitions whose
# node signature matches each extracted edge. Anything outside this subset
# should only stay on the edge if it is a non-custom (LLM generated) label.
edge_types_lst: list[dict[str, type[BaseModel]]] = []
custom_type_names = set(edge_types or {})
for extracted_edge in extracted_edges:
source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
@ -314,11 +318,16 @@ async def resolve_extracted_edges(
for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
allowed_type_names = set(extracted_edge_types)
is_custom_name = extracted_edge.name in custom_type_names
if not allowed_type_names:
if extracted_edge.name != DEFAULT_EDGE_NAME:
# No custom types are valid for this node pairing. Keep LLM generated
# labels, but flip disallowed custom names back to the default.
if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
extracted_edge.name = DEFAULT_EDGE_NAME
continue
if extracted_edge.name not in allowed_type_names:
if is_custom_name and extracted_edge.name not in allowed_type_names:
# Custom name exists but it is not permitted for this source/target
# signature, so fall back to the default edge label.
extracted_edge.name = DEFAULT_EDGE_NAME
# resolve edges with related edges in the graph and find invalidation candidates
@ -332,6 +341,7 @@ async def resolve_extracted_edges(
existing_edges,
episode,
extracted_edge_types,
custom_type_names,
clients.ensure_ascii,
)
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
@ -404,6 +414,7 @@ async def resolve_extracted_edge(
existing_edges: list[EntityEdge],
episode: EpisodicNode,
edge_type_candidates: dict[str, type[BaseModel]] | None = None,
custom_edge_type_names: set[str] | None = None,
ensure_ascii: bool = True,
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
if len(related_edges) == 0 and len(existing_edges) == 0:
@ -480,7 +491,11 @@ async def resolve_extracted_edge(
fact_type: str = response_object.fact_type
candidate_type_names = set(edge_type_candidates or {})
custom_type_names = custom_edge_type_names or set()
if candidate_type_names and fact_type in candidate_type_names:
# The LLM selected a custom type that is allowed for the node pair.
# Adopt the custom type and, if needed, extract its structured attributes.
resolved_edge.name = fact_type
edge_attributes_context = {
@ -499,9 +514,16 @@ async def resolve_extracted_edge(
)
resolved_edge.attributes = edge_attributes_response
elif fact_type.upper() != 'DEFAULT':
elif fact_type.upper() != 'DEFAULT' and fact_type in custom_type_names:
# The LLM picked a custom type that is not allowed for this signature.
# Reset to the default label and drop any structured attributes.
resolved_edge.name = DEFAULT_EDGE_NAME
resolved_edge.attributes = {}
elif fact_type.upper() != 'DEFAULT':
# Non-custom labels are allowed to pass through so long as the LLM does
# not return the sentinel DEFAULT value.
resolved_edge.name = fact_type
resolved_edge.attributes = {}
end = time()
logger.debug(

View file

@ -42,6 +42,9 @@ dev = [
"google-genai>=1.8.0",
"falkordb>=1.1.2,<2.0.0",
"kuzu>=0.11.2",
"boto3>=1.39.16",
"opensearch-py>=3.0.0",
"langchain-aws>=0.2.29",
"ipykernel>=6.29.5",
"jupyterlab>=4.2.4",
"diskcache-stubs>=5.6.3.6.20240818",

View file

@ -1,11 +1,18 @@
import sys
import types
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
opensearch_stub = types.ModuleType('opensearchpy')
opensearch_stub.AsyncOpenSearch = None
opensearch_stub.helpers = None
sys.modules.setdefault('opensearchpy', opensearch_stub)
import pytest
from pydantic import BaseModel
from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.maintenance.edge_operations import (
@ -175,7 +182,7 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
return_value={'duplicate_facts': [], 'contradicted_facts': [], 'fact_type': 'DEFAULT'}
)
clients = GraphitiClients.model_construct(
clients = SimpleNamespace(
driver=MagicMock(),
llm_client=llm_client,
embedder=MagicMock(),
@ -234,6 +241,83 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
from graphiti_core.utils.maintenance import edge_operations as edge_ops
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
async def immediate_gather(*aws, max_coroutines=None):
return [await aw for aw in aws]
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={'duplicate_facts': [], 'contradicted_facts': [], 'fact_type': 'DEFAULT'}
)
clients = SimpleNamespace(
driver=MagicMock(),
llm_client=llm_client,
embedder=MagicMock(),
cross_encoder=MagicMock(),
ensure_ascii=True,
)
source_node = EntityNode(
uuid='source_uuid',
name='User Node',
group_id='group_1',
labels=['User'],
)
target_node = EntityNode(
uuid='target_uuid',
name='Topic Node',
group_id='group_1',
labels=['Topic'],
)
extracted_edge = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='INTERACTED_WITH',
group_id='group_1',
fact='User interacted with topic',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
edge_types = {'OCCURRED_AT': OccurredAtEdge}
edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
resolved_edges, invalidated_edges = await resolve_extracted_edges(
clients,
[extracted_edge],
episode,
[source_node, target_node],
edge_types,
edge_type_map,
)
assert resolved_edges[0].name == 'INTERACTED_WITH'
assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
mock_llm_client.generate_response.return_value = {
@ -283,9 +367,69 @@ async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client
[],
episode,
edge_type_candidates={},
custom_edge_type_names={'OCCURRED_AT'},
ensure_ascii=True,
)
assert resolved_edge.name == DEFAULT_EDGE_NAME
assert duplicates == []
assert invalidated == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
mock_llm_client.generate_response.return_value = {
'duplicate_facts': [],
'contradicted_facts': [],
'fact_type': 'INTERACTED_WITH',
}
extracted_edge = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='DEFAULT',
group_id='group_1',
fact='User interacted with topic',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
related_edge = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='DEFAULT',
group_id='group_1',
fact='User mentioned a topic',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
mock_llm_client,
extracted_edge,
[related_edge],
[],
episode,
edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
custom_edge_type_names={'OCCURRED_AT'},
ensure_ascii=True,
)
assert resolved_edge.name == 'INTERACTED_WITH'
assert resolved_edge.attributes == {}
assert duplicates == []
assert invalidated == []