diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 68dc1d01..e6fcfe5a 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -113,6 +113,15 @@ class AddEpisodeResults(BaseModel): community_edges: list[CommunityEdge] +class AddBulkEpisodeResults(BaseModel): + episodes: list[EpisodicNode] + episodic_edges: list[EpisodicEdge] + nodes: list[EntityNode] + edges: list[EntityEdge] + communities: list[CommunityNode] + community_edges: list[CommunityEdge] + + class Graphiti: def __init__( self, @@ -580,7 +589,7 @@ class Graphiti: excluded_entity_types: list[str] | None = None, edge_types: dict[str, type[BaseModel]] | None = None, edge_type_map: dict[tuple[str, str], list[str]] | None = None, - ) -> AddEpisodeResults: + ) -> AddBulkEpisodeResults: """ Process multiple episodes in bulk and update the graph. @@ -848,7 +857,7 @@ class Graphiti: end = time() logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms') - return AddEpisodeResults( + return AddBulkEpisodeResults( episode=episodes, episodic_edges=resolved_episodic_edges, nodes=final_hydrated_nodes,