Allow Edge extraction to keep discovered edge labels (#950)
* chore: Update dependencies and enhance edge resolution logic - Add new dependencies: boto3, opensearch-py, and langchain-aws to pyproject.toml. - Modify Graphiti class to handle additional parameters in edge resolution. - Improve edge type handling in deduplication logic by introducing custom edge type names. - Enhance tests for edge resolution to cover new scenarios and ensure correct behavior. This update improves the flexibility and functionality of edge operations while ensuring compatibility with new libraries. * refactor: Clean up test_edge_operations.py and format response returns - Remove unnecessary stubs for opensearchpy module. - Format return values in llm_client.generate_response for consistency. - Enhance readability by ensuring proper indentation and structure in test cases. This refactor improves the clarity and maintainability of the test suite for edge operations. * bump version to 0.30.0pre5 and enhance docstring for resolve_extracted_edge function - Update version in pyproject.toml to 0.30.0pre5. - Add detailed docstring to resolve_extracted_edge function in edge_operations.py, clarifying parameters and return values. This update improves documentation clarity for the edge resolution process.
This commit is contained in:
parent
3fcd587276
commit
f2c4c97362
5 changed files with 213 additions and 9 deletions
|
|
@ -1070,6 +1070,7 @@ class Graphiti:
|
|||
group_id=edge.group_id,
|
||||
),
|
||||
None,
|
||||
None,
|
||||
self.ensure_ascii,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -477,6 +477,7 @@ async def dedupe_edges_bulk(
|
|||
candidates,
|
||||
episode,
|
||||
edge_types,
|
||||
set(edge_types),
|
||||
clients.ensure_ascii,
|
||||
)
|
||||
for episode, edge, candidates in dedupe_tuples
|
||||
|
|
|
|||
|
|
@ -283,8 +283,12 @@ async def resolve_extracted_edges(
|
|||
# Build entity hash table
|
||||
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
|
||||
|
||||
# Determine which edge types are relevant for each edge
|
||||
# Determine which edge types are relevant for each edge.
|
||||
# `edge_types_lst` stores the subset of custom edge definitions whose
|
||||
# node signature matches each extracted edge. Anything outside this subset
|
||||
# should only stay on the edge if it is a non-custom (LLM generated) label.
|
||||
edge_types_lst: list[dict[str, type[BaseModel]]] = []
|
||||
custom_type_names = set(edge_types or {})
|
||||
for extracted_edge in extracted_edges:
|
||||
source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
|
||||
target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
|
||||
|
|
@ -314,11 +318,16 @@ async def resolve_extracted_edges(
|
|||
|
||||
for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
|
||||
allowed_type_names = set(extracted_edge_types)
|
||||
is_custom_name = extracted_edge.name in custom_type_names
|
||||
if not allowed_type_names:
|
||||
if extracted_edge.name != DEFAULT_EDGE_NAME:
|
||||
# No custom types are valid for this node pairing. Keep LLM generated
|
||||
# labels, but flip disallowed custom names back to the default.
|
||||
if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
|
||||
extracted_edge.name = DEFAULT_EDGE_NAME
|
||||
continue
|
||||
if extracted_edge.name not in allowed_type_names:
|
||||
if is_custom_name and extracted_edge.name not in allowed_type_names:
|
||||
# Custom name exists but it is not permitted for this source/target
|
||||
# signature, so fall back to the default edge label.
|
||||
extracted_edge.name = DEFAULT_EDGE_NAME
|
||||
|
||||
# resolve edges with related edges in the graph and find invalidation candidates
|
||||
|
|
@ -332,6 +341,7 @@ async def resolve_extracted_edges(
|
|||
existing_edges,
|
||||
episode,
|
||||
extracted_edge_types,
|
||||
custom_type_names,
|
||||
clients.ensure_ascii,
|
||||
)
|
||||
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
||||
|
|
@ -404,8 +414,37 @@ async def resolve_extracted_edge(
|
|||
existing_edges: list[EntityEdge],
|
||||
episode: EpisodicNode,
|
||||
edge_type_candidates: dict[str, type[BaseModel]] | None = None,
|
||||
custom_edge_type_names: set[str] | None = None,
|
||||
ensure_ascii: bool = True,
|
||||
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
|
||||
"""Resolve an extracted edge against existing graph context.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
llm_client : LLMClient
|
||||
Client used to invoke the LLM for deduplication and attribute extraction.
|
||||
extracted_edge : EntityEdge
|
||||
Newly extracted edge whose canonical representation is being resolved.
|
||||
related_edges : list[EntityEdge]
|
||||
Candidate edges with identical endpoints used for duplicate detection.
|
||||
existing_edges : list[EntityEdge]
|
||||
Broader set of edges evaluated for contradiction / invalidation.
|
||||
episode : EpisodicNode
|
||||
Episode providing content context when extracting edge attributes.
|
||||
edge_type_candidates : dict[str, type[BaseModel]] | None
|
||||
Custom edge types permitted for the current source/target signature.
|
||||
custom_edge_type_names : set[str] | None
|
||||
Full catalog of registered custom edge names. Used to distinguish
|
||||
between disallowed custom types (which fall back to the default label)
|
||||
and ad-hoc labels emitted by the LLM.
|
||||
ensure_ascii : bool
|
||||
Whether prompt payloads should coerce ASCII output.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
|
||||
The resolved edge, any duplicates, and edges to invalidate.
|
||||
"""
|
||||
if len(related_edges) == 0 and len(existing_edges) == 0:
|
||||
return extracted_edge, [], []
|
||||
|
||||
|
|
@ -480,7 +519,15 @@ async def resolve_extracted_edge(
|
|||
|
||||
fact_type: str = response_object.fact_type
|
||||
candidate_type_names = set(edge_type_candidates or {})
|
||||
if candidate_type_names and fact_type in candidate_type_names:
|
||||
custom_type_names = custom_edge_type_names or set()
|
||||
|
||||
is_default_type = fact_type.upper() == 'DEFAULT'
|
||||
is_custom_type = fact_type in custom_type_names
|
||||
is_allowed_custom_type = fact_type in candidate_type_names
|
||||
|
||||
if is_allowed_custom_type:
|
||||
# The LLM selected a custom type that is allowed for the node pair.
|
||||
# Adopt the custom type and, if needed, extract its structured attributes.
|
||||
resolved_edge.name = fact_type
|
||||
|
||||
edge_attributes_context = {
|
||||
|
|
@ -499,9 +546,16 @@ async def resolve_extracted_edge(
|
|||
)
|
||||
|
||||
resolved_edge.attributes = edge_attributes_response
|
||||
elif fact_type.upper() != 'DEFAULT':
|
||||
elif not is_default_type and is_custom_type:
|
||||
# The LLM picked a custom type that is not allowed for this signature.
|
||||
# Reset to the default label and drop any structured attributes.
|
||||
resolved_edge.name = DEFAULT_EDGE_NAME
|
||||
resolved_edge.attributes = {}
|
||||
elif not is_default_type:
|
||||
# Non-custom labels are allowed to pass through so long as the LLM does
|
||||
# not return the sentinel DEFAULT value.
|
||||
resolved_edge.name = fact_type
|
||||
resolved_edge.attributes = {}
|
||||
|
||||
end = time()
|
||||
logger.debug(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "graphiti-core"
|
||||
description = "A temporal graph building library"
|
||||
version = "0.30.0pre4"
|
||||
version = "0.30.0pre5"
|
||||
authors = [
|
||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||
|
|
@ -42,6 +42,9 @@ dev = [
|
|||
"google-genai>=1.8.0",
|
||||
"falkordb>=1.1.2,<2.0.0",
|
||||
"kuzu>=0.11.2",
|
||||
"boto3>=1.39.16",
|
||||
"opensearch-py>=3.0.0",
|
||||
"langchain-aws>=0.2.29",
|
||||
"ipykernel>=6.29.5",
|
||||
"jupyterlab>=4.2.4",
|
||||
"diskcache-stubs>=5.6.3.6.20240818",
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from datetime import datetime, timedelta, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.graphiti_types import GraphitiClients
|
||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||
from graphiti_core.search.search_config import SearchResults
|
||||
from graphiti_core.utils.maintenance.edge_operations import (
|
||||
|
|
@ -172,10 +172,14 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
|
|||
|
||||
llm_client = MagicMock()
|
||||
llm_client.generate_response = AsyncMock(
|
||||
return_value={'duplicate_facts': [], 'contradicted_facts': [], 'fact_type': 'DEFAULT'}
|
||||
return_value={
|
||||
'duplicate_facts': [],
|
||||
'contradicted_facts': [],
|
||||
'fact_type': 'DEFAULT',
|
||||
}
|
||||
)
|
||||
|
||||
clients = GraphitiClients.model_construct(
|
||||
clients = SimpleNamespace(
|
||||
driver=MagicMock(),
|
||||
llm_client=llm_client,
|
||||
embedder=MagicMock(),
|
||||
|
|
@ -234,6 +238,87 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
|
|||
assert invalidated_edges == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
|
||||
from graphiti_core.utils.maintenance import edge_operations as edge_ops
|
||||
|
||||
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
|
||||
|
||||
async def immediate_gather(*aws, max_coroutines=None):
|
||||
return [await aw for aw in aws]
|
||||
|
||||
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
|
||||
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
|
||||
|
||||
llm_client = MagicMock()
|
||||
llm_client.generate_response = AsyncMock(
|
||||
return_value={
|
||||
'duplicate_facts': [],
|
||||
'contradicted_facts': [],
|
||||
'fact_type': 'DEFAULT',
|
||||
}
|
||||
)
|
||||
|
||||
clients = SimpleNamespace(
|
||||
driver=MagicMock(),
|
||||
llm_client=llm_client,
|
||||
embedder=MagicMock(),
|
||||
cross_encoder=MagicMock(),
|
||||
ensure_ascii=True,
|
||||
)
|
||||
|
||||
source_node = EntityNode(
|
||||
uuid='source_uuid',
|
||||
name='User Node',
|
||||
group_id='group_1',
|
||||
labels=['User'],
|
||||
)
|
||||
target_node = EntityNode(
|
||||
uuid='target_uuid',
|
||||
name='Topic Node',
|
||||
group_id='group_1',
|
||||
labels=['Topic'],
|
||||
)
|
||||
|
||||
extracted_edge = EntityEdge(
|
||||
source_node_uuid=source_node.uuid,
|
||||
target_node_uuid=target_node.uuid,
|
||||
name='INTERACTED_WITH',
|
||||
group_id='group_1',
|
||||
fact='User interacted with topic',
|
||||
episodes=[],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
episode = EpisodicNode(
|
||||
uuid='episode_uuid',
|
||||
name='Episode',
|
||||
group_id='group_1',
|
||||
source='message',
|
||||
source_description='desc',
|
||||
content='Episode content',
|
||||
valid_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
edge_types = {'OCCURRED_AT': OccurredAtEdge}
|
||||
edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
|
||||
|
||||
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
||||
clients,
|
||||
[extracted_edge],
|
||||
episode,
|
||||
[source_node, target_node],
|
||||
edge_types,
|
||||
edge_type_map,
|
||||
)
|
||||
|
||||
assert resolved_edges[0].name == 'INTERACTED_WITH'
|
||||
assert invalidated_edges == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
|
||||
mock_llm_client.generate_response.return_value = {
|
||||
|
|
@ -283,9 +368,69 @@ async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client
|
|||
[],
|
||||
episode,
|
||||
edge_type_candidates={},
|
||||
custom_edge_type_names={'OCCURRED_AT'},
|
||||
ensure_ascii=True,
|
||||
)
|
||||
|
||||
assert resolved_edge.name == DEFAULT_EDGE_NAME
|
||||
assert duplicates == []
|
||||
assert invalidated == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
|
||||
mock_llm_client.generate_response.return_value = {
|
||||
'duplicate_facts': [],
|
||||
'contradicted_facts': [],
|
||||
'fact_type': 'INTERACTED_WITH',
|
||||
}
|
||||
|
||||
extracted_edge = EntityEdge(
|
||||
source_node_uuid='source_uuid',
|
||||
target_node_uuid='target_uuid',
|
||||
name='DEFAULT',
|
||||
group_id='group_1',
|
||||
fact='User interacted with topic',
|
||||
episodes=[],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
episode = EpisodicNode(
|
||||
uuid='episode_uuid',
|
||||
name='Episode',
|
||||
group_id='group_1',
|
||||
source='message',
|
||||
source_description='desc',
|
||||
content='Episode content',
|
||||
valid_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
related_edge = EntityEdge(
|
||||
source_node_uuid='source_uuid',
|
||||
target_node_uuid='target_uuid',
|
||||
name='DEFAULT',
|
||||
group_id='group_1',
|
||||
fact='User mentioned a topic',
|
||||
episodes=[],
|
||||
created_at=datetime.now(timezone.utc),
|
||||
valid_at=None,
|
||||
invalid_at=None,
|
||||
)
|
||||
|
||||
resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
|
||||
mock_llm_client,
|
||||
extracted_edge,
|
||||
[related_edge],
|
||||
[],
|
||||
episode,
|
||||
edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
|
||||
custom_edge_type_names={'OCCURRED_AT'},
|
||||
ensure_ascii=True,
|
||||
)
|
||||
|
||||
assert resolved_edge.name == 'INTERACTED_WITH'
|
||||
assert resolved_edge.attributes == {}
|
||||
assert duplicates == []
|
||||
assert invalidated == []
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue