feat: Add RPM control
This commit is contained in:
parent
5d0586da28
commit
0c97a400b0
6 changed files with 91 additions and 42 deletions
|
|
@ -25,6 +25,7 @@ from cognee.infrastructure.llm.tokenizer.Mistral import (
|
|||
from cognee.infrastructure.llm.tokenizer.TikToken import (
|
||||
TikTokenTokenizer,
|
||||
)
|
||||
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
|
||||
|
||||
litellm.set_verbose = False
|
||||
logger = get_logger("LiteLLMEmbeddingEngine")
|
||||
|
|
@ -109,13 +110,14 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
|
||||
return [data["embedding"] for data in response["data"]]
|
||||
else:
|
||||
response = await litellm.aembedding(
|
||||
model=self.model,
|
||||
input=text,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
)
|
||||
async with embedding_rate_limiter_context_manager():
|
||||
response = await litellm.aembedding(
|
||||
model=self.model,
|
||||
input=text,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
)
|
||||
|
||||
return [data["embedding"] for data in response.data]
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type
|
|||
TypeBuilder,
|
||||
)
|
||||
from cognee.infrastructure.llm.structured_output_framework.baml.baml_client import b
|
||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||
import logging
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -58,11 +59,12 @@ async def acreate_structured_output(
|
|||
tb = TypeBuilder()
|
||||
type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, response_model)
|
||||
|
||||
result = await b.AcreateStructuredOutput(
|
||||
text_input=text_input,
|
||||
system_prompt=system_prompt,
|
||||
baml_options={"client_registry": config.baml_registry, "tb": type_builder},
|
||||
)
|
||||
async with llm_rate_limiter_context_manager():
|
||||
result = await b.AcreateStructuredOutput(
|
||||
text_input=text_input,
|
||||
system_prompt=system_prompt,
|
||||
baml_options={"client_registry": config.baml_registry, "tb": type_builder},
|
||||
)
|
||||
|
||||
# Transform BAML response to proper pydantic reponse model
|
||||
if response_model is str:
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
|
|||
from cognee.infrastructure.llm.exceptions import (
|
||||
ContentPolicyFilterError,
|
||||
)
|
||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
|
@ -135,34 +136,9 @@ class OpenAIAdapter(LLMInterface):
|
|||
"""
|
||||
|
||||
try:
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
InstructorRetryException,
|
||||
) as e:
|
||||
if not (self.fallback_model and self.fallback_api_key):
|
||||
raise e
|
||||
try:
|
||||
async with llm_rate_limiter_context_manager():
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.fallback_model,
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
|
|
@ -173,11 +149,38 @@ class OpenAIAdapter(LLMInterface):
|
|||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
api_key=self.fallback_api_key,
|
||||
# api_base=self.fallback_endpoint,
|
||||
api_key=self.api_key,
|
||||
api_base=self.endpoint,
|
||||
api_version=self.api_version,
|
||||
response_model=response_model,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
InstructorRetryException,
|
||||
) as e:
|
||||
if not (self.fallback_model and self.fallback_api_key):
|
||||
raise e
|
||||
try:
|
||||
async with llm_rate_limiter_context_manager():
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.fallback_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""{text_input}""",
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
],
|
||||
api_key=self.fallback_api_key,
|
||||
# api_base=self.fallback_endpoint,
|
||||
response_model=response_model,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
except (
|
||||
ContentFilterFinishReasonError,
|
||||
ContentPolicyViolationError,
|
||||
|
|
|
|||
30
cognee/shared/rate_limiting.py
Normal file
30
cognee/shared/rate_limiting.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from aiolimiter import AsyncLimiter
|
||||
from contextlib import nullcontext
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
llm_config = get_llm_config()
|
||||
|
||||
llm_rate_limiter = AsyncLimiter(
|
||||
llm_config.llm_rate_limit_requests, llm_config.embedding_rate_limit_interval
|
||||
)
|
||||
embedding_rate_limiter = AsyncLimiter(
|
||||
llm_config.embedding_rate_limit_requests, llm_config.embedding_rate_limit_interval
|
||||
)
|
||||
|
||||
|
||||
def llm_rate_limiter_context_manager():
|
||||
global llm_rate_limiter
|
||||
if llm_config.llm_rate_limit_enabled:
|
||||
return llm_rate_limiter
|
||||
else:
|
||||
# Return a no-op context manager if rate limiting is disabled
|
||||
return nullcontext()
|
||||
|
||||
|
||||
def embedding_rate_limiter_context_manager():
|
||||
global embedding_rate_limiter
|
||||
if llm_config.embedding_rate_limit_enabled:
|
||||
return embedding_rate_limiter
|
||||
else:
|
||||
# Return a no-op context manager if rate limiting is disabled
|
||||
return nullcontext()
|
||||
|
|
@ -59,6 +59,7 @@ dependencies = [
|
|||
"tenacity>=9.0.0",
|
||||
"fakeredis[lua]>=2.32.0",
|
||||
"diskcache>=5.6.3",
|
||||
"aiolimiter>=1.2.1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
13
uv.lock
generated
13
uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
|||
version = 1
|
||||
revision = 3
|
||||
revision = 2
|
||||
requires-python = ">=3.10, <3.14"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'",
|
||||
|
|
@ -187,6 +187,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/10/a1/510b0a7fadc6f43a6ce50152e69dbd86415240835868bb0bd9b5b88b1e06/aioitertools-0.13.0-py3-none-any.whl", hash = "sha256:0be0292b856f08dfac90e31f4739432f4cb6d7520ab9eb73e143f4f2fa5259be", size = 24182, upload-time = "2025-11-06T22:17:06.502Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiolimiter"
|
||||
version = "1.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/23/b52debf471f7a1e42e362d959a3982bdcb4fe13a5d46e63d28868807a79c/aiolimiter-1.2.1.tar.gz", hash = "sha256:e02a37ea1a855d9e832252a105420ad4d15011505512a1a1d814647451b5cca9", size = 7185, upload-time = "2024-12-08T15:31:51.496Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/ba/df6e8e1045aebc4778d19b8a3a9bc1808adb1619ba94ca354d9ba17d86c3/aiolimiter-1.2.1-py3-none-any.whl", hash = "sha256:d3f249e9059a20badcb56b61601a83556133655c11d1eb3dd3e04ff069e5f3c7", size = 6711, upload-time = "2024-12-08T15:31:49.874Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiosignal"
|
||||
version = "1.4.0"
|
||||
|
|
@ -942,6 +951,7 @@ source = { editable = "." }
|
|||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
{ name = "aiohttp" },
|
||||
{ name = "aiolimiter" },
|
||||
{ name = "aiosqlite" },
|
||||
{ name = "alembic" },
|
||||
{ name = "diskcache" },
|
||||
|
|
@ -1113,6 +1123,7 @@ scraping = [
|
|||
requires-dist = [
|
||||
{ name = "aiofiles", specifier = ">=23.2.1" },
|
||||
{ name = "aiohttp", specifier = ">=3.11.14,<4.0.0" },
|
||||
{ name = "aiolimiter", specifier = ">=1.2.1" },
|
||||
{ name = "aiosqlite", specifier = ">=0.20.0,<1.0.0" },
|
||||
{ name = "alembic", specifier = ">=1.13.3,<2" },
|
||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.27" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue