# Use this instead of asyncio.gather() to bound coroutines
async def semaphore_gather(
    *coroutines: Coroutine,
    max_coroutines: int = SEMAPHORE_LIMIT,
) -> list[Any]:
    """Run ``coroutines`` concurrently, at most ``max_coroutines`` at a time.

    Every coroutine is scheduled through a single ``asyncio.gather`` call and
    guarded by a shared semaphore, so a free slot is reused as soon as any
    coroutine finishes rather than waiting for a whole fixed-size batch.

    Args:
        *coroutines: Coroutine objects to await.
        max_coroutines: Upper bound on coroutines running at once. Must be
            at least 1.

    Returns:
        The coroutine results as a list, in the same order as ``coroutines``.
        An empty list when no coroutines are passed.

    Raises:
        ValueError: If ``max_coroutines`` is smaller than 1.
        Exception: The first exception raised by any coroutine is propagated,
            following ``asyncio.gather`` with ``return_exceptions=False``.
    """
    if max_coroutines < 1:
        # A zero-valued semaphore would never admit a coroutine and the call
        # would hang. Close the coroutine objects we are not going to run so
        # they do not emit "coroutine ... was never awaited" warnings.
        for unused in coroutines:
            unused.close()
        raise ValueError(f'max_coroutines must be >= 1, got {max_coroutines}')

    semaphore = asyncio.Semaphore(max_coroutines)

    async def _run_bounded(coro: Coroutine) -> Any:
        # Hold a semaphore slot only while the wrapped coroutine is running.
        async with semaphore:
            return await coro

    # gather() wraps every coroutine in a task up front, so none is left
    # un-awaited if a sibling raises; the semaphore is what limits how many
    # of them make progress at once.
    return await asyncio.gather(*(_run_bounded(coro) for coro in coroutines))