Allow Edge extraction to keep discovered edge labels (#950)

* chore: Update dependencies and enhance edge resolution logic

- Add new dependencies (boto3, opensearch-py, and langchain-aws) to the dev dependency group in pyproject.toml.
- Modify Graphiti class to handle additional parameters in edge resolution.
- Improve edge type handling in deduplication logic by introducing custom edge type names.
- Enhance tests for edge resolution to cover new scenarios and ensure correct behavior.

This update improves the flexibility and functionality of edge operations while ensuring compatibility with new libraries.

* refactor: Clean up test_edge_operations.py and format response returns

- Remove unnecessary stubs for the opensearchpy module.
- Reformat the mocked return values of llm_client.generate_response for consistency.
- Enhance readability by ensuring proper indentation and structure in test cases.

This refactor improves the clarity and maintainability of the test suite for edge operations.

* Bump version to 0.30.0pre5 and enhance the docstring for the resolve_extracted_edge function

- Update version in pyproject.toml to 0.30.0pre5.
- Add detailed docstring to resolve_extracted_edge function in edge_operations.py, clarifying parameters and return values.

This update improves documentation clarity for the edge resolution process.
This commit is contained in:
Daniel Chalef 2025-09-29 21:32:47 -07:00 committed by GitHub
parent 3fcd587276
commit f2c4c97362
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 213 additions and 9 deletions

View file

@ -1070,6 +1070,7 @@ class Graphiti:
group_id=edge.group_id,
),
None,
None,
self.ensure_ascii,
)

View file

@ -477,6 +477,7 @@ async def dedupe_edges_bulk(
candidates,
episode,
edge_types,
set(edge_types),
clients.ensure_ascii,
)
for episode, edge, candidates in dedupe_tuples

View file

@ -283,8 +283,12 @@ async def resolve_extracted_edges(
# Build entity hash table
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
# Determine which edge types are relevant for each edge
# Determine which edge types are relevant for each edge.
# `edge_types_lst` stores the subset of custom edge definitions whose
# node signature matches each extracted edge. Anything outside this subset
# should only stay on the edge if it is a non-custom (LLM generated) label.
edge_types_lst: list[dict[str, type[BaseModel]]] = []
custom_type_names = set(edge_types or {})
for extracted_edge in extracted_edges:
source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
@ -314,11 +318,16 @@ async def resolve_extracted_edges(
for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
allowed_type_names = set(extracted_edge_types)
is_custom_name = extracted_edge.name in custom_type_names
if not allowed_type_names:
if extracted_edge.name != DEFAULT_EDGE_NAME:
# No custom types are valid for this node pairing. Keep LLM generated
# labels, but flip disallowed custom names back to the default.
if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
extracted_edge.name = DEFAULT_EDGE_NAME
continue
if extracted_edge.name not in allowed_type_names:
if is_custom_name and extracted_edge.name not in allowed_type_names:
# Custom name exists but it is not permitted for this source/target
# signature, so fall back to the default edge label.
extracted_edge.name = DEFAULT_EDGE_NAME
# resolve edges with related edges in the graph and find invalidation candidates
@ -332,6 +341,7 @@ async def resolve_extracted_edges(
existing_edges,
episode,
extracted_edge_types,
custom_type_names,
clients.ensure_ascii,
)
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
@ -404,8 +414,37 @@ async def resolve_extracted_edge(
existing_edges: list[EntityEdge],
episode: EpisodicNode,
edge_type_candidates: dict[str, type[BaseModel]] | None = None,
custom_edge_type_names: set[str] | None = None,
ensure_ascii: bool = True,
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
"""Resolve an extracted edge against existing graph context.
Parameters
----------
llm_client : LLMClient
Client used to invoke the LLM for deduplication and attribute extraction.
extracted_edge : EntityEdge
Newly extracted edge whose canonical representation is being resolved.
related_edges : list[EntityEdge]
Candidate edges with identical endpoints used for duplicate detection.
existing_edges : list[EntityEdge]
Broader set of edges evaluated for contradiction / invalidation.
episode : EpisodicNode
Episode providing content context when extracting edge attributes.
edge_type_candidates : dict[str, type[BaseModel]] | None
Custom edge types permitted for the current source/target signature.
custom_edge_type_names : set[str] | None
Full catalog of registered custom edge names. Used to distinguish
between disallowed custom types (which fall back to the default label)
and ad-hoc labels emitted by the LLM.
ensure_ascii : bool
Whether prompt payloads should coerce ASCII output.
Returns
-------
tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
The resolved edge, any duplicates, and edges to invalidate.
"""
if len(related_edges) == 0 and len(existing_edges) == 0:
return extracted_edge, [], []
@ -480,7 +519,15 @@ async def resolve_extracted_edge(
fact_type: str = response_object.fact_type
candidate_type_names = set(edge_type_candidates or {})
if candidate_type_names and fact_type in candidate_type_names:
custom_type_names = custom_edge_type_names or set()
is_default_type = fact_type.upper() == 'DEFAULT'
is_custom_type = fact_type in custom_type_names
is_allowed_custom_type = fact_type in candidate_type_names
if is_allowed_custom_type:
# The LLM selected a custom type that is allowed for the node pair.
# Adopt the custom type and, if needed, extract its structured attributes.
resolved_edge.name = fact_type
edge_attributes_context = {
@ -499,9 +546,16 @@ async def resolve_extracted_edge(
)
resolved_edge.attributes = edge_attributes_response
elif fact_type.upper() != 'DEFAULT':
elif not is_default_type and is_custom_type:
# The LLM picked a custom type that is not allowed for this signature.
# Reset to the default label and drop any structured attributes.
resolved_edge.name = DEFAULT_EDGE_NAME
resolved_edge.attributes = {}
elif not is_default_type:
# Non-custom labels are allowed to pass through so long as the LLM does
# not return the sentinel DEFAULT value.
resolved_edge.name = fact_type
resolved_edge.attributes = {}
end = time()
logger.debug(

View file

@ -1,7 +1,7 @@
[project]
name = "graphiti-core"
description = "A temporal graph building library"
version = "0.30.0pre4"
version = "0.30.0pre5"
authors = [
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
@ -42,6 +42,9 @@ dev = [
"google-genai>=1.8.0",
"falkordb>=1.1.2,<2.0.0",
"kuzu>=0.11.2",
"boto3>=1.39.16",
"opensearch-py>=3.0.0",
"langchain-aws>=0.2.29",
"ipykernel>=6.29.5",
"jupyterlab>=4.2.4",
"diskcache-stubs>=5.6.3.6.20240818",

View file

@ -1,11 +1,11 @@
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import BaseModel
from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.maintenance.edge_operations import (
@ -172,10 +172,14 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={'duplicate_facts': [], 'contradicted_facts': [], 'fact_type': 'DEFAULT'}
return_value={
'duplicate_facts': [],
'contradicted_facts': [],
'fact_type': 'DEFAULT',
}
)
clients = GraphitiClients.model_construct(
clients = SimpleNamespace(
driver=MagicMock(),
llm_client=llm_client,
embedder=MagicMock(),
@ -234,6 +238,87 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
    """An edge name outside the custom edge-type catalog survives resolution.

    The extracted edge is named ``INTERACTED_WITH`` while the only custom type
    passed in is ``OCCURRED_AT`` (mapped to the ``('Event', 'Entity')``
    signature). ``resolve_extracted_edges`` is expected to leave the name
    untouched and to report no invalidated edges.
    """
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    async def run_sequentially(*aws, max_coroutines=None):
        # Stand-in for semaphore_gather: await each awaitable one after another.
        outcomes = []
        for pending in aws:
            outcomes.append(await pending)
        return outcomes

    # Swap embedding creation, edge lookup, gathering and search for stubs.
    patches = (
        (edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None)),
        (EntityEdge, 'get_between_nodes', AsyncMock(return_value=[])),
        (edge_ops, 'semaphore_gather', run_sequentially),
        (edge_ops, 'search', AsyncMock(return_value=SearchResults())),
    )
    for owner, attribute, replacement in patches:
        monkeypatch.setattr(owner, attribute, replacement)

    llm_reply = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'DEFAULT',
    }
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(return_value=llm_reply)

    clients = SimpleNamespace(
        driver=MagicMock(),
        llm_client=llm_client,
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        ensure_ascii=True,
    )

    user_node = EntityNode(
        uuid='source_uuid',
        name='User Node',
        group_id='group_1',
        labels=['User'],
    )
    topic_node = EntityNode(
        uuid='target_uuid',
        name='Topic Node',
        group_id='group_1',
        labels=['Topic'],
    )

    unknown_named_edge = EntityEdge(
        source_node_uuid=user_node.uuid,
        target_node_uuid=topic_node.uuid,
        name='INTERACTED_WITH',
        group_id='group_1',
        fact='User interacted with topic',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # The single custom type is mapped to a signature that differs from the
    # User -> Topic labels on the nodes built above.
    custom_edge_types = {'OCCURRED_AT': OccurredAtEdge}
    signature_map = {('Event', 'Entity'): ['OCCURRED_AT']}

    resolved_edges, invalidated_edges = await resolve_extracted_edges(
        clients,
        [unknown_named_edge],
        episode,
        [user_node, topic_node],
        custom_edge_types,
        signature_map,
    )

    assert resolved_edges[0].name == 'INTERACTED_WITH'
    assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
mock_llm_client.generate_response.return_value = {
@ -283,9 +368,69 @@ async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client
[],
episode,
edge_type_candidates={},
custom_edge_type_names={'OCCURRED_AT'},
ensure_ascii=True,
)
assert resolved_edge.name == DEFAULT_EDGE_NAME
assert duplicates == []
assert invalidated == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
    """A fact type outside the custom catalog is adopted as the edge name.

    The mocked LLM answers with ``INTERACTED_WITH`` while the only custom type
    supplied is ``OCCURRED_AT``. The resolved edge is expected to carry the
    LLM's label with empty attributes, and no duplicates or invalidations.
    """
    mock_llm_client.generate_response.return_value = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'INTERACTED_WITH',
    }

    def build_default_edge(fact):
        # Both edges share endpoints, group and the 'DEFAULT' name; only the
        # fact text differs between them.
        return EntityEdge(
            source_node_uuid='source_uuid',
            target_node_uuid='target_uuid',
            name='DEFAULT',
            group_id='group_1',
            fact=fact,
            episodes=[],
            created_at=datetime.now(timezone.utc),
            valid_at=None,
            invalid_at=None,
        )

    new_edge = build_default_edge('User interacted with topic')

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # A related edge is supplied alongside an empty existing-edge list.
    prior_edge = build_default_edge('User mentioned a topic')

    outcome = await resolve_extracted_edge(
        mock_llm_client,
        new_edge,
        [prior_edge],
        [],
        episode,
        edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
        custom_edge_type_names={'OCCURRED_AT'},
        ensure_ascii=True,
    )
    resolved_edge, duplicates, invalidated = outcome

    assert resolved_edge.name == 'INTERACTED_WITH'
    assert resolved_edge.attributes == {}
    assert duplicates == []
    assert invalidated == []