update semaphore gather to use batches (#471)

* update semaphore gather to use batches

* batch semaphore update

* remove return type
This commit is contained in:
Preston Rasmussen 2025-05-12 14:00:38 -04:00 committed by GitHub
parent 3d22dc16f4
commit e7ecc71983
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -18,6 +18,7 @@ import asyncio
import os import os
from collections.abc import Coroutine from collections.abc import Coroutine
from datetime import datetime from datetime import datetime
from typing import Any
import numpy as np import numpy as np
from dotenv import load_dotenv from dotenv import load_dotenv
@ -91,11 +92,27 @@ def normalize_l2(embedding: list[float]):
# Use this instead of asyncio.gather() to bound coroutines # Use this instead of asyncio.gather() to bound coroutines
async def semaphore_gather(*coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT): async def semaphore_gather(
*coroutines: Coroutine,
max_coroutines: int = SEMAPHORE_LIMIT,
):
semaphore = asyncio.Semaphore(max_coroutines) semaphore = asyncio.Semaphore(max_coroutines)
async def _wrap_coroutine(coroutine): async def _wrap(coro: Coroutine) -> Any:
async with semaphore: async with semaphore:
return await coroutine return await coro
return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines)) results = []
batch = []
for coroutine in coroutines:
batch.append(_wrap(coroutine))
# once we hit max_coroutines, gather and clear the batch
if len(batch) >= max_coroutines:
results.extend(await asyncio.gather(*batch))
batch.clear()
# gather any remaining coroutines in the final batch
if batch:
results.extend(await asyncio.gather(*batch))
return results