Bounded semaphore - limiting concurrency (#244)

* WIP

* add semaphore

* remove unused imports

* remove unused imports

* lower concurrency limit
This commit is contained in:
Preston Rasmussen 2024-12-17 13:08:18 -05:00 committed by GitHub
parent 0186ac920c
commit 00fe87679e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 87 additions and 64 deletions

View file

@ -25,6 +25,7 @@ from dotenv import load_dotenv
from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a from examples.multi_session_conversation_memory.parse_msc_messages import conversation_q_and_a
from graphiti_core import Graphiti from graphiti_core import Graphiti
from graphiti_core.helpers import semaphore_gather
from graphiti_core.prompts import prompt_library from graphiti_core.prompts import prompt_library
from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF from graphiti_core.search.search_config_recipes import COMBINED_HYBRID_SEARCH_RRF
@ -122,7 +123,7 @@ async def main():
qa_chunk = qa[i : i + 20] qa_chunk = qa[i : i + 20]
group_ids = range(len(qa))[i : i + 20] group_ids = range(len(qa))[i : i + 20]
results = list( results = list(
await asyncio.gather( await semaphore_gather(
*[ *[
evaluate_qa(graphiti, str(group_id), query, answer) evaluate_qa(graphiti, str(group_id), query, answer)
for group_id, (query, answer) in zip(group_ids, qa_chunk) for group_id, (query, answer) in zip(group_ids, qa_chunk)

View file

@ -26,6 +26,7 @@ from examples.multi_session_conversation_memory.parse_msc_messages import (
parse_msc_messages, parse_msc_messages,
) )
from graphiti_core import Graphiti from graphiti_core import Graphiti
from graphiti_core.helpers import semaphore_gather
load_dotenv() load_dotenv()
@ -75,7 +76,7 @@ async def main():
msc_message_slice = msc_messages[i : i + 10] msc_message_slice = msc_messages[i : i + 10]
group_ids = range(len(msc_messages))[i : i + 10] group_ids = range(len(msc_messages))[i : i + 10]
await asyncio.gather( await semaphore_gather(
*[ *[
add_conversation(graphiti, str(group_id), messages) add_conversation(graphiti, str(group_id), messages)
for group_id, messages in zip(group_ids, msc_message_slice) for group_id, messages in zip(group_ids, msc_message_slice)

View file

@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import logging import logging
from typing import Any from typing import Any
@ -22,6 +21,7 @@ import openai
from openai import AsyncOpenAI from openai import AsyncOpenAI
from pydantic import BaseModel from pydantic import BaseModel
from ..helpers import semaphore_gather
from ..llm_client import LLMConfig, RateLimitError from ..llm_client import LLMConfig, RateLimitError
from ..prompts import Message from ..prompts import Message
from .client import CrossEncoderClient from .client import CrossEncoderClient
@ -75,7 +75,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
for passage in passages for passage in passages
] ]
try: try:
responses = await asyncio.gather( responses = await semaphore_gather(
*[ *[
self.client.chat.completions.create( self.client.chat.completions.create(
model=DEFAULT_MODEL, model=DEFAULT_MODEL,

View file

@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import logging import logging
from datetime import datetime from datetime import datetime
from time import time from time import time
@ -27,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.helpers import DEFAULT_DATABASE from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient, OpenAIClient from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, search from graphiti_core.search.search import SearchConfig, search
@ -340,13 +339,13 @@ class Graphiti:
# Calculate Embeddings # Calculate Embeddings
await asyncio.gather( await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes] *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
) )
# Find relevant nodes already in the graph # Find relevant nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list( existing_nodes_lists: list[list[EntityNode]] = list(
await asyncio.gather( await semaphore_gather(
*[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes] *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
) )
) )
@ -354,7 +353,7 @@ class Graphiti:
# Resolve extracted nodes with nodes already in the graph and extract facts # Resolve extracted nodes with nodes already in the graph and extract facts
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather( (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
resolve_extracted_nodes( resolve_extracted_nodes(
self.llm_client, self.llm_client,
extracted_nodes, extracted_nodes,
@ -374,7 +373,7 @@ class Graphiti:
) )
# calculate embeddings # calculate embeddings
await asyncio.gather( await semaphore_gather(
*[ *[
edge.generate_embedding(self.embedder) edge.generate_embedding(self.embedder)
for edge in extracted_edges_with_resolved_pointers for edge in extracted_edges_with_resolved_pointers
@ -383,7 +382,7 @@ class Graphiti:
# Resolve extracted edges with related edges already in the graph # Resolve extracted edges with related edges already in the graph
related_edges_list: list[list[EntityEdge]] = list( related_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather( await semaphore_gather(
*[ *[
get_relevant_edges( get_relevant_edges(
self.driver, self.driver,
@ -404,7 +403,7 @@ class Graphiti:
) )
existing_source_edges_list: list[list[EntityEdge]] = list( existing_source_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather( await semaphore_gather(
*[ *[
get_relevant_edges( get_relevant_edges(
self.driver, self.driver,
@ -419,7 +418,7 @@ class Graphiti:
) )
existing_target_edges_list: list[list[EntityEdge]] = list( existing_target_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather( await semaphore_gather(
*[ *[
get_relevant_edges( get_relevant_edges(
self.driver, self.driver,
@ -468,7 +467,7 @@ class Graphiti:
# Update any communities # Update any communities
if update_communities: if update_communities:
await asyncio.gather( await semaphore_gather(
*[ *[
update_community(self.driver, self.llm_client, self.embedder, node) update_community(self.driver, self.llm_client, self.embedder, node)
for node in nodes for node in nodes
@ -538,7 +537,7 @@ class Graphiti:
] ]
# Save all the episodes # Save all the episodes
await asyncio.gather(*[episode.save(self.driver) for episode in episodes]) await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
# Get previous episode context for each episode # Get previous episode context for each episode
episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes) episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
@ -551,19 +550,19 @@ class Graphiti:
) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs) ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
# Generate embeddings # Generate embeddings
await asyncio.gather( await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in extracted_nodes], *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
*[edge.generate_embedding(self.embedder) for edge in extracted_edges], *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
) )
# Dedupe extracted nodes, compress extracted edges # Dedupe extracted nodes, compress extracted edges
(nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather( (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes), dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs), extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
) )
# save nodes to KG # save nodes to KG
await asyncio.gather(*[node.save(self.driver) for node in nodes]) await semaphore_gather(*[node.save(self.driver) for node in nodes])
# re-map edge pointers so that they don't point to discard dupe nodes # re-map edge pointers so that they don't point to discard dupe nodes
extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers( extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
@ -574,7 +573,7 @@ class Graphiti:
) )
# save episodic edges to KG # save episodic edges to KG
await asyncio.gather( await semaphore_gather(
*[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers] *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
) )
@ -587,7 +586,7 @@ class Graphiti:
# invalidate edges # invalidate edges
# save edges to KG # save edges to KG
await asyncio.gather(*[edge.save(self.driver) for edge in edges]) await semaphore_gather(*[edge.save(self.driver) for edge in edges])
end = time() end = time()
logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms') logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
@ -610,12 +609,12 @@ class Graphiti:
self.driver, self.llm_client, group_ids self.driver, self.llm_client, group_ids
) )
await asyncio.gather( await semaphore_gather(
*[node.generate_name_embedding(self.embedder) for node in community_nodes] *[node.generate_name_embedding(self.embedder) for node in community_nodes]
) )
await asyncio.gather(*[node.save(self.driver) for node in community_nodes]) await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in community_edges]) await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
return community_nodes return community_nodes
@ -698,7 +697,7 @@ class Graphiti:
async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults: async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids) episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
edges_list = await asyncio.gather( edges_list = await semaphore_gather(
*[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes] *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
) )

View file

@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import os import os
from collections.abc import Coroutine
from datetime import datetime from datetime import datetime
import numpy as np import numpy as np
@ -25,6 +27,7 @@ load_dotenv()
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None) DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False)) USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
MAX_REFLEXION_ITERATIONS = 2 MAX_REFLEXION_ITERATIONS = 2
DEFAULT_PAGE_LIMIT = 20 DEFAULT_PAGE_LIMIT = 20
@ -80,3 +83,19 @@ def normalize_l2(embedding: list[float]) -> list[float]:
else: else:
norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True) norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist() return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
# Use this instead of asyncio.gather() to bound the number of coroutines in flight
async def semaphore_gather(
    *coroutines: Coroutine,
    max_coroutines: int = SEMAPHORE_LIMIT,
    return_exceptions: bool = False,
) -> list:
    """Await ``coroutines`` concurrently, running at most ``max_coroutines`` at once.

    Intended as a drop-in replacement for ``asyncio.gather``: results are
    returned in the same order as the arguments, and ``return_exceptions`` is
    forwarded unchanged to ``asyncio.gather``.

    Args:
        *coroutines: Coroutine objects to await.
        max_coroutines: Upper bound on how many of ``coroutines`` may be
            running at the same time. Must be at least 1.
        return_exceptions: When false (the default, matching
            ``asyncio.gather``), the first exception raised by any coroutine
            propagates to the caller. When true, exceptions are returned in
            the result list in place of the failed coroutine's result.

    Returns:
        The list of results produced by ``asyncio.gather``.

    Raises:
        ValueError: If ``max_coroutines`` is less than 1.

    Note:
        A new semaphore is created on every call, so the limit applies to the
        coroutines passed to this call only; nested calls each get their own
        independent limit.
    """
    if max_coroutines < 1:
        # A semaphore with zero permits never admits anyone, so the gather
        # below would wait forever. Fail fast, and close the coroutines so
        # they do not emit "coroutine was never awaited" warnings.
        for coroutine in coroutines:
            coroutine.close()
        raise ValueError(f'max_coroutines must be >= 1, got {max_coroutines}')

    semaphore = asyncio.Semaphore(max_coroutines)

    async def _wrap_coroutine(coroutine):
        # Hold a permit only for the duration of the wrapped coroutine.
        async with semaphore:
            return await coroutine

    return await asyncio.gather(
        *(_wrap_coroutine(coroutine) for coroutine in coroutines),
        return_exceptions=return_exceptions,
    )

View file

@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import logging import logging
from collections import defaultdict from collections import defaultdict
from time import time from time import time
@ -25,6 +24,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.edges import EntityEdge from graphiti_core.edges import EntityEdge
from graphiti_core.embedder import EmbedderClient from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import SearchRerankerError from graphiti_core.errors import SearchRerankerError
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import CommunityNode, EntityNode from graphiti_core.nodes import CommunityNode, EntityNode
from graphiti_core.search.search_config import ( from graphiti_core.search.search_config import (
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_LIMIT,
@ -78,7 +78,7 @@ async def search(
# if group_ids is empty, set it to None # if group_ids is empty, set it to None
group_ids = group_ids if group_ids else None group_ids = group_ids if group_ids else None
edges, nodes, communities = await asyncio.gather( edges, nodes, communities = await semaphore_gather(
edge_search( edge_search(
driver, driver,
cross_encoder, cross_encoder,
@ -141,7 +141,7 @@ async def edge_search(
return [] return []
search_results: list[list[EntityEdge]] = list( search_results: list[list[EntityEdge]] = list(
await asyncio.gather( await semaphore_gather(
*[ *[
edge_fulltext_search(driver, query, group_ids, 2 * limit), edge_fulltext_search(driver, query, group_ids, 2 * limit),
edge_similarity_search( edge_similarity_search(
@ -226,7 +226,7 @@ async def node_search(
return [] return []
search_results: list[list[EntityNode]] = list( search_results: list[list[EntityNode]] = list(
await asyncio.gather( await semaphore_gather(
*[ *[
node_fulltext_search(driver, query, group_ids, 2 * limit), node_fulltext_search(driver, query, group_ids, 2 * limit),
node_similarity_search( node_similarity_search(
@ -295,7 +295,7 @@ async def community_search(
return [] return []
search_results: list[list[CommunityNode]] = list( search_results: list[list[CommunityNode]] = list(
await asyncio.gather( await semaphore_gather(
*[ *[
community_fulltext_search(driver, query, group_ids, 2 * limit), community_fulltext_search(driver, query, group_ids, 2 * limit),
community_similarity_search( community_similarity_search(

View file

@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import logging import logging
from collections import defaultdict from collections import defaultdict
from time import time from time import time
@ -30,6 +29,7 @@ from graphiti_core.helpers import (
USE_PARALLEL_RUNTIME, USE_PARALLEL_RUNTIME,
lucene_sanitize, lucene_sanitize,
normalize_l2, normalize_l2,
semaphore_gather,
) )
from graphiti_core.nodes import ( from graphiti_core.nodes import (
CommunityNode, CommunityNode,
@ -549,7 +549,7 @@ async def hybrid_node_search(
start = time() start = time()
results: list[list[EntityNode]] = list( results: list[list[EntityNode]] = list(
await asyncio.gather( await semaphore_gather(
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries], *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
*[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings], *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
) )
@ -619,7 +619,7 @@ async def get_relevant_edges(
relevant_edges: list[EntityEdge] = [] relevant_edges: list[EntityEdge] = []
relevant_edge_uuids = set() relevant_edge_uuids = set()
results = await asyncio.gather( results = await semaphore_gather(
*[ *[
edge_similarity_search( edge_similarity_search(
driver, driver,

View file

@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import logging import logging
import typing import typing
from collections import defaultdict from collections import defaultdict
@ -26,6 +25,7 @@ from numpy import dot, sqrt
from pydantic import BaseModel from pydantic import BaseModel
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client import LLMClient
from graphiti_core.models.edges.edge_db_queries import ( from graphiti_core.models.edges.edge_db_queries import (
ENTITY_EDGE_SAVE_BULK, ENTITY_EDGE_SAVE_BULK,
@ -71,7 +71,7 @@ class RawEpisode(BaseModel):
async def retrieve_previous_episodes_bulk( async def retrieve_previous_episodes_bulk(
driver: AsyncDriver, episodes: list[EpisodicNode] driver: AsyncDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]: ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
previous_episodes_list = await asyncio.gather( previous_episodes_list = await semaphore_gather(
*[ *[
retrieve_episodes( retrieve_episodes(
driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id] driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
@ -118,7 +118,7 @@ async def add_nodes_and_edges_bulk_tx(
async def extract_nodes_and_edges_bulk( async def extract_nodes_and_edges_bulk(
llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]]
) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]: ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]:
extracted_nodes_bulk = await asyncio.gather( extracted_nodes_bulk = await semaphore_gather(
*[ *[
extract_nodes(llm_client, episode, previous_episodes) extract_nodes(llm_client, episode, previous_episodes)
for episode, previous_episodes in episode_tuples for episode, previous_episodes in episode_tuples
@ -130,7 +130,7 @@ async def extract_nodes_and_edges_bulk(
[episode[1] for episode in episode_tuples], [episode[1] for episode in episode_tuples],
) )
extracted_edges_bulk = await asyncio.gather( extracted_edges_bulk = await semaphore_gather(
*[ *[
extract_edges( extract_edges(
llm_client, llm_client,
@ -171,13 +171,13 @@ async def dedupe_nodes_bulk(
node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)] node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
existing_nodes_chunks: list[list[EntityNode]] = list( existing_nodes_chunks: list[list[EntityNode]] = list(
await asyncio.gather( await semaphore_gather(
*[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks] *[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
) )
) )
results: list[tuple[list[EntityNode], dict[str, str]]] = list( results: list[tuple[list[EntityNode], dict[str, str]]] = list(
await asyncio.gather( await semaphore_gather(
*[ *[
dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i]) dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
for i, node_chunk in enumerate(node_chunks) for i, node_chunk in enumerate(node_chunks)
@ -205,13 +205,13 @@ async def dedupe_edges_bulk(
] ]
relevant_edges_chunks: list[list[EntityEdge]] = list( relevant_edges_chunks: list[list[EntityEdge]] = list(
await asyncio.gather( await semaphore_gather(
*[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks] *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
) )
) )
resolved_edge_chunks: list[list[EntityEdge]] = list( resolved_edge_chunks: list[list[EntityEdge]] = list(
await asyncio.gather( await semaphore_gather(
*[ *[
dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i]) dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
for i, edge_chunk in enumerate(edge_chunks) for i, edge_chunk in enumerate(edge_chunks)
@ -292,7 +292,9 @@ async def compress_nodes(
# add both nodes to the shortest chunk # add both nodes to the shortest chunk
node_chunks[-1].extend([n, m]) node_chunks[-1].extend([n, m])
results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]) results = await semaphore_gather(
*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]
)
extended_map = dict(uuid_map) extended_map = dict(uuid_map)
compressed_nodes: list[EntityNode] = [] compressed_nodes: list[EntityNode] = []
@ -315,7 +317,9 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list
# We build a map of the edges based on their source and target nodes. # We build a map of the edges based on their source and target nodes.
edge_chunks = chunk_edges_by_nodes(edges) edge_chunks = chunk_edges_by_nodes(edges)
results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]) results = await semaphore_gather(
*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]
)
compressed_edges: list[EntityEdge] = [] compressed_edges: list[EntityEdge] = []
for edge_chunk in results: for edge_chunk in results:
@ -368,7 +372,7 @@ async def extract_edge_dates_bulk(
episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
} }
results = await asyncio.gather( results = await semaphore_gather(
*[ *[
extract_edge_dates( extract_edge_dates(
llm_client, llm_client,

View file

@ -7,7 +7,7 @@ from pydantic import BaseModel
from graphiti_core.edges import CommunityEdge from graphiti_core.edges import CommunityEdge
from graphiti_core.embedder import EmbedderClient from graphiti_core.embedder import EmbedderClient
from graphiti_core.helpers import DEFAULT_DATABASE from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import ( from graphiti_core.nodes import (
CommunityNode, CommunityNode,
@ -71,7 +71,7 @@ async def get_community_clusters(
community_clusters.extend( community_clusters.extend(
list( list(
await asyncio.gather( await semaphore_gather(
*[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids] *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
) )
) )
@ -164,7 +164,7 @@ async def build_community(
odd_one_out = summaries.pop() odd_one_out = summaries.pop()
length -= 1 length -= 1
new_summaries: list[str] = list( new_summaries: list[str] = list(
await asyncio.gather( await semaphore_gather(
*[ *[
summarize_pair(llm_client, (str(left_summary), str(right_summary))) summarize_pair(llm_client, (str(left_summary), str(right_summary)))
for left_summary, right_summary in zip( for left_summary, right_summary in zip(
@ -207,7 +207,9 @@ async def build_communities(
return await build_community(llm_client, cluster) return await build_community(llm_client, cluster)
communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list( communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
await asyncio.gather(*[limited_build_community(cluster) for cluster in community_clusters]) await semaphore_gather(
*[limited_build_community(cluster) for cluster in community_clusters]
)
) )
community_nodes: list[CommunityNode] = [] community_nodes: list[CommunityNode] = []

View file

@ -14,13 +14,12 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import logging import logging
from datetime import datetime from datetime import datetime
from time import time from time import time
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library from graphiti_core.prompts import prompt_library
@ -199,7 +198,7 @@ async def resolve_extracted_edges(
) -> tuple[list[EntityEdge], list[EntityEdge]]: ) -> tuple[list[EntityEdge], list[EntityEdge]]:
# resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates # resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
results: list[tuple[EntityEdge, list[EntityEdge]]] = list( results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
await asyncio.gather( await semaphore_gather(
*[ *[
resolve_extracted_edge( resolve_extracted_edge(
llm_client, llm_client,
@ -266,7 +265,7 @@ async def resolve_extracted_edge(
current_episode: EpisodicNode, current_episode: EpisodicNode,
previous_episodes: list[EpisodicNode], previous_episodes: list[EpisodicNode],
) -> tuple[EntityEdge, list[EntityEdge]]: ) -> tuple[EntityEdge, list[EntityEdge]]:
resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather( resolved_edge, (valid_at, invalid_at), invalidation_candidates = await semaphore_gather(
dedupe_extracted_edge(llm_client, extracted_edge, related_edges), dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes), extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
get_edge_contradictions(llm_client, extracted_edge, existing_edges), get_edge_contradictions(llm_client, extracted_edge, existing_edges),

View file

@ -14,14 +14,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from neo4j import AsyncDriver from neo4j import AsyncDriver
from typing_extensions import LiteralString from typing_extensions import LiteralString
from graphiti_core.helpers import DEFAULT_DATABASE from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
from graphiti_core.nodes import EpisodeType, EpisodicNode from graphiti_core.nodes import EpisodeType, EpisodicNode
EPISODE_WINDOW_LEN = 3 EPISODE_WINDOW_LEN = 3
@ -38,7 +37,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
database_=DEFAULT_DATABASE, database_=DEFAULT_DATABASE,
) )
index_names = [record['name'] for record in records] index_names = [record['name'] for record in records]
await asyncio.gather( await semaphore_gather(
*[ *[
driver.execute_query( driver.execute_query(
"""DROP INDEX $name""", """DROP INDEX $name""",
@ -82,7 +81,7 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
index_queries: list[LiteralString] = range_indices + fulltext_indices index_queries: list[LiteralString] = range_indices + fulltext_indices
await asyncio.gather( await semaphore_gather(
*[ *[
driver.execute_query( driver.execute_query(
query, query,

View file

@ -14,11 +14,10 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import logging import logging
from time import time from time import time
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.prompts import prompt_library from graphiti_core.prompts import prompt_library
@ -223,7 +222,7 @@ async def resolve_extracted_nodes(
uuid_map: dict[str, str] = {} uuid_map: dict[str, str] = {}
resolved_nodes: list[EntityNode] = [] resolved_nodes: list[EntityNode] = []
results: list[tuple[EntityNode, dict[str, str]]] = list( results: list[tuple[EntityNode, dict[str, str]]] = list(
await asyncio.gather( await semaphore_gather(
*[ *[
resolve_extracted_node( resolve_extracted_node(
llm_client, extracted_node, existing_nodes, episode, previous_episodes llm_client, extracted_node, existing_nodes, episode, previous_episodes
@ -275,7 +274,7 @@ async def resolve_extracted_node(
else [], else [],
} }
llm_response, node_summary_response = await asyncio.gather( llm_response, node_summary_response = await semaphore_gather(
llm_client.generate_response( llm_client.generate_response(
prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate prompt_library.dedupe_nodes.node(context), response_model=NodeDuplicate
), ),

View file

@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import asyncio
import logging import logging
import os import os
import sys import sys
@ -25,6 +24,7 @@ from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti from graphiti_core.graphiti import Graphiti
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import EntityNode, EpisodicNode from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config_recipes import ( from graphiti_core.search.search_config_recipes import (
COMBINED_HYBRID_SEARCH_CROSS_ENCODER, COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
@ -137,8 +137,8 @@ async def test_graph_integration():
edges = [episodic_edge_1, episodic_edge_2, entity_edge] edges = [episodic_edge_1, episodic_edge_2, entity_edge]
# test save # test save
await asyncio.gather(*[node.save(driver) for node in nodes]) await semaphore_gather(*[node.save(driver) for node in nodes])
await asyncio.gather(*[edge.save(driver) for edge in edges]) await semaphore_gather(*[edge.save(driver) for edge in edges])
# test get # test get
assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None assert await EpisodicNode.get_by_uuid(driver, episode.uuid) is not None
@ -147,5 +147,5 @@ async def test_graph_integration():
assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None assert await EntityEdge.get_by_uuid(driver, entity_edge.uuid) is not None
# test delete # test delete
await asyncio.gather(*[node.delete(driver) for node in nodes]) await semaphore_gather(*[node.delete(driver) for node in nodes])
await asyncio.gather(*[edge.delete(driver) for edge in edges]) await semaphore_gather(*[edge.delete(driver) for edge in edges])