fix: Add edge type validation based on node labels (#948)

* fix: Add edge type validation based on node labels

- Add DEFAULT_EDGE_NAME constant for 'RELATES_TO'
- Implement pre-resolution validation to reset invalid edge names
- Add post-resolution validation for LLM-returned fact types
- Rename parameter from edge_types to edge_type_candidates for clarity
- Add comprehensive tests for validation scenarios

This ensures edges conform to edge_type_map constraints and prevents
misclassification when edge types don't match node label pairs.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* chore: Bump version to 0.30.0pre4

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Daniel Chalef 2025-09-29 16:35:00 -07:00 committed by GitHub
parent ded2bad3f2
commit 3fcd587276
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 2260 additions and 2079 deletions

View file

@ -43,6 +43,8 @@ from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
DEFAULT_EDGE_NAME = 'RELATES_TO'
logger = logging.getLogger(__name__)
@ -310,6 +312,15 @@ async def resolve_extracted_edges(
edge_types_lst.append(extracted_edge_types)
for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
allowed_type_names = set(extracted_edge_types)
if not allowed_type_names:
if extracted_edge.name != DEFAULT_EDGE_NAME:
extracted_edge.name = DEFAULT_EDGE_NAME
continue
if extracted_edge.name not in allowed_type_names:
extracted_edge.name = DEFAULT_EDGE_NAME
# resolve edges with related edges in the graph and find invalidation candidates
results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
await semaphore_gather(
@ -392,7 +403,7 @@ async def resolve_extracted_edge(
related_edges: list[EntityEdge],
existing_edges: list[EntityEdge],
episode: EpisodicNode,
edge_types: dict[str, type[BaseModel]] | None = None,
edge_type_candidates: dict[str, type[BaseModel]] | None = None,
ensure_ascii: bool = True,
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
if len(related_edges) == 0 and len(existing_edges) == 0:
@ -429,9 +440,9 @@ async def resolve_extracted_edge(
'fact_type_name': type_name,
'fact_type_description': type_model.__doc__,
}
for i, (type_name, type_model) in enumerate(edge_types.items())
for i, (type_name, type_model) in enumerate(edge_type_candidates.items())
]
if edge_types is not None
if edge_type_candidates is not None
else []
)
@ -468,7 +479,8 @@ async def resolve_extracted_edge(
]
fact_type: str = response_object.fact_type
if fact_type.upper() != 'DEFAULT' and edge_types is not None:
candidate_type_names = set(edge_type_candidates or {})
if candidate_type_names and fact_type in candidate_type_names:
resolved_edge.name = fact_type
edge_attributes_context = {
@ -478,7 +490,7 @@ async def resolve_extracted_edge(
'ensure_ascii': ensure_ascii,
}
edge_model = edge_types.get(fact_type)
edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
if edge_model is not None and len(edge_model.model_fields) != 0:
edge_attributes_response = await llm_client.generate_response(
prompt_library.extract_edges.extract_attributes(edge_attributes_context),
@ -487,6 +499,9 @@ async def resolve_extracted_edge(
)
resolved_edge.attributes = edge_attributes_response
elif fact_type.upper() != 'DEFAULT':
resolved_edge.name = DEFAULT_EDGE_NAME
resolved_edge.attributes = {}
end = time()
logger.debug(

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.30.0pre3"
version = "0.30.0pre4"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },

View file

@ -1,16 +1,25 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import BaseModel
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodicNode
from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.maintenance.edge_operations import (
DEFAULT_EDGE_NAME,
resolve_extracted_edge,
resolve_extracted_edges,
)
@pytest.fixture
def mock_llm_client():
return MagicMock()
client = MagicMock()
client.generate_response = AsyncMock()
return client
@pytest.fixture
@ -133,7 +142,7 @@ async def test_resolve_extracted_edge_exact_fact_short_circuit(
related_edges,
mock_existing_edges,
mock_current_episode,
edge_types=None,
edge_type_candidates=None,
ensure_ascii=True,
)
@ -142,3 +151,141 @@ async def test_resolve_extracted_edge_exact_fact_short_circuit(
assert duplicate_edges == []
assert invalidated == []
mock_llm_client.generate_response.assert_not_called()
class OccurredAtEdge(BaseModel):
"""Edge model stub for OCCURRED_AT."""
@pytest.mark.asyncio
async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
from graphiti_core.utils.maintenance import edge_operations as edge_ops
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
async def immediate_gather(*aws, max_coroutines=None):
return [await aw for aw in aws]
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={'duplicate_facts': [], 'contradicted_facts': [], 'fact_type': 'DEFAULT'}
)
clients = GraphitiClients.model_construct(
driver=MagicMock(),
llm_client=llm_client,
embedder=MagicMock(),
cross_encoder=MagicMock(),
ensure_ascii=True,
)
source_node = EntityNode(
uuid='source_uuid',
name='Document Node',
group_id='group_1',
labels=['Document'],
)
target_node = EntityNode(
uuid='target_uuid',
name='Topic Node',
group_id='group_1',
labels=['Topic'],
)
extracted_edge = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='OCCURRED_AT',
group_id='group_1',
fact='Document occurred at somewhere',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
edge_types = {'OCCURRED_AT': OccurredAtEdge}
edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
resolved_edges, invalidated_edges = await resolve_extracted_edges(
clients,
[extracted_edge],
episode,
[source_node, target_node],
edge_types,
edge_type_map,
)
assert resolved_edges[0].name == DEFAULT_EDGE_NAME
assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
mock_llm_client.generate_response.return_value = {
'duplicate_facts': [],
'contradicted_facts': [],
'fact_type': 'OCCURRED_AT',
}
extracted_edge = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='OCCURRED_AT',
group_id='group_1',
fact='Document occurred at somewhere',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
related_edge = EntityEdge(
source_node_uuid='alt_source',
target_node_uuid='alt_target',
name='OTHER',
group_id='group_1',
fact='Different fact',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
mock_llm_client,
extracted_edge,
[related_edge],
[],
episode,
edge_type_candidates={},
ensure_ascii=True,
)
assert resolved_edge.name == DEFAULT_EDGE_NAME
assert duplicates == []
assert invalidated == []

4155
uv.lock generated

File diff suppressed because it is too large Load diff