Entity classification updates (#285)

* updates

* tested

* remove unused imports

* LLM outputs will be dicts rather than Pydantic models

* removed unused imports
Preston Rasmussen 2025-03-05 12:08:11 -05:00 committed by GitHub
parent 7f20b21572
commit f73867e0fa
6 changed files with 48 additions and 23 deletions


@@ -18,12 +18,14 @@ import asyncio
 import logging
 import os
 import sys
+from typing import ClassVar

 from dotenv import load_dotenv
-from pydantic import BaseModel, Field
+from pydantic import Field
 from transcript_parser import parse_podcast_messages

 from graphiti_core import Graphiti
+from graphiti_core.nodes import EntityType
 from graphiti_core.utils.maintenance.graph_data_operations import clear_data

 load_dotenv()
@@ -54,7 +56,8 @@ def setup_logging():
     return logger


-class Person(BaseModel):
+class Person(EntityType):
+    type_description: ClassVar[str] = 'A human person, fictional or nonfictional.'
     first_name: str | None = Field(..., description='First name')
     last_name: str | None = Field(..., description='Last name')
     occupation: str | None = Field(..., description="The person's work occupation")

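With the entity_types parameter added in this commit (see the Graphiti.add_episode diff below), a custom type like Person can be registered by name when adding an episode. A minimal sketch, assuming an existing graphiti instance and the pre-existing add_episode arguments (name, episode_body, source_description, reference_time); all argument values are hypothetical:

from datetime import datetime, timezone

# Inside an async function; only entity_types is new in this commit.
await graphiti.add_episode(
    name='podcast-episode-1',
    episode_body='...transcript text...',
    source_description='podcast transcript',
    reference_time=datetime.now(timezone.utc),
    entity_types={'Person': Person},  # type name -> EntityType subclass
)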

@@ -29,7 +29,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient, OpenAIClient
-from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.nodes import CommunityNode, EntityNode, EntityType, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
 from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
 from graphiti_core.search.search_config_recipes import (
@@ -262,7 +262,7 @@ class Graphiti:
         group_id: str = '',
         uuid: str | None = None,
         update_communities: bool = False,
-        entity_types: dict[str, BaseModel] | None = None,
+        entity_types: dict[str, EntityType] | None = None,
     ) -> AddEpisodeResults:
         """
         Process an episode and update the graph.


@@ -19,7 +19,7 @@ from abc import ABC, abstractmethod
 from datetime import datetime
 from enum import Enum
 from time import time
-from typing import Any
+from typing import Any, ClassVar
 from uuid import uuid4

 from neo4j import AsyncDriver
@@ -39,6 +39,12 @@ from graphiti_core.utils.datetime_utils import utc_now
 logger = logging.getLogger(__name__)


+class EntityType(BaseModel):
+    type_description: ClassVar[str] = Field(
+        default='', description='Description of what the entity type represents'
+    )
+
+
 class EpisodeType(Enum):
     """
     Enumeration of different types of episodes that can be processed.

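Because type_description is annotated ClassVar, Pydantic treats it as class metadata rather than a model field: subclasses override it with a plain string, it is read off the class when building prompt context, and it never appears in a serialized entity. A quick sketch using the Person type from the podcast example above (assuming Pydantic v2 semantics):

assert Person.type_description == 'A human person, fictional or nonfictional.'
# ClassVar attributes are excluded from a Pydantic v2 model's fields:
assert 'type_description' not in Person.model_fields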

@@ -30,14 +30,17 @@ class MissedEntities(BaseModel):
     missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")


-class EntityClassification(BaseModel):
-    entities: list[str] = Field(
-        ...,
-        description='List of entities',
+class EntityClassificationTriple(BaseModel):
+    uuid: str = Field(description='UUID of the entity')
+    name: str = Field(description='Name of the entity')
+    entity_type: str | None = Field(
+        default=None, description='Type of the entity. Must be one of the provided types or None'
     )
-    entity_classifications: list[str | None] = Field(
-        ...,
-        description='List of entities classifications. The index of the classification should match the index of the entity it corresponds to.',
+
+
+class EntityClassification(BaseModel):
+    entity_classifications: list[EntityClassificationTriple] = Field(
+        ..., description='List of entities classification triples.'
     )
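Per the commit message, LLM outputs are now handled as plain dicts rather than parsed Pydantic models, but the dict the model returns should still match this schema. An illustrative response (the uuid and name values are hypothetical):

from graphiti_core.prompts.extract_nodes import EntityClassification

response = {
    'entity_classifications': [
        {'uuid': 'b6d1e3f0', 'name': 'Jane Doe', 'entity_type': 'Person'},
        {'uuid': '4f2a9c11', 'name': 'the mystery guest', 'entity_type': None},
    ]
}
# Validates the dict; raises a ValidationError on a schema mismatch.
EntityClassification(**response)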
@@ -180,7 +183,7 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:
         {context['entity_types']}
         </ENTITY TYPES>

-        Given the above conversation, extracted entities, and provided entity types, classify the extracted entities.
+        Given the above conversation, extracted entities, and provided entity types and their descriptions, classify the extracted entities.

         Guidelines:
         1. Each entity must have exactly one type


@@ -85,6 +85,8 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
         provided ENTITY. Summaries must be under 500 words.

         In addition, extract any values for the provided entity properties based on their descriptions.
+        If the value of the entity property cannot be found in the current context, set the value of the property to None.
+        Do not hallucinate entity property values if they cannot be found in the current context.

         <ENTITY>
         {context['node_name']}


@@ -18,14 +18,17 @@ import logging
 from time import time

 import pydantic
-from pydantic import BaseModel

 from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.nodes import EntityNode, EntityType, EpisodeType, EpisodicNode
 from graphiti_core.prompts import prompt_library
 from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
-from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
+from graphiti_core.prompts.extract_nodes import (
+    EntityClassification,
+    ExtractedNodes,
+    MissedEntities,
+)
 from graphiti_core.prompts.summarize_nodes import Summary
 from graphiti_core.utils.datetime_utils import utc_now
@@ -117,7 +120,7 @@ async def extract_nodes(
     llm_client: LLMClient,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, EntityType] | None = None,
 ) -> list[EntityNode]:
     start = time()
     extracted_node_names: list[str] = []
@@ -152,7 +155,11 @@ async def extract_nodes(
         'episode_content': episode.content,
         'previous_episodes': [ep.content for ep in previous_episodes],
         'extracted_entities': extracted_node_names,
-        'entity_types': entity_types.keys() if entity_types is not None else [],
+        'entity_types': {
+            type_name: values.type_description for type_name, values in entity_types.items()
+        }
+        if entity_types is not None
+        else {},
     }

     node_classifications: dict[str, str | None] = {}
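With the Person type from the podcast example registered, the classification prompt context now receives name-to-description pairs rather than bare type names. Roughly:

entity_types = {'Person': Person}
# The comprehension above then yields:
# {'Person': 'A human person, fictional or nonfictional.'}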
@@ -163,9 +170,13 @@ async def extract_nodes(
                 prompt_library.extract_nodes.classify_nodes(node_classification_context),
                 response_model=EntityClassification,
             )
-            entities = llm_response.get('entities', [])
             entity_classifications = llm_response.get('entity_classifications', [])
-            node_classifications.update(dict(zip(entities, entity_classifications)))
+            node_classifications.update(
+                {
+                    entity_classification.get('name'): entity_classification.get('entity_type')
+                    for entity_classification in entity_classifications
+                }
+            )
         # catch classification errors and continue if we can't classify
         except Exception as e:
             logger.exception(e)
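End to end, the dict-based response collapses into the name-to-type mapping used downstream, with None for entities the LLM declined to classify. A sketch reusing the hypothetical response from the prompt-schema note above:

llm_response = {
    'entity_classifications': [
        {'uuid': 'b6d1e3f0', 'name': 'Jane Doe', 'entity_type': 'Person'},
        {'uuid': '4f2a9c11', 'name': 'the mystery guest', 'entity_type': None},
    ]
}
entity_classifications = llm_response.get('entity_classifications', [])
node_classifications: dict[str, str | None] = {
    ec.get('name'): ec.get('entity_type') for ec in entity_classifications
}
# {'Jane Doe': 'Person', 'the mystery guest': None}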
@@ -251,7 +262,7 @@ async def resolve_extracted_nodes(
     existing_nodes_lists: list[list[EntityNode]],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, EntityType] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str]]:
     uuid_map: dict[str, str] = {}
     resolved_nodes: list[EntityNode] = []
@@ -284,7 +295,7 @@ async def resolve_extracted_node(
     existing_nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, EntityType] | None = None,
 ) -> tuple[EntityNode, dict[str, str]]:

     start = time()
@@ -319,7 +330,7 @@ async def resolve_extracted_node(
         'attributes': [],
     }

-    entity_type_classes: tuple[BaseModel, ...] = tuple()
+    entity_type_classes: tuple[EntityType, ...] = tuple()
     if entity_types is not None:  # type: ignore
         entity_type_classes = entity_type_classes + tuple(
             filter(