From 377225eec5d7c1fdfc55b4b77ee62e303cb799ca Mon Sep 17 00:00:00 2001 From: Preston Rasmussen <109292228+prasmussen15@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:39:57 -0400 Subject: [PATCH] add addepisode return object (#172) * add addepisode return * format --- graphiti_core/graphiti.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index fff2aa91..07507e0d 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -21,6 +21,7 @@ from time import time from dotenv import load_dotenv from neo4j import AsyncGraphDatabase +from pydantic import BaseModel from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder @@ -77,6 +78,12 @@ logger = logging.getLogger(__name__) load_dotenv() +class AddEpisodeResults(BaseModel): + episode: EpisodicNode + nodes: list[EntityNode] + edges: list[EntityEdge] + + class Graphiti: def __init__( self, @@ -245,7 +252,7 @@ class Graphiti: group_id: str = '', uuid: str | None = None, update_communities: bool = False, - ): + ) -> AddEpisodeResults: """ Process an episode and update the graph. @@ -451,6 +458,8 @@ class Graphiti: end = time() logger.info(f'Completed add_episode in {(end - start) * 1000} ms') + return AddEpisodeResults(episode=episode, nodes=nodes, edges=entity_edges) + except Exception as e: raise e