fix: Add edge type validation based on node labels (#948)
* fix: Add edge type validation based on node labels

  - Add DEFAULT_EDGE_NAME constant for 'RELATES_TO'
  - Implement pre-resolution validation to reset invalid edge names
  - Add post-resolution validation for LLM-returned fact types
  - Rename parameter from edge_types to edge_type_candidates for clarity
  - Add comprehensive tests for validation scenarios

  This ensures edges conform to edge_type_map constraints and prevents misclassification when edge types don't match node label pairs.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

* chore: Bump version to 0.30.0pre4

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent ded2bad3f2
commit 3fcd587276

4 changed files with 2260 additions and 2079 deletions
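As an illustration of the constraint this commit enforces, here is a minimal, self-contained sketch. It reuses names that appear in the diff and tests below (`DEFAULT_EDGE_NAME`, `OccurredAtEdge`, `edge_types`, `edge_type_map`); the `allowed` computation is a simplified stand-in for the library's internal label-pair resolution, not its actual code:

```python
from pydantic import BaseModel

# Names mirror the diff/tests; the matching logic below is a simplified
# stand-in for the library's resolution, not its actual implementation.
DEFAULT_EDGE_NAME = 'RELATES_TO'


class OccurredAtEdge(BaseModel):
    """Edge model stub for OCCURRED_AT."""


# Custom edge types and the node-label pairs they are allowed between.
edge_types = {'OCCURRED_AT': OccurredAtEdge}
edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}

# Labels of an extracted edge's endpoints: Document -> Topic has no entry
# in edge_type_map, so no custom type names are allowed for this pair.
source_labels, target_labels = ['Document'], ['Topic']
allowed = {
    type_name
    for (src, tgt), names in edge_type_map.items()
    if src in source_labels and tgt in target_labels
    for type_name in names
}

extracted_name = 'OCCURRED_AT'
if extracted_name not in allowed:
    extracted_name = DEFAULT_EDGE_NAME  # reset to the default edge type

print(extracted_name)  # RELATES_TO
```

The post-resolution path applies the same fallback: if the LLM returns a `fact_type` that is not among the candidates for the node-label pair, the resolved edge's name is reset to `RELATES_TO` and its attributes are cleared, as the diff below shows.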
@@ -43,6 +43,8 @@ from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
 from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
 
+DEFAULT_EDGE_NAME = 'RELATES_TO'
+
 logger = logging.getLogger(__name__)
 
 
@@ -310,6 +312,15 @@ async def resolve_extracted_edges(
 
         edge_types_lst.append(extracted_edge_types)
 
+    for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
+        allowed_type_names = set(extracted_edge_types)
+        if not allowed_type_names:
+            if extracted_edge.name != DEFAULT_EDGE_NAME:
+                extracted_edge.name = DEFAULT_EDGE_NAME
+            continue
+        if extracted_edge.name not in allowed_type_names:
+            extracted_edge.name = DEFAULT_EDGE_NAME
+
     # resolve edges with related edges in the graph and find invalidation candidates
     results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
         await semaphore_gather(
@@ -392,7 +403,7 @@ async def resolve_extracted_edge(
     related_edges: list[EntityEdge],
     existing_edges: list[EntityEdge],
     episode: EpisodicNode,
-    edge_types: dict[str, type[BaseModel]] | None = None,
+    edge_type_candidates: dict[str, type[BaseModel]] | None = None,
     ensure_ascii: bool = True,
 ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
     if len(related_edges) == 0 and len(existing_edges) == 0:
@@ -429,9 +440,9 @@ async def resolve_extracted_edge(
                 'fact_type_name': type_name,
                 'fact_type_description': type_model.__doc__,
             }
-            for i, (type_name, type_model) in enumerate(edge_types.items())
+            for i, (type_name, type_model) in enumerate(edge_type_candidates.items())
         ]
-        if edge_types is not None
+        if edge_type_candidates is not None
         else []
     )
 
@@ -468,7 +479,8 @@ async def resolve_extracted_edge(
         ]
 
     fact_type: str = response_object.fact_type
-    if fact_type.upper() != 'DEFAULT' and edge_types is not None:
+    candidate_type_names = set(edge_type_candidates or {})
+    if candidate_type_names and fact_type in candidate_type_names:
         resolved_edge.name = fact_type
 
         edge_attributes_context = {
@@ -478,7 +490,7 @@ async def resolve_extracted_edge(
             'ensure_ascii': ensure_ascii,
         }
 
-        edge_model = edge_types.get(fact_type)
+        edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
         if edge_model is not None and len(edge_model.model_fields) != 0:
            edge_attributes_response = await llm_client.generate_response(
                prompt_library.extract_edges.extract_attributes(edge_attributes_context),
@@ -487,6 +499,9 @@ async def resolve_extracted_edge(
             )
 
             resolved_edge.attributes = edge_attributes_response
+    elif fact_type.upper() != 'DEFAULT':
+        resolved_edge.name = DEFAULT_EDGE_NAME
+        resolved_edge.attributes = {}
 
     end = time()
     logger.debug(
@@ -1,7 +1,7 @@
 [project]
 name = "graphiti-core"
 description = "A temporal graph building library"
-version = "0.30.0pre3"
+version = "0.30.0pre4"
 authors = [
     { name = "Paul Paliychuk", email = "paul@getzep.com" },
     { name = "Preston Rasmussen", email = "preston@getzep.com" },
@@ -1,16 +1,25 @@
 from datetime import datetime, timedelta, timezone
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
+from pydantic import BaseModel
 
 from graphiti_core.edges import EntityEdge
-from graphiti_core.nodes import EpisodicNode
-from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge
+from graphiti_core.graphiti_types import GraphitiClients
+from graphiti_core.nodes import EntityNode, EpisodicNode
+from graphiti_core.search.search_config import SearchResults
+from graphiti_core.utils.maintenance.edge_operations import (
+    DEFAULT_EDGE_NAME,
+    resolve_extracted_edge,
+    resolve_extracted_edges,
+)
 
 
 @pytest.fixture
 def mock_llm_client():
-    return MagicMock()
+    client = MagicMock()
+    client.generate_response = AsyncMock()
+    return client
 
 
 @pytest.fixture
@@ -133,7 +142,7 @@ async def test_resolve_extracted_edge_exact_fact_short_circuit(
         related_edges,
         mock_existing_edges,
         mock_current_episode,
-        edge_types=None,
+        edge_type_candidates=None,
         ensure_ascii=True,
     )
 
@@ -142,3 +151,141 @@
     assert duplicate_edges == []
     assert invalidated == []
     mock_llm_client.generate_response.assert_not_called()
+
+
+class OccurredAtEdge(BaseModel):
+    """Edge model stub for OCCURRED_AT."""
+
+
+@pytest.mark.asyncio
+async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
+    from graphiti_core.utils.maintenance import edge_operations as edge_ops
+
+    monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
+    monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
+
+    async def immediate_gather(*aws, max_coroutines=None):
+        return [await aw for aw in aws]
+
+    monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
+    monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
+
+    llm_client = MagicMock()
+    llm_client.generate_response = AsyncMock(
+        return_value={'duplicate_facts': [], 'contradicted_facts': [], 'fact_type': 'DEFAULT'}
+    )
+
+    clients = GraphitiClients.model_construct(
+        driver=MagicMock(),
+        llm_client=llm_client,
+        embedder=MagicMock(),
+        cross_encoder=MagicMock(),
+        ensure_ascii=True,
+    )
+
+    source_node = EntityNode(
+        uuid='source_uuid',
+        name='Document Node',
+        group_id='group_1',
+        labels=['Document'],
+    )
+    target_node = EntityNode(
+        uuid='target_uuid',
+        name='Topic Node',
+        group_id='group_1',
+        labels=['Topic'],
+    )
+
+    extracted_edge = EntityEdge(
+        source_node_uuid=source_node.uuid,
+        target_node_uuid=target_node.uuid,
+        name='OCCURRED_AT',
+        group_id='group_1',
+        fact='Document occurred at somewhere',
+        episodes=[],
+        created_at=datetime.now(timezone.utc),
+        valid_at=None,
+        invalid_at=None,
+    )
+
+    episode = EpisodicNode(
+        uuid='episode_uuid',
+        name='Episode',
+        group_id='group_1',
+        source='message',
+        source_description='desc',
+        content='Episode content',
+        valid_at=datetime.now(timezone.utc),
+    )
+
+    edge_types = {'OCCURRED_AT': OccurredAtEdge}
+    edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
+
+    resolved_edges, invalidated_edges = await resolve_extracted_edges(
+        clients,
+        [extracted_edge],
+        episode,
+        [source_node, target_node],
+        edge_types,
+        edge_type_map,
+    )
+
+    assert resolved_edges[0].name == DEFAULT_EDGE_NAME
+    assert invalidated_edges == []
+
+
+@pytest.mark.asyncio
+async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
+    mock_llm_client.generate_response.return_value = {
+        'duplicate_facts': [],
+        'contradicted_facts': [],
+        'fact_type': 'OCCURRED_AT',
+    }
+
+    extracted_edge = EntityEdge(
+        source_node_uuid='source_uuid',
+        target_node_uuid='target_uuid',
+        name='OCCURRED_AT',
+        group_id='group_1',
+        fact='Document occurred at somewhere',
+        episodes=[],
+        created_at=datetime.now(timezone.utc),
+        valid_at=None,
+        invalid_at=None,
+    )
+
+    episode = EpisodicNode(
+        uuid='episode_uuid',
+        name='Episode',
+        group_id='group_1',
+        source='message',
+        source_description='desc',
+        content='Episode content',
+        valid_at=datetime.now(timezone.utc),
+    )
+
+    related_edge = EntityEdge(
+        source_node_uuid='alt_source',
+        target_node_uuid='alt_target',
+        name='OTHER',
+        group_id='group_1',
+        fact='Different fact',
+        episodes=[],
+        created_at=datetime.now(timezone.utc),
+        valid_at=None,
+        invalid_at=None,
+    )
+
+    resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
+        mock_llm_client,
+        extracted_edge,
+        [related_edge],
+        [],
+        episode,
+        edge_type_candidates={},
+        ensure_ascii=True,
+    )
+
+    assert resolved_edge.name == DEFAULT_EDGE_NAME
+    assert duplicates == []
+    assert invalidated == []