diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 4153a684..5417728f 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -41,6 +41,7 @@ from graphiti_core.search.search_config_recipes import ( from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_utils import ( RELEVANT_SCHEMA_LIMIT, + get_edge_invalidation_candidates, get_mentioned_nodes, get_relevant_edges, ) @@ -62,9 +63,8 @@ from graphiti_core.utils.maintenance.community_operations import ( ) from graphiti_core.utils.maintenance.edge_operations import ( build_episodic_edges, - dedupe_extracted_edge, extract_edges, - resolve_edge_contradictions, + resolve_extracted_edge, resolve_extracted_edges, ) from graphiti_core.utils.maintenance.graph_data_operations import ( @@ -77,7 +77,6 @@ from graphiti_core.utils.maintenance.node_operations import ( extract_nodes, resolve_extracted_nodes, ) -from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types logger = logging.getLogger(__name__) @@ -681,17 +680,15 @@ class Graphiti: updated_edge = resolve_edge_pointers([edge], uuid_map)[0] - related_edges = await get_relevant_edges(self.driver, [updated_edge], SearchFilters(), 0.8) + related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0] + existing_edges = ( + await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters()) + )[0] - resolved_edge = await dedupe_extracted_edge( - self.llm_client, - updated_edge, - related_edges[0], + resolved_edge, invalidated_edges = await resolve_extracted_edge( + self.llm_client, updated_edge, related_edges, existing_edges ) - contradicting_edges = await get_edge_contradictions(self.llm_client, edge, related_edges[0]) - invalidated_edges = resolve_edge_contradictions(resolved_edge, contradicting_edges) - await add_nodes_and_edges_bulk( self.driver, [], [], resolved_nodes, [resolved_edge] + invalidated_edges, self.embedder ) diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index 7c6175d4..21c388c2 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -87,8 +87,8 @@ def normalize_l2(embedding: list[float]) -> NDArray: # Use this instead of asyncio.gather() to bound coroutines async def semaphore_gather( - *coroutines: Coroutine, - max_coroutines: int = SEMAPHORE_LIMIT, + *coroutines: Coroutine, + max_coroutines: int = SEMAPHORE_LIMIT, ): semaphore = asyncio.Semaphore(max_coroutines) diff --git a/graphiti_core/prompts/dedupe_edges.py b/graphiti_core/prompts/dedupe_edges.py index 5354f3cc..f63011d4 100644 --- a/graphiti_core/prompts/dedupe_edges.py +++ b/graphiti_core/prompts/dedupe_edges.py @@ -27,6 +27,10 @@ class EdgeDuplicate(BaseModel): ..., description='id of the duplicate fact. If no duplicate facts are found, default to -1.', ) + contradicted_facts: list[int] = Field( + ..., + description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.', + ) class UniqueFact(BaseModel): @@ -41,11 +45,13 @@ class UniqueFacts(BaseModel): class Prompt(Protocol): edge: PromptVersion edge_list: PromptVersion + resolve_edge: PromptVersion class Versions(TypedDict): edge: PromptFunction edge_list: PromptFunction + resolve_edge: PromptFunction def edge(context: dict[str, Any]) -> list[Message]: @@ -106,4 +112,41 @@ def edge_list(context: dict[str, Any]) -> list[Message]: ] -versions: Versions = {'edge': edge, 'edge_list': edge_list} +def resolve_edge(context: dict[str, Any]) -> list[Message]: + return [ + Message( + role='system', + content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing ' + 'facts are contradicted by the new fact.', + ), + Message( + role='user', + content=f""" + + {context['new_edge']} + + + + {context['existing_edges']} + + + {context['edge_invalidation_candidates']} + + + + Task: + If the NEW FACT represents the same factual information as any fact in EXISTING FACTS, return the idx of the duplicate fact. + If the NEW FACT is not a duplicate of any of the EXISTING FACTS, return -1. + + Based on the provided FACT INVALIDATION CANDIDATES and NEW FACT, determine which existing facts the new fact contradicts. + Return a list containing all idx's of the facts that are contradicted by the NEW FACT. + If there are no contradicted facts, return an empty list. + + Guidelines: + 1. The facts do not need to be completely identical to be duplicates, they just need to express the same information. + """, + ), + ] + + +versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge} diff --git a/graphiti_core/prompts/dedupe_nodes.py b/graphiti_core/prompts/dedupe_nodes.py index 1cac6b79..318d4c9f 100644 --- a/graphiti_core/prompts/dedupe_nodes.py +++ b/graphiti_core/prompts/dedupe_nodes.py @@ -23,21 +23,31 @@ from .models import Message, PromptFunction, PromptVersion class NodeDuplicate(BaseModel): - duplicate_node_id: int = Field( + id: int = Field(..., description='integer id of the entity') + duplicate_idx: int = Field( ..., - description='id of the duplicate node. If no duplicate nodes are found, default to -1.', + description='idx of the duplicate node. If no duplicate nodes are found, default to -1.', ) - name: str = Field(..., description='Name of the entity.') + name: str = Field( + ..., + description='Name of the entity. Should be the most complete and descriptive name possible.', + ) + + +class NodeResolutions(BaseModel): + entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes') class Prompt(Protocol): node: PromptVersion node_list: PromptVersion + nodes: PromptVersion class Versions(TypedDict): node: PromptFunction node_list: PromptFunction + nodes: PromptFunction def node(context: dict[str, Any]) -> list[Message]: @@ -89,6 +99,67 @@ def node(context: dict[str, Any]) -> list[Message]: ] +def nodes(context: dict[str, Any]) -> list[Message]: + return [ + Message( + role='system', + content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates' + 'of existing entities.', + ), + Message( + role='user', + content=f""" + + {json.dumps([ep for ep in context['previous_episodes']], indent=2)} + + + {context['episode_content']} + + + + Each of the following ENTITIES were extracted from the CURRENT MESSAGE. + Each entity in ENTITIES is represented as a JSON object with the following structure: + {{ + id: integer id of the entity, + name: "name of the entity", + entity_type: "ontological classification of the entity", + entity_type_description: "Description of what the entity type represents", + duplication_candidates: [ + {{ + idx: integer index of the candidate entity, + name: "name of the candidate entity", + entity_type: "ontological classification of the candidate entity", + ... + }} + ] + }} + + + {json.dumps(context['extracted_nodes'], indent=2)} + + + For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates. + + Entities should only be considered duplicates if they refer to the *same real-world object or concept*. + + Do NOT mark entities as duplicates if: + - They are related but distinct. + - They have similar names or purposes but refer to separate instances or concepts. + + Task: + Your response will be a list called entity_resolutions which contains one entry for each entity. + + For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx + as an integer. + + - If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a + duplicate of. + - If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx + """, + ), + ] + + def node_list(context: dict[str, Any]) -> list[Message]: return [ Message( @@ -126,4 +197,4 @@ def node_list(context: dict[str, Any]) -> list[Message]: ] -versions: Versions = {'node': node, 'node_list': node_list} +versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes} diff --git a/graphiti_core/prompts/invalidate_edges.py b/graphiti_core/prompts/invalidate_edges.py index f30048a5..f5342ed3 100644 --- a/graphiti_core/prompts/invalidate_edges.py +++ b/graphiti_core/prompts/invalidate_edges.py @@ -24,7 +24,7 @@ from .models import Message, PromptFunction, PromptVersion class InvalidatedEdges(BaseModel): contradicted_facts: list[int] = Field( ..., - description='List of ids of facts that be should invalidated. If no facts should be invalidated, the list should be empty.', + description='List of ids of facts that should be invalidated. If no facts should be invalidated, the list should be empty.', ) diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 64973015..d90fba52 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -35,9 +35,6 @@ from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges from graphiti_core.utils.datetime_utils import ensure_utc, utc_now -from graphiti_core.utils.maintenance.temporal_operations import ( - get_edge_contradictions, -) logger = logging.getLogger(__name__) @@ -245,7 +242,7 @@ async def resolve_extracted_edges( search_results: tuple[list[list[EntityEdge]], list[list[EntityEdge]]] = await semaphore_gather( get_relevant_edges(driver, extracted_edges, SearchFilters()), - get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters()), + get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2), ) related_edges_lists, edge_invalidation_candidates = search_results @@ -325,11 +322,52 @@ async def resolve_extracted_edge( extracted_edge: EntityEdge, related_edges: list[EntityEdge], existing_edges: list[EntityEdge], - episode: EpisodicNode, + episode: EpisodicNode | None = None, ) -> tuple[EntityEdge, list[EntityEdge]]: - resolved_edge, invalidation_candidates = await semaphore_gather( - dedupe_extracted_edge(llm_client, extracted_edge, related_edges, episode), - get_edge_contradictions(llm_client, extracted_edge, existing_edges), + if len(related_edges) == 0 and len(existing_edges) == 0: + return extracted_edge, [] + + start = time() + + # Prepare context for LLM + related_edges_context = [ + {'id': edge.uuid, 'fact': edge.fact} for i, edge in enumerate(related_edges) + ] + + invalidation_edge_candidates_context = [ + {'id': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges) + ] + + context = { + 'existing_edges': related_edges_context, + 'new_edge': extracted_edge.fact, + 'edge_invalidation_candidates': invalidation_edge_candidates_context, + } + + llm_response = await llm_client.generate_response( + prompt_library.dedupe_edges.resolve_edge(context), + response_model=EdgeDuplicate, + model_size=ModelSize.small, + ) + + duplicate_fact_id: int = llm_response.get('duplicate_fact_id', -1) + + resolved_edge = ( + related_edges[duplicate_fact_id] + if 0 <= duplicate_fact_id < len(related_edges) + else extracted_edge + ) + + if duplicate_fact_id >= 0 and episode is not None: + resolved_edge.episodes.append(episode.uuid) + + contradicted_facts: list[int] = llm_response.get('contradicted_facts', []) + + invalidation_candidates: list[EntityEdge] = [existing_edges[i] for i in contradicted_facts] + + end = time() + logger.debug( + f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms' ) now = utc_now() diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index f25746bb..2b3de99e 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client.config import ModelSize from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings from graphiti_core.prompts import prompt_library -from graphiti_core.prompts.dedupe_nodes import NodeDuplicate +from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions from graphiti_core.prompts.extract_nodes import ( ExtractedEntities, ExtractedEntity, @@ -243,28 +243,65 @@ async def resolve_extracted_nodes( existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results] - resolved_nodes: list[EntityNode] = await semaphore_gather( - *[ - resolve_extracted_node( - llm_client, - extracted_node, - existing_nodes, - episode, - previous_episodes, - entity_types.get( - next((item for item in extracted_node.labels if item != 'Entity'), '') - ) - if entity_types is not None - else None, - ) - for extracted_node, existing_nodes in zip( - extracted_nodes, existing_nodes_lists, strict=True - ) - ] + entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {} + + # Prepare context for LLM + extracted_nodes_context = [ + { + 'id': i, + 'name': node.name, + 'entity_type': node.labels, + 'entity_type_description': entity_types_dict.get( + next((item for item in node.labels if item != 'Entity'), '') + ).__doc__ + or 'Default Entity Type', + 'duplication_candidates': [ + { + **{ + 'idx': j, + 'name': candidate.name, + 'entity_types': candidate.labels, + }, + **candidate.attributes, + } + for j, candidate in enumerate(existing_nodes_lists[i]) + ], + } + for i, node in enumerate(extracted_nodes) + ] + + context = { + 'extracted_nodes': extracted_nodes_context, + 'episode_content': episode.content if episode is not None else '', + 'previous_episodes': [ep.content for ep in previous_episodes] + if previous_episodes is not None + else [], + } + + llm_response = await llm_client.generate_response( + prompt_library.dedupe_nodes.nodes(context), + response_model=NodeResolutions, ) + node_resolutions: list = llm_response.get('entity_resolutions', []) + + resolved_nodes: list[EntityNode] = [] uuid_map: dict[str, str] = {} - for extracted_node, resolved_node in zip(extracted_nodes, resolved_nodes, strict=True): + for resolution in node_resolutions: + resolution_id = resolution.get('id', -1) + duplicate_idx = resolution.get('duplicate_idx', -1) + + extracted_node = extracted_nodes[resolution_id] + + resolved_node = ( + existing_nodes_lists[resolution_id][duplicate_idx] + if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id]) + else extracted_node + ) + + resolved_node.name = resolution.get('name') + + resolved_nodes.append(resolved_node) uuid_map[extracted_node.uuid] = resolved_node.uuid logger.debug(f'Resolved nodes: {[(n.name, n.uuid) for n in resolved_nodes]}') @@ -410,6 +447,7 @@ async def extract_attributes_from_node( llm_response = await llm_client.generate_response( prompt_library.extract_nodes.extract_attributes(summary_context), response_model=entity_attributes_model, + model_size=ModelSize.small, ) node.summary = llm_response.get('summary', node.summary) diff --git a/pyproject.toml b/pyproject.toml index 2da24413..6623627a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.11.6pre9" +version = "0.11.6" authors = [ { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" }, diff --git a/tests/utils/maintenance/test_edge_operations.py b/tests/utils/maintenance/test_edge_operations.py index 3145b74d..cdb1de9f 100644 --- a/tests/utils/maintenance/test_edge_operations.py +++ b/tests/utils/maintenance/test_edge_operations.py @@ -1,12 +1,10 @@ from datetime import datetime, timedelta, timezone -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest -from pytest import MonkeyPatch from graphiti_core.edges import EntityEdge from graphiti_core.nodes import EpisodicNode -from graphiti_core.utils.maintenance.edge_operations import resolve_extracted_edge @pytest.fixture @@ -91,96 +89,6 @@ def mock_previous_episodes(): ] -@pytest.mark.asyncio -async def test_resolve_extracted_edge_no_changes( - mock_llm_client, - mock_extracted_edge, - mock_related_edges, - mock_existing_edges, - mock_current_episode, - mock_previous_episodes, - monkeypatch: MonkeyPatch, -): - # Mock the function calls - dedupe_mock = AsyncMock(return_value=mock_extracted_edge) - get_contradictions_mock = AsyncMock(return_value=[]) - - # Patch the function calls - monkeypatch.setattr( - 'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock - ) - monkeypatch.setattr( - 'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions', - get_contradictions_mock, - ) - - resolved_edge, invalidated_edges = await resolve_extracted_edge( - mock_llm_client, - mock_extracted_edge, - mock_related_edges, - mock_existing_edges, - mock_current_episode, - ) - - assert resolved_edge.uuid == mock_extracted_edge.uuid - assert invalidated_edges == [] - dedupe_mock.assert_called_once() - get_contradictions_mock.assert_called_once() - - -@pytest.mark.asyncio -async def test_resolve_extracted_edge_with_invalidation( - mock_llm_client, - mock_extracted_edge, - mock_related_edges, - mock_existing_edges, - mock_current_episode, - mock_previous_episodes, - monkeypatch: MonkeyPatch, -): - valid_at = datetime.now(timezone.utc) - timedelta(days=1) - mock_extracted_edge.valid_at = valid_at - - invalidation_candidate = EntityEdge( - source_node_uuid='source_uuid_4', - target_node_uuid='target_uuid_4', - name='invalidation_candidate', - group_id='group_1', - fact='Invalidation candidate fact', - episodes=['episode_4'], - created_at=datetime.now(timezone.utc), - valid_at=datetime.now(timezone.utc) - timedelta(days=2), - invalid_at=None, - ) - - # Mock the function calls - dedupe_mock = AsyncMock(return_value=mock_extracted_edge) - get_contradictions_mock = AsyncMock(return_value=[invalidation_candidate]) - - # Patch the function calls - monkeypatch.setattr( - 'graphiti_core.utils.maintenance.edge_operations.dedupe_extracted_edge', dedupe_mock - ) - monkeypatch.setattr( - 'graphiti_core.utils.maintenance.edge_operations.get_edge_contradictions', - get_contradictions_mock, - ) - - resolved_edge, invalidated_edges = await resolve_extracted_edge( - mock_llm_client, - mock_extracted_edge, - mock_related_edges, - mock_existing_edges, - mock_current_episode, - ) - - assert resolved_edge.uuid == mock_extracted_edge.uuid - assert len(invalidated_edges) == 1 - assert invalidated_edges[0].uuid == invalidation_candidate.uuid - assert invalidated_edges[0].invalid_at == valid_at - assert invalidated_edges[0].expired_at is not None - - # Run the tests if __name__ == '__main__': pytest.main([__file__])