update semaphore gather to use batches (#471)
* update semaphore gather to use batches * batch semaphore update * remove return type
This commit is contained in:
parent
3d22dc16f4
commit
e7ecc71983
1 changed files with 21 additions and 4 deletions
|
|
@ -18,6 +18,7 @@ import asyncio
|
|||
import os
|
||||
from collections.abc import Coroutine
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
|
|
@ -91,11 +92,27 @@ def normalize_l2(embedding: list[float]):
|
|||
|
||||
|
||||
# Use this instead of asyncio.gather() to bound coroutines
|
||||
async def semaphore_gather(*coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT):
|
||||
async def semaphore_gather(
|
||||
*coroutines: Coroutine,
|
||||
max_coroutines: int = SEMAPHORE_LIMIT,
|
||||
):
|
||||
semaphore = asyncio.Semaphore(max_coroutines)
|
||||
|
||||
async def _wrap_coroutine(coroutine):
|
||||
async def _wrap(coro: Coroutine) -> Any:
|
||||
async with semaphore:
|
||||
return await coroutine
|
||||
return await coro
|
||||
|
||||
return await asyncio.gather(*(_wrap_coroutine(coroutine) for coroutine in coroutines))
|
||||
results = []
|
||||
batch = []
|
||||
for coroutine in coroutines:
|
||||
batch.append(_wrap(coroutine))
|
||||
# once we hit max_coroutines, gather and clear the batch
|
||||
if len(batch) >= max_coroutines:
|
||||
results.extend(await asyncio.gather(*batch))
|
||||
batch.clear()
|
||||
|
||||
# gather any remaining coroutines in the final batch
|
||||
if batch:
|
||||
results.extend(await asyncio.gather(*batch))
|
||||
|
||||
return results
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue