feat: Add RPM limiting to Cognee
This commit is contained in:
parent
0c97a400b0
commit
7deaa6e8e9
10 changed files with 146 additions and 683 deletions
|
|
@ -17,6 +17,7 @@ from cognee.infrastructure.databases.exceptions import EmbeddingException
|
||||||
from cognee.infrastructure.llm.tokenizer.TikToken import (
|
from cognee.infrastructure.llm.tokenizer.TikToken import (
|
||||||
TikTokenTokenizer,
|
TikTokenTokenizer,
|
||||||
)
|
)
|
||||||
|
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
|
||||||
|
|
||||||
litellm.set_verbose = False
|
litellm.set_verbose = False
|
||||||
logger = get_logger("FastembedEmbeddingEngine")
|
logger = get_logger("FastembedEmbeddingEngine")
|
||||||
|
|
@ -68,7 +69,7 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
|
|
@ -96,11 +97,12 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
||||||
if self.mock:
|
if self.mock:
|
||||||
return [[0.0] * self.dimensions for _ in text]
|
return [[0.0] * self.dimensions for _ in text]
|
||||||
else:
|
else:
|
||||||
embeddings = self.embedding_model.embed(
|
async with embedding_rate_limiter_context_manager():
|
||||||
text,
|
embeddings = self.embedding_model.embed(
|
||||||
batch_size=len(text),
|
text,
|
||||||
parallel=None,
|
batch_size=len(text),
|
||||||
)
|
parallel=None,
|
||||||
|
)
|
||||||
|
|
||||||
return list(embeddings)
|
return list(embeddings)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,7 @@ from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import Em
|
||||||
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
|
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
|
||||||
HuggingFaceTokenizer,
|
HuggingFaceTokenizer,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
|
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
|
||||||
embedding_rate_limit_async,
|
|
||||||
embedding_sleep_and_retry_async,
|
|
||||||
)
|
|
||||||
from cognee.shared.utils import create_secure_ssl_context
|
from cognee.shared.utils import create_secure_ssl_context
|
||||||
|
|
||||||
logger = get_logger("OllamaEmbeddingEngine")
|
logger = get_logger("OllamaEmbeddingEngine")
|
||||||
|
|
@ -101,7 +98,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
|
|
@ -120,14 +117,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
ssl_context = create_secure_ssl_context()
|
ssl_context = create_secure_ssl_context()
|
||||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
async with aiohttp.ClientSession(connector=connector) as session:
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
async with session.post(
|
async with embedding_rate_limiter_context_manager():
|
||||||
self.endpoint, json=payload, headers=headers, timeout=60.0
|
async with session.post(
|
||||||
) as response:
|
self.endpoint, json=payload, headers=headers, timeout=60.0
|
||||||
data = await response.json()
|
) as response:
|
||||||
if "embeddings" in data:
|
data = await response.json()
|
||||||
return data["embeddings"][0]
|
if "embeddings" in data:
|
||||||
else:
|
return data["embeddings"][0]
|
||||||
return data["data"][0]["embedding"]
|
else:
|
||||||
|
return data["data"][0]["embedding"]
|
||||||
|
|
||||||
def get_vector_size(self) -> int:
|
def get_vector_size(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,544 +0,0 @@
|
||||||
import threading
|
|
||||||
import logging
|
|
||||||
import functools
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import asyncio
|
|
||||||
import random
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
# Common error patterns that indicate rate limiting
|
|
||||||
RATE_LIMIT_ERROR_PATTERNS = [
|
|
||||||
"rate limit",
|
|
||||||
"rate_limit",
|
|
||||||
"ratelimit",
|
|
||||||
"too many requests",
|
|
||||||
"retry after",
|
|
||||||
"capacity",
|
|
||||||
"quota",
|
|
||||||
"limit exceeded",
|
|
||||||
"tps limit exceeded",
|
|
||||||
"request limit exceeded",
|
|
||||||
"maximum requests",
|
|
||||||
"exceeded your current quota",
|
|
||||||
"throttled",
|
|
||||||
"throttling",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Default retry settings
|
|
||||||
DEFAULT_MAX_RETRIES = 5
|
|
||||||
DEFAULT_INITIAL_BACKOFF = 1.0 # seconds
|
|
||||||
DEFAULT_BACKOFF_FACTOR = 2.0 # exponential backoff multiplier
|
|
||||||
DEFAULT_JITTER = 0.1 # 10% jitter to avoid thundering herd
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingRateLimiter:
|
|
||||||
"""
|
|
||||||
Rate limiter for embedding API calls.
|
|
||||||
|
|
||||||
This class implements a singleton pattern to ensure that rate limiting
|
|
||||||
is consistent across all embedding requests. It uses the limits
|
|
||||||
library with a moving window strategy to control request rates.
|
|
||||||
|
|
||||||
The rate limiter uses the same configuration as the LLM API rate limiter
|
|
||||||
but uses a separate key to track embedding API calls independently.
|
|
||||||
|
|
||||||
Public Methods:
|
|
||||||
- get_instance
|
|
||||||
- reset_instance
|
|
||||||
- hit_limit
|
|
||||||
- wait_if_needed
|
|
||||||
- async_wait_if_needed
|
|
||||||
|
|
||||||
Instance Variables:
|
|
||||||
- enabled
|
|
||||||
- requests_limit
|
|
||||||
- interval_seconds
|
|
||||||
- request_times
|
|
||||||
- lock
|
|
||||||
"""
|
|
||||||
|
|
||||||
_instance = None
|
|
||||||
lock = threading.Lock()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls):
|
|
||||||
"""
|
|
||||||
Retrieve the singleton instance of the EmbeddingRateLimiter.
|
|
||||||
|
|
||||||
This method ensures that only one instance of the class exists and
|
|
||||||
is thread-safe. It lazily initializes the instance if it doesn't
|
|
||||||
already exist.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
The singleton instance of the EmbeddingRateLimiter class.
|
|
||||||
"""
|
|
||||||
if cls._instance is None:
|
|
||||||
with cls.lock:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = cls()
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def reset_instance(cls):
|
|
||||||
"""
|
|
||||||
Reset the singleton instance of the EmbeddingRateLimiter.
|
|
||||||
|
|
||||||
This method is thread-safe and sets the instance to None, allowing
|
|
||||||
for a new instance to be created when requested again.
|
|
||||||
"""
|
|
||||||
with cls.lock:
|
|
||||||
cls._instance = None
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
config = get_llm_config()
|
|
||||||
self.enabled = config.embedding_rate_limit_enabled
|
|
||||||
self.requests_limit = config.embedding_rate_limit_requests
|
|
||||||
self.interval_seconds = config.embedding_rate_limit_interval
|
|
||||||
self.request_times = []
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
f"EmbeddingRateLimiter initialized: enabled={self.enabled}, "
|
|
||||||
f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def hit_limit(self) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the current request would exceed the rate limit.
|
|
||||||
|
|
||||||
This method checks if the rate limiter is enabled and evaluates
|
|
||||||
the number of requests made in the elapsed interval.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- bool: True if the rate limit would be exceeded, False otherwise.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
- bool: True if the rate limit would be exceeded, otherwise False.
|
|
||||||
"""
|
|
||||||
if not self.enabled:
|
|
||||||
return False
|
|
||||||
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
# Remove expired request times
|
|
||||||
cutoff_time = current_time - self.interval_seconds
|
|
||||||
self.request_times = [t for t in self.request_times if t > cutoff_time]
|
|
||||||
|
|
||||||
# Check if adding a new request would exceed the limit
|
|
||||||
if len(self.request_times) >= self.requests_limit:
|
|
||||||
logger.info(
|
|
||||||
f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds"
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Otherwise, we're under the limit
|
|
||||||
return False
|
|
||||||
|
|
||||||
def wait_if_needed(self) -> float:
|
|
||||||
"""
|
|
||||||
Block until a request can be made without exceeding the rate limit.
|
|
||||||
|
|
||||||
This method will wait if the current request would exceed the
|
|
||||||
rate limit and returns the time waited in seconds.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- float: Time waited in seconds before a request is allowed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
- float: Time waited in seconds before proceeding.
|
|
||||||
"""
|
|
||||||
if not self.enabled:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
wait_time = 0
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
while self.hit_limit():
|
|
||||||
time.sleep(0.5) # Poll every 0.5 seconds
|
|
||||||
wait_time = time.time() - start_time
|
|
||||||
|
|
||||||
# Record this request
|
|
||||||
with self.lock:
|
|
||||||
self.request_times.append(time.time())
|
|
||||||
|
|
||||||
return wait_time
|
|
||||||
|
|
||||||
async def async_wait_if_needed(self) -> float:
|
|
||||||
"""
|
|
||||||
Asynchronously wait until a request can be made without exceeding the rate limit.
|
|
||||||
|
|
||||||
This method will wait if the current request would exceed the
|
|
||||||
rate limit and returns the time waited in seconds.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- float: Time waited in seconds before a request is allowed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
- float: Time waited in seconds before proceeding.
|
|
||||||
"""
|
|
||||||
if not self.enabled:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
wait_time = 0
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
while self.hit_limit():
|
|
||||||
await asyncio.sleep(0.5) # Poll every 0.5 seconds
|
|
||||||
wait_time = time.time() - start_time
|
|
||||||
|
|
||||||
# Record this request
|
|
||||||
with self.lock:
|
|
||||||
self.request_times.append(time.time())
|
|
||||||
|
|
||||||
return wait_time
|
|
||||||
|
|
||||||
|
|
||||||
def embedding_rate_limit_sync(func):
|
|
||||||
"""
|
|
||||||
Apply rate limiting to a synchronous embedding function.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- func: Function to decorate with rate limiting logic.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
Returns the decorated function that applies rate limiting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@functools.wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
"""
|
|
||||||
Wrap the given function with rate limiting logic to control the embedding API usage.
|
|
||||||
|
|
||||||
Checks if the rate limit has been exceeded before allowing the function to execute. If
|
|
||||||
the limit is hit, it logs a warning and raises an EmbeddingException. Otherwise, it
|
|
||||||
updates the request count and proceeds to call the original function.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- *args: Variable length argument list for the wrapped function.
|
|
||||||
- **kwargs: Keyword arguments for the wrapped function.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
Returns the result of the wrapped function if rate limiting conditions are met.
|
|
||||||
"""
|
|
||||||
limiter = EmbeddingRateLimiter.get_instance()
|
|
||||||
|
|
||||||
# Check if rate limiting is enabled and if we're at the limit
|
|
||||||
if limiter.hit_limit():
|
|
||||||
error_msg = "Embedding API rate limit exceeded"
|
|
||||||
logger.warning(error_msg)
|
|
||||||
|
|
||||||
# Create a custom embedding rate limit exception
|
|
||||||
from cognee.infrastructure.databases.exceptions import EmbeddingException
|
|
||||||
|
|
||||||
raise EmbeddingException(error_msg)
|
|
||||||
|
|
||||||
# Add this request to the counter and proceed
|
|
||||||
limiter.wait_if_needed()
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def embedding_rate_limit_async(func):
|
|
||||||
"""
|
|
||||||
Decorator that applies rate limiting to an asynchronous embedding function.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- func: Async function to decorate.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
Returns the decorated async function that applies rate limiting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
"""
|
|
||||||
Handle function calls with embedding rate limiting.
|
|
||||||
|
|
||||||
This asynchronous wrapper checks if the embedding API rate limit is exceeded before
|
|
||||||
allowing the function to execute. If the limit is exceeded, it logs a warning and raises
|
|
||||||
an EmbeddingException. If not, it waits as necessary and proceeds with the function
|
|
||||||
call.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- *args: Positional arguments passed to the wrapped function.
|
|
||||||
- **kwargs: Keyword arguments passed to the wrapped function.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
Returns the result of the wrapped function after handling rate limiting.
|
|
||||||
"""
|
|
||||||
limiter = EmbeddingRateLimiter.get_instance()
|
|
||||||
|
|
||||||
# Check if rate limiting is enabled and if we're at the limit
|
|
||||||
if limiter.hit_limit():
|
|
||||||
error_msg = "Embedding API rate limit exceeded"
|
|
||||||
logger.warning(error_msg)
|
|
||||||
|
|
||||||
# Create a custom embedding rate limit exception
|
|
||||||
from cognee.infrastructure.databases.exceptions import EmbeddingException
|
|
||||||
|
|
||||||
raise EmbeddingException(error_msg)
|
|
||||||
|
|
||||||
# Add this request to the counter and proceed
|
|
||||||
await limiter.async_wait_if_needed()
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5):
|
|
||||||
"""
|
|
||||||
Add retry with exponential backoff for synchronous embedding functions.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- max_retries: Maximum number of retries before giving up. (default 5)
|
|
||||||
- base_backoff: Base backoff time in seconds for retry intervals. (default 1.0)
|
|
||||||
- jitter: Jitter factor to randomize the backoff time to avoid collision. (default
|
|
||||||
0.5)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
A decorator that retries the wrapped function on rate limit errors, applying
|
|
||||||
exponential backoff with jitter.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(func):
|
|
||||||
"""
|
|
||||||
Wraps a function to apply retry logic on rate limit errors.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- func: The function to be wrapped with retry logic.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
Returns the wrapped function with retry logic applied.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@functools.wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
"""
|
|
||||||
Retry the execution of a function with backoff on failure due to rate limit errors.
|
|
||||||
|
|
||||||
This wrapper function will call the specified function and if it raises an exception, it
|
|
||||||
will handle retries according to defined conditions. It will check the environment for a
|
|
||||||
DISABLE_RETRIES flag to determine whether to retry or propagate errors immediately
|
|
||||||
during tests. If the error is identified as a rate limit error, it will apply an
|
|
||||||
exponential backoff strategy with jitter before retrying, up to a maximum number of
|
|
||||||
retries. If the retries are exhausted, it raises the last encountered error.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- *args: Positional arguments passed to the wrapped function.
|
|
||||||
- **kwargs: Keyword arguments passed to the wrapped function.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
Returns the result of the wrapped function if successful; otherwise, raises the last
|
|
||||||
error encountered after maximum retries are exhausted.
|
|
||||||
"""
|
|
||||||
# If DISABLE_RETRIES is set, don't retry for testing purposes
|
|
||||||
disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
|
|
||||||
"true",
|
|
||||||
"1",
|
|
||||||
"yes",
|
|
||||||
)
|
|
||||||
|
|
||||||
retries = 0
|
|
||||||
last_error = None
|
|
||||||
|
|
||||||
while retries <= max_retries:
|
|
||||||
try:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
# Check if this is a rate limit error
|
|
||||||
error_str = str(e).lower()
|
|
||||||
error_type = type(e).__name__
|
|
||||||
is_rate_limit = any(
|
|
||||||
pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
|
|
||||||
)
|
|
||||||
|
|
||||||
if disable_retries:
|
|
||||||
# For testing, propagate the exception immediately
|
|
||||||
raise
|
|
||||||
|
|
||||||
if is_rate_limit and retries < max_retries:
|
|
||||||
# Calculate backoff with jitter
|
|
||||||
backoff = (
|
|
||||||
base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"Embedding rate limit hit, retrying in {backoff:.2f}s "
|
|
||||||
f"(attempt {retries + 1}/{max_retries}): "
|
|
||||||
f"({error_str!r}, {error_type!r})"
|
|
||||||
)
|
|
||||||
|
|
||||||
time.sleep(backoff)
|
|
||||||
retries += 1
|
|
||||||
last_error = e
|
|
||||||
else:
|
|
||||||
# Not a rate limit error or max retries reached, raise
|
|
||||||
raise
|
|
||||||
|
|
||||||
# If we exit the loop due to max retries, raise the last error
|
|
||||||
if last_error:
|
|
||||||
raise last_error
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5):
|
|
||||||
"""
|
|
||||||
Add retry logic with exponential backoff for asynchronous embedding functions.
|
|
||||||
|
|
||||||
This decorator retries the wrapped asynchronous function upon encountering rate limit
|
|
||||||
errors, utilizing exponential backoff with optional jitter to space out retry attempts.
|
|
||||||
It allows for a maximum number of retries before giving up and raising the last error
|
|
||||||
encountered.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- max_retries: Maximum number of retries allowed before giving up. (default 5)
|
|
||||||
- base_backoff: Base amount of time in seconds to wait before retrying after a rate
|
|
||||||
limit error. (default 1.0)
|
|
||||||
- jitter: Amount of randomness to add to the backoff duration to help mitigate burst
|
|
||||||
issues on retries. (default 0.5)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
Returns a decorated asynchronous function that implements the retry logic on rate
|
|
||||||
limit errors.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(func):
|
|
||||||
"""
|
|
||||||
Handle retries for an async function with exponential backoff and jitter.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- func: An asynchronous function to be wrapped with retry logic.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
Returns the wrapper function that manages the retry behavior for the wrapped async
|
|
||||||
function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
"""
|
|
||||||
Handle retries for an async function with exponential backoff and jitter.
|
|
||||||
|
|
||||||
If the environment variable DISABLE_RETRIES is set to true, 1, or yes, the function will
|
|
||||||
not retry on errors.
|
|
||||||
It attempts to call the wrapped function until it succeeds or the maximum number of
|
|
||||||
retries is reached. If an exception occurs, it checks if it's a rate limit error to
|
|
||||||
determine if a retry is needed.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- *args: Positional arguments passed to the wrapped function.
|
|
||||||
- **kwargs: Keyword arguments passed to the wrapped function.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
Returns the result of the wrapped async function if successful; raises the last
|
|
||||||
encountered error if all retries fail.
|
|
||||||
"""
|
|
||||||
# If DISABLE_RETRIES is set, don't retry for testing purposes
|
|
||||||
disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
|
|
||||||
"true",
|
|
||||||
"1",
|
|
||||||
"yes",
|
|
||||||
)
|
|
||||||
|
|
||||||
retries = 0
|
|
||||||
last_error = None
|
|
||||||
|
|
||||||
while retries <= max_retries:
|
|
||||||
try:
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
# Check if this is a rate limit error
|
|
||||||
error_str = str(e).lower()
|
|
||||||
error_type = type(e).__name__
|
|
||||||
is_rate_limit = any(
|
|
||||||
pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
|
|
||||||
)
|
|
||||||
|
|
||||||
if disable_retries:
|
|
||||||
# For testing, propagate the exception immediately
|
|
||||||
raise
|
|
||||||
|
|
||||||
if is_rate_limit and retries < max_retries:
|
|
||||||
# Calculate backoff with jitter
|
|
||||||
backoff = (
|
|
||||||
base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"Embedding rate limit hit, retrying in {backoff:.2f}s "
|
|
||||||
f"(attempt {retries + 1}/{max_retries}): "
|
|
||||||
f"({error_str!r}, {error_type!r})"
|
|
||||||
)
|
|
||||||
|
|
||||||
await asyncio.sleep(backoff)
|
|
||||||
retries += 1
|
|
||||||
last_error = e
|
|
||||||
else:
|
|
||||||
# Not a rate limit error or max retries reached, raise
|
|
||||||
raise
|
|
||||||
|
|
||||||
# If we exit the loop due to max retries, raise the last error
|
|
||||||
if last_error:
|
|
||||||
raise last_error
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
@ -15,6 +15,7 @@ from tenacity import (
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
@ -45,7 +46,7 @@ class AnthropicAdapter(LLMInterface):
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
|
|
@ -69,17 +70,17 @@ class AnthropicAdapter(LLMInterface):
|
||||||
|
|
||||||
- BaseModel: An instance of BaseModel containing the structured response.
|
- BaseModel: An instance of BaseModel containing the structured response.
|
||||||
"""
|
"""
|
||||||
|
async with llm_rate_limiter_context_manager():
|
||||||
return await self.aclient(
|
return await self.aclient(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
max_tokens=4096,
|
max_tokens=4096,
|
||||||
max_retries=5,
|
max_retries=2,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"""Use the given format to extract information
|
"content": f"""Use the given format to extract information
|
||||||
from the following input: {text_input}. {system_prompt}""",
|
from the following input: {text_input}. {system_prompt}""",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
import logging
|
import logging
|
||||||
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
|
|
@ -73,7 +74,7 @@ class GeminiAdapter(LLMInterface):
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
|
|
@ -105,24 +106,25 @@ class GeminiAdapter(LLMInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.aclient.chat.completions.create(
|
async with llm_rate_limiter_context_manager():
|
||||||
model=self.model,
|
return await self.aclient.chat.completions.create(
|
||||||
messages=[
|
model=self.model,
|
||||||
{
|
messages=[
|
||||||
"role": "user",
|
{
|
||||||
"content": f"""{text_input}""",
|
"role": "user",
|
||||||
},
|
"content": f"""{text_input}""",
|
||||||
{
|
},
|
||||||
"role": "system",
|
{
|
||||||
"content": system_prompt,
|
"role": "system",
|
||||||
},
|
"content": system_prompt,
|
||||||
],
|
},
|
||||||
api_key=self.api_key,
|
],
|
||||||
max_retries=5,
|
api_key=self.api_key,
|
||||||
api_base=self.endpoint,
|
max_retries=2,
|
||||||
api_version=self.api_version,
|
api_base=self.endpoint,
|
||||||
response_model=response_model,
|
api_version=self.api_version,
|
||||||
)
|
response_model=response_model,
|
||||||
|
)
|
||||||
except (
|
except (
|
||||||
ContentFilterFinishReasonError,
|
ContentFilterFinishReasonError,
|
||||||
ContentPolicyViolationError,
|
ContentPolicyViolationError,
|
||||||
|
|
@ -140,23 +142,24 @@ class GeminiAdapter(LLMInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.aclient.chat.completions.create(
|
async with llm_rate_limiter_context_manager():
|
||||||
model=self.fallback_model,
|
return await self.aclient.chat.completions.create(
|
||||||
messages=[
|
model=self.fallback_model,
|
||||||
{
|
messages=[
|
||||||
"role": "user",
|
{
|
||||||
"content": f"""{text_input}""",
|
"role": "user",
|
||||||
},
|
"content": f"""{text_input}""",
|
||||||
{
|
},
|
||||||
"role": "system",
|
{
|
||||||
"content": system_prompt,
|
"role": "system",
|
||||||
},
|
"content": system_prompt,
|
||||||
],
|
},
|
||||||
max_retries=5,
|
],
|
||||||
api_key=self.fallback_api_key,
|
max_retries=2,
|
||||||
api_base=self.fallback_endpoint,
|
api_key=self.fallback_api_key,
|
||||||
response_model=response_model,
|
api_base=self.fallback_endpoint,
|
||||||
)
|
response_model=response_model,
|
||||||
|
)
|
||||||
except (
|
except (
|
||||||
ContentFilterFinishReasonError,
|
ContentFilterFinishReasonError,
|
||||||
ContentPolicyViolationError,
|
ContentPolicyViolationError,
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
import logging
|
import logging
|
||||||
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
|
|
@ -73,7 +74,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
|
|
@ -105,23 +106,24 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.aclient.chat.completions.create(
|
async with llm_rate_limiter_context_manager():
|
||||||
model=self.model,
|
return await self.aclient.chat.completions.create(
|
||||||
messages=[
|
model=self.model,
|
||||||
{
|
messages=[
|
||||||
"role": "user",
|
{
|
||||||
"content": f"""{text_input}""",
|
"role": "user",
|
||||||
},
|
"content": f"""{text_input}""",
|
||||||
{
|
},
|
||||||
"role": "system",
|
{
|
||||||
"content": system_prompt,
|
"role": "system",
|
||||||
},
|
"content": system_prompt,
|
||||||
],
|
},
|
||||||
max_retries=5,
|
],
|
||||||
api_key=self.api_key,
|
max_retries=2,
|
||||||
api_base=self.endpoint,
|
api_key=self.api_key,
|
||||||
response_model=response_model,
|
api_base=self.endpoint,
|
||||||
)
|
response_model=response_model,
|
||||||
|
)
|
||||||
except (
|
except (
|
||||||
ContentFilterFinishReasonError,
|
ContentFilterFinishReasonError,
|
||||||
ContentPolicyViolationError,
|
ContentPolicyViolationError,
|
||||||
|
|
@ -139,23 +141,24 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
) from error
|
) from error
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.aclient.chat.completions.create(
|
async with llm_rate_limiter_context_manager():
|
||||||
model=self.fallback_model,
|
return await self.aclient.chat.completions.create(
|
||||||
messages=[
|
model=self.fallback_model,
|
||||||
{
|
messages=[
|
||||||
"role": "user",
|
{
|
||||||
"content": f"""{text_input}""",
|
"role": "user",
|
||||||
},
|
"content": f"""{text_input}""",
|
||||||
{
|
},
|
||||||
"role": "system",
|
{
|
||||||
"content": system_prompt,
|
"role": "system",
|
||||||
},
|
"content": system_prompt,
|
||||||
],
|
},
|
||||||
max_retries=5,
|
],
|
||||||
api_key=self.fallback_api_key,
|
max_retries=2,
|
||||||
api_base=self.fallback_endpoint,
|
api_key=self.fallback_api_key,
|
||||||
response_model=response_model,
|
api_base=self.fallback_endpoint,
|
||||||
)
|
response_model=response_model,
|
||||||
|
)
|
||||||
except (
|
except (
|
||||||
ContentFilterFinishReasonError,
|
ContentFilterFinishReasonError,
|
||||||
ContentPolicyViolationError,
|
ContentPolicyViolationError,
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
|
|
@ -62,7 +63,7 @@ class MistralAdapter(LLMInterface):
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
|
|
@ -97,13 +98,14 @@ class MistralAdapter(LLMInterface):
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
try:
|
try:
|
||||||
response = await self.aclient.chat.completions.create(
|
async with llm_rate_limiter_context_manager():
|
||||||
model=self.model,
|
response = await self.aclient.chat.completions.create(
|
||||||
max_tokens=self.max_completion_tokens,
|
model=self.model,
|
||||||
max_retries=5,
|
max_tokens=self.max_completion_tokens,
|
||||||
messages=messages,
|
max_retries=2,
|
||||||
response_model=response_model,
|
messages=messages,
|
||||||
)
|
response_model=response_model,
|
||||||
|
)
|
||||||
if response.choices and response.choices[0].message.content:
|
if response.choices and response.choices[0].message.content:
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
return response_model.model_validate_json(content)
|
return response_model.model_validate_json(content)
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,8 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
stop_after_delay,
|
stop_after_delay,
|
||||||
|
|
@ -68,7 +70,7 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
|
|
@ -95,28 +97,28 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
|
|
||||||
- BaseModel: A structured output that conforms to the specified response model.
|
- BaseModel: A structured output that conforms to the specified response model.
|
||||||
"""
|
"""
|
||||||
|
async with llm_rate_limiter_context_manager():
|
||||||
response = self.aclient.chat.completions.create(
|
response = self.aclient.chat.completions.create(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"{text_input}",
|
"content": f"{text_input}",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": system_prompt,
|
"content": system_prompt,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
max_retries=5,
|
max_retries=2,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
@observe(as_type="generation")
|
@observe(as_type="generation")
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,7 @@ from typing import List
|
||||||
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
|
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
|
||||||
LiteLLMEmbeddingEngine,
|
LiteLLMEmbeddingEngine,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
|
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
|
||||||
embedding_rate_limit_async,
|
|
||||||
embedding_sleep_and_retry_async,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
|
class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
|
||||||
|
|
@ -34,8 +31,6 @@ class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
|
||||||
self.fail_every_n_requests = fail_every_n_requests
|
self.fail_every_n_requests = fail_every_n_requests
|
||||||
self.add_delay = add_delay
|
self.add_delay = add_delay
|
||||||
|
|
||||||
@embedding_sleep_and_retry_async()
|
|
||||||
@embedding_rate_limit_async
|
|
||||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||||
"""
|
"""
|
||||||
Mock implementation that returns fixed embeddings and can
|
Mock implementation that returns fixed embeddings and can
|
||||||
|
|
@ -52,4 +47,5 @@ class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
|
||||||
raise Exception(f"Mock failure on request #{self.request_count}")
|
raise Exception(f"Mock failure on request #{self.request_count}")
|
||||||
|
|
||||||
# Return mock embeddings of the correct dimension
|
# Return mock embeddings of the correct dimension
|
||||||
return [[0.1] * self.dimensions for _ in text]
|
async with embedding_rate_limiter_context_manager():
|
||||||
|
return [[0.1] * self.dimensions for _ in text]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue