Add max_coroutines parameter to Graphiti and update semaphore_gather function (#619)
- Introduced max_coroutines parameter in the Graphiti class to control the maximum number of concurrent operations. - Updated the semaphore_gather function to accept max_coroutines as an optional argument, defaulting to SEMAPHORE_LIMIT if not provided. - Adjusted multiple calls to semaphore_gather throughout the Graphiti class to utilize the new max_coroutines parameter for better concurrency management.
This commit is contained in:
parent
ae7f2234a8
commit
fe870b953f
2 changed files with 58 additions and 13 deletions
|
|
@ -103,6 +103,7 @@ class Graphiti:
|
||||||
cross_encoder: CrossEncoderClient | None = None,
|
cross_encoder: CrossEncoderClient | None = None,
|
||||||
store_raw_episode_content: bool = True,
|
store_raw_episode_content: bool = True,
|
||||||
graph_driver: GraphDriver | None = None,
|
graph_driver: GraphDriver | None = None,
|
||||||
|
max_coroutines: int | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a Graphiti instance.
|
Initialize a Graphiti instance.
|
||||||
|
|
@ -121,6 +122,20 @@ class Graphiti:
|
||||||
llm_client : LLMClient | None, optional
|
llm_client : LLMClient | None, optional
|
||||||
An instance of LLMClient for natural language processing tasks.
|
An instance of LLMClient for natural language processing tasks.
|
||||||
If not provided, a default OpenAIClient will be initialized.
|
If not provided, a default OpenAIClient will be initialized.
|
||||||
|
embedder : EmbedderClient | None, optional
|
||||||
|
An instance of EmbedderClient for embedding tasks.
|
||||||
|
If not provided, a default OpenAIEmbedder will be initialized.
|
||||||
|
cross_encoder : CrossEncoderClient | None, optional
|
||||||
|
An instance of CrossEncoderClient for reranking tasks.
|
||||||
|
If not provided, a default OpenAIRerankerClient will be initialized.
|
||||||
|
store_raw_episode_content : bool, optional
|
||||||
|
Whether to store the raw content of episodes. Defaults to True.
|
||||||
|
graph_driver : GraphDriver | None, optional
|
||||||
|
An instance of GraphDriver for database operations.
|
||||||
|
If not provided, a default Neo4jDriver will be initialized.
|
||||||
|
max_coroutines : int | None, optional
|
||||||
|
The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
|
||||||
|
If not set, the Graphiti default is used.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|
@ -145,6 +160,7 @@ class Graphiti:
|
||||||
|
|
||||||
self.database = DEFAULT_DATABASE
|
self.database = DEFAULT_DATABASE
|
||||||
self.store_raw_episode_content = store_raw_episode_content
|
self.store_raw_episode_content = store_raw_episode_content
|
||||||
|
self.max_coroutines = max_coroutines
|
||||||
if llm_client:
|
if llm_client:
|
||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
else:
|
else:
|
||||||
|
|
@ -393,6 +409,7 @@ class Graphiti:
|
||||||
group_id,
|
group_id,
|
||||||
edge_types,
|
edge_types,
|
||||||
),
|
),
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
edges = resolve_edge_pointers(extracted_edges, uuid_map)
|
||||||
|
|
@ -409,6 +426,7 @@ class Graphiti:
|
||||||
extract_attributes_from_nodes(
|
extract_attributes_from_nodes(
|
||||||
self.clients, nodes, episode, previous_episodes, entity_types
|
self.clients, nodes, episode, previous_episodes, entity_types
|
||||||
),
|
),
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
|
|
||||||
duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
|
duplicate_of_edges = build_duplicate_of_edges(episode, now, node_duplicates)
|
||||||
|
|
@ -432,7 +450,8 @@ class Graphiti:
|
||||||
*[
|
*[
|
||||||
update_community(self.driver, self.llm_client, self.embedder, node)
|
update_community(self.driver, self.llm_client, self.embedder, node)
|
||||||
for node in nodes
|
for node in nodes
|
||||||
]
|
],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
|
||||||
|
|
@ -499,7 +518,10 @@ class Graphiti:
|
||||||
]
|
]
|
||||||
|
|
||||||
# Save all the episodes
|
# Save all the episodes
|
||||||
await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
|
await semaphore_gather(
|
||||||
|
*[episode.save(self.driver) for episode in episodes],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
|
)
|
||||||
|
|
||||||
# Get previous episode context for each episode
|
# Get previous episode context for each episode
|
||||||
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
|
||||||
|
|
@ -515,16 +537,21 @@ class Graphiti:
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
|
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
|
||||||
*[edge.generate_embedding(self.embedder) for edge in extracted_edges],
|
*[edge.generate_embedding(self.embedder) for edge in extracted_edges],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Dedupe extracted nodes, compress extracted edges
|
# Dedupe extracted nodes, compress extracted edges
|
||||||
(nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
|
(nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
|
||||||
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
|
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
|
||||||
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
|
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
|
|
||||||
# save nodes to KG
|
# save nodes to KG
|
||||||
await semaphore_gather(*[node.save(self.driver) for node in nodes])
|
await semaphore_gather(
|
||||||
|
*[node.save(self.driver) for node in nodes],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
|
)
|
||||||
|
|
||||||
# re-map edge pointers so that they don't point to discard dupe nodes
|
# re-map edge pointers so that they don't point to discard dupe nodes
|
||||||
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
|
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
|
||||||
|
|
@ -536,7 +563,8 @@ class Graphiti:
|
||||||
|
|
||||||
# save episodic edges to KG
|
# save episodic edges to KG
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
|
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Dedupe extracted edges
|
# Dedupe extracted edges
|
||||||
|
|
@ -548,7 +576,10 @@ class Graphiti:
|
||||||
# invalidate edges
|
# invalidate edges
|
||||||
|
|
||||||
# save edges to KG
|
# save edges to KG
|
||||||
await semaphore_gather(*[edge.save(self.driver) for edge in edges])
|
await semaphore_gather(
|
||||||
|
*[edge.save(self.driver) for edge in edges],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
|
)
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
|
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
|
||||||
|
|
@ -572,11 +603,18 @@ class Graphiti:
|
||||||
)
|
)
|
||||||
|
|
||||||
await semaphore_gather(
|
await semaphore_gather(
|
||||||
*[node.generate_name_embedding(self.embedder) for node in community_nodes]
|
*[node.generate_name_embedding(self.embedder) for node in community_nodes],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
|
|
||||||
await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
|
await semaphore_gather(
|
||||||
await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
|
*[node.save(self.driver) for node in community_nodes],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
|
)
|
||||||
|
await semaphore_gather(
|
||||||
|
*[edge.save(self.driver) for edge in community_edges],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
|
)
|
||||||
|
|
||||||
return community_nodes
|
return community_nodes
|
||||||
|
|
||||||
|
|
@ -683,7 +721,8 @@ class Graphiti:
|
||||||
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
|
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
|
||||||
|
|
||||||
edges_list = await semaphore_gather(
|
edges_list = await semaphore_gather(
|
||||||
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
|
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
)
|
)
|
||||||
|
|
||||||
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
|
edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
|
||||||
|
|
@ -759,6 +798,12 @@ class Graphiti:
|
||||||
if record['episode_count'] == 1:
|
if record['episode_count'] == 1:
|
||||||
nodes_to_delete.append(node)
|
nodes_to_delete.append(node)
|
||||||
|
|
||||||
await semaphore_gather(*[node.delete(self.driver) for node in nodes_to_delete])
|
await semaphore_gather(
|
||||||
await semaphore_gather(*[edge.delete(self.driver) for edge in edges_to_delete])
|
*[node.delete(self.driver) for node in nodes_to_delete],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
|
)
|
||||||
|
await semaphore_gather(
|
||||||
|
*[edge.delete(self.driver) for edge in edges_to_delete],
|
||||||
|
max_coroutines=self.max_coroutines,
|
||||||
|
)
|
||||||
await episode.delete(self.driver)
|
await episode.delete(self.driver)
|
||||||
|
|
|
||||||
|
|
@ -94,9 +94,9 @@ def normalize_l2(embedding: list[float]) -> NDArray:
|
||||||
# Use this instead of asyncio.gather() to bound coroutines
|
# Use this instead of asyncio.gather() to bound coroutines
|
||||||
async def semaphore_gather(
|
async def semaphore_gather(
|
||||||
*coroutines: Coroutine,
|
*coroutines: Coroutine,
|
||||||
max_coroutines: int = SEMAPHORE_LIMIT,
|
max_coroutines: int | None = None,
|
||||||
):
|
):
|
||||||
semaphore = asyncio.Semaphore(max_coroutines)
|
semaphore = asyncio.Semaphore(max_coroutines or SEMAPHORE_LIMIT)
|
||||||
|
|
||||||
async def _wrap_coroutine(coroutine):
|
async def _wrap_coroutine(coroutine):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue