feat: Adding rate limiting (#709)
## Description

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
Parent: d1eab97102
Commit: 4e9ca94e78
20 changed files with 3191 additions and 1482 deletions
.github/workflows/notebooks_tests.yml (24 changes, vendored)
@@ -4,12 +4,12 @@ on:
  workflow_call:

jobs:
  run-main-notebook:
    name: Main Notebook Test
    uses: ./.github/workflows/reusable_notebook.yml
    with:
      notebook-location: notebooks/cognee_demo.ipynb
    secrets: inherit
  # run-main-notebook:
  #   name: Main Notebook Test
  #   uses: ./.github/workflows/reusable_notebook.yml
  #   with:
  #     notebook-location: notebooks/cognee_demo.ipynb
  #   secrets: inherit

  run-llama-index-integration:
    name: LlamaIndex Integration Notebook
@@ -32,9 +32,9 @@ jobs:
      notebook-location: notebooks/cognee_multimedia_demo.ipynb
    secrets: inherit

  run-graphrag-vs-rag:
    name: Graphrag vs Rag notebook
    uses: ./.github/workflows/reusable_notebook.yml
    with:
      notebook-location: notebooks/graphrag_vs_rag.ipynb
    secrets: inherit
  # run-graphrag-vs-rag:
  #   name: Graphrag vs Rag notebook
  #   uses: ./.github/workflows/reusable_notebook.yml
  #   with:
  #     notebook-location: notebooks/graphrag_vs_rag.ipynb
  #   secrets: inherit

cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -10,6 +10,10 @@ from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.tokenizer.Mistral import MistralTokenizer
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
from cognee.infrastructure.llm.embedding_rate_limiter import (
    embedding_rate_limit_async,
    embedding_sleep_and_retry_async,
)

litellm.set_verbose = False
logger = get_logger("LiteLLMEmbeddingEngine")

@@ -51,17 +55,12 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
        enable_mocking = str(enable_mocking).lower()
        self.mock = enable_mocking in ("true", "1", "yes")

    @embedding_sleep_and_retry_async()
    @embedding_rate_limit_async
    async def embed_text(self, text: List[str]) -> List[List[float]]:
        async def exponential_backoff(attempt):
            wait_time = min(10 * (2**attempt), 60)  # Max 60 seconds
            await asyncio.sleep(wait_time)

        try:
            if self.mock:
                response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}

                self.retry_count = 0

                return [data["embedding"] for data in response["data"]]
            else:
                response = await litellm.aembedding(

@@ -72,8 +71,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
                    api_version=self.api_version,
                )

                self.retry_count = 0  # Reset retry count on successful call

                return [data["embedding"] for data in response.data]

        except litellm.exceptions.ContextWindowExceededError as error:

@@ -95,15 +92,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
            logger.error("Context window exceeded for embedding text: %s", str(error))
            raise error

        except litellm.exceptions.RateLimitError:
            if self.retry_count >= self.MAX_RETRIES:
                raise Exception("Rate limit exceeded and no more retries left.")

            await exponential_backoff(self.retry_count)
            self.retry_count += 1

            return await self.embed_text(text)

        except (
            litellm.exceptions.BadRequestError,
            litellm.exceptions.NotFoundError,
@@ -9,6 +9,10 @@ import aiohttp.http_exceptions
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.rate_limiter import (
    embedding_rate_limit_async,
    embedding_sleep_and_retry_async,
)

logger = get_logger("OllamaEmbeddingEngine")

@@ -43,6 +47,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
        enable_mocking = str(enable_mocking).lower()
        self.mock = enable_mocking in ("true", "1", "yes")

    @embedding_rate_limit_async
    async def embed_text(self, text: List[str]) -> List[List[float]]:
        """
        Given a list of text prompts, returns a list of embedding vectors.

@@ -53,6 +58,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
        embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
        return embeddings

    @embedding_sleep_and_retry_async()
    async def _get_embedding(self, prompt: str) -> List[float]:
        """
        Internal method to call the Ollama embeddings endpoint for a single prompt.

@@ -66,26 +72,12 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        retries = 0
        while retries < self.MAX_RETRIES:
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        self.endpoint, json=payload, headers=headers, timeout=60.0
                    ) as response:
                        data = await response.json()
                        return data["embedding"]
            except aiohttp.http_exceptions.HttpBadRequest as e:
                logger.error(f"HTTP error on attempt {retries + 1}: {e}")
                retries += 1
                await asyncio.sleep(min(2**retries, 60))
            except Exception as e:
                logger.error(f"Error on attempt {retries + 1}: {e}")
                retries += 1
                await asyncio.sleep(min(2**retries, 60))
        raise EmbeddingException(
            f"Failed to embed text using model {self.model} after {self.MAX_RETRIES} retries"
        )
        async with aiohttp.ClientSession() as session:
            async with session.post(
                self.endpoint, json=payload, headers=headers, timeout=60.0
            ) as response:
                data = await response.json()
                return data["embedding"]

    def get_vector_size(self) -> int:
        return self.dimensions
@@ -5,6 +5,7 @@ import instructor
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async


class AnthropicAdapter(LLMInterface):

@@ -23,6 +24,8 @@ class AnthropicAdapter(LLMInterface):
        self.model = model
        self.max_tokens = max_tokens

    @sleep_and_retry_async()
    @rate_limit_async
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
@@ -16,6 +16,12 @@ class LLMConfig(BaseSettings):
    llm_max_tokens: int = 16384
    transcription_model: str = "whisper-1"
    graph_prompt_path: str = "generate_graph_prompt.txt"
    llm_rate_limit_enabled: bool = False
    llm_rate_limit_requests: int = 60
    llm_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)
    embedding_rate_limit_enabled: bool = False
    embedding_rate_limit_requests: int = 60
    embedding_rate_limit_interval: int = 60  # in seconds (default is 60 requests per minute)

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

@@ -85,6 +91,12 @@ class LLMConfig(BaseSettings):
            "max_tokens": self.llm_max_tokens,
            "transcription_model": self.transcription_model,
            "graph_prompt_path": self.graph_prompt_path,
            "rate_limit_enabled": self.llm_rate_limit_enabled,
            "rate_limit_requests": self.llm_rate_limit_requests,
            "rate_limit_interval": self.llm_rate_limit_interval,
            "embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
            "embedding_rate_limit_requests": self.embedding_rate_limit_requests,
            "embedding_rate_limit_interval": self.embedding_rate_limit_interval,
        }
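Because LLMConfig is a pydantic BaseSettings class with env_file=".env", these new fields map directly to environment variables, which is exactly how the tests in this PR flip them. A minimal sketch of enabling both limiters at the defaults:

import os
from cognee.infrastructure.llm.config import get_llm_config

# Same switches the tests below set (a .env file with the same keys works too):
os.environ["LLM_RATE_LIMIT_ENABLED"] = "true"
os.environ["LLM_RATE_LIMIT_REQUESTS"] = "60"
os.environ["LLM_RATE_LIMIT_INTERVAL"] = "60"
os.environ["EMBEDDING_RATE_LIMIT_ENABLED"] = "true"
os.environ["EMBEDDING_RATE_LIMIT_REQUESTS"] = "60"
os.environ["EMBEDDING_RATE_LIMIT_INTERVAL"] = "60"

get_llm_config.cache_clear()  # the config is cached; clear it after changing the environment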
cognee/infrastructure/llm/embedding_rate_limiter.py (369 lines, new file)

@@ -0,0 +1,369 @@
import threading
import logging
import functools
import os
import time
import asyncio
import random
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.config import get_llm_config


logger = get_logger()

# Common error patterns that indicate rate limiting
RATE_LIMIT_ERROR_PATTERNS = [
    "rate limit",
    "rate_limit",
    "ratelimit",
    "too many requests",
    "retry after",
    "capacity",
    "quota",
    "limit exceeded",
    "tps limit exceeded",
    "request limit exceeded",
    "maximum requests",
    "exceeded your current quota",
    "throttled",
    "throttling",
]

# Default retry settings
DEFAULT_MAX_RETRIES = 5
DEFAULT_INITIAL_BACKOFF = 1.0  # seconds
DEFAULT_BACKOFF_FACTOR = 2.0  # exponential backoff multiplier
DEFAULT_JITTER = 0.1  # 10% jitter to avoid thundering herd


class EmbeddingRateLimiter:
    """
    Rate limiter for embedding API calls.

    This class implements a singleton pattern to ensure that rate limiting
    is consistent across all embedding requests. It uses the limits
    library with a moving window strategy to control request rates.

    The rate limiter uses the same configuration as the LLM API rate limiter
    but uses a separate key to track embedding API calls independently.
    """

    _instance = None
    lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            with cls.lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @classmethod
    def reset_instance(cls):
        with cls.lock:
            cls._instance = None

    def __init__(self):
        config = get_llm_config()
        self.enabled = config.embedding_rate_limit_enabled
        self.requests_limit = config.embedding_rate_limit_requests
        self.interval_seconds = config.embedding_rate_limit_interval
        self.request_times = []
        self.lock = threading.Lock()

        logging.info(
            f"EmbeddingRateLimiter initialized: enabled={self.enabled}, "
            f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}"
        )

    def hit_limit(self) -> bool:
        """
        Check if the current request would exceed the rate limit.

        Returns:
            bool: True if the rate limit would be exceeded, False otherwise.
        """
        if not self.enabled:
            return False

        current_time = time.time()

        with self.lock:
            # Remove expired request times
            cutoff_time = current_time - self.interval_seconds
            self.request_times = [t for t in self.request_times if t > cutoff_time]

            # Check if adding a new request would exceed the limit
            if len(self.request_times) >= self.requests_limit:
                logger.info(
                    f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds"
                )
                return True

            # Otherwise, we're under the limit
            return False

    def wait_if_needed(self) -> float:
        """
        Block until a request can be made without exceeding the rate limit.

        Returns:
            float: Time waited in seconds.
        """
        if not self.enabled:
            return 0

        wait_time = 0
        start_time = time.time()

        while self.hit_limit():
            time.sleep(0.5)  # Poll every 0.5 seconds
            wait_time = time.time() - start_time

        # Record this request
        with self.lock:
            self.request_times.append(time.time())

        return wait_time

    async def async_wait_if_needed(self) -> float:
        """
        Asynchronously wait until a request can be made without exceeding the rate limit.

        Returns:
            float: Time waited in seconds.
        """
        if not self.enabled:
            return 0

        wait_time = 0
        start_time = time.time()

        while self.hit_limit():
            await asyncio.sleep(0.5)  # Poll every 0.5 seconds
            wait_time = time.time() - start_time

        # Record this request
        with self.lock:
            self.request_times.append(time.time())

        return wait_time


def embedding_rate_limit_sync(func):
    """
    Decorator that applies rate limiting to a synchronous embedding function.

    This decorator checks if the request would exceed the rate limit,
    and blocks if necessary.

    Args:
        func: Function to decorate.

    Returns:
        Decorated function that applies rate limiting.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        limiter = EmbeddingRateLimiter.get_instance()

        # Check if rate limiting is enabled and if we're at the limit
        if limiter.hit_limit():
            error_msg = "Embedding API rate limit exceeded"
            logger.warning(error_msg)

            # Create a custom embedding rate limit exception
            from cognee.infrastructure.databases.exceptions.EmbeddingException import (
                EmbeddingException,
            )

            raise EmbeddingException(error_msg)

        # Add this request to the counter and proceed
        limiter.wait_if_needed()
        return func(*args, **kwargs)

    return wrapper


def embedding_rate_limit_async(func):
    """
    Decorator that applies rate limiting to an asynchronous embedding function.

    This decorator checks if the request would exceed the rate limit,
    and waits asynchronously if necessary.

    Args:
        func: Async function to decorate.

    Returns:
        Decorated async function that applies rate limiting.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        limiter = EmbeddingRateLimiter.get_instance()

        # Check if rate limiting is enabled and if we're at the limit
        if limiter.hit_limit():
            error_msg = "Embedding API rate limit exceeded"
            logger.warning(error_msg)

            # Create a custom embedding rate limit exception
            from cognee.infrastructure.databases.exceptions.EmbeddingException import (
                EmbeddingException,
            )

            raise EmbeddingException(error_msg)

        # Add this request to the counter and proceed
        await limiter.async_wait_if_needed()
        return await func(*args, **kwargs)

    return wrapper


def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Decorator that adds retry with exponential backoff for synchronous embedding functions.

    The decorator will retry the function with exponential backoff if it
    fails due to a rate limit error.

    Args:
        max_retries: Maximum number of retries.
        base_backoff: Base backoff time in seconds.
        jitter: Jitter factor to randomize backoff time.

    Returns:
        Decorated function that retries on rate limit errors.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # If DISABLE_RETRIES is set, don't retry for testing purposes
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )

            retries = 0
            last_error = None

            while retries <= max_retries:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # Check if this is a rate limit error
                    error_str = str(e).lower()
                    error_type = type(e).__name__
                    is_rate_limit = any(
                        pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )

                    if disable_retries:
                        # For testing, propagate the exception immediately
                        raise

                    if is_rate_limit and retries < max_retries:
                        # Calculate backoff with jitter
                        backoff = (
                            base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
                        )

                        logger.warning(
                            f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                            f"(attempt {retries + 1}/{max_retries}): "
                            f"({error_str!r}, {error_type!r})"
                        )

                        time.sleep(backoff)
                        retries += 1
                        last_error = e
                    else:
                        # Not a rate limit error or max retries reached, raise
                        raise

            # If we exit the loop due to max retries, raise the last error
            if last_error:
                raise last_error

        return wrapper

    return decorator


def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Decorator that adds retry with exponential backoff for asynchronous embedding functions.

    The decorator will retry the function with exponential backoff if it
    fails due to a rate limit error.

    Args:
        max_retries: Maximum number of retries.
        base_backoff: Base backoff time in seconds.
        jitter: Jitter factor to randomize backoff time.

    Returns:
        Decorated async function that retries on rate limit errors.
    """

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # If DISABLE_RETRIES is set, don't retry for testing purposes
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )

            retries = 0
            last_error = None

            while retries <= max_retries:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    # Check if this is a rate limit error
                    error_str = str(e).lower()
                    error_type = type(e).__name__
                    is_rate_limit = any(
                        pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )

                    if disable_retries:
                        # For testing, propagate the exception immediately
                        raise

                    if is_rate_limit and retries < max_retries:
                        # Calculate backoff with jitter
                        backoff = (
                            base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
                        )

                        logger.warning(
                            f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                            f"(attempt {retries + 1}/{max_retries}): "
                            f"({error_str!r}, {error_type!r})"
                        )

                        await asyncio.sleep(backoff)
                        retries += 1
                        last_error = e
                    else:
                        # Not a rate limit error or max retries reached, raise
                        raise

            # If we exit the loop due to max retries, raise the last error
            if last_error:
                raise last_error

        return wrapper

    return decorator
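A minimal usage sketch of the two embedding decorators stacked the way the engines in this PR apply them (the function name is hypothetical). The order matters: the inner decorator raises EmbeddingException with a "rate limit" message once the window is full, and the outer decorator matches that message against RATE_LIMIT_ERROR_PATTERNS and retries with backoff:

from cognee.infrastructure.llm.embedding_rate_limiter import (
    embedding_rate_limit_async,
    embedding_sleep_and_retry_async,
)


@embedding_sleep_and_retry_async()  # outer: retries rate-limit failures with backoff
@embedding_rate_limit_async  # inner: rejects calls that would exceed the window
async def fetch_embeddings(texts):  # hypothetical stand-in for an engine method
    ...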
@@ -7,6 +7,10 @@ from cognee.shared.data_models import MonitoringTool
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.rate_limiter import (
    rate_limit_async,
    sleep_and_retry_async,
)
from cognee.base_config import get_base_config

logger = get_logger()

@@ -37,6 +41,8 @@ class GeminiAdapter(LLMInterface):
        self.max_tokens = max_tokens

    @observe(as_type="generation")
    @sleep_and_retry_async()
    @rate_limit_async
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
@@ -6,6 +6,7 @@ from pydantic import BaseModel
import instructor
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async
import litellm

@@ -27,6 +28,8 @@ class GenericAPIAdapter(LLMInterface):
            litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
        )

    @sleep_and_retry_async()
    @rate_limit_async
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:
@@ -3,6 +3,11 @@ from pydantic import BaseModel
import instructor
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.rate_limiter import (
    rate_limit_async,
    rate_limit_sync,
    sleep_and_retry_async,
)
from openai import OpenAI
import base64
import os

@@ -22,6 +27,8 @@ class OllamaAPIAdapter(LLMInterface):
            OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
        )

    @sleep_and_retry_async()
    @rate_limit_async
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:

@@ -45,6 +52,7 @@ class OllamaAPIAdapter(LLMInterface):

        return response

    @rate_limit_sync
    def create_transcript(self, input_file: str) -> str:
        """Generate an audio transcript from a user query."""

@@ -64,6 +72,7 @@ class OllamaAPIAdapter(LLMInterface):

        return transcription.text

    @rate_limit_sync
    def transcribe_image(self, input_file: str) -> str:
        """Transcribe content from an image using base64 encoding."""

@@ -79,7 +88,7 @@ class OllamaAPIAdapter(LLMInterface):
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What’s in this image?"},
                    {"type": "text", "text": "What's in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
@@ -10,6 +10,12 @@ from cognee.shared.data_models import MonitoringTool
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.rate_limiter import (
    rate_limit_async,
    rate_limit_sync,
    sleep_and_retry_async,
    sleep_and_retry_sync,
)
from cognee.base_config import get_base_config

monitoring = get_base_config().monitoring_tool

@@ -49,6 +55,8 @@ class OpenAIAdapter(LLMInterface):
        self.streaming = streaming

    @observe(as_type="generation")
    @sleep_and_retry_async()
    @rate_limit_async
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:

@@ -75,6 +83,8 @@ class OpenAIAdapter(LLMInterface):
        )

    @observe
    @sleep_and_retry_sync()
    @rate_limit_sync
    def create_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel:

@@ -100,6 +110,7 @@ class OpenAIAdapter(LLMInterface):
            max_retries=self.MAX_RETRIES,
        )

    @rate_limit_sync
    def create_transcript(self, input):
        """Generate an audio transcript from a user query."""

@@ -120,6 +131,7 @@ class OpenAIAdapter(LLMInterface):

        return transcription

    @rate_limit_sync
    def transcribe_image(self, input) -> BaseModel:
        with open(input, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

@@ -132,7 +144,7 @@ class OpenAIAdapter(LLMInterface):
                "content": [
                    {
                        "type": "text",
                        "text": "What’s in this image?",
                        "text": "What's in this image?",
                    },
                    {
                        "type": "image_url",
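All five adapters apply the same stack, and the order matters: @rate_limit_async sits closest to the API call so the moving-window wait happens immediately before each request, while @sleep_and_retry_async() wraps the whole thing so a provider-side rate-limit error that still gets through is retried with backoff. Unlike the embedding limiter, the LLM limiter waits rather than raises. A hypothetical sketch:

from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async


@sleep_and_retry_async()  # outer: retry provider rate-limit errors with backoff
@rate_limit_async  # inner: wait until the moving window has capacity
async def call_provider(prompt):  # hypothetical stand-in for an adapter method
    ...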
cognee/infrastructure/llm/rate_limiter.py (376 lines, new file)

@@ -0,0 +1,376 @@
"""
|
||||
Rate limiter for LLM API calls.
|
||||
|
||||
This module provides rate limiting functionality for LLM API calls to prevent exceeding
|
||||
API provider rate limits. The implementation uses the `limits` library with a moving window
|
||||
strategy to limit requests.
|
||||
|
||||
Configuration is done through the LLMConfig with these settings:
|
||||
- llm_rate_limit_enabled: Whether rate limiting is enabled (default: False)
|
||||
- llm_rate_limit_requests: Maximum number of requests allowed per interval (default: 60)
|
||||
- llm_rate_limit_interval: Interval in seconds for the rate limiting window (default: 60)
|
||||
|
||||
Usage:
|
||||
1. Add the decorator to any function that makes API calls:
|
||||
@rate_limit_sync
|
||||
def my_function():
|
||||
# Function that makes API calls
|
||||
|
||||
2. For async functions, use the async decorator:
|
||||
@rate_limit_async
|
||||
async def my_async_function():
|
||||
# Async function that makes API calls
|
||||
|
||||
3. For automatic retrying on rate limit errors:
|
||||
@sleep_and_retry_sync
|
||||
def my_function():
|
||||
# Function that may experience rate limit errors
|
||||
|
||||
4. For async functions with automatic retrying:
|
||||
@sleep_and_retry_async
|
||||
async def my_async_function():
|
||||
# Async function that may experience rate limit errors
|
||||
|
||||
5. For embedding rate limiting (uses the same configuration but separate limiter):
|
||||
@embedding_rate_limit_async
|
||||
async def my_embedding_function():
|
||||
# Async function for embedding API calls
|
||||
|
||||
6. For embedding with auto-retry:
|
||||
@embedding_sleep_and_retry_async
|
||||
async def my_embedding_function():
|
||||
# Async function for embedding with auto-retry
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import random
|
||||
from functools import wraps
|
||||
from limits import RateLimitItemPerMinute, storage
|
||||
from limits.strategies import MovingWindowRateLimiter
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
import threading
|
||||
import logging
|
||||
import functools
|
||||
import openai
|
||||
import os
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# Common error patterns that indicate rate limiting
|
||||
RATE_LIMIT_ERROR_PATTERNS = [
|
||||
"rate limit",
|
||||
"rate_limit",
|
||||
"ratelimit",
|
||||
"too many requests",
|
||||
"retry after",
|
||||
"capacity",
|
||||
"quota",
|
||||
"limit exceeded",
|
||||
"tps limit exceeded",
|
||||
"request limit exceeded",
|
||||
"maximum requests",
|
||||
"exceeded your current quota",
|
||||
"throttled",
|
||||
"throttling",
|
||||
]
|
||||
|
||||
# Default retry settings
|
||||
DEFAULT_MAX_RETRIES = 5
|
||||
DEFAULT_INITIAL_BACKOFF = 1.0 # seconds
|
||||
DEFAULT_BACKOFF_FACTOR = 2.0 # exponential backoff multiplier
|
||||
DEFAULT_JITTER = 0.1 # 10% jitter to avoid thundering herd
|
||||
|
||||
|
||||
class llm_rate_limiter:
|
||||
"""
|
||||
Rate limiter for LLM API calls.
|
||||
|
||||
This class implements a singleton pattern to ensure that rate limiting
|
||||
is consistent across all parts of the application. It uses the limits
|
||||
library with a moving window strategy to control request rates.
|
||||
|
||||
The rate limiter converts the configured requests/interval to a per-minute
|
||||
rate for compatibility with the limits library's built-in rate limit items.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(llm_rate_limiter, cls).__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
config = get_llm_config()
|
||||
self._enabled = config.llm_rate_limit_enabled
|
||||
self._requests = config.llm_rate_limit_requests
|
||||
self._interval = config.llm_rate_limit_interval
|
||||
|
||||
# Using in-memory storage by default
|
||||
self._storage = storage.MemoryStorage()
|
||||
self._limiter = MovingWindowRateLimiter(self._storage)
|
||||
|
||||
# Use the built-in per-minute rate limit item
|
||||
# We need to adjust the number of requests if interval isn't exactly 60s
|
||||
if self._interval == 60:
|
||||
self._rate_per_minute = self._requests
|
||||
else:
|
||||
self._rate_per_minute = int(self._requests * (60 / self._interval))
|
||||
|
||||
self._initialized = True
|
||||
|
||||
if self._enabled:
|
||||
logger.info(
|
||||
f"LLM Rate Limiter initialized: {self._requests} requests per {self._interval}s"
|
||||
)
|
||||
|
||||
def hit_limit(self) -> bool:
|
||||
"""
|
||||
Record a hit and check if limit is exceeded.
|
||||
|
||||
This method checks whether making a request now would exceed the
|
||||
configured rate limit. If rate limiting is disabled, it always
|
||||
returns True.
|
||||
|
||||
Returns:
|
||||
bool: True if the request is allowed, False otherwise.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return True
|
||||
|
||||
# Create a fresh rate limit item for each check
|
||||
rate_limit = RateLimitItemPerMinute(self._rate_per_minute)
|
||||
|
||||
# Use a consistent key for the API to ensure proper rate limiting
|
||||
return self._limiter.hit(rate_limit, "llm_api")
|
||||
|
||||
def wait_if_needed(self) -> float:
|
||||
"""
|
||||
Wait if rate limit is reached.
|
||||
|
||||
This method blocks until the request can be made without exceeding
|
||||
the rate limit. It polls every 0.5 seconds.
|
||||
|
||||
Returns:
|
||||
float: Time waited in seconds.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return 0
|
||||
|
||||
waited = 0
|
||||
while not self.hit_limit():
|
||||
time.sleep(0.5)
|
||||
waited += 0.5
|
||||
|
||||
return waited
|
||||
|
||||
async def async_wait_if_needed(self) -> float:
|
||||
"""
|
||||
Async wait if rate limit is reached.
|
||||
|
||||
This method asynchronously waits until the request can be made without
|
||||
exceeding the rate limit. It polls every 0.5 seconds.
|
||||
|
||||
Returns:
|
||||
float: Time waited in seconds.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return 0
|
||||
|
||||
waited = 0
|
||||
while not self.hit_limit():
|
||||
await asyncio.sleep(0.5)
|
||||
waited += 0.5
|
||||
|
||||
return waited
|
||||
|
||||
|
||||
def rate_limit_sync(func):
|
||||
"""
|
||||
Decorator for rate limiting synchronous functions.
|
||||
|
||||
This decorator ensures that the decorated function respects the
|
||||
configured rate limits. If the rate limit would be exceeded,
|
||||
the decorator blocks until the request can be made.
|
||||
|
||||
Args:
|
||||
func: The synchronous function to decorate.
|
||||
|
||||
Returns:
|
||||
The decorated function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
limiter = llm_rate_limiter()
|
||||
waited = limiter.wait_if_needed()
|
||||
if waited > 0:
|
||||
logger.debug(f"Rate limited LLM API call, waited for {waited}s")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def rate_limit_async(func):
|
||||
"""
|
||||
Decorator for rate limiting asynchronous functions.
|
||||
|
||||
This decorator ensures that the decorated async function respects the
|
||||
configured rate limits. If the rate limit would be exceeded,
|
||||
the decorator asynchronously waits until the request can be made.
|
||||
|
||||
Args:
|
||||
func: The asynchronous function to decorate.
|
||||
|
||||
Returns:
|
||||
The decorated async function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
limiter = llm_rate_limiter()
|
||||
waited = await limiter.async_wait_if_needed()
|
||||
if waited > 0:
|
||||
logger.debug(f"Rate limited LLM API call, waited for {waited}s")
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_rate_limit_error(error):
|
||||
"""
|
||||
Check if an error is related to rate limiting.
|
||||
|
||||
Args:
|
||||
error: The exception to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the error is rate-limit related, False otherwise.
|
||||
"""
|
||||
error_str = str(error).lower()
|
||||
return any(pattern.lower() in error_str for pattern in RATE_LIMIT_ERROR_PATTERNS)
|
||||
|
||||
|
||||
def calculate_backoff(
|
||||
attempt,
|
||||
initial_backoff=DEFAULT_INITIAL_BACKOFF,
|
||||
backoff_factor=DEFAULT_BACKOFF_FACTOR,
|
||||
jitter=DEFAULT_JITTER,
|
||||
):
|
||||
"""
|
||||
Calculate the backoff time for a retry attempt with jitter.
|
||||
|
||||
Args:
|
||||
attempt: The current retry attempt (0-based).
|
||||
initial_backoff: The initial backoff time in seconds.
|
||||
backoff_factor: The multiplier for exponential backoff.
|
||||
jitter: The jitter factor to avoid thundering herd.
|
||||
|
||||
Returns:
|
||||
float: The backoff time in seconds.
|
||||
"""
|
||||
backoff = initial_backoff * (backoff_factor**attempt)
|
||||
jitter_amount = backoff * jitter
|
||||
return backoff + random.uniform(-jitter_amount, jitter_amount)
|
||||
|
||||
|
||||
def sleep_and_retry_sync(
|
||||
max_retries=DEFAULT_MAX_RETRIES,
|
||||
initial_backoff=DEFAULT_INITIAL_BACKOFF,
|
||||
backoff_factor=DEFAULT_BACKOFF_FACTOR,
|
||||
jitter=DEFAULT_JITTER,
|
||||
):
|
||||
"""
|
||||
Decorator that automatically retries a synchronous function when rate limit errors occur.
|
||||
|
||||
This decorator implements an exponential backoff strategy with jitter
|
||||
to handle rate limit errors efficiently.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts.
|
||||
initial_backoff: Initial backoff time in seconds.
|
||||
backoff_factor: Multiplier for exponential backoff.
|
||||
jitter: Jitter factor to avoid the thundering herd problem.
|
||||
|
||||
Returns:
|
||||
The decorated function.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
attempt = 0
|
||||
while True:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
if not is_rate_limit_error(e) or attempt > max_retries:
|
||||
raise
|
||||
|
||||
backoff_time = calculate_backoff(
|
||||
attempt, initial_backoff, backoff_factor, jitter
|
||||
)
|
||||
logger.warning(
|
||||
f"Rate limit hit, retrying in {backoff_time:.2f}s "
|
||||
f"(attempt {attempt}/{max_retries}): {str(e)}"
|
||||
)
|
||||
time.sleep(backoff_time)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def sleep_and_retry_async(
|
||||
max_retries=DEFAULT_MAX_RETRIES,
|
||||
initial_backoff=DEFAULT_INITIAL_BACKOFF,
|
||||
backoff_factor=DEFAULT_BACKOFF_FACTOR,
|
||||
jitter=DEFAULT_JITTER,
|
||||
):
|
||||
"""
|
||||
Decorator that automatically retries an asynchronous function when rate limit errors occur.
|
||||
|
||||
This decorator implements an exponential backoff strategy with jitter
|
||||
to handle rate limit errors efficiently.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts.
|
||||
initial_backoff: Initial backoff time in seconds.
|
||||
backoff_factor: Multiplier for exponential backoff.
|
||||
jitter: Jitter factor to avoid the thundering herd problem.
|
||||
|
||||
Returns:
|
||||
The decorated async function.
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
attempt = 0
|
||||
while True:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
if not is_rate_limit_error(e) or attempt > max_retries:
|
||||
raise
|
||||
|
||||
backoff_time = calculate_backoff(
|
||||
attempt, initial_backoff, backoff_factor, jitter
|
||||
)
|
||||
logger.warning(
|
||||
f"Rate limit hit, retrying in {backoff_time:.2f}s "
|
||||
f"(attempt {attempt}/{max_retries}): {str(e)}"
|
||||
)
|
||||
await asyncio.sleep(backoff_time)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
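A quick worked example of the schedule these retry defaults produce, using the module's own calculate_backoff (a sketch only; jitter makes each value vary by about ±10%). Note that the retry wrappers increment attempt before computing the delay, so the first retry waits ~2s. Note also the per-minute conversion in llm_rate_limiter: 5 requests per 10 seconds becomes a 30-per-minute moving window.

from cognee.infrastructure.llm.rate_limiter import calculate_backoff

# With DEFAULT_INITIAL_BACKOFF=1.0 and DEFAULT_BACKOFF_FACTOR=2.0:
#   attempt 1 -> ~2.0s (1.8-2.2s after jitter)
#   attempt 2 -> ~4.0s (3.6-4.4s)
#   attempt 3 -> ~8.0s (7.2-8.8s)
for attempt in range(1, 4):
    print(f"attempt {attempt}: {calculate_backoff(attempt):.2f}s")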
cognee/tests/unit/infrastructure/databases/test_rate_limiter.py (171 lines, new file)

@@ -0,0 +1,171 @@
"""Tests for the LLM rate limiter."""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
from cognee.infrastructure.llm.rate_limiter import (
|
||||
llm_rate_limiter,
|
||||
rate_limit_async,
|
||||
rate_limit_sync,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_limiter_singleton():
|
||||
"""Reset the singleton instance between tests."""
|
||||
llm_rate_limiter._instance = None
|
||||
yield
|
||||
|
||||
|
||||
def test_rate_limiter_initialization():
|
||||
"""Test that the rate limiter can be initialized properly."""
|
||||
with patch("cognee.infrastructure.llm.rate_limiter.get_llm_config") as mock_config:
|
||||
mock_config.return_value.llm_rate_limit_enabled = True
|
||||
mock_config.return_value.llm_rate_limit_requests = 10
|
||||
mock_config.return_value.llm_rate_limit_interval = 60 # 1 minute
|
||||
|
||||
limiter = llm_rate_limiter()
|
||||
|
||||
assert limiter._enabled is True
|
||||
assert limiter._requests == 10
|
||||
assert limiter._interval == 60
|
||||
|
||||
|
||||
def test_rate_limiter_disabled():
|
||||
"""Test that the rate limiter is disabled by default."""
|
||||
with patch("cognee.infrastructure.llm.rate_limiter.get_llm_config") as mock_config:
|
||||
mock_config.return_value.llm_rate_limit_enabled = False
|
||||
|
||||
limiter = llm_rate_limiter()
|
||||
|
||||
assert limiter._enabled is False
|
||||
assert limiter.hit_limit() is True # Should always return True when disabled
|
||||
|
||||
|
||||
def test_rate_limiter_singleton():
|
||||
"""Test that the rate limiter is a singleton."""
|
||||
with patch("cognee.infrastructure.llm.rate_limiter.get_llm_config") as mock_config:
|
||||
mock_config.return_value.llm_rate_limit_enabled = True
|
||||
mock_config.return_value.llm_rate_limit_requests = 5
|
||||
mock_config.return_value.llm_rate_limit_interval = 60
|
||||
|
||||
limiter1 = llm_rate_limiter()
|
||||
limiter2 = llm_rate_limiter()
|
||||
|
||||
assert limiter1 is limiter2
|
||||
|
||||
|
||||
def test_sync_decorator():
|
||||
"""Test the sync decorator."""
|
||||
with patch("cognee.infrastructure.llm.rate_limiter.llm_rate_limiter") as mock_limiter_class:
|
||||
mock_limiter = mock_limiter_class.return_value
|
||||
mock_limiter.wait_if_needed.return_value = 0
|
||||
|
||||
@rate_limit_sync
|
||||
def test_func():
|
||||
return "success"
|
||||
|
||||
result = test_func()
|
||||
|
||||
assert result == "success"
|
||||
mock_limiter.wait_if_needed.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_decorator():
|
||||
"""Test the async decorator."""
|
||||
with patch("cognee.infrastructure.llm.rate_limiter.llm_rate_limiter") as mock_limiter_class:
|
||||
mock_limiter = mock_limiter_class.return_value
|
||||
|
||||
# Mock an async method with a coroutine
|
||||
async def mock_wait():
|
||||
return 0
|
||||
|
||||
mock_limiter.async_wait_if_needed.return_value = mock_wait()
|
||||
|
||||
@rate_limit_async
|
||||
async def test_func():
|
||||
return "success"
|
||||
|
||||
result = await test_func()
|
||||
|
||||
assert result == "success"
|
||||
mock_limiter.async_wait_if_needed.assert_called_once()
|
||||
|
||||
|
||||
def test_rate_limiting_actual():
|
||||
"""Test actual rate limiting behavior with a small window."""
|
||||
with patch("cognee.infrastructure.llm.rate_limiter.get_llm_config") as mock_config:
|
||||
# Configure for 3 requests per minute
|
||||
mock_config.return_value.llm_rate_limit_enabled = True
|
||||
mock_config.return_value.llm_rate_limit_requests = 3
|
||||
mock_config.return_value.llm_rate_limit_interval = 60
|
||||
|
||||
# Create a fresh instance
|
||||
llm_rate_limiter._instance = None
|
||||
limiter = llm_rate_limiter()
|
||||
|
||||
# First 3 requests should succeed
|
||||
assert limiter.hit_limit() is True
|
||||
assert limiter.hit_limit() is True
|
||||
assert limiter.hit_limit() is True
|
||||
|
||||
# Fourth request should fail (exceed limit)
|
||||
assert limiter.hit_limit() is False
|
||||
|
||||
|
||||
def test_rate_limit_60_per_minute():
|
||||
"""Test rate limiting with the default 60 requests per minute limit."""
|
||||
with patch("cognee.infrastructure.llm.rate_limiter.get_llm_config") as mock_config:
|
||||
# Configure for default values: 60 requests per 60 seconds
|
||||
mock_config.return_value.llm_rate_limit_enabled = True
|
||||
mock_config.return_value.llm_rate_limit_requests = 60 # 60 requests
|
||||
mock_config.return_value.llm_rate_limit_interval = 60 # per minute
|
||||
|
||||
# Create a fresh instance
|
||||
llm_rate_limiter._instance = None
|
||||
limiter = llm_rate_limiter()
|
||||
|
||||
# Track successful and failed requests
|
||||
successes = []
|
||||
failures = []
|
||||
|
||||
# Send requests in batches until we see some failures
|
||||
# This simulates reaching the rate limit
|
||||
num_test_requests = 70 # Try more than our limit of 60
|
||||
|
||||
for i in range(num_test_requests):
|
||||
if limiter.hit_limit():
|
||||
successes.append(f"Request {i}")
|
||||
else:
|
||||
failures.append(f"Request {i}")
|
||||
|
||||
# Print the results
|
||||
print(f"Total successful requests: {len(successes)}")
|
||||
print(f"Total failed requests: {len(failures)}")
|
||||
|
||||
if len(failures) > 0:
|
||||
print(f"First failed request: {failures[0]}")
|
||||
|
||||
# Verify we got the expected behavior (close to 60 requests allowed)
|
||||
# Allow small variations due to timing
|
||||
assert 58 <= len(successes) <= 62, f"Expected ~60 successful requests, got {len(successes)}"
|
||||
assert len(failures) > 0, "Expected at least some rate-limited requests"
|
||||
|
||||
# Verify that roughly the first 60 requests succeeded
|
||||
if len(failures) > 0:
|
||||
first_failure_idx = int(failures[0].split()[1])
|
||||
print(f"First failure occurred at request index: {first_failure_idx}")
|
||||
assert 58 <= first_failure_idx <= 62, (
|
||||
f"Expected first failure around request #60, got #{first_failure_idx}"
|
||||
)
|
||||
|
||||
# Calculate requests per minute
|
||||
rate_per_minute = len(successes)
|
||||
print(f"Rate: {rate_per_minute} requests per minute")
|
||||
|
||||
# Verify the rate is close to our target of 60 requests per minute
|
||||
assert 58 <= rate_per_minute <= 62, (
|
||||
f"Expected rate of ~60 requests per minute, got {rate_per_minute}"
|
||||
)
|
||||
cognee/tests/unit/infrastructure/mock_embedding_engine.py (55 lines, new file)

@@ -0,0 +1,55 @@
import asyncio
from typing import List

from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
    LiteLLMEmbeddingEngine,
)
from cognee.infrastructure.llm.embedding_rate_limiter import (
    embedding_rate_limit_async,
    embedding_sleep_and_retry_async,
)


class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
    """
    Mock version of LiteLLMEmbeddingEngine that returns fixed embeddings
    and can be configured to simulate rate limiting and failures.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mock = True
        self.fail_every_n_requests = 0
        self.request_count = 0
        self.add_delay = 0

    def configure_mock(self, fail_every_n_requests=0, add_delay=0):
        """
        Configure the mock's behavior.

        Args:
            fail_every_n_requests: Raise an exception every n requests (0 = never fail)
            add_delay: Add artificial delay in seconds to each request
        """
        self.fail_every_n_requests = fail_every_n_requests
        self.add_delay = add_delay

    @embedding_sleep_and_retry_async()
    @embedding_rate_limit_async
    async def embed_text(self, text: List[str]) -> List[List[float]]:
        """
        Mock implementation that returns fixed embeddings and can
        simulate failures and delays based on configuration.
        """
        self.request_count += 1

        # Simulate processing delay if configured
        if self.add_delay > 0:
            await asyncio.sleep(self.add_delay)

        # Simulate failures if configured
        if self.fail_every_n_requests > 0 and self.request_count % self.fail_every_n_requests == 0:
            raise Exception(f"Mock failure on request #{self.request_count}")

        # Return mock embeddings of the correct dimension
        return [[0.1] * self.dimensions for _ in text]
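The embedding tests below drive this mock roughly like so (condensed from the calls that appear in them):

engine = MockEmbeddingEngine()
engine.configure_mock(fail_every_n_requests=3, add_delay=0.1)  # every 3rd call raises

vectors = await engine.embed_text(["hello"])  # -> [[0.1] * engine.dimensions]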
@@ -0,0 +1,201 @@
import os
import time
import asyncio
from functools import lru_cache
import logging

from cognee.infrastructure.llm.config import LLMConfig, get_llm_config
from cognee.infrastructure.llm.embedding_rate_limiter import EmbeddingRateLimiter
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
    LiteLLMEmbeddingEngine,
)
from cognee.tests.unit.infrastructure.mock_embedding_engine import MockEmbeddingEngine

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def test_embedding_rate_limiting_realistic():
    """
    Test the embedding rate limiting feature with a realistic scenario:
    - Set limit to 3 requests per 5 seconds
    - Send requests in bursts with waiting periods
    - Track successful and rate-limited requests
    - Verify the rate limiter's behavior
    """
    # Set up environment variables for rate limiting
    os.environ["EMBEDDING_RATE_LIMIT_ENABLED"] = "true"
    os.environ["EMBEDDING_RATE_LIMIT_REQUESTS"] = "3"  # Only 3 requests per interval
    os.environ["EMBEDDING_RATE_LIMIT_INTERVAL"] = "5"
    os.environ["MOCK_EMBEDDING"] = "true"  # Use mock embeddings for testing
    os.environ["DISABLE_RETRIES"] = "true"  # Disable automatic retries for testing

    # Clear the config and rate limiter caches to ensure our settings are applied
    get_llm_config.cache_clear()
    EmbeddingRateLimiter.reset_instance()

    # Create a fresh config instance and verify settings
    config = get_llm_config()
    logger.info(f"Embedding Rate Limiting Enabled: {config.embedding_rate_limit_enabled}")
    logger.info(
        f"Embedding Rate Limit: {config.embedding_rate_limit_requests} requests per {config.embedding_rate_limit_interval} seconds"
    )

    # Create a mock embedding engine
    engine = MockEmbeddingEngine()
    # Configure some delay to simulate realistic API calls but not too long
    engine.configure_mock(add_delay=0.1)

    # Track overall statistics
    total_requests = 0
    total_successes = 0
    total_rate_limited = 0

    # Create a list of tasks to simulate concurrent requests
    async def make_request(i):
        nonlocal total_successes, total_rate_limited
        try:
            logger.info(f"Making request #{i + 1}")
            text = f"Concurrent - Text {i}"
            embedding = await engine.embed_text([text])
            logger.info(f"Request #{i + 1} succeeded with embedding size: {len(embedding[0])}")
            return True
        except Exception as e:
            logger.info(f"Request #{i + 1} rate limited: {e}")
            return False

    # Batch 1: Send 10 concurrent requests (expect 3 to succeed, 7 to be rate limited)
    batch_size = 10
    logger.info(f"\n--- Batch 1: Sending {batch_size} concurrent requests ---")

    batch_start = time.time()
    tasks = [make_request(i) for i in range(batch_size)]
    results = await asyncio.gather(*tasks, return_exceptions=False)

    batch_successes = results.count(True)
    batch_rate_limited = results.count(False)

    batch_end = time.time()
    logger.info(f"Batch 1 completed in {batch_end - batch_start:.2f} seconds")
    logger.info(f"Successes: {batch_successes}, Rate limited: {batch_rate_limited}")

    total_requests += batch_size
    total_successes += batch_successes
    total_rate_limited += batch_rate_limited

    # Wait 2 seconds (should recover some capacity but not all)
    wait_time = 2
    logger.info(f"\nWaiting {wait_time} seconds to allow partial capacity recovery...")
    await asyncio.sleep(wait_time)

    # Batch 2: Send 5 more requests (expect some to succeed, some to be rate limited)
    batch_size = 5
    logger.info(f"\n--- Batch 2: Sending {batch_size} concurrent requests ---")

    batch_start = time.time()
    tasks = [make_request(i) for i in range(batch_size)]
    results = await asyncio.gather(*tasks, return_exceptions=False)

    batch_successes = results.count(True)
    batch_rate_limited = results.count(False)

    batch_end = time.time()
    logger.info(f"Batch 2 completed in {batch_end - batch_start:.2f} seconds")
    logger.info(f"Successes: {batch_successes}, Rate limited: {batch_rate_limited}")

    total_requests += batch_size
    total_successes += batch_successes
    total_rate_limited += batch_rate_limited

    # Wait 5 seconds (should recover full capacity)
    wait_time = 5
    logger.info(f"\nWaiting {wait_time} seconds to allow full capacity recovery...")
    await asyncio.sleep(wait_time)

    # Batch 3: Send 3 more requests sequentially (all should succeed)
    batch_size = 3
    logger.info(f"\n--- Batch 3: Sending {batch_size} sequential requests ---")

    batch_start = time.time()
    batch_successes = 0
    batch_rate_limited = 0

    for i in range(batch_size):
        try:
            logger.info(f"Making request #{i + 1}")
            text = f"Sequential - Text {i}"
            embedding = await engine.embed_text([text])
            logger.info(f"Request #{i + 1} succeeded with embedding size: {len(embedding[0])}")
            batch_successes += 1
        except Exception as e:
            logger.info(f"Request #{i + 1} rate limited: {e}")
            batch_rate_limited += 1

    batch_end = time.time()
    logger.info(f"Batch 3 completed in {batch_end - batch_start:.2f} seconds")
    logger.info(f"Successes: {batch_successes}, Rate limited: {batch_rate_limited}")

    total_requests += batch_size
    total_successes += batch_successes
    total_rate_limited += batch_rate_limited

    # Log overall results
    logger.info("\n--- Test Summary ---")
    logger.info(f"Total requests: {total_requests}")
    logger.info(f"Total successes: {total_successes}")
    logger.info(f"Total rate limited: {total_rate_limited}")

    # Verify the behavior
    assert total_successes > 0, "Expected some successful requests"
    assert total_rate_limited > 0, "Expected some rate limited requests"

    # Reset environment variables
    os.environ.pop("EMBEDDING_RATE_LIMIT_ENABLED", None)
    os.environ.pop("EMBEDDING_RATE_LIMIT_REQUESTS", None)
    os.environ.pop("EMBEDDING_RATE_LIMIT_INTERVAL", None)
    os.environ.pop("MOCK_EMBEDDING", None)
    os.environ.pop("DISABLE_RETRIES", None)


async def test_with_mock_failures():
    """
    Test with the mock engine's ability to generate controlled failures.
    """
    # Setup rate limiting (more permissive settings)
    os.environ["EMBEDDING_RATE_LIMIT_ENABLED"] = "true"
    os.environ["EMBEDDING_RATE_LIMIT_REQUESTS"] = "10"
    os.environ["EMBEDDING_RATE_LIMIT_INTERVAL"] = "5"
    os.environ["DISABLE_RETRIES"] = "true"

    # Clear caches
    get_llm_config.cache_clear()
    EmbeddingRateLimiter.reset_instance()

    # Create a mock engine configured to fail every 3rd request
    engine = MockEmbeddingEngine()
    engine.configure_mock(fail_every_n_requests=3, add_delay=0.1)

    logger.info("\n--- Testing controlled failures with mock ---")

    # Send 10 requests, expecting every 3rd to fail
    for i in range(10):
        try:
            logger.info(f"Making request #{i + 1}")
            text = f"Test text {i}"
            embedding = await engine.embed_text([text])

            logger.info(f"Request #{i + 1} succeeded for {str(embedding)}")
        except Exception as e:
            logger.info(f"Request #{i + 1} failed as expected: {e}")

    # Reset environment variables
    os.environ.pop("EMBEDDING_RATE_LIMIT_ENABLED", None)
    os.environ.pop("EMBEDDING_RATE_LIMIT_REQUESTS", None)
    os.environ.pop("EMBEDDING_RATE_LIMIT_INTERVAL", None)
    os.environ.pop("DISABLE_RETRIES", None)


if __name__ == "__main__":
    asyncio.run(test_embedding_rate_limiting_realistic())
    asyncio.run(test_with_mock_failures())
cognee/tests/unit/infrastructure/test_rate_limiting_realistic.py (166 lines, new file)

@@ -0,0 +1,166 @@
import asyncio
|
||||
import time
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from unittest.mock import patch
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.infrastructure.llm.rate_limiter import llm_rate_limiter
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
|
||||
|
||||
async def test_rate_limiting_realistic():
|
||||
"""
|
||||
Test the rate limiting feature with a smaller limit to demonstrate
|
||||
how rate limiting works in practice.
|
||||
"""
|
||||
print("\n=== Testing Rate Limiting Feature (Realistic Test) ===")
|
||||
|
||||
# Configure a lower rate limit for faster testing: 5 requests per 10 seconds
|
||||
os.environ["LLM_RATE_LIMIT_ENABLED"] = "true"
|
||||
os.environ["LLM_RATE_LIMIT_REQUESTS"] = "5"
|
||||
os.environ["LLM_RATE_LIMIT_INTERVAL"] = "10"
|
||||
|
||||
# Clear the cached config and limiter
|
||||
get_llm_config.cache_clear()
|
||||
llm_rate_limiter._instance = None
|
||||
|
||||
# Create fresh instances
|
||||
config = get_llm_config()
|
||||
print(
|
||||
f"Rate limit settings: {config.llm_rate_limit_enabled=}, {config.llm_rate_limit_requests=}, {config.llm_rate_limit_interval=}"
|
||||
)
|
||||
|
||||
# We'll use monkey patching to guarantee rate limiting for test purposes
|
||||
with patch.object(llm_rate_limiter, "hit_limit") as mock_hit_limit:
|
||||
# Setup mock behavior: first 5 calls succeed, then one fails, then all succeed
|
||||
# This simulates the window moving after waiting
|
||||
mock_hit_limit.side_effect = [
|
||||
# First batch: 5 allowed, then 5 limited
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
# Second batch after waiting: 2 allowed (partial window)
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
# Third batch after full window reset: all 5 allowed
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
]
|
||||

        limiter = llm_rate_limiter()
        print(f"Rate limiter initialized with {limiter._rate_per_minute} requests per minute")

        # First batch - should allow 5 and limit the rest
        print("\nBatch 1: Sending 10 requests (expecting only 5 to succeed)...")
        batch1_success = []
        batch1_failure = []

        for i in range(10):
            result = limiter.hit_limit()
            if result:
                batch1_success.append(i)
                print(f"✓ Request {i}: Success")
            else:
                batch1_failure.append(i)
                print(f"✗ Request {i}: Rate limited")

        print(f"Batch 1 results: {len(batch1_success)} successes, {len(batch1_failure)} failures")

        if len(batch1_failure) > 0:
            print(f"First rate-limited request: #{batch1_failure[0]}")

        # Wait for the window to partially reset
        wait_time = 5  # seconds - half the rate limit interval
        print(f"\nWaiting for {wait_time} seconds to allow capacity to partially regenerate...")
        await asyncio.sleep(wait_time)

        # Second batch - should get some capacity back
        print("\nBatch 2: Sending 5 more requests (expecting 2 to succeed)...")
        batch2_success = []
        batch2_failure = []

        for i in range(5):
            result = limiter.hit_limit()
            if result:
                batch2_success.append(i)
                print(f"✓ Request {i}: Success")
            else:
                batch2_failure.append(i)
                print(f"✗ Request {i}: Rate limited")

        print(f"Batch 2 results: {len(batch2_success)} successes, {len(batch2_failure)} failures")

        # Wait for the full window to reset
        full_wait = 10  # seconds - the full rate limit interval
        print(f"\nWaiting for {full_wait} seconds for full capacity to regenerate...")
        await asyncio.sleep(full_wait)

        # Third batch - should have full capacity again
        print("\nBatch 3: Sending 5 requests (expecting all to succeed)...")
        batch3_success = []
        batch3_failure = []

        for i in range(5):
            result = limiter.hit_limit()
            if result:
                batch3_success.append(i)
                print(f"✓ Request {i}: Success")
            else:
                batch3_failure.append(i)
                print(f"✗ Request {i}: Rate limited")

        print(f"Batch 3 results: {len(batch3_success)} successes, {len(batch3_failure)} failures")

        # Calculate total successes and failures
        total_success = len(batch1_success) + len(batch2_success) + len(batch3_success)
        total_failure = len(batch1_failure) + len(batch2_failure) + len(batch3_failure)

        print(f"\nTotal requests: {total_success + total_failure}")
        print(f"Total successful: {total_success}")
        print(f"Total rate limited: {total_failure}")

        # Verify the rate limiting behavior
        if len(batch1_success) == 5 and len(batch1_failure) == 5:
            print("\n✅ PASS: Rate limiter correctly limited first batch to 5 requests")
        else:
            print(f"\n❌ FAIL: First batch should allow 5 requests, got {len(batch1_success)}")

        if len(batch2_success) == 2 and len(batch2_failure) == 3:
            print("✅ PASS: Rate limiter correctly allowed 2 requests after partial wait")
        else:
            print(f"❌ FAIL: Second batch should allow 2 requests, got {len(batch2_success)}")

        if len(batch3_success) == 5 and len(batch3_failure) == 0:
            print("✅ PASS: Rate limiter correctly allowed all requests after full window expired")
        else:
            print(f"❌ FAIL: Third batch should allow all 5 requests, got {len(batch3_success)}")

    print("=== Rate Limiting Test Complete ===\n")


async def main():
    """Run the realistic rate limiting test."""
    await test_rate_limiting_realistic()


if __name__ == "__main__":
    logger = get_logger()
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
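Because the test patches hit_limit, it never exercises a real sliding window. For reference, a minimal sketch of what an unpatched moving-window hit_limit could look like; SketchRateLimiter and its internals are illustrative assumptions, not the PR's llm_rate_limiter:

import time
from collections import deque


class SketchRateLimiter:
    """Illustrative moving-window limiter (not the PR's implementation)."""

    def __init__(self, max_requests: int = 5, interval_seconds: float = 10.0):
        self.max_requests = max_requests
        self.interval = interval_seconds
        self._timestamps: deque = deque()

    def hit_limit(self) -> bool:
        now = time.monotonic()
        # Drop hits that have aged out of the window
        while self._timestamps and now - self._timestamps[0] > self.interval:
            self._timestamps.popleft()
        if len(self._timestamps) < self.max_requests:
            self._timestamps.append(now)
            return True  # request allowed
        return False  # rate limited
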
196
cognee/tests/unit/infrastructure/test_rate_limiting_retry.py
Normal file

@@ -0,0 +1,196 @@
import asyncio
import time
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.rate_limiter import (
    sleep_and_retry_sync,
    sleep_and_retry_async,
    is_rate_limit_error,
)

logger = get_logger()


# Test function to be decorated
@sleep_and_retry_sync(max_retries=3, initial_backoff=0.1, backoff_factor=2.0)
def test_function_sync():
    """A test function that raises rate limit errors a few times, then succeeds."""
    if hasattr(test_function_sync, "counter"):
        test_function_sync.counter += 1
    else:
        test_function_sync.counter = 1

    if test_function_sync.counter <= 2:
        error_msg = "429 Too Many Requests: Rate limit exceeded"
        logger.info(f"Attempt {test_function_sync.counter}: Raising rate limit error")
        raise Exception(error_msg)

    logger.info(f"Attempt {test_function_sync.counter}: Success!")
    return f"Success on attempt {test_function_sync.counter}"


# Test async function to be decorated
@sleep_and_retry_async(max_retries=3, initial_backoff=0.1, backoff_factor=2.0)
async def test_function_async():
    """An async test function that raises rate limit errors a few times, then succeeds."""
    if hasattr(test_function_async, "counter"):
        test_function_async.counter += 1
    else:
        test_function_async.counter = 1

    if test_function_async.counter <= 2:
        error_msg = "429 Too Many Requests: Rate limit exceeded"
        logger.info(f"Attempt {test_function_async.counter}: Raising rate limit error")
        raise Exception(error_msg)

    logger.info(f"Attempt {test_function_async.counter}: Success!")
    return f"Success on attempt {test_function_async.counter}"


def test_is_rate_limit_error():
    """Test the rate limit error detection function."""
    print("\n=== Testing Rate Limit Error Detection ===")

    # Error messages that should be detected as rate limit errors
    rate_limit_errors = [
        "429 Rate limit exceeded",
        "Too many requests",
        "rate_limit_exceeded",
        "ratelimit error",
        "You have exceeded your quota",
        "capacity has been exceeded",
        "Service throttled",
    ]

    # Error messages that should not be detected as rate limit errors
    non_rate_limit_errors = [
        "404 Not Found",
        "500 Internal Server Error",
        "Invalid API Key",
        "Bad Request",
    ]

    # Check that rate limit errors are correctly identified
    for error in rate_limit_errors:
        result = is_rate_limit_error(Exception(error))
        assert result, f"Failed to identify rate limit error: {error}"
        print(f"✓ Correctly identified as rate limit error: {error}")

    # Check that non-rate-limit errors are not misidentified
    for error in non_rate_limit_errors:
        result = is_rate_limit_error(Exception(error))
        assert not result, f"Incorrectly identified as rate limit error: {error}"
        print(f"✓ Correctly identified as non-rate limit error: {error}")

    print("✅ PASS: Rate limit error detection is working correctly")
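
# For reference, a detection rule consistent with the assertions above:
# a case-insensitive substring match. This sketch is an assumption; the
# real is_rate_limit_error in cognee.infrastructure.llm.rate_limiter may differ.
RATE_LIMIT_MARKERS = (
    "429",
    "too many requests",
    "rate limit",
    "rate_limit",
    "ratelimit",
    "quota",
    "capacity",
    "throttle",
)


def is_rate_limit_error_sketch(error: BaseException) -> bool:
    message = str(error).lower()
    return any(marker in message for marker in RATE_LIMIT_MARKERS)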


def test_sync_retry():
    """Test the synchronous retry decorator."""
    print("\n=== Testing Synchronous Sleep and Retry ===")

    # Reset the counter on the test function
    if hasattr(test_function_sync, "counter"):
        del test_function_sync.counter

    # Time the execution to verify the backoff is working
    start_time = time.time()

    try:
        result = test_function_sync()
        elapsed = time.time() - start_time

        # Verify results
        print(f"Result: {result}")
        print(f"Test completed in {elapsed:.2f} seconds")
        print(f"Number of attempts: {test_function_sync.counter}")

        # The function should succeed on the 3rd attempt (after 2 failures)
        assert test_function_sync.counter == 3, (
            f"Expected 3 attempts, got {test_function_sync.counter}"
        )
        # Two backoffs (0.1s, then 0.2s) add up to at least 0.3s of waiting
        assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"

        print("✅ PASS: Synchronous retry mechanism is working correctly")
    except Exception as e:
        print(f"❌ FAIL: Test encountered an unexpected error: {str(e)}")
        raise


async def test_async_retry():
    """Test the asynchronous retry decorator."""
    print("\n=== Testing Asynchronous Sleep and Retry ===")

    # Reset the counter on the test function
    if hasattr(test_function_async, "counter"):
        del test_function_async.counter

    # Time the execution to verify the backoff is working
    start_time = time.time()

    try:
        result = await test_function_async()
        elapsed = time.time() - start_time

        # Verify results
        print(f"Result: {result}")
        print(f"Test completed in {elapsed:.2f} seconds")
        print(f"Number of attempts: {test_function_async.counter}")

        # The function should succeed on the 3rd attempt (after 2 failures)
        assert test_function_async.counter == 3, (
            f"Expected 3 attempts, got {test_function_async.counter}"
        )
        # Two backoffs (0.1s, then 0.2s) add up to at least 0.3s of waiting
        assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"

        print("✅ PASS: Asynchronous retry mechanism is working correctly")
    except Exception as e:
        print(f"❌ FAIL: Test encountered an unexpected error: {str(e)}")
        raise


async def test_retry_max_exceeded():
    """Test what happens when max retries is exceeded."""
    print("\n=== Testing Max Retries Exceeded ===")

    @sleep_and_retry_async(max_retries=2, initial_backoff=0.1)
    async def always_fails():
        """A function that always raises a rate limit error."""
        error_msg = "429 Too Many Requests: Rate limit always exceeded"
        logger.info(f"Always fails with: {error_msg}")
        raise Exception(error_msg)

    try:
        # This should fail after 2 retries (3 attempts total)
        await always_fails()
        print("❌ FAIL: Function should have failed but succeeded")
    except Exception as e:
        print(f"Expected error after max retries: {str(e)}")
        print("✅ PASS: Function correctly failed after max retries exceeded")

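
# A minimal sketch of the async retry decorator exercised above, assuming
# exponential backoff is applied only to rate-limit errors; the real
# sleep_and_retry_async in cognee.infrastructure.llm.rate_limiter may differ.
import functools


def sleep_and_retry_async_sketch(max_retries=3, initial_backoff=0.1, backoff_factor=2.0):
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            backoff = initial_backoff
            for attempt in range(max_retries + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception as error:
                    # Non-rate-limit errors and exhausted retries propagate
                    if not is_rate_limit_error(error) or attempt == max_retries:
                        raise
                    await asyncio.sleep(backoff)
                    backoff *= backoff_factor

        return wrapper

    return decorator
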

async def main():
    """Run all the retry tests."""
    test_is_rate_limit_error()
    test_sync_retry()
    await test_async_retry()
    await test_retry_max_exceeded()

    print("\n=== All Rate Limiting Retry Tests Complete ===")


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.run_until_complete(loop.shutdown_asyncgens())
3002
poetry.lock
generated

File diff suppressed because it is too large

@@ -48,6 +48,7 @@ lancedb = "0.16.0"
alembic = "^1.13.3"
pre-commit = "^4.0.1"
scikit-learn = "^1.6.1"
limits = "^4.4.1"
fastapi = {version = "0.115.7"}
fastapi-users = {version = "14.0.0", extras = ["sqlalchemy"]}
uvicorn = {version = "0.34.0", optional = true}
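The new limits dependency added above plausibly backs the moving-window behavior tested earlier. A minimal, self-contained sketch of a 5-requests-per-10-seconds check using the limits package follows; the exact wiring into cognee's llm_rate_limiter is an assumption.

from limits import RateLimitItemPerSecond
from limits.storage import MemoryStorage
from limits.strategies import MovingWindowRateLimiter

# 5 requests per 10 seconds, matching the test configuration above
rate = RateLimitItemPerSecond(5, 10)
limiter = MovingWindowRateLimiter(MemoryStorage())

for i in range(7):
    allowed = limiter.hit(rate, "llm")  # False once the window is full
    print(f"Request {i}: {'allowed' if allowed else 'rate limited'}")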

@@ -1,7 +0,0 @@
def test_import_cognee():
    try:
        import cognee

        assert True  # Pass the test if no error occurs
    except ImportError as e:
        assert False, f"Failed to import cognee: {e}"