diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py
index 1564cf0a..1d556b85 100644
--- a/examples/podcast/podcast_runner.py
+++ b/examples/podcast/podcast_runner.py
@@ -18,14 +18,12 @@ import asyncio
 import logging
 import os
 import sys
-from typing import ClassVar
 
 from dotenv import load_dotenv
-from pydantic import Field
+from pydantic import BaseModel, Field
 from transcript_parser import parse_podcast_messages
 
 from graphiti_core import Graphiti
-from graphiti_core.nodes import EntityType
 from graphiti_core.utils.maintenance.graph_data_operations import clear_data
 
 load_dotenv()
@@ -56,8 +54,9 @@ def setup_logging():
     return logger
 
 
-class Person(EntityType):
-    type_description: ClassVar[str] = 'A human person, fictional or nonfictional.'
+class Person(BaseModel):
+    """A human person, fictional or nonfictional."""
+
     first_name: str | None = Field(..., description='First name')
     last_name: str | None = Field(..., description='Last name')
     occupation: str | None = Field(..., description="The person's work occupation")
diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index 7b5e5dc5..ffed107d 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -29,7 +29,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
 from graphiti_core.llm_client import LLMClient, OpenAIClient
-from graphiti_core.nodes import CommunityNode, EntityNode, EntityType, EpisodeType, EpisodicNode
+from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search
 from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
 from graphiti_core.search.search_config_recipes import (
@@ -262,7 +262,7 @@ class Graphiti:
         group_id: str = '',
         uuid: str | None = None,
         update_communities: bool = False,
-        entity_types: dict[str, EntityType] | None = None,
+        entity_types: dict[str, BaseModel] | None = None,
     ) -> AddEpisodeResults:
         """
         Process an episode and update the graph.
diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py
index e598cd90..3341f857 100644
--- a/graphiti_core/nodes.py
+++ b/graphiti_core/nodes.py
@@ -19,7 +19,7 @@ from abc import ABC, abstractmethod
 from datetime import datetime
 from enum import Enum
 from time import time
-from typing import Any, ClassVar
+from typing import Any
 from uuid import uuid4
 
 from neo4j import AsyncDriver
@@ -39,12 +39,6 @@ from graphiti_core.utils.datetime_utils import utc_now
 logger = logging.getLogger(__name__)
 
 
-class EntityType(BaseModel):
-    type_description: ClassVar[str] = Field(
-        default='', description='Description of what the entity type represents'
-    )
-
-
 class EpisodeType(Enum):
     """
     Enumeration of different types of episodes that can be processed.
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index b71f03b3..a84b8b94 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -18,10 +18,11 @@ import logging
 from time import time
 
 import pydantic
+from pydantic import BaseModel
 
 from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 from graphiti_core.llm_client import LLMClient
-from graphiti_core.nodes import EntityNode, EntityType, EpisodeType, EpisodicNode
+from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.prompts import prompt_library
 from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
 from graphiti_core.prompts.extract_nodes import (
@@ -120,7 +121,7 @@ async def extract_nodes(
     llm_client: LLMClient,
     episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
-    entity_types: dict[str, EntityType] | None = None,
+    entity_types: dict[str, BaseModel] | None = None,
 ) -> list[EntityNode]:
     start = time()
     extracted_node_names: list[str] = []
@@ -156,7 +157,8 @@
         'previous_episodes': [ep.content for ep in previous_episodes],
         'extracted_entities': extracted_node_names,
         'entity_types': {
-            type_name: values.type_description for type_name, values in entity_types.items()
+            type_name: values.model_json_schema().get('description')
+            for type_name, values in entity_types.items()
         }
         if entity_types is not None
         else {},
@@ -262,7 +264,7 @@ async def resolve_extracted_nodes(
     existing_nodes_lists: list[list[EntityNode]],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, EntityType] | None = None,
+    entity_types: dict[str, BaseModel] | None = None,
 ) -> tuple[list[EntityNode], dict[str, str]]:
     uuid_map: dict[str, str] = {}
     resolved_nodes: list[EntityNode] = []
@@ -295,7 +297,7 @@ async def resolve_extracted_node(
     existing_nodes: list[EntityNode],
     episode: EpisodicNode | None = None,
     previous_episodes: list[EpisodicNode] | None = None,
-    entity_types: dict[str, EntityType] | None = None,
+    entity_types: dict[str, BaseModel] | None = None,
 ) -> tuple[EntityNode, dict[str, str]]:
     start = time()
 
@@ -330,7 +332,7 @@
         'attributes': [],
     }
 
-    entity_type_classes: tuple[EntityType, ...] = tuple()
+    entity_type_classes: tuple[BaseModel, ...] = tuple()
     if entity_types is not None:  # type: ignore
         entity_type_classes = entity_type_classes + tuple(
             filter(
diff --git a/pyproject.toml b/pyproject.toml
index 2768674b..4564d493 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.7.8"
+version = "0.7.9"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk ",
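Below is a minimal usage sketch, not part of the diff, of the convention this change introduces: a custom entity type is a plain Pydantic `BaseModel`, and the description that previously lived in `EntityType.type_description` now comes from the class docstring, which `extract_nodes` reads via `model_json_schema()`. The `Person` model is copied from the podcast example above; the standalone dictionary and print statement are illustrative only.

```python
from pydantic import BaseModel, Field


class Person(BaseModel):
    """A human person, fictional or nonfictional."""

    first_name: str | None = Field(..., description='First name')
    last_name: str | None = Field(..., description='Last name')
    occupation: str | None = Field(..., description="The person's work occupation")


# Callers now pass a plain {type name -> BaseModel subclass} mapping, e.g.
# graphiti.add_episode(..., entity_types={'Person': Person}).
entity_types = {'Person': Person}

# extract_nodes derives each type's description from the model's JSON schema,
# i.e. from the class docstring, when building the prompt context.
descriptions = {
    type_name: model.model_json_schema().get('description')
    for type_name, model in entity_types.items()
}
print(descriptions)  # {'Person': 'A human person, fictional or nonfictional.'}
```

This keeps entity-type metadata in standard Pydantic form instead of requiring a library-specific `EntityType` base class.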