diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index e27ea9c6..3eed0488 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -1070,6 +1070,7 @@ class Graphiti: group_id=edge.group_id, ), None, + None, self.ensure_ascii, ) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index 66a15bcb..291c7825 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -477,6 +477,7 @@ async def dedupe_edges_bulk( candidates, episode, edge_types, + set(edge_types), clients.ensure_ascii, ) for episode, edge, candidates in dedupe_tuples diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 7f5b2e6e..788acf97 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -283,8 +283,12 @@ async def resolve_extracted_edges( # Build entity hash table uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities} - # Determine which edge types are relevant for each edge + # Determine which edge types are relevant for each edge. + # `edge_types_lst` stores the subset of custom edge definitions whose + # node signature matches each extracted edge. Anything outside this subset + # should only stay on the edge if it is a non-custom (LLM generated) label. edge_types_lst: list[dict[str, type[BaseModel]]] = [] + custom_type_names = set(edge_types or {}) for extracted_edge in extracted_edges: source_node = uuid_entity_map.get(extracted_edge.source_node_uuid) target_node = uuid_entity_map.get(extracted_edge.target_node_uuid) @@ -314,11 +318,16 @@ async def resolve_extracted_edges( for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True): allowed_type_names = set(extracted_edge_types) + is_custom_name = extracted_edge.name in custom_type_names if not allowed_type_names: - if extracted_edge.name != DEFAULT_EDGE_NAME: + # No custom types are valid for this node pairing. Keep LLM generated + # labels, but flip disallowed custom names back to the default. + if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME: extracted_edge.name = DEFAULT_EDGE_NAME continue - if extracted_edge.name not in allowed_type_names: + if is_custom_name and extracted_edge.name not in allowed_type_names: + # Custom name exists but it is not permitted for this source/target + # signature, so fall back to the default edge label. extracted_edge.name = DEFAULT_EDGE_NAME # resolve edges with related edges in the graph and find invalidation candidates @@ -332,6 +341,7 @@ async def resolve_extracted_edges( existing_edges, episode, extracted_edge_types, + custom_type_names, clients.ensure_ascii, ) for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip( @@ -404,8 +414,37 @@ async def resolve_extracted_edge( existing_edges: list[EntityEdge], episode: EpisodicNode, edge_type_candidates: dict[str, type[BaseModel]] | None = None, + custom_edge_type_names: set[str] | None = None, ensure_ascii: bool = True, ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]: + """Resolve an extracted edge against existing graph context. + + Parameters + ---------- + llm_client : LLMClient + Client used to invoke the LLM for deduplication and attribute extraction. + extracted_edge : EntityEdge + Newly extracted edge whose canonical representation is being resolved. + related_edges : list[EntityEdge] + Candidate edges with identical endpoints used for duplicate detection. + existing_edges : list[EntityEdge] + Broader set of edges evaluated for contradiction / invalidation. + episode : EpisodicNode + Episode providing content context when extracting edge attributes. + edge_type_candidates : dict[str, type[BaseModel]] | None + Custom edge types permitted for the current source/target signature. + custom_edge_type_names : set[str] | None + Full catalog of registered custom edge names. Used to distinguish + between disallowed custom types (which fall back to the default label) + and ad-hoc labels emitted by the LLM. + ensure_ascii : bool + Whether prompt payloads should coerce ASCII output. + + Returns + ------- + tuple[EntityEdge, list[EntityEdge], list[EntityEdge]] + The resolved edge, any duplicates, and edges to invalidate. + """ if len(related_edges) == 0 and len(existing_edges) == 0: return extracted_edge, [], [] @@ -480,7 +519,15 @@ async def resolve_extracted_edge( fact_type: str = response_object.fact_type candidate_type_names = set(edge_type_candidates or {}) - if candidate_type_names and fact_type in candidate_type_names: + custom_type_names = custom_edge_type_names or set() + + is_default_type = fact_type.upper() == 'DEFAULT' + is_custom_type = fact_type in custom_type_names + is_allowed_custom_type = fact_type in candidate_type_names + + if is_allowed_custom_type: + # The LLM selected a custom type that is allowed for the node pair. + # Adopt the custom type and, if needed, extract its structured attributes. resolved_edge.name = fact_type edge_attributes_context = { @@ -499,9 +546,16 @@ async def resolve_extracted_edge( ) resolved_edge.attributes = edge_attributes_response - elif fact_type.upper() != 'DEFAULT': + elif not is_default_type and is_custom_type: + # The LLM picked a custom type that is not allowed for this signature. + # Reset to the default label and drop any structured attributes. resolved_edge.name = DEFAULT_EDGE_NAME resolved_edge.attributes = {} + elif not is_default_type: + # Non-custom labels are allowed to pass through so long as the LLM does + # not return the sentinel DEFAULT value. + resolved_edge.name = fact_type + resolved_edge.attributes = {} end = time() logger.debug( diff --git a/pyproject.toml b/pyproject.toml index 654c9de4..5d8a879d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.30.0pre4" +version = "0.30.0pre5" authors = [ { name = "Paul Paliychuk", email = "paul@getzep.com" }, { name = "Preston Rasmussen", email = "preston@getzep.com" }, @@ -42,6 +42,9 @@ dev = [ "google-genai>=1.8.0", "falkordb>=1.1.2,<2.0.0", "kuzu>=0.11.2", + "boto3>=1.39.16", + "opensearch-py>=3.0.0", + "langchain-aws>=0.2.29", "ipykernel>=6.29.5", "jupyterlab>=4.2.4", "diskcache-stubs>=5.6.3.6.20240818", diff --git a/tests/utils/maintenance/test_edge_operations.py b/tests/utils/maintenance/test_edge_operations.py index cfe7eb08..5dc05798 100644 --- a/tests/utils/maintenance/test_edge_operations.py +++ b/tests/utils/maintenance/test_edge_operations.py @@ -1,11 +1,11 @@ from datetime import datetime, timedelta, timezone +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock import pytest from pydantic import BaseModel from graphiti_core.edges import EntityEdge -from graphiti_core.graphiti_types import GraphitiClients from graphiti_core.nodes import EntityNode, EpisodicNode from graphiti_core.search.search_config import SearchResults from graphiti_core.utils.maintenance.edge_operations import ( @@ -172,10 +172,14 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch): llm_client = MagicMock() llm_client.generate_response = AsyncMock( - return_value={'duplicate_facts': [], 'contradicted_facts': [], 'fact_type': 'DEFAULT'} + return_value={ + 'duplicate_facts': [], + 'contradicted_facts': [], + 'fact_type': 'DEFAULT', + } ) - clients = GraphitiClients.model_construct( + clients = SimpleNamespace( driver=MagicMock(), llm_client=llm_client, embedder=MagicMock(), @@ -234,6 +238,87 @@ async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch): assert invalidated_edges == [] +@pytest.mark.asyncio +async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch): + from graphiti_core.utils.maintenance import edge_operations as edge_ops + + monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None)) + monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[])) + + async def immediate_gather(*aws, max_coroutines=None): + return [await aw for aw in aws] + + monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather) + monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults())) + + llm_client = MagicMock() + llm_client.generate_response = AsyncMock( + return_value={ + 'duplicate_facts': [], + 'contradicted_facts': [], + 'fact_type': 'DEFAULT', + } + ) + + clients = SimpleNamespace( + driver=MagicMock(), + llm_client=llm_client, + embedder=MagicMock(), + cross_encoder=MagicMock(), + ensure_ascii=True, + ) + + source_node = EntityNode( + uuid='source_uuid', + name='User Node', + group_id='group_1', + labels=['User'], + ) + target_node = EntityNode( + uuid='target_uuid', + name='Topic Node', + group_id='group_1', + labels=['Topic'], + ) + + extracted_edge = EntityEdge( + source_node_uuid=source_node.uuid, + target_node_uuid=target_node.uuid, + name='INTERACTED_WITH', + group_id='group_1', + fact='User interacted with topic', + episodes=[], + created_at=datetime.now(timezone.utc), + valid_at=None, + invalid_at=None, + ) + + episode = EpisodicNode( + uuid='episode_uuid', + name='Episode', + group_id='group_1', + source='message', + source_description='desc', + content='Episode content', + valid_at=datetime.now(timezone.utc), + ) + + edge_types = {'OCCURRED_AT': OccurredAtEdge} + edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']} + + resolved_edges, invalidated_edges = await resolve_extracted_edges( + clients, + [extracted_edge], + episode, + [source_node, target_node], + edge_types, + edge_type_map, + ) + + assert resolved_edges[0].name == 'INTERACTED_WITH' + assert invalidated_edges == [] + + @pytest.mark.asyncio async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client): mock_llm_client.generate_response.return_value = { @@ -283,9 +368,69 @@ async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client [], episode, edge_type_candidates={}, + custom_edge_type_names={'OCCURRED_AT'}, ensure_ascii=True, ) assert resolved_edge.name == DEFAULT_EDGE_NAME assert duplicates == [] assert invalidated == [] + + +@pytest.mark.asyncio +async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client): + mock_llm_client.generate_response.return_value = { + 'duplicate_facts': [], + 'contradicted_facts': [], + 'fact_type': 'INTERACTED_WITH', + } + + extracted_edge = EntityEdge( + source_node_uuid='source_uuid', + target_node_uuid='target_uuid', + name='DEFAULT', + group_id='group_1', + fact='User interacted with topic', + episodes=[], + created_at=datetime.now(timezone.utc), + valid_at=None, + invalid_at=None, + ) + + episode = EpisodicNode( + uuid='episode_uuid', + name='Episode', + group_id='group_1', + source='message', + source_description='desc', + content='Episode content', + valid_at=datetime.now(timezone.utc), + ) + + related_edge = EntityEdge( + source_node_uuid='source_uuid', + target_node_uuid='target_uuid', + name='DEFAULT', + group_id='group_1', + fact='User mentioned a topic', + episodes=[], + created_at=datetime.now(timezone.utc), + valid_at=None, + invalid_at=None, + ) + + resolved_edge, duplicates, invalidated = await resolve_extracted_edge( + mock_llm_client, + extracted_edge, + [related_edge], + [], + episode, + edge_type_candidates={'OCCURRED_AT': OccurredAtEdge}, + custom_edge_type_names={'OCCURRED_AT'}, + ensure_ascii=True, + ) + + assert resolved_edge.name == 'INTERACTED_WITH' + assert resolved_edge.attributes == {} + assert duplicates == [] + assert invalidated == []