swap type_description to docstring (#287)

* swap type_description to docstring

* remove unused imports

* bump version

* removed unused imports
This commit is contained in:
Preston Rasmussen 2025-03-05 15:27:03 -05:00 committed by GitHub
parent 5ef849cac9
commit e83bcbb435
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 16 additions and 21 deletions

View file

@ -18,14 +18,12 @@ import asyncio
import logging
import os
import sys
from typing import ClassVar
from dotenv import load_dotenv
from pydantic import Field
from pydantic import BaseModel, Field
from transcript_parser import parse_podcast_messages
from graphiti_core import Graphiti
from graphiti_core.nodes import EntityType
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
@ -56,8 +54,9 @@ def setup_logging():
return logger
class Person(EntityType):
type_description: ClassVar[str] = 'A human person, fictional or nonfictional.'
class Person(BaseModel):
"""A human person, fictional or nonfictional."""
first_name: str | None = Field(..., description='First name')
last_name: str | None = Field(..., description='Last name')
occupation: str | None = Field(..., description="The person's work occupation")

View file

@ -29,7 +29,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import CommunityNode, EntityNode, EntityType, EpisodeType, EpisodicNode
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, search
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
from graphiti_core.search.search_config_recipes import (
@ -262,7 +262,7 @@ class Graphiti:
group_id: str = '',
uuid: str | None = None,
update_communities: bool = False,
entity_types: dict[str, EntityType] | None = None,
entity_types: dict[str, BaseModel] | None = None,
) -> AddEpisodeResults:
"""
Process an episode and update the graph.

View file

@ -19,7 +19,7 @@ from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from time import time
from typing import Any, ClassVar
from typing import Any
from uuid import uuid4
from neo4j import AsyncDriver
@ -39,12 +39,6 @@ from graphiti_core.utils.datetime_utils import utc_now
logger = logging.getLogger(__name__)
class EntityType(BaseModel):
type_description: ClassVar[str] = Field(
default='', description='Description of what the entity type represents'
)
class EpisodeType(Enum):
"""
Enumeration of different types of episodes that can be processed.

View file

@ -18,10 +18,11 @@ import logging
from time import time
import pydantic
from pydantic import BaseModel
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EntityType, EpisodeType, EpisodicNode
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
from graphiti_core.prompts.extract_nodes import (
@ -120,7 +121,7 @@ async def extract_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
entity_types: dict[str, EntityType] | None = None,
entity_types: dict[str, BaseModel] | None = None,
) -> list[EntityNode]:
start = time()
extracted_node_names: list[str] = []
@ -156,7 +157,8 @@ async def extract_nodes(
'previous_episodes': [ep.content for ep in previous_episodes],
'extracted_entities': extracted_node_names,
'entity_types': {
type_name: values.type_description for type_name, values in entity_types.items()
type_name: values.model_json_schema().get('description')
for type_name, values in entity_types.items()
}
if entity_types is not None
else {},
@ -262,7 +264,7 @@ async def resolve_extracted_nodes(
existing_nodes_lists: list[list[EntityNode]],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, EntityType] | None = None,
entity_types: dict[str, BaseModel] | None = None,
) -> tuple[list[EntityNode], dict[str, str]]:
uuid_map: dict[str, str] = {}
resolved_nodes: list[EntityNode] = []
@ -295,7 +297,7 @@ async def resolve_extracted_node(
existing_nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, EntityType] | None = None,
entity_types: dict[str, BaseModel] | None = None,
) -> tuple[EntityNode, dict[str, str]]:
start = time()
@ -330,7 +332,7 @@ async def resolve_extracted_node(
'attributes': [],
}
entity_type_classes: tuple[EntityType, ...] = tuple()
entity_type_classes: tuple[BaseModel, ...] = tuple()
if entity_types is not None: # type: ignore
entity_type_classes = entity_type_classes + tuple(
filter(

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.7.8"
version = "0.7.9"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <paul@getzep.com>",