From ebee09b33502c335f7d2ff3c22ef0b94cd7171ce Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Fri, 6 Jun 2025 12:28:52 -0400 Subject: [PATCH] Edge extraction and Node Deduplication updates (#564) * update tests * updated fact extraction * optimize node deduplication * linting * Update graphiti_core/utils/maintenance/edge_operations.py Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --------- Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- examples/podcast/podcast_runner.py | 6 ++ graphiti_core/prompts/dedupe_nodes.py | 10 +- graphiti_core/prompts/extract_edges.py | 11 +- .../utils/maintenance/edge_operations.py | 14 +-- .../utils/maintenance/node_operations.py | 102 ++++-------------- pyproject.toml | 2 +- 6 files changed, 50 insertions(+), 95 deletions(-) diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index 2c53889e..146d20e6 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -63,6 +63,10 @@ class Person(BaseModel): occupation: str | None = Field(..., description="The person's work occupation") +class IsPresidentOf(BaseModel): + """Relationship between a person and the entity they are a president of""" + + async def main(): setup_logging() client = Graphiti(neo4j_uri, neo4j_user, neo4j_password) @@ -84,6 +88,8 @@ async def main(): source_description='Podcast Transcript', group_id=group_id, entity_types={'Person': Person}, + edge_types={'IS_PRESIDENT_OF': IsPresidentOf}, + edge_type_map={('Person', 'Entity'): ['PRESIDENT_OF']}, previous_episode_uuids=episode_uuids, ) diff --git a/graphiti_core/prompts/dedupe_nodes.py b/graphiti_core/prompts/dedupe_nodes.py index 318d4c9f..16fee8d9 100644 --- a/graphiti_core/prompts/dedupe_nodes.py +++ b/graphiti_core/prompts/dedupe_nodes.py @@ -137,8 +137,12 @@ def nodes(context: dict[str, Any]) -> list[Message]: {json.dumps(context['extracted_nodes'], indent=2)} + + + {json.dumps(context['existing_nodes'], indent=2)} + - For each of the above ENTITIES, determine if the entity is a duplicate of any of its duplication candidates. + For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES. Entities should only be considered duplicates if they refer to the *same real-world object or concept*. @@ -152,9 +156,9 @@ def nodes(context: dict[str, Any]) -> list[Message]: For each entity, return the id of the entity as id, the name of the entity as name, and the duplicate_idx as an integer. - - If an entity is a duplicate of one of its duplication_candidates, return the idx of the candidate it is a + - If an entity is a duplicate of one of the EXISTING ENTITIES, return the idx of the candidate it is a duplicate of. - - If an entity is not a duplicate of one of its duplication candidates, return the -1 as the duplication_idx + - If an entity is not a duplicate of one of the EXISTING ENTITIES, return the -1 as the duplication_idx """, ), ] diff --git a/graphiti_core/prompts/extract_edges.py b/graphiti_core/prompts/extract_edges.py index 37db4699..cd73edfa 100644 --- a/graphiti_core/prompts/extract_edges.py +++ b/graphiti_core/prompts/extract_edges.py @@ -24,8 +24,8 @@ from .models import Message, PromptFunction, PromptVersion class Edge(BaseModel): relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE') - source_entity_name: str = Field(..., description='The name of the source entity of the fact.') - target_entity_name: str = Field(..., description='The name of the target entity of the fact.') + source_entity_id: int = Field(..., description='The id of the source entity of the fact.') + target_entity_id: int = Field(..., description='The id of the target entity of the fact.') fact: str = Field(..., description='') valid_at: str | None = Field( None, @@ -77,7 +77,7 @@ def edge(context: dict[str, Any]) -> list[Message]: -{context['nodes']} # Each has: id, label (e.g., Person, Org), name, aliases +{context['nodes']} @@ -94,8 +94,9 @@ Only extract facts that: - involve two DISTINCT ENTITIES from the ENTITIES list, - are clearly stated or unambiguously implied in the CURRENT MESSAGE, and can be represented as edges in a knowledge graph. -- The FACT TYPES provide a list of the most important types of facts, make sure to extract any facts that - could be classified into one of the provided fact types +- The FACT TYPES provide a list of the most important types of facts, make sure to extract facts of these types +- The FACT TYPES are not an exhaustive list, extract all facts from the message even if they do not fit into one + of the FACT TYPES You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity. diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 332706b7..9c90a8e9 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -92,8 +92,6 @@ async def extract_edges( extract_edges_max_tokens = 16384 llm_client = clients.llm_client - node_uuids_by_name_map = {node.name: node.uuid for node in nodes} - edge_types_context = ( [ { @@ -109,7 +107,7 @@ async def extract_edges( # Prepare context for LLM context = { 'episode_content': episode.content, - 'nodes': [node.name for node in nodes], + 'nodes': [{'id': idx, 'name': node.name} for idx, node in enumerate(nodes)], 'previous_episodes': [ep.content for ep in previous_episodes], 'reference_time': episode.valid_at, 'edge_types': edge_types_context, @@ -160,14 +158,16 @@ async def extract_edges( invalid_at = edge_data.get('invalid_at', None) valid_at_datetime = None invalid_at_datetime = None - source_node_uuid = node_uuids_by_name_map.get(edge_data.get('source_entity_name', ''), '') - target_node_uuid = node_uuids_by_name_map.get(edge_data.get('target_entity_name', ''), '') - if source_node_uuid == '' or target_node_uuid == '': + source_node_idx = edge_data.get('source_entity_id', -1) + target_node_idx = edge_data.get('target_entity_id', -1) + if not (-1 < source_node_idx < len(nodes) and -1 < target_node_idx < len(nodes)): logger.warning( - f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_uuid} and target_node_uuid: {target_node_uuid} ' + f'WARNING: source or target node not filled {edge_data.get("edge_name")}. source_node_uuid: {source_node_idx} and target_node_uuid: {target_node_idx} ' ) continue + source_node_uuid = nodes[source_node_idx].uuid + target_node_uuid = nodes[edge_data.get('target_entity_id')].uuid if valid_at: try: diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 2b3de99e..ac572765 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -29,7 +29,7 @@ from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client.config import ModelSize from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings from graphiti_core.prompts import prompt_library -from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions +from graphiti_core.prompts.dedupe_nodes import NodeResolutions from graphiti_core.prompts.extract_nodes import ( ExtractedEntities, ExtractedEntity, @@ -241,7 +241,25 @@ async def resolve_extracted_nodes( ] ) - existing_nodes_lists: list[list[EntityNode]] = [result.nodes for result in search_results] + existing_nodes_dict: dict[str, EntityNode] = { + node.uuid: node for result in search_results for node in result.nodes + } + + existing_nodes: list[EntityNode] = list(existing_nodes_dict.values()) + + existing_nodes_context = ( + [ + { + **{ + 'idx': i, + 'name': candidate.name, + 'entity_types': candidate.labels, + }, + **candidate.attributes, + } + for i, candidate in enumerate(existing_nodes) + ], + ) entity_types_dict: dict[str, BaseModel] = entity_types if entity_types is not None else {} @@ -255,23 +273,13 @@ async def resolve_extracted_nodes( next((item for item in node.labels if item != 'Entity'), '') ).__doc__ or 'Default Entity Type', - 'duplication_candidates': [ - { - **{ - 'idx': j, - 'name': candidate.name, - 'entity_types': candidate.labels, - }, - **candidate.attributes, - } - for j, candidate in enumerate(existing_nodes_lists[i]) - ], } for i, node in enumerate(extracted_nodes) ] context = { 'extracted_nodes': extracted_nodes_context, + 'existing_nodes': existing_nodes_context, 'episode_content': episode.content if episode is not None else '', 'previous_episodes': [ep.content for ep in previous_episodes] if previous_episodes is not None @@ -294,8 +302,8 @@ async def resolve_extracted_nodes( extracted_node = extracted_nodes[resolution_id] resolved_node = ( - existing_nodes_lists[resolution_id][duplicate_idx] - if 0 <= duplicate_idx < len(existing_nodes_lists[resolution_id]) + existing_nodes[duplicate_idx] + if 0 <= duplicate_idx < len(existing_nodes) else extracted_node ) @@ -309,70 +317,6 @@ async def resolve_extracted_nodes( return resolved_nodes, uuid_map -async def resolve_extracted_node( - llm_client: LLMClient, - extracted_node: EntityNode, - existing_nodes: list[EntityNode], - episode: EpisodicNode | None = None, - previous_episodes: list[EpisodicNode] | None = None, - entity_type: BaseModel | None = None, -) -> EntityNode: - start = time() - if len(existing_nodes) == 0: - return extracted_node - - # Prepare context for LLM - existing_nodes_context = [ - { - **{ - 'id': i, - 'name': node.name, - 'entity_types': node.labels, - }, - **node.attributes, - } - for i, node in enumerate(existing_nodes) - ] - - extracted_node_context = { - 'name': extracted_node.name, - 'entity_type': entity_type.__name__ if entity_type is not None else 'Entity', # type: ignore - } - - context = { - 'existing_nodes': existing_nodes_context, - 'extracted_node': extracted_node_context, - 'entity_type_description': entity_type.__doc__ - if entity_type is not None - else 'Default Entity Type', - 'episode_content': episode.content if episode is not None else '', - 'previous_episodes': [ep.content for ep in previous_episodes] - if previous_episodes is not None - else [], - } - - llm_response = await llm_client.generate_response( - prompt_library.dedupe_nodes.node(context), - response_model=NodeDuplicate, - model_size=ModelSize.small, - ) - - duplicate_id: int = llm_response.get('duplicate_node_id', -1) - - node = ( - existing_nodes[duplicate_id] if 0 <= duplicate_id < len(existing_nodes) else extracted_node - ) - - node.name = llm_response.get('name', '') - - end = time() - logger.debug( - f'Resolved node: {extracted_node.name} is {node.name}, in {(end - start) * 1000} ms' - ) - - return node - - async def extract_attributes_from_nodes( clients: GraphitiClients, nodes: list[EntityNode], diff --git a/pyproject.toml b/pyproject.toml index e94b0cd0..64e46bef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "graphiti-core" description = "A temporal graph building library" -version = "0.12.0pre4" +version = "0.12.0" authors = [ { "name" = "Paul Paliychuk", "email" = "paul@getzep.com" }, { "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },