From f73867e0fa37c339459ecddd48d3c8b302fc9956 Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Wed, 5 Mar 2025 12:08:11 -0500 Subject: [PATCH] Entity classification updates (#285) * updates * tested * remove unused imports * llm outputs will be dicts rather than pydantic models * removed unused imports --- examples/podcast/podcast_runner.py | 7 +++-- graphiti_core/graphiti.py | 4 +-- graphiti_core/nodes.py | 8 ++++- graphiti_core/prompts/extract_nodes.py | 19 +++++++----- graphiti_core/prompts/summarize_nodes.py | 2 ++ .../utils/maintenance/node_operations.py | 31 +++++++++++++------ 6 files changed, 48 insertions(+), 23 deletions(-) diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index d5511375..1564cf0a 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -18,12 +18,14 @@ import asyncio import logging import os import sys +from typing import ClassVar from dotenv import load_dotenv -from pydantic import BaseModel, Field +from pydantic import Field from transcript_parser import parse_podcast_messages from graphiti_core import Graphiti +from graphiti_core.nodes import EntityType from graphiti_core.utils.maintenance.graph_data_operations import clear_data load_dotenv() @@ -54,7 +56,8 @@ def setup_logging(): return logger -class Person(BaseModel): +class Person(EntityType): + type_description: ClassVar[str] = 'A human person, fictional or nonfictional.' first_name: str | None = Field(..., description='First name') last_name: str | None = Field(..., description='Last name') occupation: str | None = Field(..., description="The person's work occupation") diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index ffed107d..7b5e5dc5 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -29,7 +29,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather from graphiti_core.llm_client import LLMClient, OpenAIClient -from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode +from graphiti_core.nodes import CommunityNode, EntityNode, EntityType, EpisodeType, EpisodicNode from graphiti_core.search.search import SearchConfig, search from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults from graphiti_core.search.search_config_recipes import ( @@ -262,7 +262,7 @@ class Graphiti: group_id: str = '', uuid: str | None = None, update_communities: bool = False, - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, EntityType] | None = None, ) -> AddEpisodeResults: """ Process an episode and update the graph. diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index 3341f857..e598cd90 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -19,7 +19,7 @@ from abc import ABC, abstractmethod from datetime import datetime from enum import Enum from time import time -from typing import Any +from typing import Any, ClassVar from uuid import uuid4 from neo4j import AsyncDriver @@ -39,6 +39,12 @@ from graphiti_core.utils.datetime_utils import utc_now logger = logging.getLogger(__name__) +class EntityType(BaseModel): + type_description: ClassVar[str] = Field( + default='', description='Description of what the entity type represents' + ) + + class EpisodeType(Enum): """ Enumeration of different types of episodes that can be processed. diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py index 845f1377..7218a173 100644 --- a/graphiti_core/prompts/extract_nodes.py +++ b/graphiti_core/prompts/extract_nodes.py @@ -30,14 +30,17 @@ class MissedEntities(BaseModel): missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted") -class EntityClassification(BaseModel): - entities: list[str] = Field( - ..., - description='List of entities', +class EntityClassificationTriple(BaseModel): + uuid: str = Field(description='UUID of the entity') + name: str = Field(description='Name of the entity') + entity_type: str | None = Field( + default=None, description='Type of the entity. Must be one of the provided types or None' ) - entity_classifications: list[str | None] = Field( - ..., - description='List of entities classifications. The index of the classification should match the index of the entity it corresponds to.', + + +class EntityClassification(BaseModel): + entity_classifications: list[EntityClassificationTriple] = Field( + ..., description='List of entities classification triples.' ) @@ -180,7 +183,7 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]: {context['entity_types']} - Given the above conversation, extracted entities, and provided entity types, classify the extracted entities. + Given the above conversation, extracted entities, and provided entity types and their descriptions, classify the extracted entities. Guidelines: 1. Each entity must have exactly one type diff --git a/graphiti_core/prompts/summarize_nodes.py b/graphiti_core/prompts/summarize_nodes.py index 0a880a82..80c3b0e8 100644 --- a/graphiti_core/prompts/summarize_nodes.py +++ b/graphiti_core/prompts/summarize_nodes.py @@ -85,6 +85,8 @@ def summarize_context(context: dict[str, Any]) -> list[Message]: provided ENTITY. Summaries must be under 500 words. In addition, extract any values for the provided entity properties based on their descriptions. + If the value of the entity property cannot be found in the current context, set the value of the property to None. + Do not hallucinate entity property values if they cannot be found in the current context. {context['node_name']} diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py index 85a28a9a..b71f03b3 100644 --- a/graphiti_core/utils/maintenance/node_operations.py +++ b/graphiti_core/utils/maintenance/node_operations.py @@ -18,14 +18,17 @@ import logging from time import time import pydantic -from pydantic import BaseModel from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather from graphiti_core.llm_client import LLMClient -from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode +from graphiti_core.nodes import EntityNode, EntityType, EpisodeType, EpisodicNode from graphiti_core.prompts import prompt_library from graphiti_core.prompts.dedupe_nodes import NodeDuplicate -from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities +from graphiti_core.prompts.extract_nodes import ( + EntityClassification, + ExtractedNodes, + MissedEntities, +) from graphiti_core.prompts.summarize_nodes import Summary from graphiti_core.utils.datetime_utils import utc_now @@ -117,7 +120,7 @@ async def extract_nodes( llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode], - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, EntityType] | None = None, ) -> list[EntityNode]: start = time() extracted_node_names: list[str] = [] @@ -152,7 +155,11 @@ async def extract_nodes( 'episode_content': episode.content, 'previous_episodes': [ep.content for ep in previous_episodes], 'extracted_entities': extracted_node_names, - 'entity_types': entity_types.keys() if entity_types is not None else [], + 'entity_types': { + type_name: values.type_description for type_name, values in entity_types.items() + } + if entity_types is not None + else {}, } node_classifications: dict[str, str | None] = {} @@ -163,9 +170,13 @@ async def extract_nodes( prompt_library.extract_nodes.classify_nodes(node_classification_context), response_model=EntityClassification, ) - entities = llm_response.get('entities', []) entity_classifications = llm_response.get('entity_classifications', []) - node_classifications.update(dict(zip(entities, entity_classifications))) + node_classifications.update( + { + entity_classification.get('name'): entity_classification.get('entity_type') + for entity_classification in entity_classifications + } + ) # catch classification errors and continue if we can't classify except Exception as e: logger.exception(e) @@ -251,7 +262,7 @@ async def resolve_extracted_nodes( existing_nodes_lists: list[list[EntityNode]], episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, EntityType] | None = None, ) -> tuple[list[EntityNode], dict[str, str]]: uuid_map: dict[str, str] = {} resolved_nodes: list[EntityNode] = [] @@ -284,7 +295,7 @@ async def resolve_extracted_node( existing_nodes: list[EntityNode], episode: EpisodicNode | None = None, previous_episodes: list[EpisodicNode] | None = None, - entity_types: dict[str, BaseModel] | None = None, + entity_types: dict[str, EntityType] | None = None, ) -> tuple[EntityNode, dict[str, str]]: start = time() @@ -319,7 +330,7 @@ async def resolve_extracted_node( 'attributes': [], } - entity_type_classes: tuple[BaseModel, ...] = tuple() + entity_type_classes: tuple[EntityType, ...] = tuple() if entity_types is not None: # type: ignore entity_type_classes = entity_type_classes + tuple( filter(