update semaphore gather to use batches (#471)

* update semaphore gather to use batches

* batch semaphore update

* remove return type
This commit is contained in:
Preston Rasmussen 2025-05-12 14:00:38 -04:00 committed by GitHub
parent 3d22dc16f4
commit e7ecc71983
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -18,6 +18,7 @@ import asyncio
import os
from collections.abc import Coroutine
from datetime import datetime
from typing import Any
import numpy as np
from dotenv import load_dotenv
@ -91,11 +92,27 @@ def normalize_l2(embedding: list[float]):
# Use this instead of asyncio.gather() to bound coroutines
async def semaphore_gather(*coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT):
async def semaphore_gather(
*coroutines: Coroutine,
max_coroutines: int = SEMAPHORE_LIMIT,
):
semaphore = asyncio.Semaphore(max_coroutines)
async def _wrap_coroutine(coroutine):
async def _wrap(coro: Coroutine) -> Any:
async with semaphore:
return await coroutine
return await coro
return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
results = []
batch = []
for coroutine in coroutines:
batch.append(_wrap(coroutine))
# once we hit max_coroutines, gather and clear the batch
if len(batch) >= max_coroutines:
results.extend(await asyncio.gather(*batch))
batch.clear()
# gather any remaining coroutines in the final batch
if batch:
results.extend(await asyncio.gather(*batch))
return results