swap type_description to docstring (#287)
* swap type_description to docstring * remove unused imports * bump version * removed unused imports
This commit is contained in:
parent
5ef849cac9
commit
e83bcbb435
5 changed files with 16 additions and 21 deletions
|
|
@ -18,14 +18,12 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import ClassVar
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import Field
|
from pydantic import BaseModel, Field
|
||||||
from transcript_parser import parse_podcast_messages
|
from transcript_parser import parse_podcast_messages
|
||||||
|
|
||||||
from graphiti_core import Graphiti
|
from graphiti_core import Graphiti
|
||||||
from graphiti_core.nodes import EntityType
|
|
||||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
@ -56,8 +54,9 @@ def setup_logging():
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
class Person(EntityType):
|
class Person(BaseModel):
|
||||||
type_description: ClassVar[str] = 'A human person, fictional or nonfictional.'
|
"""A human person, fictional or nonfictional."""
|
||||||
|
|
||||||
first_name: str | None = Field(..., description='First name')
|
first_name: str | None = Field(..., description='First name')
|
||||||
last_name: str | None = Field(..., description='Last name')
|
last_name: str | None = Field(..., description='Last name')
|
||||||
occupation: str | None = Field(..., description="The person's work occupation")
|
occupation: str | None = Field(..., description="The person's work occupation")
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge
|
||||||
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
|
||||||
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
from graphiti_core.llm_client import LLMClient, OpenAIClient
|
||||||
from graphiti_core.nodes import CommunityNode, EntityNode, EntityType, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.search.search import SearchConfig, search
|
from graphiti_core.search.search import SearchConfig, search
|
||||||
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
|
||||||
from graphiti_core.search.search_config_recipes import (
|
from graphiti_core.search.search_config_recipes import (
|
||||||
|
|
@ -262,7 +262,7 @@ class Graphiti:
|
||||||
group_id: str = '',
|
group_id: str = '',
|
||||||
uuid: str | None = None,
|
uuid: str | None = None,
|
||||||
update_communities: bool = False,
|
update_communities: bool = False,
|
||||||
entity_types: dict[str, EntityType] | None = None,
|
entity_types: dict[str, BaseModel] | None = None,
|
||||||
) -> AddEpisodeResults:
|
) -> AddEpisodeResults:
|
||||||
"""
|
"""
|
||||||
Process an episode and update the graph.
|
Process an episode and update the graph.
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, ClassVar
|
from typing import Any
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from neo4j import AsyncDriver
|
from neo4j import AsyncDriver
|
||||||
|
|
@ -39,12 +39,6 @@ from graphiti_core.utils.datetime_utils import utc_now
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EntityType(BaseModel):
|
|
||||||
type_description: ClassVar[str] = Field(
|
|
||||||
default='', description='Description of what the entity type represents'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EpisodeType(Enum):
|
class EpisodeType(Enum):
|
||||||
"""
|
"""
|
||||||
Enumeration of different types of episodes that can be processed.
|
Enumeration of different types of episodes that can be processed.
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,11 @@ import logging
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
|
||||||
from graphiti_core.llm_client import LLMClient
|
from graphiti_core.llm_client import LLMClient
|
||||||
from graphiti_core.nodes import EntityNode, EntityType, EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
|
||||||
from graphiti_core.prompts import prompt_library
|
from graphiti_core.prompts import prompt_library
|
||||||
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate
|
||||||
from graphiti_core.prompts.extract_nodes import (
|
from graphiti_core.prompts.extract_nodes import (
|
||||||
|
|
@ -120,7 +121,7 @@ async def extract_nodes(
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
episode: EpisodicNode,
|
episode: EpisodicNode,
|
||||||
previous_episodes: list[EpisodicNode],
|
previous_episodes: list[EpisodicNode],
|
||||||
entity_types: dict[str, EntityType] | None = None,
|
entity_types: dict[str, BaseModel] | None = None,
|
||||||
) -> list[EntityNode]:
|
) -> list[EntityNode]:
|
||||||
start = time()
|
start = time()
|
||||||
extracted_node_names: list[str] = []
|
extracted_node_names: list[str] = []
|
||||||
|
|
@ -156,7 +157,8 @@ async def extract_nodes(
|
||||||
'previous_episodes': [ep.content for ep in previous_episodes],
|
'previous_episodes': [ep.content for ep in previous_episodes],
|
||||||
'extracted_entities': extracted_node_names,
|
'extracted_entities': extracted_node_names,
|
||||||
'entity_types': {
|
'entity_types': {
|
||||||
type_name: values.type_description for type_name, values in entity_types.items()
|
type_name: values.model_json_schema().get('description')
|
||||||
|
for type_name, values in entity_types.items()
|
||||||
}
|
}
|
||||||
if entity_types is not None
|
if entity_types is not None
|
||||||
else {},
|
else {},
|
||||||
|
|
@ -262,7 +264,7 @@ async def resolve_extracted_nodes(
|
||||||
existing_nodes_lists: list[list[EntityNode]],
|
existing_nodes_lists: list[list[EntityNode]],
|
||||||
episode: EpisodicNode | None = None,
|
episode: EpisodicNode | None = None,
|
||||||
previous_episodes: list[EpisodicNode] | None = None,
|
previous_episodes: list[EpisodicNode] | None = None,
|
||||||
entity_types: dict[str, EntityType] | None = None,
|
entity_types: dict[str, BaseModel] | None = None,
|
||||||
) -> tuple[list[EntityNode], dict[str, str]]:
|
) -> tuple[list[EntityNode], dict[str, str]]:
|
||||||
uuid_map: dict[str, str] = {}
|
uuid_map: dict[str, str] = {}
|
||||||
resolved_nodes: list[EntityNode] = []
|
resolved_nodes: list[EntityNode] = []
|
||||||
|
|
@ -295,7 +297,7 @@ async def resolve_extracted_node(
|
||||||
existing_nodes: list[EntityNode],
|
existing_nodes: list[EntityNode],
|
||||||
episode: EpisodicNode | None = None,
|
episode: EpisodicNode | None = None,
|
||||||
previous_episodes: list[EpisodicNode] | None = None,
|
previous_episodes: list[EpisodicNode] | None = None,
|
||||||
entity_types: dict[str, EntityType] | None = None,
|
entity_types: dict[str, BaseModel] | None = None,
|
||||||
) -> tuple[EntityNode, dict[str, str]]:
|
) -> tuple[EntityNode, dict[str, str]]:
|
||||||
start = time()
|
start = time()
|
||||||
|
|
||||||
|
|
@ -330,7 +332,7 @@ async def resolve_extracted_node(
|
||||||
'attributes': [],
|
'attributes': [],
|
||||||
}
|
}
|
||||||
|
|
||||||
entity_type_classes: tuple[EntityType, ...] = tuple()
|
entity_type_classes: tuple[BaseModel, ...] = tuple()
|
||||||
if entity_types is not None: # type: ignore
|
if entity_types is not None: # type: ignore
|
||||||
entity_type_classes = entity_type_classes + tuple(
|
entity_type_classes = entity_type_classes + tuple(
|
||||||
filter(
|
filter(
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.7.8"
|
version = "0.7.9"
|
||||||
description = "A temporal graph building library"
|
description = "A temporal graph building library"
|
||||||
authors = [
|
authors = [
|
||||||
"Paul Paliychuk <paul@getzep.com>",
|
"Paul Paliychuk <paul@getzep.com>",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue