diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 1d49ac3a..50eda87a 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -103,6 +103,7 @@ class Graphiti: cross_encoder: CrossEncoderClient | None = None, store_raw_episode_content: bool = True, graph_driver: GraphDriver | None = None, + max_coroutines: int | None = None, ): """ Initialize a Graphiti instance. @@ -121,6 +122,20 @@ class Graphiti: llm_client : LLMClient | None, optional An instance of LLMClient for natural language processing tasks. If not provided, a default OpenAIClient will be initialized. + embedder : EmbedderClient | None, optional + An instance of EmbedderClient for embedding tasks. + If not provided, a default OpenAIEmbedder will be initialized. + cross_encoder : CrossEncoderClient | None, optional + An instance of CrossEncoderClient for reranking tasks. + If not provided, a default OpenAIRerankerClient will be initialized. + store_raw_episode_content : bool, optional + Whether to store the raw content of episodes. Defaults to True. + graph_driver : GraphDriver | None, optional + An instance of GraphDriver for database operations. + If not provided, a default Neo4jDriver will be initialized. + max_coroutines : int | None, optional + The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment. + If not set, the Graphiti default is used. Returns ------- @@ -145,6 +160,7 @@ class Graphiti: self.database = DEFAULT_DATABASE self.store_raw_episode_content = store_raw_episode_content + self.max_coroutines = max_coroutines if llm_client: self.llm_client = llm_client else: @@ -393,6 +409,7 @@ class Graphiti: group_id, edge_types, ), + max_coroutines=self.max_coroutines, ) edges = resolve_edge_pointers(extracted_edges, uuid_map) @@ -409,6 +426,7 @@ class Graphiti: extract_attributes_from_nodes( self.clients, nodes, episode, previous_episodes, entity_types ), + max_coroutines=self.max_coroutines, ) duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates) @@ -432,7 +450,8 @@ class Graphiti: *[ update_community(self.driver, self.llm_client, self.embedder, node) for node in nodes - ] + ], + max_coroutines=self.max_coroutines, ) end = time() logger.info(f'Completed add_episode in {(end - start) * 1000} ms') @@ -499,7 +518,10 @@ class Graphiti: ] # Save all the episodes - await semaphore_gather(*[episode.save(self.driver) for episode in episodes]) + await semaphore_gather( + *[episode.save(self.driver) for episode in episodes], + max_coroutines=self.max_coroutines, + ) # Get previous episode context for each episode episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes) @@ -515,16 +537,21 @@ class Graphiti: await semaphore_gather( *[node.generate_name_embedding(self.embedder) for node in extracted_nodes], *[edge.generate_embedding(self.embedder) for edge in extracted_edges], + max_coroutines=self.max_coroutines, ) # Dedupe extracted nodes, compress extracted edges (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather( dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes), extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs), + max_coroutines=self.max_coroutines, ) # save nodes to KG - await semaphore_gather(*[node.save(self.driver) for node in nodes]) + await semaphore_gather( + *[node.save(self.driver) for node in nodes], + max_coroutines=self.max_coroutines, + ) # re-map edge pointers so that they don't point to discard dupe nodes extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers( @@ -536,7 +563,8 @@ class Graphiti: # save episodic edges to KG await semaphore_gather( - *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers] + *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers], + max_coroutines=self.max_coroutines, ) # Dedupe extracted edges @@ -548,7 +576,10 @@ class Graphiti: # invalidate edges # save edges to KG - await semaphore_gather(*[edge.save(self.driver) for edge in edges]) + await semaphore_gather( + *[edge.save(self.driver) for edge in edges], + max_coroutines=self.max_coroutines, + ) end = time() logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms') @@ -572,11 +603,18 @@ class Graphiti: ) await semaphore_gather( - *[node.generate_name_embedding(self.embedder) for node in community_nodes] + *[node.generate_name_embedding(self.embedder) for node in community_nodes], + max_coroutines=self.max_coroutines, ) - await semaphore_gather(*[node.save(self.driver) for node in community_nodes]) - await semaphore_gather(*[edge.save(self.driver) for edge in community_edges]) + await semaphore_gather( + *[node.save(self.driver) for node in community_nodes], + max_coroutines=self.max_coroutines, + ) + await semaphore_gather( + *[edge.save(self.driver) for edge in community_edges], + max_coroutines=self.max_coroutines, + ) return community_nodes @@ -683,7 +721,8 @@ class Graphiti: episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids) edges_list = await semaphore_gather( - *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes] + *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes], + max_coroutines=self.max_coroutines, ) edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst] @@ -759,6 +798,12 @@ class Graphiti: if record['episode_count'] == 1: nodes_to_delete.append(node) - await semaphore_gather(*[node.delete(self.driver) for node in nodes_to_delete]) - await semaphore_gather(*[edge.delete(self.driver) for edge in edges_to_delete]) + await semaphore_gather( + *[node.delete(self.driver) for node in nodes_to_delete], + max_coroutines=self.max_coroutines, + ) + await semaphore_gather( + *[edge.delete(self.driver) for edge in edges_to_delete], + max_coroutines=self.max_coroutines, + ) await episode.delete(self.driver) diff --git a/graphiti_core/helpers.py b/graphiti_core/helpers.py index de5020a9..47b3a6da 100644 --- a/graphiti_core/helpers.py +++ b/graphiti_core/helpers.py @@ -94,9 +94,9 @@ def normalize_l2(embedding: list[float]) -> NDArray: # Use this instead of asyncio.gather() to bound coroutines async def semaphore_gather( *coroutines: Coroutine, - max_coroutines: int = SEMAPHORE_LIMIT, + max_coroutines: int | None = None, ): - semaphore = asyncio.Semaphore(max_coroutines) + semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT) async def _wrap_coroutine(coroutine): async with semaphore: