Allow Edge extraction to keep discovered edge labels (#950)
* chore: Update dependencies and enhance edge resolution logic
  - Add new dependencies: boto3, opensearch-py, and langchain-aws to pyproject.toml.
  - Modify Graphiti class to handle additional parameters in edge resolution.
  - Improve edge type handling in deduplication logic by introducing custom edge type names.
  - Enhance tests for edge resolution to cover new scenarios and ensure correct behavior.
  This update improves the flexibility and functionality of edge operations while ensuring compatibility with new libraries.
* refactor: Clean up test_edge_operations.py and format response returns
  - Remove unnecessary stubs for opensearchpy module.
  - Format return values in llm_client.generate_response for consistency.
  - Enhance readability by ensuring proper indentation and structure in test cases.
  This refactor improves the clarity and maintainability of the test suite for edge operations.
* bump version to 0.30.0pre5 and enhance docstring for resolve_extracted_edge function
  - Update version in pyproject.toml to 0.30.0pre5.
  - Add detailed docstring to resolve_extracted_edge function in edge_operations.py, clarifying parameters and return values.
  This update improves documentation clarity for the edge resolution process.
This commit is contained in:
parent
3fcd587276
commit
f2c4c97362
5 changed files with 213 additions and 9 deletions
|
|
@ -1070,6 +1070,7 @@ class Graphiti:
|
||||||
group_id=edge.group_id,
|
group_id=edge.group_id,
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
self.ensure_ascii,
|
self.ensure_ascii,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -477,6 +477,7 @@ async def dedupe_edges_bulk(
|
||||||
candidates,
|
candidates,
|
||||||
episode,
|
episode,
|
||||||
edge_types,
|
edge_types,
|
||||||
|
set(edge_types),
|
||||||
clients.ensure_ascii,
|
clients.ensure_ascii,
|
||||||
)
|
)
|
||||||
for episode, edge, candidates in dedupe_tuples
|
for episode, edge, candidates in dedupe_tuples
|
||||||
|
|
|
||||||
|
|
@ -283,8 +283,12 @@ async def resolve_extracted_edges(
|
||||||
# Build entity hash table
|
# Build entity hash table
|
||||||
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
|
uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
|
||||||
|
|
||||||
# Determine which edge types are relevant for each edge
|
# Determine which edge types are relevant for each edge.
|
||||||
|
# `edge_types_lst` stores the subset of custom edge definitions whose
|
||||||
|
# node signature matches each extracted edge. Anything outside this subset
|
||||||
|
# should only stay on the edge if it is a non-custom (LLM generated) label.
|
||||||
edge_types_lst: list[dict[str, type[BaseModel]]] = []
|
edge_types_lst: list[dict[str, type[BaseModel]]] = []
|
||||||
|
custom_type_names = set(edge_types or {})
|
||||||
for extracted_edge in extracted_edges:
|
for extracted_edge in extracted_edges:
|
||||||
source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
|
source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
|
||||||
target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
|
target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
|
||||||
|
|
@ -314,11 +318,16 @@ async def resolve_extracted_edges(
|
||||||
|
|
||||||
for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
|
for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
|
||||||
allowed_type_names = set(extracted_edge_types)
|
allowed_type_names = set(extracted_edge_types)
|
||||||
|
is_custom_name = extracted_edge.name in custom_type_names
|
||||||
if not allowed_type_names:
|
if not allowed_type_names:
|
||||||
if extracted_edge.name != DEFAULT_EDGE_NAME:
|
# No custom types are valid for this node pairing. Keep LLM generated
|
||||||
|
# labels, but flip disallowed custom names back to the default.
|
||||||
|
if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
|
||||||
extracted_edge.name = DEFAULT_EDGE_NAME
|
extracted_edge.name = DEFAULT_EDGE_NAME
|
||||||
continue
|
continue
|
||||||
if extracted_edge.name not in allowed_type_names:
|
if is_custom_name and extracted_edge.name not in allowed_type_names:
|
||||||
|
# Custom name exists but it is not permitted for this source/target
|
||||||
|
# signature, so fall back to the default edge label.
|
||||||
extracted_edge.name = DEFAULT_EDGE_NAME
|
extracted_edge.name = DEFAULT_EDGE_NAME
|
||||||
|
|
||||||
# resolve edges with related edges in the graph and find invalidation candidates
|
# resolve edges with related edges in the graph and find invalidation candidates
|
||||||
|
|
@ -332,6 +341,7 @@ async def resolve_extracted_edges(
|
||||||
existing_edges,
|
existing_edges,
|
||||||
episode,
|
episode,
|
||||||
extracted_edge_types,
|
extracted_edge_types,
|
||||||
|
custom_type_names,
|
||||||
clients.ensure_ascii,
|
clients.ensure_ascii,
|
||||||
)
|
)
|
||||||
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
|
||||||
|
|
@ -404,8 +414,37 @@ async def resolve_extracted_edge(
|
||||||
existing_edges: list[EntityEdge],
|
existing_edges: list[EntityEdge],
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
edge_type_candidates: dict[str, type[BaseModel]] | None = None,
|
edge_type_candidates: dict[str, type[BaseModel]] | None = None,
|
||||||
|
custom_edge_type_names: set[str] | None = None,
|
||||||
ensure_ascii: bool = True,
|
ensure_ascii: bool = True,
|
||||||
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
|
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
|
||||||
|
"""Resolve an extracted edge against existing graph context.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
llm_client : LLMClient
|
||||||
|
Client used to invoke the LLM for deduplication and attribute extraction.
|
||||||
|
extracted_edge : EntityEdge
|
||||||
|
Newly extracted edge whose canonical representation is being resolved.
|
||||||
|
related_edges : list[EntityEdge]
|
||||||
|
Candidate edges with identical endpoints used for duplicate detection.
|
||||||
|
existing_edges : list[EntityEdge]
|
||||||
|
Broader set of edges evaluated for contradiction / invalidation.
|
||||||
|
episode : EpisodicNode
|
||||||
|
Episode providing content context when extracting edge attributes.
|
||||||
|
edge_type_candidates : dict[str, type[BaseModel]] | None
|
||||||
|
Custom edge types permitted for the current source/target signature.
|
||||||
|
custom_edge_type_names : set[str] | None
|
||||||
|
Full catalog of registered custom edge names. Used to distinguish
|
||||||
|
between disallowed custom types (which fall back to the default label)
|
||||||
|
and ad-hoc labels emitted by the LLM.
|
||||||
|
ensure_ascii : bool
|
||||||
|
Whether prompt payloads should coerce ASCII output.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
|
||||||
|
The resolved edge, any duplicates, and edges to invalidate.
|
||||||
|
"""
|
||||||
if len(related_edges) == 0 and len(existing_edges) == 0:
|
if len(related_edges) == 0 and len(existing_edges) == 0:
|
||||||
return extracted_edge, [], []
|
return extracted_edge, [], []
|
||||||
|
|
||||||
|
|
@ -480,7 +519,15 @@ async def resolve_extracted_edge(
|
||||||
|
|
||||||
fact_type: str = response_object.fact_type
|
fact_type: str = response_object.fact_type
|
||||||
candidate_type_names = set(edge_type_candidates or {})
|
candidate_type_names = set(edge_type_candidates or {})
|
||||||
if candidate_type_names and fact_type in candidate_type_names:
|
custom_type_names = custom_edge_type_names or set()
|
||||||
|
|
||||||
|
is_default_type = fact_type.upper() == 'DEFAULT'
|
||||||
|
is_custom_type = fact_type in custom_type_names
|
||||||
|
is_allowed_custom_type = fact_type in candidate_type_names
|
||||||
|
|
||||||
|
if is_allowed_custom_type:
|
||||||
|
# The LLM selected a custom type that is allowed for the node pair.
|
||||||
|
# Adopt the custom type and, if needed, extract its structured attributes.
|
||||||
resolved_edge.name = fact_type
|
resolved_edge.name = fact_type
|
||||||
|
|
||||||
edge_attributes_context = {
|
edge_attributes_context = {
|
||||||
|
|
@ -499,9 +546,16 @@ async def resolve_extracted_edge(
|
||||||
)
|
)
|
||||||
|
|
||||||
resolved_edge.attributes = edge_attributes_response
|
resolved_edge.attributes = edge_attributes_response
|
||||||
elif fact_type.upper() != 'DEFAULT':
|
elif not is_default_type and is_custom_type:
|
||||||
|
# The LLM picked a custom type that is not allowed for this signature.
|
||||||
|
# Reset to the default label and drop any structured attributes.
|
||||||
resolved_edge.name = DEFAULT_EDGE_NAME
|
resolved_edge.name = DEFAULT_EDGE_NAME
|
||||||
resolved_edge.attributes = {}
|
resolved_edge.attributes = {}
|
||||||
|
elif not is_default_type:
|
||||||
|
# Non-custom labels are allowed to pass through so long as the LLM does
|
||||||
|
# not return the sentinel DEFAULT value.
|
||||||
|
resolved_edge.name = fact_type
|
||||||
|
resolved_edge.attributes = {}
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
[project]
|
[project]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
version = "0.30.0pre4"
|
version = "0.30.0pre5"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
{ name = "Paul Paliychuk", email = "paul@getzep.com" },
|
||||||
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
{ name = "Preston Rasmussen", email = "preston@getzep.com" },
|
||||||
|
|
@ -42,6 +42,9 @@ dev = [
|
||||||
"google-genai>=1.8.0",
|
"google-genai>=1.8.0",
|
||||||
"falkordb>=1.1.2,<2.0.0",
|
"falkordb>=1.1.2,<2.0.0",
|
||||||
"kuzu>=0.11.2",
|
"kuzu>=0.11.2",
|
||||||
|
"boto3>=1.39.16",
|
||||||
|
"opensearch-py>=3.0.0",
|
||||||
|
"langchain-aws>=0.2.29",
|
||||||
"ipykernel>=6.29.5",
|
"ipykernel>=6.29.5",
|
||||||
"jupyterlab>=4.2.4",
|
"jupyterlab>=4.2.4",
|
||||||
"diskcache-stubs>=5.6.3.6.20240818",
|
"diskcache-stubs>=5.6.3.6.20240818",
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.graphiti_types import GraphitiClients
|
|
||||||
from graphiti_core.nodes import EntityNode, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodicNode
|
||||||
from graphiti_core.search.search_config import SearchResults
|
from graphiti_core.search.search_config import SearchResults
|
||||||
from graphiti_core.utils.maintenance.edge_operations import (
|
from graphiti_core.utils.maintenance.edge_operations import (
|
||||||
|
|
@ -172,10 +172,14 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
|
||||||
|
|
||||||
llm_client = MagicMock()
|
llm_client = MagicMock()
|
||||||
llm_client.generate_response = AsyncMock(
|
llm_client.generate_response = AsyncMock(
|
||||||
return_value={'duplicate_facts': [], 'contradicted_facts': [], 'fact_type': 'DEFAULT'}
|
return_value={
|
||||||
|
'duplicate_facts': [],
|
||||||
|
'contradicted_facts': [],
|
||||||
|
'fact_type': 'DEFAULT',
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
clients = GraphitiClients.model_construct(
|
clients = SimpleNamespace(
|
||||||
driver=MagicMock(),
|
driver=MagicMock(),
|
||||||
llm_client=llm_client,
|
llm_client=llm_client,
|
||||||
embedder=MagicMock(),
|
embedder=MagicMock(),
|
||||||
|
|
@ -234,6 +238,87 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
|
||||||
assert invalidated_edges == []
|
assert invalidated_edges == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
|
||||||
|
from graphiti_core.utils.maintenance import edge_operations as edge_ops
|
||||||
|
|
||||||
|
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
|
||||||
|
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
|
||||||
|
|
||||||
|
async def immediate_gather(*aws, max_coroutines=None):
|
||||||
|
return [await aw for aw in aws]
|
||||||
|
|
||||||
|
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
|
||||||
|
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
|
||||||
|
|
||||||
|
llm_client = MagicMock()
|
||||||
|
llm_client.generate_response = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
'duplicate_facts': [],
|
||||||
|
'contradicted_facts': [],
|
||||||
|
'fact_type': 'DEFAULT',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
clients = SimpleNamespace(
|
||||||
|
driver=MagicMock(),
|
||||||
|
llm_client=llm_client,
|
||||||
|
embedder=MagicMock(),
|
||||||
|
cross_encoder=MagicMock(),
|
||||||
|
ensure_ascii=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
source_node = EntityNode(
|
||||||
|
uuid='source_uuid',
|
||||||
|
name='User Node',
|
||||||
|
group_id='group_1',
|
||||||
|
labels=['User'],
|
||||||
|
)
|
||||||
|
target_node = EntityNode(
|
||||||
|
uuid='target_uuid',
|
||||||
|
name='Topic Node',
|
||||||
|
group_id='group_1',
|
||||||
|
labels=['Topic'],
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted_edge = EntityEdge(
|
||||||
|
source_node_uuid=source_node.uuid,
|
||||||
|
target_node_uuid=target_node.uuid,
|
||||||
|
name='INTERACTED_WITH',
|
||||||
|
group_id='group_1',
|
||||||
|
fact='User interacted with topic',
|
||||||
|
episodes=[],
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
valid_at=None,
|
||||||
|
invalid_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
episode = EpisodicNode(
|
||||||
|
uuid='episode_uuid',
|
||||||
|
name='Episode',
|
||||||
|
group_id='group_1',
|
||||||
|
source='message',
|
||||||
|
source_description='desc',
|
||||||
|
content='Episode content',
|
||||||
|
valid_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
edge_types = {'OCCURRED_AT': OccurredAtEdge}
|
||||||
|
edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
|
||||||
|
|
||||||
|
resolved_edges, invalidated_edges = await resolve_extracted_edges(
|
||||||
|
clients,
|
||||||
|
[extracted_edge],
|
||||||
|
episode,
|
||||||
|
[source_node, target_node],
|
||||||
|
edge_types,
|
||||||
|
edge_type_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved_edges[0].name == 'INTERACTED_WITH'
|
||||||
|
assert invalidated_edges == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
|
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
|
||||||
mock_llm_client.generate_response.return_value = {
|
mock_llm_client.generate_response.return_value = {
|
||||||
|
|
@ -283,9 +368,69 @@ async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client
|
||||||
[],
|
[],
|
||||||
episode,
|
episode,
|
||||||
edge_type_candidates={},
|
edge_type_candidates={},
|
||||||
|
custom_edge_type_names={'OCCURRED_AT'},
|
||||||
ensure_ascii=True,
|
ensure_ascii=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert resolved_edge.name == DEFAULT_EDGE_NAME
|
assert resolved_edge.name == DEFAULT_EDGE_NAME
|
||||||
assert duplicates == []
|
assert duplicates == []
|
||||||
assert invalidated == []
|
assert invalidated == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
|
||||||
|
mock_llm_client.generate_response.return_value = {
|
||||||
|
'duplicate_facts': [],
|
||||||
|
'contradicted_facts': [],
|
||||||
|
'fact_type': 'INTERACTED_WITH',
|
||||||
|
}
|
||||||
|
|
||||||
|
extracted_edge = EntityEdge(
|
||||||
|
source_node_uuid='source_uuid',
|
||||||
|
target_node_uuid='target_uuid',
|
||||||
|
name='DEFAULT',
|
||||||
|
group_id='group_1',
|
||||||
|
fact='User interacted with topic',
|
||||||
|
episodes=[],
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
valid_at=None,
|
||||||
|
invalid_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
episode = EpisodicNode(
|
||||||
|
uuid='episode_uuid',
|
||||||
|
name='Episode',
|
||||||
|
group_id='group_1',
|
||||||
|
source='message',
|
||||||
|
source_description='desc',
|
||||||
|
content='Episode content',
|
||||||
|
valid_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
related_edge = EntityEdge(
|
||||||
|
source_node_uuid='source_uuid',
|
||||||
|
target_node_uuid='target_uuid',
|
||||||
|
name='DEFAULT',
|
||||||
|
group_id='group_1',
|
||||||
|
fact='User mentioned a topic',
|
||||||
|
episodes=[],
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
valid_at=None,
|
||||||
|
invalid_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
|
||||||
|
mock_llm_client,
|
||||||
|
extracted_edge,
|
||||||
|
[related_edge],
|
||||||
|
[],
|
||||||
|
episode,
|
||||||
|
edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
|
||||||
|
custom_edge_type_names={'OCCURRED_AT'},
|
||||||
|
ensure_ascii=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resolved_edge.name == 'INTERACTED_WITH'
|
||||||
|
assert resolved_edge.attributes == {}
|
||||||
|
assert duplicates == []
|
||||||
|
assert invalidated == []
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue