add concurrent embedding limit
This commit is contained in:
parent
d0a4ef252e
commit
f6eeedb050
1 changed files with 20 additions and 1 deletions
|
|
@ -17,6 +17,17 @@ import tiktoken
|
|||
|
||||
from lightrag.prompt import PROMPTS
|
||||
|
||||
|
||||
class UnlimitedSemaphore:
|
||||
"""A context manager that allows unlimited access."""
|
||||
|
||||
async def __aenter__(self):
|
||||
pass
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
pass
|
||||
|
||||
|
||||
ENCODER = None
|
||||
|
||||
logger = logging.getLogger("lightrag")
|
||||
|
|
@ -42,9 +53,17 @@ class EmbeddingFunc:
|
|||
embedding_dim: int
|
||||
max_token_size: int
|
||||
func: callable
|
||||
concurrent_limit: int = 16
|
||||
|
||||
def __post_init__(self):
|
||||
if self.concurrent_limit != 0:
|
||||
self._semaphore = asyncio.Semaphore(self.concurrent_limit)
|
||||
else:
|
||||
self._semaphore = UnlimitedSemaphore()
|
||||
|
||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
||||
return await self.func(*args, **kwargs)
|
||||
async with self._semaphore:
|
||||
return await self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue