Entity classification updates (#285)
* updates
* tested
* remove unused imports
* llm outputs will be dicts rather than pydantic models
* removed unused imports
This commit is contained in:
parent
7f20b21572
commit
f73867e0fa
6 changed files with 48 additions and 23 deletions
|
|
@ -18,12 +18,14 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import Field
|
||||||
from transcript_parser import parse_podcast_messages
|
from transcript_parser import parse_podcast_messages
|
||||||
|
|
||||||
from graphiti_core import Graphiti
|
from graphiti_core import Graphiti
|
||||||
|
from graphiti_core.nodes import EntityType
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
@ -54,7 +56,8 @@ def setup_logging():
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
class Person(BaseModel):
|
class Person(EntityType):
|
||||||
|
type_description: ClassVar[str] = 'A human person, fictional or nonfictional.'
|
||||||
first_name: str | None = Field(..., description='First name')
|
first_name: str | None = Field(..., description='First name')
|
||||||
last_name: str | None = Field(..., description='Last name')
|
last_name: str | None = Field(..., description='Last name')
|
||||||
occupation: str | None = Field(..., description="The person's work occupation")
|
occupation: str | None = Field(..., description="The person's work occupation")
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EntityType, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.search.search import SearchConfig, search
|
from graphiti_core.search.search import SearchConfig, search
|
||||||
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
||||||
from graphiti_core.search.search_config_recipes import (
|
from graphiti_core.search.search_config_recipes import (
|
||||||
|
|
@ -262,7 +262,7 @@ class Graphiti:
|
||||||
group_id: str = '',
|
group_id: str = '',
|
||||||
uuid: str | None = None,
|
uuid: str | None = None,
|
||||||
update_communities: bool = False,
|
update_communities: bool = False,
|
||||||
entity_types: dict[str, BaseModel] | None = None,
|
entity_types: dict[str, EntityType] | None = None,
|
||||||
) -> AddEpisodeResults:
|
) -> AddEpisodeResults:
|
||||||
"""
|
"""
|
||||||
Process an episode and update the graph.
|
Process an episode and update the graph.
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any
|
from typing import Any, ClassVar
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
|
|
@ -39,6 +39,12 @@ from graphiti_core.utils.datetime_utils import utc_now
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EntityType(BaseModel):
|
||||||
|
type_description: ClassVar[str] = Field(
|
||||||
|
default='', description='Description of what the entity type represents'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EpisodeType(Enum):
|
class EpisodeType(Enum):
|
||||||
"""
|
"""
|
||||||
Enumeration of different types of episodes that can be processed.
|
Enumeration of different types of episodes that can be processed.
|
||||||
|
|
|
||||||
|
|
@ -30,14 +30,17 @@ class MissedEntities(BaseModel):
|
||||||
missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")
|
missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")
|
||||||
|
|
||||||
|
|
||||||
class EntityClassification(BaseModel):
|
class EntityClassificationTriple(BaseModel):
|
||||||
entities: list[str] = Field(
|
uuid: str = Field(description='UUID of the entity')
|
||||||
...,
|
name: str = Field(description='Name of the entity')
|
||||||
description='List of entities',
|
entity_type: str | None = Field(
|
||||||
|
default=None, description='Type of the entity. Must be one of the provided types or None'
|
||||||
)
|
)
|
||||||
entity_classifications: list[str | None] = Field(
|
|
||||||
...,
|
|
||||||
description='List of entities classifications. The index of the classification should match the index of the entity it corresponds to.',
|
class EntityClassification(BaseModel):
|
||||||
|
entity_classifications: list[EntityClassificationTriple] = Field(
|
||||||
|
..., description='List of entities classification triples.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -180,7 +183,7 @@ def classify_nodes(context: dict[str, Any]) -> list[Message]:
|
||||||
{context['entity_types']}
|
{context['entity_types']}
|
||||||
</ENTITY TYPES>
|
</ENTITY TYPES>
|
||||||
|
|
||||||
Given the above conversation, extracted entities, and provided entity types, classify the extracted entities.
|
Given the above conversation, extracted entities, and provided entity types and their descriptions, classify the extracted entities.
|
||||||
|
|
||||||
Guidelines:
|
Guidelines:
|
||||||
1. Each entity must have exactly one type
|
1. Each entity must have exactly one type
|
||||||
|
|
|
||||||
|
|
@ -85,6 +85,8 @@ def summarize_context(context: dict[str, Any]) -> list[Message]:
|
||||||
provided ENTITY. Summaries must be under 500 words.
|
provided ENTITY. Summaries must be under 500 words.
|
||||||
|
|
||||||
In addition, extract any values for the provided entity properties based on their descriptions.
|
In addition, extract any values for the provided entity properties based on their descriptions.
|
||||||
|
If the value of the entity property cannot be found in the current context, set the value of the property to None.
|
||||||
|
Do not hallucinate entity property values if they cannot be found in the current context.
|
||||||
|
|
||||||
<ENTITY>
|
<ENTITY>
|
||||||
{context['node_name']}
|
{context['node_name']}
|
||||||
|
|
|
||||||
|
|
@ -18,14 +18,17 @@ import logging
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EntityType, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
||||||
from graphiti_core.prompts.extract_nodes import EntityClassification, ExtractedNodes, MissedEntities
|
from graphiti_core.prompts.extract_nodes import (
|
||||||
|
EntityClassification,
|
||||||
|
ExtractedNodes,
|
||||||
|
MissedEntities,
|
||||||
|
)
|
||||||
from graphiti_core.prompts.summarize_nodes import Summary
|
from graphiti_core.prompts.summarize_nodes import Summary
|
||||||
from graphiti_core.utils.datetime_utils import utc_now
|
from graphiti_core.utils.datetime_utils import utc_now
|
||||||
|
|
||||||
|
|
@ -117,7 +120,7 @@ async def extract_nodes(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
entity_types: dict[str, BaseModel] | None = None,
|
entity_types: dict[str, EntityType] | None = None,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
start = time()
|
start = time()
|
||||||
extracted_node_names: list[str] = []
|
extracted_node_names: list[str] = []
|
||||||
|
|
@ -152,7 +155,11 @@ async def extract_nodes(
|
||||||
'episode_content': episode.content,
|
'episode_content': episode.content,
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||||
'extracted_entities': extracted_node_names,
|
'extracted_entities': extracted_node_names,
|
||||||
'entity_types': entity_types.keys() if entity_types is not None else [],
|
'entity_types': {
|
||||||
|
type_name: values.type_description for type_name, values in entity_types.items()
|
||||||
|
}
|
||||||
|
if entity_types is not None
|
||||||
|
else {},
|
||||||
}
|
}
|
||||||
|
|
||||||
node_classifications: dict[str, str | None] = {}
|
node_classifications: dict[str, str | None] = {}
|
||||||
|
|
@ -163,9 +170,13 @@ async def extract_nodes(
|
||||||
prompt_library.extract_nodes.classify_nodes(node_classification_context),
|
prompt_library.extract_nodes.classify_nodes(node_classification_context),
|
||||||
response_model=EntityClassification,
|
response_model=EntityClassification,
|
||||||
)
|
)
|
||||||
entities = llm_response.get('entities', [])
|
|
||||||
entity_classifications = llm_response.get('entity_classifications', [])
|
entity_classifications = llm_response.get('entity_classifications', [])
|
||||||
node_classifications.update(dict(zip(entities, entity_classifications)))
|
node_classifications.update(
|
||||||
|
{
|
||||||
|
entity_classification.get('name'): entity_classification.get('entity_type')
|
||||||
|
for entity_classification in entity_classifications
|
||||||
|
}
|
||||||
|
)
|
||||||
# catch classification errors and continue if we can't classify
|
# catch classification errors and continue if we can't classify
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
|
|
@ -251,7 +262,7 @@ async def resolve_extracted_nodes(
|
||||||
existing_nodes_lists: list[list[EntityNode]],
|
existing_nodes_lists: list[list[EntityNode]],
|
||||||
episode: EpisodicNode | None = None,
|
episode: EpisodicNode | None = None,
|
||||||
previous_episodes: list[EpisodicNode] | None = None,
|
previous_episodes: list[EpisodicNode] | None = None,
|
||||||
entity_types: dict[str, BaseModel] | None = None,
|
entity_types: dict[str, EntityType] | None = None,
|
||||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||||
uuid_map: dict[str, str] = {}
|
uuid_map: dict[str, str] = {}
|
||||||
resolved_nodes: list[EntityNode] = []
|
resolved_nodes: list[EntityNode] = []
|
||||||
|
|
@ -284,7 +295,7 @@ async def resolve_extracted_node(
|
||||||
existing_nodes: list[EntityNode],
|
existing_nodes: list[EntityNode],
|
||||||
episode: EpisodicNode | None = None,
|
episode: EpisodicNode | None = None,
|
||||||
previous_episodes: list[EpisodicNode] | None = None,
|
previous_episodes: list[EpisodicNode] | None = None,
|
||||||
entity_types: dict[str, BaseModel] | None = None,
|
entity_types: dict[str, EntityType] | None = None,
|
||||||
) -> tuple[EntityNode, dict[str, str]]:
|
) -> tuple[EntityNode, dict[str, str]]:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
|
|
@ -319,7 +330,7 @@ async def resolve_extracted_node(
|
||||||
'attributes': [],
|
'attributes': [],
|
||||||
}
|
}
|
||||||
|
|
||||||
entity_type_classes: tuple[BaseModel, ...] = tuple()
|
entity_type_classes: tuple[EntityType, ...] = tuple()
|
||||||
if entity_types is not None: # type: ignore
|
if entity_types is not None: # type: ignore
|
||||||
entity_type_classes = entity_type_classes + tuple(
|
entity_type_classes = entity_type_classes + tuple(
|
||||||
filter(
|
filter(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue