feat: Add RPM limiting to Cognee

Igor Ilic 2025-12-05 18:56:34 +01:00
parent 0c97a400b0
commit 7deaa6e8e9
10 changed files with 146 additions and 683 deletions
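This change removes the bespoke EmbeddingRateLimiter module (deleted below) and instead wraps embedding and LLM calls in shared async context managers imported from cognee.shared.rate_limiting, which enforce a requests-per-minute (RPM) budget. The implementation of that shared module is not part of this diff, so the following is only a minimal sketch, assuming an asyncio-based sliding-window limiter; every name here other than embedding_rate_limiter_context_manager is hypothetical.

import asyncio
import time
from contextlib import asynccontextmanager


class _SlidingWindowLimiter:
    """Hypothetical sliding-window RPM limiter; the real module is not shown in this diff."""

    def __init__(self, max_requests: int, interval_seconds: float):
        self.max_requests = max_requests
        self.interval_seconds = interval_seconds
        self._timestamps: list[float] = []
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        while True:
            async with self._lock:
                now = time.monotonic()
                cutoff = now - self.interval_seconds
                # Keep only requests that are still inside the window.
                self._timestamps = [t for t in self._timestamps if t > cutoff]
                if len(self._timestamps) < self.max_requests:
                    self._timestamps.append(now)
                    return
            # Poll until a slot frees up, mirroring the removed limiter's behaviour.
            await asyncio.sleep(0.5)


_embedding_limiter = _SlidingWindowLimiter(max_requests=60, interval_seconds=60.0)


@asynccontextmanager
async def embedding_rate_limiter_context_manager():
    # Block until the RPM budget allows another request, then run the caller's block.
    await _embedding_limiter.acquire()
    yield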

View file

@@ -17,6 +17,7 @@ from cognee.infrastructure.databases.exceptions import EmbeddingException
from cognee.infrastructure.llm.tokenizer.TikToken import (
TikTokenTokenizer,
)
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
litellm.set_verbose = False
logger = get_logger("FastembedEmbeddingEngine")
@@ -68,7 +69,7 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
wait=wait_exponential_jitter(8, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
@@ -96,11 +97,12 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
if self.mock:
return [[0.0] * self.dimensions for _ in text]
else:
embeddings = self.embedding_model.embed(
text,
batch_size=len(text),
parallel=None,
)
async with embedding_rate_limiter_context_manager():
embeddings = self.embedding_model.embed(
text,
batch_size=len(text),
parallel=None,
)
return list(embeddings)
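Note that fastembed's embed() is synchronous; the new context manager only paces when each call may start, it does not make the embedding itself non-blocking. A quick usage sketch, assuming an already constructed engine (the constructor arguments are not shown in this hunk):

import asyncio


async def embed_batches(engine, batches):
    # Each embed_text call awaits a limiter slot before invoking fastembed,
    # so concurrent batches are spread across the configured RPM budget.
    return await asyncio.gather(*(engine.embed_text(batch) for batch in batches))


# Example (engine construction assumed):
# vectors = asyncio.run(embed_batches(engine, [["first text"], ["second text"]]))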

View file

@@ -18,10 +18,7 @@ from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import Em
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
HuggingFaceTokenizer,
)
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
embedding_rate_limit_async,
embedding_sleep_and_retry_async,
)
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
from cognee.shared.utils import create_secure_ssl_context
logger = get_logger("OllamaEmbeddingEngine")
@@ -101,7 +98,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
wait=wait_exponential_jitter(8, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
@@ -120,14 +117,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
ssl_context = create_secure_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
async with session.post(
self.endpoint, json=payload, headers=headers, timeout=60.0
) as response:
data = await response.json()
if "embeddings" in data:
return data["embeddings"][0]
else:
return data["data"][0]["embedding"]
async with embedding_rate_limiter_context_manager():
async with session.post(
self.endpoint, json=payload, headers=headers, timeout=60.0
) as response:
data = await response.json()
if "embeddings" in data:
return data["embeddings"][0]
else:
return data["data"][0]["embedding"]
def get_vector_size(self) -> int:
"""

View file

@@ -1,544 +0,0 @@
import threading
import logging
import functools
import os
import time
import asyncio
import random
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.config import get_llm_config
logger = get_logger()
# Common error patterns that indicate rate limiting
RATE_LIMIT_ERROR_PATTERNS = [
"rate limit",
"rate_limit",
"ratelimit",
"too many requests",
"retry after",
"capacity",
"quota",
"limit exceeded",
"tps limit exceeded",
"request limit exceeded",
"maximum requests",
"exceeded your current quota",
"throttled",
"throttling",
]
# Default retry settings
DEFAULT_MAX_RETRIES = 5
DEFAULT_INITIAL_BACKOFF = 1.0 # seconds
DEFAULT_BACKOFF_FACTOR = 2.0 # exponential backoff multiplier
DEFAULT_JITTER = 0.1 # 10% jitter to avoid thundering herd
class EmbeddingRateLimiter:
"""
Rate limiter for embedding API calls.
This class implements a singleton pattern to ensure that rate limiting
is consistent across all embedding requests. It tracks request
timestamps over a moving window to control request rates.
The rate limiter uses the same configuration as the LLM API rate limiter
but uses a separate key to track embedding API calls independently.
Public Methods:
- get_instance
- reset_instance
- hit_limit
- wait_if_needed
- async_wait_if_needed
Instance Variables:
- enabled
- requests_limit
- interval_seconds
- request_times
- lock
"""
_instance = None
lock = threading.Lock()
@classmethod
def get_instance(cls):
"""
Retrieve the singleton instance of the EmbeddingRateLimiter.
This method ensures that only one instance of the class exists and
is thread-safe. It lazily initializes the instance if it doesn't
already exist.
Returns:
--------
The singleton instance of the EmbeddingRateLimiter class.
"""
if cls._instance is None:
with cls.lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def reset_instance(cls):
"""
Reset the singleton instance of the EmbeddingRateLimiter.
This method is thread-safe and sets the instance to None, allowing
for a new instance to be created when requested again.
"""
with cls.lock:
cls._instance = None
def __init__(self):
config = get_llm_config()
self.enabled = config.embedding_rate_limit_enabled
self.requests_limit = config.embedding_rate_limit_requests
self.interval_seconds = config.embedding_rate_limit_interval
self.request_times = []
self.lock = threading.Lock()
logging.info(
f"EmbeddingRateLimiter initialized: enabled={self.enabled}, "
f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}"
)
def hit_limit(self) -> bool:
"""
Check if the current request would exceed the rate limit.
This method checks if the rate limiter is enabled and evaluates
the number of requests made in the elapsed interval.
Returns:
--------
- bool: True if the rate limit would be exceeded, False otherwise.
"""
if not self.enabled:
return False
current_time = time.time()
with self.lock:
# Remove expired request times
cutoff_time = current_time - self.interval_seconds
self.request_times = [t for t in self.request_times if t > cutoff_time]
# Check if adding a new request would exceed the limit
if len(self.request_times) >= self.requests_limit:
logger.info(
f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds"
)
return True
# Otherwise, we're under the limit
return False
def wait_if_needed(self) -> float:
"""
Block until a request can be made without exceeding the rate limit.
This method will wait if the current request would exceed the
rate limit and returns the time waited in seconds.
Returns:
--------
- float: Time waited in seconds before a request is allowed.
"""
if not self.enabled:
return 0
wait_time = 0
start_time = time.time()
while self.hit_limit():
time.sleep(0.5) # Poll every 0.5 seconds
wait_time = time.time() - start_time
# Record this request
with self.lock:
self.request_times.append(time.time())
return wait_time
async def async_wait_if_needed(self) -> float:
"""
Asynchronously wait until a request can be made without exceeding the rate limit.
This method will wait if the current request would exceed the
rate limit and returns the time waited in seconds.
Returns:
--------
- float: Time waited in seconds before a request is allowed.
"""
if not self.enabled:
return 0
wait_time = 0
start_time = time.time()
while self.hit_limit():
await asyncio.sleep(0.5) # Poll every 0.5 seconds
wait_time = time.time() - start_time
# Record this request
with self.lock:
self.request_times.append(time.time())
return wait_time
def embedding_rate_limit_sync(func):
"""
Apply rate limiting to a synchronous embedding function.
Parameters:
-----------
- func: Function to decorate with rate limiting logic.
Returns:
--------
Returns the decorated function that applies rate limiting.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
"""
Wrap the given function with rate limiting logic to control the embedding API usage.
Checks if the rate limit has been exceeded before allowing the function to execute. If
the limit is hit, it logs a warning and raises an EmbeddingException. Otherwise, it
updates the request count and proceeds to call the original function.
Parameters:
-----------
- *args: Variable length argument list for the wrapped function.
- **kwargs: Keyword arguments for the wrapped function.
Returns:
--------
Returns the result of the wrapped function if rate limiting conditions are met.
"""
limiter = EmbeddingRateLimiter.get_instance()
# Check if rate limiting is enabled and if we're at the limit
if limiter.hit_limit():
error_msg = "Embedding API rate limit exceeded"
logger.warning(error_msg)
# Create a custom embedding rate limit exception
from cognee.infrastructure.databases.exceptions import EmbeddingException
raise EmbeddingException(error_msg)
# Add this request to the counter and proceed
limiter.wait_if_needed()
return func(*args, **kwargs)
return wrapper
def embedding_rate_limit_async(func):
"""
Decorator that applies rate limiting to an asynchronous embedding function.
Parameters:
-----------
- func: Async function to decorate.
Returns:
--------
Returns the decorated async function that applies rate limiting.
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
"""
Handle function calls with embedding rate limiting.
This asynchronous wrapper checks if the embedding API rate limit is exceeded before
allowing the function to execute. If the limit is exceeded, it logs a warning and raises
an EmbeddingException. If not, it waits as necessary and proceeds with the function
call.
Parameters:
-----------
- *args: Positional arguments passed to the wrapped function.
- **kwargs: Keyword arguments passed to the wrapped function.
Returns:
--------
Returns the result of the wrapped function after handling rate limiting.
"""
limiter = EmbeddingRateLimiter.get_instance()
# Check if rate limiting is enabled and if we're at the limit
if limiter.hit_limit():
error_msg = "Embedding API rate limit exceeded"
logger.warning(error_msg)
# Create a custom embedding rate limit exception
from cognee.infrastructure.databases.exceptions import EmbeddingException
raise EmbeddingException(error_msg)
# Add this request to the counter and proceed
await limiter.async_wait_if_needed()
return await func(*args, **kwargs)
return wrapper
def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5):
"""
Add retry with exponential backoff for synchronous embedding functions.
Parameters:
-----------
- max_retries: Maximum number of retries before giving up. (default 5)
- base_backoff: Base backoff time in seconds for retry intervals. (default 1.0)
- jitter: Jitter factor to randomize the backoff time to avoid collision. (default
0.5)
Returns:
--------
A decorator that retries the wrapped function on rate limit errors, applying
exponential backoff with jitter.
"""
def decorator(func):
"""
Wraps a function to apply retry logic on rate limit errors.
Parameters:
-----------
- func: The function to be wrapped with retry logic.
Returns:
--------
Returns the wrapped function with retry logic applied.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
"""
Retry the execution of a function with backoff on failure due to rate limit errors.
This wrapper function will call the specified function and if it raises an exception, it
will handle retries according to defined conditions. It will check the environment for a
DISABLE_RETRIES flag to determine whether to retry or propagate errors immediately
during tests. If the error is identified as a rate limit error, it will apply an
exponential backoff strategy with jitter before retrying, up to a maximum number of
retries. If the retries are exhausted, it raises the last encountered error.
Parameters:
-----------
- *args: Positional arguments passed to the wrapped function.
- **kwargs: Keyword arguments passed to the wrapped function.
Returns:
--------
Returns the result of the wrapped function if successful; otherwise, raises the last
error encountered after maximum retries are exhausted.
"""
# If DISABLE_RETRIES is set, don't retry for testing purposes
disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
"true",
"1",
"yes",
)
retries = 0
last_error = None
while retries <= max_retries:
try:
return func(*args, **kwargs)
except Exception as e:
# Check if this is a rate limit error
error_str = str(e).lower()
error_type = type(e).__name__
is_rate_limit = any(
pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
)
if disable_retries:
# For testing, propagate the exception immediately
raise
if is_rate_limit and retries < max_retries:
# Calculate backoff with jitter
backoff = (
base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
)
logger.warning(
f"Embedding rate limit hit, retrying in {backoff:.2f}s "
f"(attempt {retries + 1}/{max_retries}): "
f"({error_str!r}, {error_type!r})"
)
time.sleep(backoff)
retries += 1
last_error = e
else:
# Not a rate limit error or max retries reached, raise
raise
# If we exit the loop due to max retries, raise the last error
if last_error:
raise last_error
return wrapper
return decorator
def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5):
"""
Add retry logic with exponential backoff for asynchronous embedding functions.
This decorator retries the wrapped asynchronous function upon encountering rate limit
errors, utilizing exponential backoff with optional jitter to space out retry attempts.
It allows for a maximum number of retries before giving up and raising the last error
encountered.
Parameters:
-----------
- max_retries: Maximum number of retries allowed before giving up. (default 5)
- base_backoff: Base amount of time in seconds to wait before retrying after a rate
limit error. (default 1.0)
- jitter: Amount of randomness to add to the backoff duration to help mitigate burst
issues on retries. (default 0.5)
Returns:
--------
Returns a decorated asynchronous function that implements the retry logic on rate
limit errors.
"""
def decorator(func):
"""
Handle retries for an async function with exponential backoff and jitter.
Parameters:
-----------
- func: An asynchronous function to be wrapped with retry logic.
Returns:
--------
Returns the wrapper function that manages the retry behavior for the wrapped async
function.
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
"""
Handle retries for an async function with exponential backoff and jitter.
If the environment variable DISABLE_RETRIES is set to true, 1, or yes, the function will
not retry on errors.
It attempts to call the wrapped function until it succeeds or the maximum number of
retries is reached. If an exception occurs, it checks if it's a rate limit error to
determine if a retry is needed.
Parameters:
-----------
- *args: Positional arguments passed to the wrapped function.
- **kwargs: Keyword arguments passed to the wrapped function.
Returns:
--------
Returns the result of the wrapped async function if successful; raises the last
encountered error if all retries fail.
"""
# If DISABLE_RETRIES is set, don't retry for testing purposes
disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
"true",
"1",
"yes",
)
retries = 0
last_error = None
while retries <= max_retries:
try:
return await func(*args, **kwargs)
except Exception as e:
# Check if this is a rate limit error
error_str = str(e).lower()
error_type = type(e).__name__
is_rate_limit = any(
pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
)
if disable_retries:
# For testing, propagate the exception immediately
raise
if is_rate_limit and retries < max_retries:
# Calculate backoff with jitter
backoff = (
base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
)
logger.warning(
f"Embedding rate limit hit, retrying in {backoff:.2f}s "
f"(attempt {retries + 1}/{max_retries}): "
f"({error_str!r}, {error_type!r})"
)
await asyncio.sleep(backoff)
retries += 1
last_error = e
else:
# Not a rate limit error or max retries reached, raise
raise
# If we exit the loop due to max retries, raise the last error
if last_error:
raise last_error
return wrapper
return decorator
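With this module deleted, backoff on rate-limit errors is left to the tenacity decorators already on the adapters and embedding engines, whose initial wait this commit raises from 2 to 8 seconds. A rough comparison of the nominal schedules, assuming tenacity's wait_exponential_jitter(initial, max) waits about initial * 2**n seconds (plus up to one second of jitter) capped at max, while stop_after_delay(128) abandons retries after 128 seconds overall:

# Illustrative comparison only; jitter is omitted from both schedules.
old_schedule = [1.0 * 2**n for n in range(5)]          # removed decorator defaults: base_backoff * 2**retries, +/-50% jitter
new_schedule = [min(8 * 2**n, 128) for n in range(6)]  # wait_exponential_jitter(8, 128): 8, 16, 32, 64, 128, 128 seconds
print(old_schedule, new_schedule)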

View file

@@ -15,6 +15,7 @@ from tenacity import (
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.infrastructure.llm.config import get_llm_config
logger = get_logger()
@@ -45,7 +46,7 @@ class AnthropicAdapter(LLMInterface):
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
wait=wait_exponential_jitter(8, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
@@ -69,17 +70,17 @@ class AnthropicAdapter(LLMInterface):
- BaseModel: An instance of BaseModel containing the structured response.
"""
return await self.aclient(
model=self.model,
max_tokens=4096,
max_retries=5,
messages=[
{
"role": "user",
"content": f"""Use the given format to extract information
from the following input: {text_input}. {system_prompt}""",
}
],
response_model=response_model,
)
async with llm_rate_limiter_context_manager():
return await self.aclient(
model=self.model,
max_tokens=4096,
max_retries=2,
messages=[
{
"role": "user",
"content": f"""Use the given format to extract information
from the following input: {text_input}. {system_prompt}""",
}
],
response_model=response_model,
)
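Across the adapters the limiter is acquired inside the coroutine that carries the @retry decorator, so every retry attempt waits for a fresh RPM slot, and the per-call max_retries passed to the instructor client drops from 5 to 2. A condensed sketch of the retained pattern (call_llm and aclient are illustrative names; only the imports shown in this diff are assumed to exist):

import logging

import litellm
from tenacity import (
    retry,
    stop_after_delay,
    wait_exponential_jitter,
    retry_if_not_exception_type,
    before_sleep_log,
)

from cognee.shared.rate_limiting import llm_rate_limiter_context_manager

logger = logging.getLogger(__name__)


@retry(
    stop=stop_after_delay(128),
    wait=wait_exponential_jitter(8, 128),
    retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
async def call_llm(aclient, **request_kwargs):
    # The limiter is entered on every attempt, including tenacity retries.
    async with llm_rate_limiter_context_manager():
        return await aclient.chat.completions.create(max_retries=2, **request_kwargs)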

View file

@@ -13,6 +13,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
LLMInterface,
)
import logging
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.shared.logging_utils import get_logger
from tenacity import (
retry,
@@ -73,7 +74,7 @@ class GeminiAdapter(LLMInterface):
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
wait=wait_exponential_jitter(8, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
@@ -105,24 +106,25 @@ class GeminiAdapter(LLMInterface):
"""
try:
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
api_key=self.api_key,
max_retries=5,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
)
async with llm_rate_limiter_context_manager():
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
api_key=self.api_key,
max_retries=2,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,
@@ -140,23 +142,24 @@ class GeminiAdapter(LLMInterface):
)
try:
return await self.aclient.chat.completions.create(
model=self.fallback_model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
response_model=response_model,
)
async with llm_rate_limiter_context_manager():
return await self.aclient.chat.completions.create(
model=self.fallback_model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=2,
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,

View file

@@ -13,6 +13,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
LLMInterface,
)
import logging
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.shared.logging_utils import get_logger
from tenacity import (
retry,
@@ -73,7 +74,7 @@ class GenericAPIAdapter(LLMInterface):
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
wait=wait_exponential_jitter(8, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
@@ -105,23 +106,24 @@ class GenericAPIAdapter(LLMInterface):
"""
try:
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
api_key=self.api_key,
api_base=self.endpoint,
response_model=response_model,
)
async with llm_rate_limiter_context_manager():
return await self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=2,
api_key=self.api_key,
api_base=self.endpoint,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,
@@ -139,23 +141,24 @@ class GenericAPIAdapter(LLMInterface):
) from error
try:
return await self.aclient.chat.completions.create(
model=self.fallback_model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
response_model=response_model,
)
async with llm_rate_limiter_context_manager():
return await self.aclient.chat.completions.create(
model=self.fallback_model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=2,
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
response_model=response_model,
)
except (
ContentFilterFinishReasonError,
ContentPolicyViolationError,

View file

@@ -10,6 +10,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
LLMInterface,
)
from cognee.infrastructure.llm.config import get_llm_config
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
import logging
from tenacity import (
@@ -62,7 +63,7 @@ class MistralAdapter(LLMInterface):
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
wait=wait_exponential_jitter(8, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
@@ -97,13 +98,14 @@ class MistralAdapter(LLMInterface):
},
]
try:
response = await self.aclient.chat.completions.create(
model=self.model,
max_tokens=self.max_completion_tokens,
max_retries=5,
messages=messages,
response_model=response_model,
)
async with llm_rate_limiter_context_manager():
response = await self.aclient.chat.completions.create(
model=self.model,
max_tokens=self.max_completion_tokens,
max_retries=2,
messages=messages,
response_model=response_model,
)
if response.choices and response.choices[0].message.content:
content = response.choices[0].message.content
return response_model.model_validate_json(content)

View file

@@ -11,6 +11,8 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from tenacity import (
retry,
stop_after_delay,
@@ -68,7 +70,7 @@ class OllamaAPIAdapter(LLMInterface):
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
wait=wait_exponential_jitter(8, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
@@ -95,28 +97,28 @@ class OllamaAPIAdapter(LLMInterface):
- BaseModel: A structured output that conforms to the specified response model.
"""
response = self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"{text_input}",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=5,
response_model=response_model,
)
async with llm_rate_limiter_context_manager():
response = self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"{text_input}",
},
{
"role": "system",
"content": system_prompt,
},
],
max_retries=2,
response_model=response_model,
)
return response
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
wait=wait_exponential_jitter(8, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,

View file

@@ -106,7 +106,7 @@ class OpenAIAdapter(LLMInterface):
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
wait=wait_exponential_jitter(8, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,

View file

@@ -4,10 +4,7 @@ from typing import List
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
LiteLLMEmbeddingEngine,
)
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
embedding_rate_limit_async,
embedding_sleep_and_retry_async,
)
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
@@ -34,8 +31,6 @@ class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
self.fail_every_n_requests = fail_every_n_requests
self.add_delay = add_delay
@embedding_sleep_and_retry_async()
@embedding_rate_limit_async
async def embed_text(self, text: List[str]) -> List[List[float]]:
"""
Mock implementation that returns fixed embeddings and can
@@ -52,4 +47,5 @@ class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
raise Exception(f"Mock failure on request #{self.request_count}")
# Return mock embeddings of the correct dimension
return [[0.1] * self.dimensions for _ in text]
async with embedding_rate_limiter_context_manager():
return [[0.1] * self.dimensions for _ in text]
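Finally, a test-style usage sketch for the mock engine under the shared limiter. The construction details are assumed, not shown in this diff; only fail_every_n_requests, add_delay, and dimensions are visible in the hunks above.

import asyncio


async def smoke_test(engine):
    # Ten tiny batches; the embedding limiter paces how quickly each one starts,
    # while fail_every_n_requests can be used to exercise callers' retry paths.
    results = await asyncio.gather(*(engine.embed_text(["hello"]) for _ in range(10)))
    assert all(len(vector) == engine.dimensions for batch in results for vector in batch)


# Example (construction assumed): asyncio.run(smoke_test(MockEmbeddingEngine(...)))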