diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 7776a76a..101a74f4 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -122,19 +122,24 @@ class AddBulkEpisodeResults(BaseModel): community_edges: list[CommunityEdge] +class AddTripletResults(BaseModel): + nodes: list[EntityNode] + edges: list[EntityEdge] + + class Graphiti: def __init__( - self, - uri: str | None = None, - user: str | None = None, - password: str | None = None, - llm_client: LLMClient | None = None, - embedder: EmbedderClient | None = None, - cross_encoder: CrossEncoderClient | None = None, - store_raw_episode_content: bool = True, - graph_driver: GraphDriver | None = None, - max_coroutines: int | None = None, - ensure_ascii: bool = False, + self, + uri: str | None = None, + user: str | None = None, + password: str | None = None, + llm_client: LLMClient | None = None, + embedder: EmbedderClient | None = None, + cross_encoder: CrossEncoderClient | None = None, + store_raw_episode_content: bool = True, + graph_driver: GraphDriver | None = None, + max_coroutines: int | None = None, + ensure_ascii: bool = False, ): """ Initialize a Graphiti instance. @@ -345,11 +350,11 @@ class Graphiti: await build_indices_and_constraints(self.driver, delete_existing) async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, - group_ids: list[str] | None = None, - source: EpisodeType | None = None, + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, + group_ids: list[str] | None = None, + source: EpisodeType | None = None, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -379,20 +384,20 @@ class Graphiti: return await retrieve_episodes(self.driver, reference_time, last_n, group_ids, source) async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime, - source: EpisodeType = EpisodeType.message, - group_id: str | None = None, - uuid: str | None = None, - update_communities: bool = False, - entity_types: dict[str, type[BaseModel]] | None = None, - excluded_entity_types: list[str] | None = None, - previous_episode_uuids: list[str] | None = None, - edge_types: dict[str, type[BaseModel]] | None = None, - edge_type_map: dict[tuple[str, str], list[str]] | None = None, + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + source: EpisodeType = EpisodeType.message, + group_id: str | None = None, + uuid: str | None = None, + update_communities: bool = False, + entity_types: dict[str, type[BaseModel]] | None = None, + excluded_entity_types: list[str] | None = None, + previous_episode_uuids: list[str] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, + edge_type_map: dict[tuple[str, str], list[str]] | None = None, ) -> AddEpisodeResults: """ Process an episode and update the graph. @@ -582,13 +587,13 @@ class Graphiti: raise e async def add_episode_bulk( - self, - bulk_episodes: list[RawEpisode], - group_id: str | None = None, - entity_types: dict[str, type[BaseModel]] | None = None, - excluded_entity_types: list[str] | None = None, - edge_types: dict[str, type[BaseModel]] | None = None, - edge_type_map: dict[tuple[str, str], list[str]] | None = None, + self, + bulk_episodes: list[RawEpisode], + group_id: str | None = None, + entity_types: dict[str, type[BaseModel]] | None = None, + excluded_entity_types: list[str] | None = None, + edge_types: dict[str, type[BaseModel]] | None = None, + edge_type_map: dict[tuple[str, str], list[str]] | None = None, ) -> AddBulkEpisodeResults: """ Process multiple episodes in bulk and update the graph. @@ -870,7 +875,7 @@ class Graphiti: raise e async def build_communities( - self, group_ids: list[str] | None = None + self, group_ids: list[str] | None = None ) -> tuple[list[CommunityNode], list[CommunityEdge]]: """ Use a community clustering algorithm to find communities of nodes. Create community nodes summarising @@ -903,12 +908,12 @@ class Graphiti: return community_nodes, community_edges async def search( - self, - query: str, - center_node_uuid: str | None = None, - group_ids: list[str] | None = None, - num_results=DEFAULT_SEARCH_LIMIT, - search_filter: SearchFilters | None = None, + self, + query: str, + center_node_uuid: str | None = None, + group_ids: list[str] | None = None, + num_results=DEFAULT_SEARCH_LIMIT, + search_filter: SearchFilters | None = None, ) -> list[EntityEdge]: """ Perform a hybrid search on the knowledge graph. @@ -962,13 +967,13 @@ class Graphiti: return edges async def _search( - self, - query: str, - config: SearchConfig, - group_ids: list[str] | None = None, - center_node_uuid: str | None = None, - bfs_origin_node_uuids: list[str] | None = None, - search_filter: SearchFilters | None = None, + self, + query: str, + config: SearchConfig, + group_ids: list[str] | None = None, + center_node_uuid: str | None = None, + bfs_origin_node_uuids: list[str] | None = None, + search_filter: SearchFilters | None = None, ) -> SearchResults: """DEPRECATED""" return await self.search_( @@ -976,13 +981,13 @@ class Graphiti: ) async def search_( - self, - query: str, - config: SearchConfig = COMBINED_HYBRID_SEARCH_CROSS_ENCODER, - group_ids: list[str] | None = None, - center_node_uuid: str | None = None, - bfs_origin_node_uuids: list[str] | None = None, - search_filter: SearchFilters | None = None, + self, + query: str, + config: SearchConfig = COMBINED_HYBRID_SEARCH_CROSS_ENCODER, + group_ids: list[str] | None = None, + center_node_uuid: str | None = None, + bfs_origin_node_uuids: list[str] | None = None, + search_filter: SearchFilters | None = None, ) -> SearchResults: """search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and @@ -1015,8 +1020,9 @@ class Graphiti: return SearchResults(edges=edges, nodes=nodes) - async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, - target_node: EntityNode) -> AddEpisodeResults: + async def add_triplet( + self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode + ) -> AddTripletResults: if source_node.name_embedding is None: await source_node.generate_name_embedding(self.embedder) if target_node.name_embedding is None: @@ -1060,7 +1066,7 @@ class Graphiti: await create_entity_node_embeddings(self.embedder, nodes) await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder) - return AddEpisodeResults(edges=edges, nodes=nodes) + return AddTripletResults(episode=None, edges=edges, nodes=nodes) async def remove_episode(self, episode_uuid: str): # Find the episode to be deleted diff --git a/uv.lock b/uv.lock index a001c9ea..e2a3b855 100644 --- a/uv.lock +++ b/uv.lock @@ -783,7 +783,7 @@ wheels = [ [[package]] name = "graphiti-core" -version = "0.20.3" +version = "0.20.4" source = { editable = "." } dependencies = [ { name = "diskcache" },