Entity classification updates (#285)

* updates
* tested
* remove unused imports
* llm outputs will be dicts rather than pydantic models
* removed unused imports

parent 7f20b21572
commit f73867e0fa

6 changed files with 48 additions and 23 deletions
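Per the commit message, structured LLM outputs are now plain dicts rather than pydantic model instances, so call sites index into the response instead of reading attributes. A minimal sketch of the consumption pattern (variable names follow the diff below):

    # response_model still describes the expected schema,
    # but the returned llm_response is now a plain dict
    entity_classifications = llm_response.get('entity_classifications', [])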
@@ -18,12 +18,14 @@ import asyncio
 import logging
 import os
 import sys
+from typing import ClassVar

 from dotenv import load_dotenv
-from pydantic import BaseModel, Field
+from pydantic import Field
 from transcript_parser import parse_podcast_messages

 from graphiti_core import Graphiti
+from graphiti_core.nodes import EntityType
 from graphiti_core.utils.maintenance.graph_data_operations import clear_data

 load_dotenv()
@@ -54,7 +56,8 @@ def setup_logging():
     return logger


-class Person(BaseModel):
+class Person(EntityType):
+    type_description: ClassVar[str] = 'A human person, fictional or nonfictional.'
     first_name: str | None = Field(..., description='First name')
     last_name: str | None = Field(..., description='Last name')
     occupation: str | None = Field(..., description="The person's work occupation")
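A hedged sketch of how the example could pass this typed entity into ingestion, based on the add_episode signature updated later in this commit; graphiti is assumed to be a constructed Graphiti instance, and the episode arguments are illustrative placeholders, not taken from this diff:

    from datetime import datetime, timezone

    await graphiti.add_episode(
        name='podcast-episode-1',              # placeholder episode name
        episode_body='...transcript text...',  # placeholder content
        source_description='podcast transcript',
        reference_time=datetime.now(timezone.utc),
        entity_types={'Person': Person},       # new: type name -> EntityType subclass
    )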
@@ -29,7 +29,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient, OpenAIClient
-from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.nodes import CommunityNode, EntityNode, EntityType, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
 from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
 from graphiti_core.search.search_config_recipes import (
@@ -262,7 +262,7 @@ class Graphiti:
         group_id: str = '',
         uuid: str | None = None,
         update_communities: bool = False,
-        entity_types: dict[str, BaseModel] | None = None,
+        entity_types: dict[str, EntityType] | None = None,
     ) -> AddEpisodeResults:
         """
         Process an episode and update the graph.
@@ -19,7 +19,7 @@ from abc import ABC, abstractmethod
 from datetime import datetime
 from enum import Enum
 from time import time
-from typing import Any
+from typing import Any, ClassVar
 from uuid import uuid4

 from neo4j import AsyncDriver
@@ -39,6 +39,12 @@ from graphiti_core.utils.datetime_utils import utc_now
 logger = logging.getLogger(__name__)


+class EntityType(BaseModel):
+    type_description: ClassVar[str] = Field(
+        default='', description='Description of what the entity type represents'
+    )
+
+
 class EpisodeType(Enum):
     """
     Enumeration of different types of episodes that can be processed.
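The new EntityType base keeps type_description as a ClassVar, so the description annotates the class itself and never becomes a per-instance field; subclass fields stay free for extracted properties. A hypothetical subclass (not part of this diff) following the same pattern as Person above:

    class Company(EntityType):
        type_description: ClassVar[str] = 'A business or other organization.'
        industry: str | None = Field(..., description='Primary industry of the company')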
@@ -30,14 +30,17 @@ class MissedEntities(BaseModel):
     missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")


-class EntityClassification(BaseModel):
-    entities: list[str] = Field(
-        ...,
-        description='List of entities',
-    )
-    entity_classifications: list[str | None] = Field(
-        ...,
-        description='List of entities classifications. The index of the classification should match the index of the entity it corresponds to.',
-    )
+class EntityClassificationTriple(BaseModel):
+    uuid: str = Field(description='UUID of the entity')
+    name: str = Field(description='Name of the entity')
+    entity_type: str | None = Field(
+        default=None, description='Type of the entity. Must be one of the provided types or None'
+    )
+
+
+class EntityClassification(BaseModel):
+    entity_classifications: list[EntityClassificationTriple] = Field(
+        ..., description='List of entities classification triples.'
+    )
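The replacement schema asks the LLM for one triple per entity instead of two parallel lists, which removes the old index-alignment requirement. Since outputs are now consumed as dicts, a response would be shaped roughly like this (entity values are illustrative):

    {
        'entity_classifications': [
            {'uuid': '…', 'name': 'Alice', 'entity_type': 'Person'},
            {'uuid': '…', 'name': 'Acme', 'entity_type': None},
        ]
    }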
@@ -180,7 +183,7 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:
     {context['entity_types']}
     </ENTITY TYPES>

-    Given the above conversation, extracted entities, and provided entity types, classify the extracted entities.
+    Given the above conversation, extracted entities, and provided entity types and their descriptions, classify the extracted entities.

     Guidelines:
     1. Each entity must have exactly one type
@@ -85,6 +85,8 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
         provided ENTITY. Summaries must be under 500 words.

+        In addition, extract any values for the provided entity properties based on their descriptions.
+        If the value of the entity property cannot be found in the current context, set the value of the property to None.
+        Do not hallucinate entity property values if they cannot be found in the current context.
+
         <ENTITY>
         {context['node_name']}
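Concretely, for the Person type defined in the example file above, this instructs the summarizer to emit property values shaped like the following (illustrative), with None for anything the context does not support rather than a guessed value:

    {'first_name': 'Alice', 'last_name': None, 'occupation': 'journalist'}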
@@ -18,14 +18,17 @@ import logging
 from time import time

 import pydantic
-from pydantic import BaseModel

 from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
-from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
+from graphiti_core.nodes import EntityNode, EntityType, EpisodeType, EpisodicNode
 from graphiti_core.prompts import prompt_library
 from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
-from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
+from graphiti_core.prompts.extract_nodes import (
+    EntityClassification,
+    ExtractedNodes,
+    MissedEntities,
+)
 from graphiti_core.prompts.summarize_nodes import Summary
 from graphiti_core.utils.datetime_utils import utc_now
@@ -117,7 +120,7 @@ async def extract_nodes(
     llm_client: LLMClient,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, EntityType] | None = None,
 ) -> list[EntityNode]:
     start = time()
     extracted_node_names: list[str] = []
@@ -152,7 +155,11 @@ async def extract_nodes(
         'episode_content': episode.content,
         'previous_episodes': [ep.content for ep in previous_episodes],
         'extracted_entities': extracted_node_names,
-        'entity_types': entity_types.keys() if entity_types is not None else [],
+        'entity_types': {
+            type_name: values.type_description for type_name, values in entity_types.items()
+        }
+        if entity_types is not None
+        else {},
     }

     node_classifications: dict[str, str | None] = {}
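With this change the prompt context carries each type's description rather than just its name, so classify_nodes can show the model what each type means. Using the Person type from the example file above, the context value becomes:

    {'Person': 'A human person, fictional or nonfictional.'}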
@@ -163,9 +170,13 @@ async def extract_nodes(
             prompt_library.extract_nodes.classify_nodes(node_classification_context),
             response_model=EntityClassification,
         )
-        entities = llm_response.get('entities', [])
         entity_classifications = llm_response.get('entity_classifications', [])
-        node_classifications.update(dict(zip(entities, entity_classifications)))
+        node_classifications.update(
+            {
+                entity_classification.get('name'): entity_classification.get('entity_type')
+                for entity_classification in entity_classifications
+            }
+        )
     # catch classification errors and continue if we can't classify
     except Exception as e:
         logger.exception(e)
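Worked through with the response shape sketched earlier, the new comprehension keys the classification map by entity name; entities the model could not classify simply carry None, matching the dict[str, str | None] annotation above:

    # given the earlier illustrative response, the update yields
    node_classifications == {'Alice': 'Person', 'Acme': None}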
@@ -251,7 +262,7 @@ async def resolve_extracted_nodes(
     existing_nodes_lists: list[list[EntityNode]],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, EntityType] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str]]:
     uuid_map: dict[str, str] = {}
     resolved_nodes: list[EntityNode] = []
@@ -284,7 +295,7 @@ async def resolve_extracted_node(
     existing_nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, BaseModel] | None = None,
+    entity_types: dict[str, EntityType] | None = None,
 ) -> tuple[EntityNode, dict[str, str]]:
     start = time()
@@ -319,7 +330,7 @@ async def resolve_extracted_node(
         'attributes': [],
     }

-    entity_type_classes: tuple[BaseModel, ...] = tuple()
+    entity_type_classes: tuple[EntityType, ...] = tuple()
     if entity_types is not None:  # type: ignore
         entity_type_classes = entity_type_classes + tuple(
             filter(