add addepisode return object (#172)

* add addepisode return

* format
This commit is contained in:
Preston Rasmussen 2024-10-03 15:39:57 -04:00 committed by GitHub
parent c8ff5be8ce
commit 377225eec5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -21,6 +21,7 @@ from time import time
from dotenv import load_dotenv from dotenv import load_dotenv
from neo4j import AsyncGraphDatabase from neo4j import AsyncGraphDatabase
from pydantic import BaseModel
from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
@ -77,6 +78,12 @@ logger = logging.getLogger(__name__)
load_dotenv() load_dotenv()
class AddEpisodeResults(BaseModel):
episode: EpisodicNode
nodes: list[EntityNode]
edges: list[EntityEdge]
class Graphiti: class Graphiti:
def __init__( def __init__(
self, self,
@ -245,7 +252,7 @@ class Graphiti:
group_id: str = '', group_id: str = '',
uuid: str | None = None, uuid: str | None = None,
update_communities: bool = False, update_communities: bool = False,
): ) -> AddEpisodeResults:
""" """
Process an episode and update the graph. Process an episode and update the graph.
@ -451,6 +458,8 @@ class Graphiti:
end = time() end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms') logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges)
except Exception as e: except Exception as e:
raise e raise e