Add max_coroutines parameter to Graphiti and update semaphore_gather function (#619)

- Introduced max_coroutines parameter in the Graphiti class to control the maximum number of concurrent operations.
- Updated the semaphore_gather function to accept max_coroutines as an optional argument, defaulting to SEMAPHORE_LIMIT if not provided.
- Adjusted multiple calls to semaphore_gather throughout the Graphiti class to utilize the new max_coroutines parameter for better concurrency management.
This commit is contained in:
Daniel Chalef 2025-06-24 09:32:16 -07:00 committed by GitHub
parent ae7f2234a8
commit fe870b953f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 58 additions and 13 deletions

View file

@ -103,6 +103,7 @@ class Graphiti:
cross_encoder: CrossEncoderClient | None = None, cross_encoder: CrossEncoderClient | None = None,
store_raw_episode_content: bool = True, store_raw_episode_content: bool = True,
graph_driver: GraphDriver | None = None, graph_driver: GraphDriver | None = None,
max_coroutines: int | None = None,
): ):
""" """
Initialize a Graphiti instance. Initialize a Graphiti instance.
@ -121,6 +122,20 @@ class Graphiti:
llm_client : LLMClient | None, optional llm_client : LLMClient | None, optional
An instance of LLMClient for natural language processing tasks. An instance of LLMClient for natural language processing tasks.
If not provided, a default OpenAIClient will be initialized. If not provided, a default OpenAIClient will be initialized.
embedder : EmbedderClient | None, optional
An instance of EmbedderClient for embedding tasks.
If not provided, a default OpenAIEmbedder will be initialized.
cross_encoder : CrossEncoderClient | None, optional
An instance of CrossEncoderClient for reranking tasks.
If not provided, a default OpenAIRerankerClient will be initialized.
store_raw_episode_content : bool, optional
Whether to store the raw content of episodes. Defaults to True.
graph_driver : GraphDriver | None, optional
An instance of GraphDriver for database operations.
If not provided, a default Neo4jDriver will be initialized.
max_coroutines : int | None, optional
The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
If not set, the Graphiti default is used.
Returns Returns
------- -------
@ -145,6 +160,7 @@ class Graphiti:
self.database = DEFAULT_DATABASE self.database = DEFAULT_DATABASE
self.store_raw_episode_content = store_raw_episode_content self.store_raw_episode_content = store_raw_episode_content
self.max_coroutines = max_coroutines
if llm_client: if llm_client:
self.llm_client = llm_client self.llm_client = llm_client
else: else:
@ -393,6 +409,7 @@ class Graphiti:
group_id, group_id,
edge_types, edge_types,
), ),
max_coroutines=self.max_coroutines,
) )
edges = resolve_edge_pointers(extracted_edges, uuid_map) edges = resolve_edge_pointers(extracted_edges, uuid_map)
@ -409,6 +426,7 @@ class Graphiti:
extract_attributes_from_nodes( extract_attributes_from_nodes(
self.clients, nodes, episode, previous_episodes, entity_types self.clients, nodes, episode, previous_episodes, entity_types
), ),
max_coroutines=self.max_coroutines,
) )
duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates) duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
@ -432,7 +450,8 @@ class Graphiti:
*[ *[
update_community(self.driver, self.llm_client, self.embedder, node) update_community(self.driver, self.llm_client, self.embedder, node)
for node in nodes for node in nodes
] ],
max_coroutines=self.max_coroutines,
) )
end = time() end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms') logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
@ -499,7 +518,10 @@ class Graphiti:
] ]
# Save all the episodes # Save all the episodes
await semaphore_gather(*[episode.save(self.driver) for episode in episodes]) await semaphore_gather(
*[episode.save(self.driver) for episode in episodes],
max_coroutines=self.max_coroutines,
)
# Get previous episode context for each episode # Get previous episode context for each episode
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes) episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
@ -515,16 +537,21 @@ class Graphiti:
await semaphore_gather( await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes], *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
*[edge.generate_embedding(self.embedder) for edge in extracted_edges], *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
max_coroutines=self.max_coroutines,
) )
# Dedupe extracted nodes, compress extracted edges # Dedupe extracted nodes, compress extracted edges
(nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather( (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes), dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs), extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
max_coroutines=self.max_coroutines,
) )
# save nodes to KG # save nodes to KG
await semaphore_gather(*[node.save(self.driver) for node in nodes]) await semaphore_gather(
*[node.save(self.driver) for node in nodes],
max_coroutines=self.max_coroutines,
)
# re-map edge pointers so that they don't point to discard dupe nodes # re-map edge pointers so that they don't point to discard dupe nodes
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers( extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
@ -536,7 +563,8 @@ class Graphiti:
# save episodic edges to KG # save episodic edges to KG
await semaphore_gather( await semaphore_gather(
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers] *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers],
max_coroutines=self.max_coroutines,
) )
# Dedupe extracted edges # Dedupe extracted edges
@ -548,7 +576,10 @@ class Graphiti:
# invalidate edges # invalidate edges
# save edges to KG # save edges to KG
await semaphore_gather(*[edge.save(self.driver) for edge in edges]) await semaphore_gather(
*[edge.save(self.driver) for edge in edges],
max_coroutines=self.max_coroutines,
)
end = time() end = time()
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms') logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
@ -572,11 +603,18 @@ class Graphiti:
) )
await semaphore_gather( await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in community_nodes] *[node.generate_name_embedding(self.embedder) for node in community_nodes],
max_coroutines=self.max_coroutines,
) )
await semaphore_gather(*[node.save(self.driver) for node in community_nodes]) await semaphore_gather(
await semaphore_gather(*[edge.save(self.driver) for edge in community_edges]) *[node.save(self.driver) for node in community_nodes],
max_coroutines=self.max_coroutines,
)
await semaphore_gather(
*[edge.save(self.driver) for edge in community_edges],
max_coroutines=self.max_coroutines,
)
return community_nodes return community_nodes
@ -683,7 +721,8 @@ class Graphiti:
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids) episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
edges_list = await semaphore_gather( edges_list = await semaphore_gather(
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes] *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes],
max_coroutines=self.max_coroutines,
) )
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst] edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
@ -759,6 +798,12 @@ class Graphiti:
if record['episode_count'] == 1: if record['episode_count'] == 1:
nodes_to_delete.append(node) nodes_to_delete.append(node)
await semaphore_gather(*[node.delete(self.driver) for node in nodes_to_delete]) await semaphore_gather(
await semaphore_gather(*[edge.delete(self.driver) for edge in edges_to_delete]) *[node.delete(self.driver) for node in nodes_to_delete],
max_coroutines=self.max_coroutines,
)
await semaphore_gather(
*[edge.delete(self.driver) for edge in edges_to_delete],
max_coroutines=self.max_coroutines,
)
await episode.delete(self.driver) await episode.delete(self.driver)

View file

@ -94,9 +94,9 @@ def normalize_l2(embedding: list[float]) -> NDArray:
# Use this instead of asyncio.gather() to bound coroutines # Use this instead of asyncio.gather() to bound coroutines
async def semaphore_gather( async def semaphore_gather(
*coroutines: Coroutine, *coroutines: Coroutine,
max_coroutines: int = SEMAPHORE_LIMIT, max_coroutines: int | None = None,
): ):
semaphore = asyncio.Semaphore(max_coroutines) semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)
async def _wrap_coroutine(coroutine): async def _wrap_coroutine(coroutine):
async with semaphore: async with semaphore: