fix: Add edge type validation based on node labels (#948)

* fix: Add edge type validation based on node labels

- Add DEFAULT_EDGE_NAME constant for 'RELATES_TO'
- Implement pre-resolution validation to reset invalid edge names
- Add post-resolution validation for LLM-returned fact types
- Rename parameter from edge_types to edge_type_candidates for clarity
- Add comprehensive tests for validation scenarios

This ensures edges conform to edge_type_map constraints and prevents
misclassification when edge types don't match node label pairs.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* chore: Bump version to 0.30.0pre4

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Daniel Chalef 2025-09-29 16:35:00 -07:00 committed by GitHub
parent ded2bad3f2
commit 3fcd587276
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 2260 additions and 2079 deletions

View file

@ -43,6 +43,8 @@ from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
DEFAULT_EDGE_NAME = 'RELATES_TO'
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -310,6 +312,15 @@ async def resolve_extracted_edges(
edge_types_lst.append(extracted_edge_types) edge_types_lst.append(extracted_edge_types)
for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
allowed_type_names = set(extracted_edge_types)
if not allowed_type_names:
if extracted_edge.name != DEFAULT_EDGE_NAME:
extracted_edge.name = DEFAULT_EDGE_NAME
continue
if extracted_edge.name not in allowed_type_names:
extracted_edge.name = DEFAULT_EDGE_NAME
# resolve edges with related edges in the graph and find invalidation candidates # resolve edges with related edges in the graph and find invalidation candidates
results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list( results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
await semaphore_gather( await semaphore_gather(
@ -392,7 +403,7 @@ async def resolve_extracted_edge(
related_edges: list[EntityEdge], related_edges: list[EntityEdge],
existing_edges: list[EntityEdge], existing_edges: list[EntityEdge],
episode: EpisodicNode, episode: EpisodicNode,
edge_types: dict[str, type[BaseModel]] | None = None, edge_type_candidates: dict[str, type[BaseModel]] | None = None,
ensure_ascii: bool = True, ensure_ascii: bool = True,
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]: ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
if len(related_edges) == 0 and len(existing_edges) == 0: if len(related_edges) == 0 and len(existing_edges) == 0:
@ -429,9 +440,9 @@ async def resolve_extracted_edge(
'fact_type_name': type_name, 'fact_type_name': type_name,
'fact_type_description': type_model.__doc__, 'fact_type_description': type_model.__doc__,
} }
for i, (type_name, type_model) in enumerate(edge_types.items()) for i, (type_name, type_model) in enumerate(edge_type_candidates.items())
] ]
if edge_types is not None if edge_type_candidates is not None
else [] else []
) )
@ -468,7 +479,8 @@ async def resolve_extracted_edge(
] ]
fact_type: str = response_object.fact_type fact_type: str = response_object.fact_type
if fact_type.upper() != 'DEFAULT' and edge_types is not None: candidate_type_names = set(edge_type_candidates or {})
if candidate_type_names and fact_type in candidate_type_names:
resolved_edge.name = fact_type resolved_edge.name = fact_type
edge_attributes_context = { edge_attributes_context = {
@ -478,7 +490,7 @@ async def resolve_extracted_edge(
'ensure_ascii': ensure_ascii, 'ensure_ascii': ensure_ascii,
} }
edge_model = edge_types.get(fact_type) edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
if edge_model is not None and len(edge_model.model_fields) != 0: if edge_model is not None and len(edge_model.model_fields) != 0:
edge_attributes_response = await llm_client.generate_response( edge_attributes_response = await llm_client.generate_response(
prompt_library.extract_edges.extract_attributes(edge_attributes_context), prompt_library.extract_edges.extract_attributes(edge_attributes_context),
@ -487,6 +499,9 @@ async def resolve_extracted_edge(
) )
resolved_edge.attributes = edge_attributes_response resolved_edge.attributes = edge_attributes_response
elif fact_type.upper() != 'DEFAULT':
resolved_edge.name = DEFAULT_EDGE_NAME
resolved_edge.attributes = {}
end = time() end = time()
logger.debug( logger.debug(

View file

@ -1,7 +1,7 @@
[project] [project]
name = "graphiti-core" name = "graphiti-core"
description = "A temporal graph building library" description = "A temporal graph building library"
version = "0.30.0pre3" version = "0.30.0pre4"
authors = [ authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" },

View file

@ -1,16 +1,25 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from pydantic import BaseModel
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodicNode from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.maintenance.edge_operations import (
DEFAULT_EDGE_NAME,
resolve_extracted_edge,
resolve_extracted_edges,
)
@pytest.fixture @pytest.fixture
def mock_llm_client(): def mock_llm_client():
return MagicMock() client = MagicMock()
client.generate_response = AsyncMock()
return client
@pytest.fixture @pytest.fixture
@ -133,7 +142,7 @@ async def test_resolve_extracted_edge_exact_fact_short_circuit(
related_edges, related_edges,
mock_existing_edges, mock_existing_edges,
mock_current_episode, mock_current_episode,
edge_types=None, edge_type_candidates=None,
ensure_ascii=True, ensure_ascii=True,
) )
@ -142,3 +151,141 @@ async def test_resolve_extracted_edge_exact_fact_short_circuit(
assert duplicate_edges == [] assert duplicate_edges == []
assert invalidated == [] assert invalidated == []
mock_llm_client.generate_response.assert_not_called() mock_llm_client.generate_response.assert_not_called()
# Minimal pydantic stub with no fields; only its docstring matters, since
# resolve_extracted_edge surfaces `type_model.__doc__` as the fact-type
# description in the LLM context.
class OccurredAtEdge(BaseModel):
    """Edge model stub for OCCURRED_AT."""
@pytest.mark.asyncio
async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
    """An extracted edge whose name is not permitted for its node-label pair
    must be reset to DEFAULT_EDGE_NAME by resolve_extracted_edges.

    Here the edge is named OCCURRED_AT, but the edge_type_map only allows
    OCCURRED_AT between ('Event', 'Entity') labels, while the actual nodes
    are labeled Document and Topic — so the name must fall back.
    """
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    # Neutralize everything that would touch an embedder or the graph store.
    monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
    monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
    monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))

    async def sequential_gather(*aws, max_coroutines=None):
        # Drop-in for semaphore_gather: await each coroutine in order.
        return [await aw for aw in aws]

    monkeypatch.setattr(edge_ops, 'semaphore_gather', sequential_gather)

    # LLM always answers "no duplicates, no contradictions, DEFAULT type".
    stub_llm = MagicMock()
    stub_llm.generate_response = AsyncMock(
        return_value={'duplicate_facts': [], 'contradicted_facts': [], 'fact_type': 'DEFAULT'}
    )

    clients = GraphitiClients.model_construct(
        driver=MagicMock(),
        llm_client=stub_llm,
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        ensure_ascii=True,
    )

    doc_node = EntityNode(
        uuid='source_uuid',
        name='Document Node',
        group_id='group_1',
        labels=['Document'],
    )
    topic_node = EntityNode(
        uuid='target_uuid',
        name='Topic Node',
        group_id='group_1',
        labels=['Topic'],
    )
    candidate_edge = EntityEdge(
        source_node_uuid=doc_node.uuid,
        target_node_uuid=topic_node.uuid,
        name='OCCURRED_AT',
        group_id='group_1',
        fact='Document occurred at somewhere',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )
    current_episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # OCCURRED_AT exists as a type, but only for (Event, Entity) label pairs.
    type_definitions = {'OCCURRED_AT': OccurredAtEdge}
    label_pair_map = {('Event', 'Entity'): ['OCCURRED_AT']}

    resolved_edges, invalidated_edges = await resolve_extracted_edges(
        clients,
        [candidate_edge],
        current_episode,
        [doc_node, topic_node],
        type_definitions,
        label_pair_map,
    )

    assert resolved_edges[0].name == DEFAULT_EDGE_NAME
    assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
    """If the LLM returns a fact_type that is absent from the candidate map
    (here: an empty edge_type_candidates), resolve_extracted_edge must not
    adopt it and instead resets the edge name to DEFAULT_EDGE_NAME.
    """
    # LLM claims the edge is OCCURRED_AT, which is not a valid candidate.
    mock_llm_client.generate_response.return_value = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'OCCURRED_AT',
    }

    incoming_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='OCCURRED_AT',
        group_id='group_1',
        fact='Document occurred at somewhere',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )
    current_episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )
    # A non-matching related edge forces the LLM dedupe path to run.
    unrelated_edge = EntityEdge(
        source_node_uuid='alt_source',
        target_node_uuid='alt_target',
        name='OTHER',
        group_id='group_1',
        fact='Different fact',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
        mock_llm_client,
        incoming_edge,
        [unrelated_edge],
        [],
        current_episode,
        edge_type_candidates={},
        ensure_ascii=True,
    )

    assert resolved_edge.name == DEFAULT_EDGE_NAME
    assert duplicates == []
    assert invalidated == []

4155
uv.lock generated

File diff suppressed because it is too large Load diff