chore: Update dependencies and enhance edge resolution logic
- Add new dependencies: boto3, opensearch-py, and langchain-aws to pyproject.toml.
- Modify the Graphiti class to handle additional parameters in edge resolution.
- Improve edge type handling in deduplication logic by introducing custom edge type names.
- Enhance tests for edge resolution to cover new scenarios and ensure correct behavior.

This update improves the flexibility and functionality of edge operations while ensuring compatibility with the new libraries.
This commit is contained in:
parent
3fcd587276
commit
e5337d3504
5 changed files with 177 additions and 6 deletions
|
|
@ -1070,6 +1070,7 @@ class Graphiti:
|
|||
group_id=edge.group_id,
|
||||
),
|
||||
None,
|
||||
None,
|
||||
self.ensure_ascii,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -477,6 +477,7 @@ async def dedupe_edges_bulk(
|
|||
candidates,
|
||||
episode,
|
||||
edge_types,
|
||||
set(edge_types),
|
||||
clients.ensure_ascii,
|
||||
)
|
||||
for episode, edge, candidates in dedupe_tuples
|
||||
|
|
|
|||
|
|
@ -283,8 +283,12 @@ async def resolve_extracted_edges(
|
|||
# Build entity hash table
|
||||
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
|
||||
|
||||
# Determine which edge types are relevant for each edge
|
||||
# Determine which edge types are relevant for each edge.
|
||||
# `edge_types_lst` stores the subset of custom edge definitions whose
|
||||
# node signature matches each extracted edge. Anything outside this subset
|
||||
# should only stay on the edge if it is a non-custom (LLM generated) label.
|
||||
edge_types_lst: list[dict[str, type[BaseModel]]] = []
|
||||
custom_type_names = set(edge_types or {})
|
||||
for extracted_edge in extracted_edges:
|
||||
source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
|
||||
target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
|
||||
|
|
@ -314,11 +318,16 @@ async def resolve_extracted_edges(
|
|||
|
||||
for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
|
||||
allowed_type_names = set(extracted_edge_types)
|
||||
is_custom_name = extracted_edge.name in custom_type_names
|
||||
if not allowed_type_names:
|
||||
if extracted_edge.name != DEFAULT_EDGE_NAME:
|
||||
# No custom types are valid for this node pairing. Keep LLM generated
|
||||
# labels, but flip disallowed custom names back to the default.
|
||||
if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
|
||||
extracted_edge.name = DEFAULT_EDGE_NAME
|
||||
continue
|
||||
if extracted_edge.name not in allowed_type_names:
|
||||
if is_custom_name and extracted_edge.name not in allowed_type_names:
|
||||
# Custom name exists but it is not permitted for this source/target
|
||||
# signature, so fall back to the default edge label.
|
||||
extracted_edge.name = DEFAULT_EDGE_NAME
|
||||
|
||||
# resolve edges with related edges in the graph and find invalidation candidates
|
||||
|
|
@ -332,6 +341,7 @@ async def resolve_extracted_edges(
|
|||
existing_edges,
|
||||
episode,
|
||||
extracted_edge_types,
|
||||
custom_type_names,
|
||||
clients.ensure_ascii,
|
||||
)
|
||||
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
||||
|
|
@ -404,6 +414,7 @@ async def resolve_extracted_edge(
|
|||
existing_edges: list[EntityEdge],
|
||||
episode: EpisodicNode,
|
||||
edge_type_candidates: dict[str, type[BaseModel]] | None = None,
|
||||
custom_edge_type_names: set[str] | None = None,
|
||||
ensure_ascii: bool = True,
|
||||
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
|
||||
if len(related_edges) == 0 and len(existing_edges) == 0:
|
||||
|
|
@ -480,7 +491,11 @@ async def resolve_extracted_edge(
|
|||
|
||||
fact_type: str = response_object.fact_type
|
||||
candidate_type_names = set(edge_type_candidates or {})
|
||||
custom_type_names = custom_edge_type_names or set()
|
||||
|
||||
if candidate_type_names and fact_type in candidate_type_names:
|
||||
# The LLM selected a custom type that is allowed for the node pair.
|
||||
# Adopt the custom type and, if needed, extract its structured attributes.
|
||||
resolved_edge.name = fact_type
|
||||
|
||||
edge_attributes_context = {
|
||||
|
|
@ -499,9 +514,16 @@ async def resolve_extracted_edge(
|
|||
)
|
||||
|
||||
resolved_edge.attributes = edge_attributes_response
|
||||
elif fact_type.upper() != 'DEFAULT':
|
||||
elif fact_type.upper() != 'DEFAULT' and fact_type in custom_type_names:
|
||||
# The LLM picked a custom type that is not allowed for this signature.
|
||||
# Reset to the default label and drop any structured attributes.
|
||||
resolved_edge.name = DEFAULT_EDGE_NAME
|
||||
resolved_edge.attributes = {}
|
||||
elif fact_type.upper() != 'DEFAULT':
|
||||
# Non-custom labels are allowed to pass through so long as the LLM does
|
||||
# not return the sentinel DEFAULT value.
|
||||
resolved_edge.name = fact_type
|
||||
resolved_edge.attributes = {}
|
||||
|
||||
end = time()
|
||||
logger.debug(
|
||||
|
|
|
|||
|
|
@ -42,6 +42,9 @@ dev = [
|
|||
"google-genai>=1.8.0",
|
||||
"falkordb>=1.1.2,<2.0.0",
|
||||
"kuzu>=0.11.2",
|
||||
"boto3>=1.39.16",
|
||||
"opensearch-py>=3.0.0",
|
||||
"langchain-aws>=0.2.29",
|
||||
"ipykernel>=6.29.5",
|
||||
"jupyterlab>=4.2.4",
|
||||
"diskcache-stubs>=5.6.3.6.20240818",
|
||||
|
|
|
|||
|
|
@ -1,11 +1,18 @@
|
|||
import sys
|
||||
import types
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
opensearch_stub = types.ModuleType('opensearchpy')
|
||||
opensearch_stub.AsyncOpenSearch = None
|
||||
opensearch_stub.helpers = None
|
||||
sys.modules.setdefault('opensearchpy', opensearch_stub)
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.search.search_config import SearchResults
|
||||
from graphiti_core.utils.maintenance.edge_operations import (
|
||||
|
|
@ -175,7 +182,7 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
|
|||
return_value={'duplicate_facts': [], 'contradicted_facts': [], 'fact_type': 'DEFAULT'}
|
||||
)
|
||||
|
||||
clients = GraphitiClients.model_construct(
|
||||
clients = SimpleNamespace(
|
||||
driver=MagicMock(),
|
||||
llm_client=llm_client,
|
||||
embedder=MagicMock(),
|
||||
|
|
@ -234,6 +241,83 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
|
|||
assert invalidated_edges == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
    """An edge name that is not a registered custom type must survive resolution.

    ``INTERACTED_WITH`` is absent from ``edge_types``, and the only custom type
    (``OCCURRED_AT``) is mapped to an ``('Event', 'Entity')`` signature while the
    nodes here are labelled ``User`` and ``Topic``. The extracted name is
    therefore expected to come back unchanged.
    """
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    async def immediate_gather(*aws, max_coroutines=None):
        # Await each awaitable in order instead of scheduling them concurrently.
        gathered = []
        for aw in aws:
            gathered.append(await aw)
        return gathered

    # Replace every collaborator that would otherwise reach an external service.
    monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
    monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
    monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
    monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))

    canned_llm_reply = {'duplicate_facts': [], 'contradicted_facts': [], 'fact_type': 'DEFAULT'}
    fake_llm = MagicMock()
    fake_llm.generate_response = AsyncMock(return_value=canned_llm_reply)

    fake_clients = SimpleNamespace(
        driver=MagicMock(),
        llm_client=fake_llm,
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        ensure_ascii=True,
    )

    user_node = EntityNode(
        uuid='source_uuid',
        name='User Node',
        group_id='group_1',
        labels=['User'],
    )
    topic_node = EntityNode(
        uuid='target_uuid',
        name='Topic Node',
        group_id='group_1',
        labels=['Topic'],
    )

    candidate_edge = EntityEdge(
        source_node_uuid=user_node.uuid,
        target_node_uuid=topic_node.uuid,
        name='INTERACTED_WITH',
        group_id='group_1',
        fact='User interacted with topic',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    current_episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # The single custom type is registered for a signature these nodes do not match.
    registered_edge_types = {'OCCURRED_AT': OccurredAtEdge}
    signature_map = {('Event', 'Entity'): ['OCCURRED_AT']}

    resolved, invalidated = await resolve_extracted_edges(
        fake_clients,
        [candidate_edge],
        current_episode,
        [user_node, topic_node],
        registered_edge_types,
        signature_map,
    )

    assert resolved[0].name == 'INTERACTED_WITH'
    assert invalidated == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
|
||||
mock_llm_client.generate_response.return_value = {
|
||||
|
|
@ -283,9 +367,69 @@ async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client
|
|||
[],
|
||||
episode,
|
||||
edge_type_candidates={},
|
||||
custom_edge_type_names={'OCCURRED_AT'},
|
||||
ensure_ascii=True,
|
||||
)
|
||||
|
||||
assert resolved_edge.name == DEFAULT_EDGE_NAME
|
||||
assert duplicates == []
|
||||
assert invalidated == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
    """A fact type outside the custom type names is adopted as the edge name.

    The mocked LLM answers with ``INTERACTED_WITH`` while the only custom type
    is ``OCCURRED_AT``. The resolved edge is expected to take the LLM-provided
    label and carry no structured attributes.
    """

    def build_default_edge(fact_text):
        # Both edges share endpoints, group and the DEFAULT label; only the fact differs.
        return EntityEdge(
            source_node_uuid='source_uuid',
            target_node_uuid='target_uuid',
            name='DEFAULT',
            group_id='group_1',
            fact=fact_text,
            episodes=[],
            created_at=datetime.now(timezone.utc),
            valid_at=None,
            invalid_at=None,
        )

    llm_reply = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'INTERACTED_WITH',
    }
    mock_llm_client.generate_response.return_value = llm_reply

    new_edge = build_default_edge('User interacted with topic')

    current_episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # A related edge is supplied as the only comparison candidate for the new edge.
    neighbour_edge = build_default_edge('User mentioned a topic')

    outcome = await resolve_extracted_edge(
        mock_llm_client,
        new_edge,
        [neighbour_edge],
        [],
        current_episode,
        edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
        custom_edge_type_names={'OCCURRED_AT'},
        ensure_ascii=True,
    )
    resolved_edge, duplicates, invalidated = outcome

    assert resolved_edge.name == 'INTERACTED_WITH'
    assert resolved_edge.attributes == {}
    assert duplicates == []
    assert invalidated == []
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue