feat: Adding rate limiting (#709)


## Description

Adds configurable rate limiting for LLM and embedding API calls:

- New `rate_limiter.py` and `embedding_rate_limiter.py` modules under `cognee/infrastructure/llm`, providing sync/async rate-limit decorators and sleep-and-retry decorators with exponential backoff and jitter.
- New `LLMConfig` settings (`llm_rate_limit_*` and `embedding_rate_limit_*`), disabled by default and defaulting to 60 requests per 60-second window.
- The decorators are applied to `OpenAIAdapter`, `AnthropicAdapter`, `GeminiAdapter`, `GenericAPIAdapter`, and `OllamaAPIAdapter`, and to `LiteLLMEmbeddingEngine` and `OllamaEmbeddingEngine`, replacing the previous hand-rolled retry loops.
- Adds the `limits` dependency plus unit and integration tests for the limiter and retry behavior.
- Comments out the main demo and graphrag-vs-rag notebook jobs in the notebook CI workflow and removes the trivial `test_import_cognee` test.

A minimal enabling sketch is shown below.
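Assuming configuration via environment variables (the names mirror the new `LLMConfig` fields and the tests added in this PR; the values are illustrative):

```python
import os

# LLM call limiting: 60 requests per 60-second moving window (the defaults)
os.environ["LLM_RATE_LIMIT_ENABLED"] = "true"
os.environ["LLM_RATE_LIMIT_REQUESTS"] = "60"
os.environ["LLM_RATE_LIMIT_INTERVAL"] = "60"

# Embedding call limiting is tracked independently
os.environ["EMBEDDING_RATE_LIMIT_ENABLED"] = "true"
os.environ["EMBEDDING_RATE_LIMIT_REQUESTS"] = "60"
os.environ["EMBEDDING_RATE_LIMIT_INTERVAL"] = "60"

from cognee.infrastructure.llm.config import get_llm_config

get_llm_config.cache_clear()  # the config is cached; clear it so the new values are picked up
config = get_llm_config()
print(config.llm_rate_limit_enabled, config.embedding_rate_limit_requests)
```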

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Igor Ilic <30923996+dexters1@users.noreply.github.com>
Authored by Vasilije on 2025-04-16 12:03:46 +02:00, committed by GitHub.
Parent d1eab97102, commit 4e9ca94e78.
20 changed files with 3191 additions and 1482 deletions.


@ -4,12 +4,12 @@ on:
workflow_call:
jobs:
run-main-notebook:
name: Main Notebook Test
uses: ./.github/workflows/reusable_notebook.yml
with:
notebook-location: notebooks/cognee_demo.ipynb
secrets: inherit
# run-main-notebook:
# name: Main Notebook Test
# uses: ./.github/workflows/reusable_notebook.yml
# with:
# notebook-location: notebooks/cognee_demo.ipynb
# secrets: inherit
run-llama-index-integration:
name: LlamaIndex Integration Notebook
@ -32,9 +32,9 @@ jobs:
notebook-location: notebooks/cognee_multimedia_demo.ipynb
secrets: inherit
run-graphrag-vs-rag:
name: Graphrag vs Rag notebook
uses: ./.github/workflows/reusable_notebook.yml
with:
notebook-location: notebooks/graphrag_vs_rag.ipynb
secrets: inherit
# run-graphrag-vs-rag:
# name: Graphrag vs Rag notebook
# uses: ./.github/workflows/reusable_notebook.yml
# with:
# notebook-location: notebooks/graphrag_vs_rag.ipynb
# secrets: inherit


@ -10,6 +10,10 @@ from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.tokenizer.Mistral import MistralTokenizer
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
from cognee.infrastructure.llm.embedding_rate_limiter import (
embedding_rate_limit_async,
embedding_sleep_and_retry_async,
)
litellm.set_verbose = False
logger = get_logger("LiteLLMEmbeddingEngine")
@ -51,17 +55,12 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
enable_mocking = str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes")
@embedding_sleep_and_retry_async()
@embedding_rate_limit_async
async def embed_text(self, text: List[str]) -> List[List[float]]:
async def exponential_backoff(attempt):
wait_time = min(10 * (2**attempt), 60) # Max 60 seconds
await asyncio.sleep(wait_time)
try:
if self.mock:
response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
self.retry_count = 0
return [data["embedding"] for data in response["data"]]
else:
response = await litellm.aembedding(
@ -72,8 +71,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_version=self.api_version,
)
self.retry_count = 0 # Reset retry count on successful call
return [data["embedding"] for data in response.data]
except litellm.exceptions.ContextWindowExceededError as error:
@ -95,15 +92,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
logger.error("Context window exceeded for embedding text: %s", str(error))
raise error
except litellm.exceptions.RateLimitError:
if self.retry_count >= self.MAX_RETRIES:
raise Exception("Rate limit exceeded and no more retries left.")
await exponential_backoff(self.retry_count)
self.retry_count += 1
return await self.embed_text(text)
except (
litellm.exceptions.BadRequestError,
litellm.exceptions.NotFoundError,


@ -9,6 +9,10 @@ import aiohttp.http_exceptions
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.rate_limiter import (
embedding_rate_limit_async,
embedding_sleep_and_retry_async,
)
logger = get_logger("OllamaEmbeddingEngine")
@ -43,6 +47,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
enable_mocking = str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes")
@embedding_rate_limit_async
async def embed_text(self, text: List[str]) -> List[List[float]]:
"""
Given a list of text prompts, returns a list of embedding vectors.
@ -53,6 +58,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
return embeddings
@embedding_sleep_and_retry_async()
async def _get_embedding(self, prompt: str) -> List[float]:
"""
Internal method to call the Ollama embeddings endpoint for a single prompt.
@ -66,26 +72,12 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
retries = 0
while retries < self.MAX_RETRIES:
try:
async with aiohttp.ClientSession() as session:
async with session.post(
self.endpoint, json=payload, headers=headers, timeout=60.0
) as response:
data = await response.json()
return data["embedding"]
except aiohttp.http_exceptions.HttpBadRequest as e:
logger.error(f"HTTP error on attempt {retries + 1}: {e}")
retries += 1
await asyncio.sleep(min(2**retries, 60))
except Exception as e:
logger.error(f"Error on attempt {retries + 1}: {e}")
retries += 1
await asyncio.sleep(min(2**retries, 60))
raise EmbeddingException(
f"Failed to embed text using model {self.model} after {self.MAX_RETRIES} retries"
)
async with aiohttp.ClientSession() as session:
async with session.post(
self.endpoint, json=payload, headers=headers, timeout=60.0
) as response:
data = await response.json()
return data["embedding"]
def get_vector_size(self) -> int:
return self.dimensions


@ -5,6 +5,7 @@ import instructor
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async
class AnthropicAdapter(LLMInterface):
@ -23,6 +24,8 @@ class AnthropicAdapter(LLMInterface):
self.model = model
self.max_tokens = max_tokens
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:


@ -16,6 +16,12 @@ class LLMConfig(BaseSettings):
llm_max_tokens: int = 16384
transcription_model: str = "whisper-1"
graph_prompt_path: str = "generate_graph_prompt.txt"
llm_rate_limit_enabled: bool = False
llm_rate_limit_requests: int = 60
llm_rate_limit_interval: int = 60 # in seconds (default is 60 requests per minute)
embedding_rate_limit_enabled: bool = False
embedding_rate_limit_requests: int = 60
embedding_rate_limit_interval: int = 60 # in seconds (default is 60 requests per minute)
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@ -85,6 +91,12 @@ class LLMConfig(BaseSettings):
"max_tokens": self.llm_max_tokens,
"transcription_model": self.transcription_model,
"graph_prompt_path": self.graph_prompt_path,
"rate_limit_enabled": self.llm_rate_limit_enabled,
"rate_limit_requests": self.llm_rate_limit_requests,
"rate_limit_interval": self.llm_rate_limit_interval,
"embedding_rate_limit_enabled": self.embedding_rate_limit_enabled,
"embedding_rate_limit_requests": self.embedding_rate_limit_requests,
"embedding_rate_limit_interval": self.embedding_rate_limit_interval,
}


@ -0,0 +1,369 @@
import threading
import logging
import functools
import os
import time
import asyncio
import random
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.config import get_llm_config
logger = get_logger()
# Common error patterns that indicate rate limiting
RATE_LIMIT_ERROR_PATTERNS = [
"rate limit",
"rate_limit",
"ratelimit",
"too many requests",
"retry after",
"capacity",
"quota",
"limit exceeded",
"tps limit exceeded",
"request limit exceeded",
"maximum requests",
"exceeded your current quota",
"throttled",
"throttling",
]
# Default retry settings
DEFAULT_MAX_RETRIES = 5
DEFAULT_INITIAL_BACKOFF = 1.0 # seconds
DEFAULT_BACKOFF_FACTOR = 2.0 # exponential backoff multiplier
DEFAULT_JITTER = 0.1 # 10% jitter to avoid thundering herd
class EmbeddingRateLimiter:
"""
Rate limiter for embedding API calls.
This class implements a singleton pattern to ensure that rate limiting
is consistent across all embedding requests. It keeps a moving window
of recent request timestamps to control request rates.
The rate limiter is configured analogously to the LLM API rate limiter
(via the embedding_rate_limit_* settings) but tracks embedding API calls independently.
"""
_instance = None
lock = threading.Lock()
@classmethod
def get_instance(cls):
if cls._instance is None:
with cls.lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def reset_instance(cls):
with cls.lock:
cls._instance = None
def __init__(self):
config = get_llm_config()
self.enabled = config.embedding_rate_limit_enabled
self.requests_limit = config.embedding_rate_limit_requests
self.interval_seconds = config.embedding_rate_limit_interval
self.request_times = []
self.lock = threading.Lock()
logging.info(
f"EmbeddingRateLimiter initialized: enabled={self.enabled}, "
f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}"
)
def hit_limit(self) -> bool:
"""
Check if the current request would exceed the rate limit.
Returns:
bool: True if the rate limit would be exceeded, False otherwise.
"""
if not self.enabled:
return False
current_time = time.time()
with self.lock:
# Remove expired request times
cutoff_time = current_time - self.interval_seconds
self.request_times = [t for t in self.request_times if t > cutoff_time]
# Check if adding a new request would exceed the limit
if len(self.request_times) >= self.requests_limit:
logger.info(
f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds"
)
return True
# Otherwise, we're under the limit
return False
def wait_if_needed(self) -> float:
"""
Block until a request can be made without exceeding the rate limit.
Returns:
float: Time waited in seconds.
"""
if not self.enabled:
return 0
wait_time = 0
start_time = time.time()
while self.hit_limit():
time.sleep(0.5) # Poll every 0.5 seconds
wait_time = time.time() - start_time
# Record this request
with self.lock:
self.request_times.append(time.time())
return wait_time
async def async_wait_if_needed(self) -> float:
"""
Asynchronously wait until a request can be made without exceeding the rate limit.
Returns:
float: Time waited in seconds.
"""
if not self.enabled:
return 0
wait_time = 0
start_time = time.time()
while self.hit_limit():
await asyncio.sleep(0.5) # Poll every 0.5 seconds
wait_time = time.time() - start_time
# Record this request
with self.lock:
self.request_times.append(time.time())
return wait_time
def embedding_rate_limit_sync(func):
"""
Decorator that applies rate limiting to a synchronous embedding function.
This decorator checks if the request would exceed the rate limit,
and blocks if necessary.
Args:
func: Function to decorate.
Returns:
Decorated function that applies rate limiting.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
limiter = EmbeddingRateLimiter.get_instance()
# Check if rate limiting is enabled and if we're at the limit
if limiter.hit_limit():
error_msg = "Embedding API rate limit exceeded"
logger.warning(error_msg)
# Create a custom embedding rate limit exception
from cognee.infrastructure.databases.exceptions.EmbeddingException import (
EmbeddingException,
)
raise EmbeddingException(error_msg)
# Add this request to the counter and proceed
limiter.wait_if_needed()
return func(*args, **kwargs)
return wrapper
def embedding_rate_limit_async(func):
"""
Decorator that applies rate limiting to an asynchronous embedding function.
This decorator checks if the request would exceed the rate limit,
and waits asynchronously if necessary.
Args:
func: Async function to decorate.
Returns:
Decorated async function that applies rate limiting.
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
limiter = EmbeddingRateLimiter.get_instance()
# Check if rate limiting is enabled and if we're at the limit
if limiter.hit_limit():
error_msg = "Embedding API rate limit exceeded"
logger.warning(error_msg)
# Create a custom embedding rate limit exception
from cognee.infrastructure.databases.exceptions.EmbeddingException import (
EmbeddingException,
)
raise EmbeddingException(error_msg)
# Add this request to the counter and proceed
await limiter.async_wait_if_needed()
return await func(*args, **kwargs)
return wrapper
def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5):
"""
Decorator that adds retry with exponential backoff for synchronous embedding functions.
The decorator will retry the function with exponential backoff if it
fails due to a rate limit error.
Args:
max_retries: Maximum number of retries.
base_backoff: Base backoff time in seconds.
jitter: Jitter factor to randomize backoff time.
Returns:
Decorated function that retries on rate limit errors.
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# If DISABLE_RETRIES is set, don't retry for testing purposes
disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
"true",
"1",
"yes",
)
retries = 0
last_error = None
while retries <= max_retries:
try:
return func(*args, **kwargs)
except Exception as e:
# Check if this is a rate limit error
error_str = str(e).lower()
error_type = type(e).__name__
is_rate_limit = any(
pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
)
if disable_retries:
# For testing, propagate the exception immediately
raise
if is_rate_limit and retries < max_retries:
# Calculate backoff with jitter
backoff = (
base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
)
logger.warning(
f"Embedding rate limit hit, retrying in {backoff:.2f}s "
f"(attempt {retries + 1}/{max_retries}): "
f"({error_str!r}, {error_type!r})"
)
time.sleep(backoff)
retries += 1
last_error = e
else:
# Not a rate limit error or max retries reached, raise
raise
# If we exit the loop due to max retries, raise the last error
if last_error:
raise last_error
return wrapper
return decorator
def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5):
"""
Decorator that adds retry with exponential backoff for asynchronous embedding functions.
The decorator will retry the function with exponential backoff if it
fails due to a rate limit error.
Args:
max_retries: Maximum number of retries.
base_backoff: Base backoff time in seconds.
jitter: Jitter factor to randomize backoff time.
Returns:
Decorated async function that retries on rate limit errors.
"""
def decorator(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
# If DISABLE_RETRIES is set, don't retry for testing purposes
disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
"true",
"1",
"yes",
)
retries = 0
last_error = None
while retries <= max_retries:
try:
return await func(*args, **kwargs)
except Exception as e:
# Check if this is a rate limit error
error_str = str(e).lower()
error_type = type(e).__name__
is_rate_limit = any(
pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
)
if disable_retries:
# For testing, propagate the exception immediately
raise
if is_rate_limit and retries < max_retries:
# Calculate backoff with jitter
backoff = (
base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
)
logger.warning(
f"Embedding rate limit hit, retrying in {backoff:.2f}s "
f"(attempt {retries + 1}/{max_retries}): "
f"({error_str!r}, {error_type!r})"
)
await asyncio.sleep(backoff)
retries += 1
last_error = e
else:
# Not a rate limit error or max retries reached, raise
raise
# If we exit the loop due to max retries, raise the last error
if last_error:
raise last_error
return wrapper
return decorator
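For reference, a minimal sketch of how the two embedding decorators are meant to compose (hypothetical function name; it mirrors their use on `LiteLLMEmbeddingEngine.embed_text` above):

```python
from typing import List

from cognee.infrastructure.llm.embedding_rate_limiter import (
    embedding_rate_limit_async,
    embedding_sleep_and_retry_async,
)


@embedding_sleep_and_retry_async()  # decorator factory: note the parentheses
@embedding_rate_limit_async  # raises EmbeddingException when the window is full
async def embed_batch(texts: List[str]) -> List[List[float]]:
    # the provider's embedding call would go here
    ...
```

Because the rejection message contains "rate limit", the outer retry decorator recognizes it and backs off exponentially rather than failing immediately (unless `DISABLE_RETRIES` is set).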


@ -7,6 +7,10 @@ from cognee.shared.data_models import MonitoringTool
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
from cognee.base_config import get_base_config
logger = get_logger()
@ -37,6 +41,8 @@ class GeminiAdapter(LLMInterface):
self.max_tokens = max_tokens
@observe(as_type="generation")
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:


@ -6,6 +6,7 @@ from pydantic import BaseModel
import instructor
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async
import litellm
@ -27,6 +28,8 @@ class GenericAPIAdapter(LLMInterface):
litellm.acompletion, mode=instructor.Mode.JSON, api_key=api_key
)
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:


@ -3,6 +3,11 @@ from pydantic import BaseModel
import instructor
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.rate_limiter import (
rate_limit_async,
rate_limit_sync,
sleep_and_retry_async,
)
from openai import OpenAI
import base64
import os
@ -22,6 +27,8 @@ class OllamaAPIAdapter(LLMInterface):
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
)
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
@ -45,6 +52,7 @@ class OllamaAPIAdapter(LLMInterface):
return response
@rate_limit_sync
def create_transcript(self, input_file: str) -> str:
"""Generate an audio transcript from a user query."""
@ -64,6 +72,7 @@ class OllamaAPIAdapter(LLMInterface):
return transcription.text
@rate_limit_sync
def transcribe_image(self, input_file: str) -> str:
"""Transcribe content from an image using base64 encoding."""
@ -79,7 +88,7 @@ class OllamaAPIAdapter(LLMInterface):
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},


@ -10,6 +10,12 @@ from cognee.shared.data_models import MonitoringTool
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.infrastructure.llm.rate_limiter import (
rate_limit_async,
rate_limit_sync,
sleep_and_retry_async,
sleep_and_retry_sync,
)
from cognee.base_config import get_base_config
monitoring = get_base_config().monitoring_tool
@ -49,6 +55,8 @@ class OpenAIAdapter(LLMInterface):
self.streaming = streaming
@observe(as_type="generation")
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
@ -75,6 +83,8 @@ class OpenAIAdapter(LLMInterface):
)
@observe
@sleep_and_retry_sync()
@rate_limit_sync
def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
@ -100,6 +110,7 @@ class OpenAIAdapter(LLMInterface):
max_retries=self.MAX_RETRIES,
)
@rate_limit_sync
def create_transcript(self, input):
"""Generate a audio transcript from a user query."""
@ -120,6 +131,7 @@ class OpenAIAdapter(LLMInterface):
return transcription
@rate_limit_sync
def transcribe_image(self, input) -> BaseModel:
with open(input, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
@ -132,7 +144,7 @@ class OpenAIAdapter(LLMInterface):
"content": [
{
"type": "text",
"text": "Whats in this image?",
"text": "What's in this image?",
},
{
"type": "image_url",


@ -0,0 +1,376 @@
"""
Rate limiter for LLM API calls.
This module provides rate limiting functionality for LLM API calls to prevent exceeding
API provider rate limits. The implementation uses the `limits` library with a moving window
strategy to limit requests.
Configuration is done through the LLMConfig with these settings:
- llm_rate_limit_enabled: Whether rate limiting is enabled (default: False)
- llm_rate_limit_requests: Maximum number of requests allowed per interval (default: 60)
- llm_rate_limit_interval: Interval in seconds for the rate limiting window (default: 60)
Usage:
1. Add the decorator to any function that makes API calls:
@rate_limit_sync
def my_function():
# Function that makes API calls
2. For async functions, use the async decorator:
@rate_limit_async
async def my_async_function():
# Async function that makes API calls
3. For automatic retrying on rate limit errors:
@sleep_and_retry_sync()
def my_function():
# Function that may experience rate limit errors
4. For async functions with automatic retrying:
@sleep_and_retry_async()
async def my_async_function():
# Async function that may experience rate limit errors
5. For embedding rate limiting (uses the same configuration but separate limiter):
@embedding_rate_limit_async
async def my_embedding_function():
# Async function for embedding API calls
6. For embedding with auto-retry:
@embedding_sleep_and_retry_async()
async def my_embedding_function():
# Async function for embedding with auto-retry
"""
import time
import asyncio
import random
from functools import wraps
from limits import RateLimitItemPerMinute, storage
from limits.strategies import MovingWindowRateLimiter
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.config import get_llm_config
import threading
import logging
import functools
import openai
import os
logger = get_logger()
# Common error patterns that indicate rate limiting
RATE_LIMIT_ERROR_PATTERNS = [
"rate limit",
"rate_limit",
"ratelimit",
"too many requests",
"retry after",
"capacity",
"quota",
"limit exceeded",
"tps limit exceeded",
"request limit exceeded",
"maximum requests",
"exceeded your current quota",
"throttled",
"throttling",
]
# Default retry settings
DEFAULT_MAX_RETRIES = 5
DEFAULT_INITIAL_BACKOFF = 1.0 # seconds
DEFAULT_BACKOFF_FACTOR = 2.0 # exponential backoff multiplier
DEFAULT_JITTER = 0.1 # 10% jitter to avoid thundering herd
class llm_rate_limiter:
"""
Rate limiter for LLM API calls.
This class implements a singleton pattern to ensure that rate limiting
is consistent across all parts of the application. It uses the limits
library with a moving window strategy to control request rates.
The rate limiter converts the configured requests/interval to a per-minute
rate for compatibility with the limits library's built-in rate limit items.
"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(llm_rate_limiter, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
config = get_llm_config()
self._enabled = config.llm_rate_limit_enabled
self._requests = config.llm_rate_limit_requests
self._interval = config.llm_rate_limit_interval
# Using in-memory storage by default
self._storage = storage.MemoryStorage()
self._limiter = MovingWindowRateLimiter(self._storage)
# Use the built-in per-minute rate limit item
# We need to adjust the number of requests if interval isn't exactly 60s
if self._interval == 60:
self._rate_per_minute = self._requests
else:
self._rate_per_minute = int(self._requests * (60 / self._interval))
self._initialized = True
if self._enabled:
logger.info(
f"LLM Rate Limiter initialized: {self._requests} requests per {self._interval}s"
)
def hit_limit(self) -> bool:
"""
Record a hit and check if limit is exceeded.
This method checks whether making a request now would exceed the
configured rate limit. If rate limiting is disabled, it always
returns True.
Returns:
bool: True if the request is allowed, False otherwise.
"""
if not self._enabled:
return True
# Create a fresh rate limit item for each check
rate_limit = RateLimitItemPerMinute(self._rate_per_minute)
# Use a consistent key for the API to ensure proper rate limiting
return self._limiter.hit(rate_limit, "llm_api")
def wait_if_needed(self) -> float:
"""
Wait if rate limit is reached.
This method blocks until the request can be made without exceeding
the rate limit. It polls every 0.5 seconds.
Returns:
float: Time waited in seconds.
"""
if not self._enabled:
return 0
waited = 0
while not self.hit_limit():
time.sleep(0.5)
waited += 0.5
return waited
async def async_wait_if_needed(self) -> float:
"""
Async wait if rate limit is reached.
This method asynchronously waits until the request can be made without
exceeding the rate limit. It polls every 0.5 seconds.
Returns:
float: Time waited in seconds.
"""
if not self._enabled:
return 0
waited = 0
while not self.hit_limit():
await asyncio.sleep(0.5)
waited += 0.5
return waited
def rate_limit_sync(func):
"""
Decorator for rate limiting synchronous functions.
This decorator ensures that the decorated function respects the
configured rate limits. If the rate limit would be exceeded,
the decorator blocks until the request can be made.
Args:
func: The synchronous function to decorate.
Returns:
The decorated function.
"""
@wraps(func)
def wrapper(*args, **kwargs):
limiter = llm_rate_limiter()
waited = limiter.wait_if_needed()
if waited > 0:
logger.debug(f"Rate limited LLM API call, waited for {waited}s")
return func(*args, **kwargs)
return wrapper
def rate_limit_async(func):
"""
Decorator for rate limiting asynchronous functions.
This decorator ensures that the decorated async function respects the
configured rate limits. If the rate limit would be exceeded,
the decorator asynchronously waits until the request can be made.
Args:
func: The asynchronous function to decorate.
Returns:
The decorated async function.
"""
@wraps(func)
async def wrapper(*args, **kwargs):
limiter = llm_rate_limiter()
waited = await limiter.async_wait_if_needed()
if waited > 0:
logger.debug(f"Rate limited LLM API call, waited for {waited}s")
return await func(*args, **kwargs)
return wrapper
def is_rate_limit_error(error):
"""
Check if an error is related to rate limiting.
Args:
error: The exception to check.
Returns:
bool: True if the error is rate-limit related, False otherwise.
"""
error_str = str(error).lower()
return any(pattern.lower() in error_str for pattern in RATE_LIMIT_ERROR_PATTERNS)
def calculate_backoff(
attempt,
initial_backoff=DEFAULT_INITIAL_BACKOFF,
backoff_factor=DEFAULT_BACKOFF_FACTOR,
jitter=DEFAULT_JITTER,
):
"""
Calculate the backoff time for a retry attempt with jitter.
Args:
attempt: The current retry attempt (0-based).
initial_backoff: The initial backoff time in seconds.
backoff_factor: The multiplier for exponential backoff.
jitter: The jitter factor to avoid thundering herd.
Returns:
float: The backoff time in seconds.
"""
backoff = initial_backoff * (backoff_factor**attempt)
jitter_amount = backoff * jitter
return backoff + random.uniform(-jitter_amount, jitter_amount)
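For intuition, the defaults above give a roughly doubling schedule with ±10% jitter. A standalone sketch recomputing the same formula (illustrative only, not part of the module); note that `sleep_and_retry_sync`/`_async` below increment the attempt counter before computing the backoff, so the first retry waits about 2 s:

```python
import random


def backoff(attempt, initial=1.0, factor=2.0, jitter=0.1):
    # same formula as calculate_backoff: exponential growth plus symmetric jitter
    base = initial * (factor**attempt)
    return base + random.uniform(-base * jitter, base * jitter)


# attempt 1 -> ~2 s, 2 -> ~4 s, 3 -> ~8 s, 4 -> ~16 s, 5 -> ~32 s
for attempt in range(1, 6):
    print(f"retry {attempt}: ~{backoff(attempt):.1f}s")
```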
def sleep_and_retry_sync(
max_retries=DEFAULT_MAX_RETRIES,
initial_backoff=DEFAULT_INITIAL_BACKOFF,
backoff_factor=DEFAULT_BACKOFF_FACTOR,
jitter=DEFAULT_JITTER,
):
"""
Decorator that automatically retries a synchronous function when rate limit errors occur.
This decorator implements an exponential backoff strategy with jitter
to handle rate limit errors efficiently.
Args:
max_retries: Maximum number of retry attempts.
initial_backoff: Initial backoff time in seconds.
backoff_factor: Multiplier for exponential backoff.
jitter: Jitter factor to avoid the thundering herd problem.
Returns:
The decorated function.
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
attempt = 0
while True:
try:
return func(*args, **kwargs)
except Exception as e:
attempt += 1
if not is_rate_limit_error(e) or attempt > max_retries:
raise
backoff_time = calculate_backoff(
attempt, initial_backoff, backoff_factor, jitter
)
logger.warning(
f"Rate limit hit, retrying in {backoff_time:.2f}s "
f"(attempt {attempt}/{max_retries}): {str(e)}"
)
time.sleep(backoff_time)
return wrapper
return decorator
def sleep_and_retry_async(
max_retries=DEFAULT_MAX_RETRIES,
initial_backoff=DEFAULT_INITIAL_BACKOFF,
backoff_factor=DEFAULT_BACKOFF_FACTOR,
jitter=DEFAULT_JITTER,
):
"""
Decorator that automatically retries an asynchronous function when rate limit errors occur.
This decorator implements an exponential backoff strategy with jitter
to handle rate limit errors efficiently.
Args:
max_retries: Maximum number of retry attempts.
initial_backoff: Initial backoff time in seconds.
backoff_factor: Multiplier for exponential backoff.
jitter: Jitter factor to avoid the thundering herd problem.
Returns:
The decorated async function.
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
attempt = 0
while True:
try:
return await func(*args, **kwargs)
except Exception as e:
attempt += 1
if not is_rate_limit_error(e) or attempt > max_retries:
raise
backoff_time = calculate_backoff(
attempt, initial_backoff, backoff_factor, jitter
)
logger.warning(
f"Rate limit hit, retrying in {backoff_time:.2f}s "
f"(attempt {attempt}/{max_retries}): {str(e)}"
)
await asyncio.sleep(backoff_time)
return wrapper
return decorator
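A minimal sketch (hypothetical function name) of the decorator stacking this module is used with in the adapters above; ordering matters, since the retry decorator must be outermost so that a rate-limit error raised inside the call is caught and backed off:

```python
from cognee.infrastructure.llm.rate_limiter import rate_limit_async, sleep_and_retry_async


@sleep_and_retry_async()  # outermost: retries when the provider still returns a rate-limit error
@rate_limit_async  # innermost: waits until the moving window has capacity before calling
async def call_provider(prompt: str) -> str:
    # the actual provider API call would go here
    ...
```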


@ -0,0 +1,171 @@
"""Tests for the LLM rate limiter."""
import pytest
import asyncio
import time
from unittest.mock import patch
from cognee.infrastructure.llm.rate_limiter import (
llm_rate_limiter,
rate_limit_async,
rate_limit_sync,
)
@pytest.fixture(autouse=True)
def reset_limiter_singleton():
"""Reset the singleton instance between tests."""
llm_rate_limiter._instance = None
yield
def test_rate_limiter_initialization():
"""Test that the rate limiter can be initialized properly."""
with patch("cognee.infrastructure.llm.rate_limiter.get_llm_config") as mock_config:
mock_config.return_value.llm_rate_limit_enabled = True
mock_config.return_value.llm_rate_limit_requests = 10
mock_config.return_value.llm_rate_limit_interval = 60 # 1 minute
limiter = llm_rate_limiter()
assert limiter._enabled is True
assert limiter._requests == 10
assert limiter._interval == 60
def test_rate_limiter_disabled():
"""Test that the rate limiter is disabled by default."""
with patch("cognee.infrastructure.llm.rate_limiter.get_llm_config") as mock_config:
mock_config.return_value.llm_rate_limit_enabled = False
limiter = llm_rate_limiter()
assert limiter._enabled is False
assert limiter.hit_limit() is True # Should always return True when disabled
def test_rate_limiter_singleton():
"""Test that the rate limiter is a singleton."""
with patch("cognee.infrastructure.llm.rate_limiter.get_llm_config") as mock_config:
mock_config.return_value.llm_rate_limit_enabled = True
mock_config.return_value.llm_rate_limit_requests = 5
mock_config.return_value.llm_rate_limit_interval = 60
limiter1 = llm_rate_limiter()
limiter2 = llm_rate_limiter()
assert limiter1 is limiter2
def test_sync_decorator():
"""Test the sync decorator."""
with patch("cognee.infrastructure.llm.rate_limiter.llm_rate_limiter") as mock_limiter_class:
mock_limiter = mock_limiter_class.return_value
mock_limiter.wait_if_needed.return_value = 0
@rate_limit_sync
def test_func():
return "success"
result = test_func()
assert result == "success"
mock_limiter.wait_if_needed.assert_called_once()
@pytest.mark.asyncio
async def test_async_decorator():
"""Test the async decorator."""
with patch("cognee.infrastructure.llm.rate_limiter.llm_rate_limiter") as mock_limiter_class:
mock_limiter = mock_limiter_class.return_value
# Mock an async method with a coroutine
async def mock_wait():
return 0
mock_limiter.async_wait_if_needed.return_value = mock_wait()
@rate_limit_async
async def test_func():
return "success"
result = await test_func()
assert result == "success"
mock_limiter.async_wait_if_needed.assert_called_once()
def test_rate_limiting_actual():
"""Test actual rate limiting behavior with a small window."""
with patch("cognee.infrastructure.llm.rate_limiter.get_llm_config") as mock_config:
# Configure for 3 requests per minute
mock_config.return_value.llm_rate_limit_enabled = True
mock_config.return_value.llm_rate_limit_requests = 3
mock_config.return_value.llm_rate_limit_interval = 60
# Create a fresh instance
llm_rate_limiter._instance = None
limiter = llm_rate_limiter()
# First 3 requests should succeed
assert limiter.hit_limit() is True
assert limiter.hit_limit() is True
assert limiter.hit_limit() is True
# Fourth request should fail (exceed limit)
assert limiter.hit_limit() is False
def test_rate_limit_60_per_minute():
"""Test rate limiting with the default 60 requests per minute limit."""
with patch("cognee.infrastructure.llm.rate_limiter.get_llm_config") as mock_config:
# Configure for default values: 60 requests per 60 seconds
mock_config.return_value.llm_rate_limit_enabled = True
mock_config.return_value.llm_rate_limit_requests = 60 # 60 requests
mock_config.return_value.llm_rate_limit_interval = 60 # per minute
# Create a fresh instance
llm_rate_limiter._instance = None
limiter = llm_rate_limiter()
# Track successful and failed requests
successes = []
failures = []
# Send requests in batches until we see some failures
# This simulates reaching the rate limit
num_test_requests = 70 # Try more than our limit of 60
for i in range(num_test_requests):
if limiter.hit_limit():
successes.append(f"Request {i}")
else:
failures.append(f"Request {i}")
# Print the results
print(f"Total successful requests: {len(successes)}")
print(f"Total failed requests: {len(failures)}")
if len(failures) > 0:
print(f"First failed request: {failures[0]}")
# Verify we got the expected behavior (close to 60 requests allowed)
# Allow small variations due to timing
assert 58 <= len(successes) <= 62, f"Expected ~60 successful requests, got {len(successes)}"
assert len(failures) > 0, "Expected at least some rate-limited requests"
# Verify that roughly the first 60 requests succeeded
if len(failures) > 0:
first_failure_idx = int(failures[0].split()[1])
print(f"First failure occurred at request index: {first_failure_idx}")
assert 58 <= first_failure_idx <= 62, (
f"Expected first failure around request #60, got #{first_failure_idx}"
)
# Calculate requests per minute
rate_per_minute = len(successes)
print(f"Rate: {rate_per_minute} requests per minute")
# Verify the rate is close to our target of 60 requests per minute
assert 58 <= rate_per_minute <= 62, (
f"Expected rate of ~60 requests per minute, got {rate_per_minute}"
)


@ -0,0 +1,55 @@
import asyncio
from typing import List
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
LiteLLMEmbeddingEngine,
)
from cognee.infrastructure.llm.embedding_rate_limiter import (
embedding_rate_limit_async,
embedding_sleep_and_retry_async,
)
class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
"""
Mock version of LiteLLMEmbeddingEngine that returns fixed embeddings
and can be configured to simulate rate limiting and failures.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.mock = True
self.fail_every_n_requests = 0
self.request_count = 0
self.add_delay = 0
def configure_mock(self, fail_every_n_requests=0, add_delay=0):
"""
Configure the mock's behavior
Args:
fail_every_n_requests: Raise an exception every n requests (0 = never fail)
add_delay: Add artificial delay in seconds to each request
"""
self.fail_every_n_requests = fail_every_n_requests
self.add_delay = add_delay
@embedding_sleep_and_retry_async()
@embedding_rate_limit_async
async def embed_text(self, text: List[str]) -> List[List[float]]:
"""
Mock implementation that returns fixed embeddings and can
simulate failures and delays based on configuration.
"""
self.request_count += 1
# Simulate processing delay if configured
if self.add_delay > 0:
await asyncio.sleep(self.add_delay)
# Simulate failures if configured
if self.fail_every_n_requests > 0 and self.request_count % self.fail_every_n_requests == 0:
raise Exception(f"Mock failure on request #{self.request_count}")
# Return mock embeddings of the correct dimension
return [[0.1] * self.dimensions for _ in text]


@ -0,0 +1,201 @@
import os
import time
import asyncio
from functools import lru_cache
import logging
from cognee.infrastructure.llm.config import LLMConfig, get_llm_config
from cognee.infrastructure.llm.embedding_rate_limiter import EmbeddingRateLimiter
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
LiteLLMEmbeddingEngine,
)
from cognee.tests.unit.infrastructure.mock_embedding_engine import MockEmbeddingEngine
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def test_embedding_rate_limiting_realistic():
"""
Test the embedding rate limiting feature with a realistic scenario:
- Set limit to 3 requests per 5 seconds
- Send requests in bursts with waiting periods
- Track successful and rate-limited requests
- Verify the rate limiter's behavior
"""
# Set up environment variables for rate limiting
os.environ["EMBEDDING_RATE_LIMIT_ENABLED"] = "true"
os.environ["EMBEDDING_RATE_LIMIT_REQUESTS"] = "3" # Only 3 requests per interval
os.environ["EMBEDDING_RATE_LIMIT_INTERVAL"] = "5"
os.environ["MOCK_EMBEDDING"] = "true" # Use mock embeddings for testing
os.environ["DISABLE_RETRIES"] = "true" # Disable automatic retries for testing
# Clear the config and rate limiter caches to ensure our settings are applied
get_llm_config.cache_clear()
EmbeddingRateLimiter.reset_instance()
# Create a fresh config instance and verify settings
config = get_llm_config()
logger.info(f"Embedding Rate Limiting Enabled: {config.embedding_rate_limit_enabled}")
logger.info(
f"Embedding Rate Limit: {config.embedding_rate_limit_requests} requests per {config.embedding_rate_limit_interval} seconds"
)
# Create a mock embedding engine
engine = MockEmbeddingEngine()
# Configure some delay to simulate realistic API calls but not too long
engine.configure_mock(add_delay=0.1)
# Track overall statistics
total_requests = 0
total_successes = 0
total_rate_limited = 0
# Create a list of tasks to simulate concurrent requests
async def make_request(i):
nonlocal total_successes, total_rate_limited
try:
logger.info(f"Making request #{i + 1}")
text = f"Concurrent - Text {i}"
embedding = await engine.embed_text([text])
logger.info(f"Request #{i + 1} succeeded with embedding size: {len(embedding[0])}")
return True
except Exception as e:
logger.info(f"Request #{i + 1} rate limited: {e}")
return False
# Batch 1: Send 10 concurrent requests (expect 3 to succeed, 7 to be rate limited)
batch_size = 10
logger.info(f"\n--- Batch 1: Sending {batch_size} concurrent requests ---")
batch_start = time.time()
tasks = [make_request(i) for i in range(batch_size)]
results = await asyncio.gather(*tasks, return_exceptions=False)
batch_successes = results.count(True)
batch_rate_limited = results.count(False)
batch_end = time.time()
logger.info(f"Batch 1 completed in {batch_end - batch_start:.2f} seconds")
logger.info(f"Successes: {batch_successes}, Rate limited: {batch_rate_limited}")
total_requests += batch_size
total_successes += batch_successes
total_rate_limited += batch_rate_limited
# Wait 2 seconds (should recover some capacity but not all)
wait_time = 2
logger.info(f"\nWaiting {wait_time} seconds to allow partial capacity recovery...")
await asyncio.sleep(wait_time)
# Batch 2: Send 5 more requests (expect some to succeed, some to be rate limited)
batch_size = 5
logger.info(f"\n--- Batch 2: Sending {batch_size} concurrent requests ---")
batch_start = time.time()
tasks = [make_request(i) for i in range(batch_size)]
results = await asyncio.gather(*tasks, return_exceptions=False)
batch_successes = results.count(True)
batch_rate_limited = results.count(False)
batch_end = time.time()
logger.info(f"Batch 2 completed in {batch_end - batch_start:.2f} seconds")
logger.info(f"Successes: {batch_successes}, Rate limited: {batch_rate_limited}")
total_requests += batch_size
total_successes += batch_successes
total_rate_limited += batch_rate_limited
# Wait 5 seconds (should recover full capacity)
wait_time = 5
logger.info(f"\nWaiting {wait_time} seconds to allow full capacity recovery...")
await asyncio.sleep(wait_time)
# Batch 3: Send 3 more requests sequentially (all should succeed)
batch_size = 3
logger.info(f"\n--- Batch 3: Sending {batch_size} sequential requests ---")
batch_start = time.time()
batch_successes = 0
batch_rate_limited = 0
for i in range(batch_size):
try:
logger.info(f"Making request #{i + 1}")
text = f"Sequential - Text {i}"
embedding = await engine.embed_text([text])
logger.info(f"Request #{i + 1} succeeded with embedding size: {len(embedding[0])}")
batch_successes += 1
except Exception as e:
logger.info(f"Request #{i + 1} rate limited: {e}")
batch_rate_limited += 1
batch_end = time.time()
logger.info(f"Batch 3 completed in {batch_end - batch_start:.2f} seconds")
logger.info(f"Successes: {batch_successes}, Rate limited: {batch_rate_limited}")
total_requests += batch_size
total_successes += batch_successes
total_rate_limited += batch_rate_limited
# Log overall results
logger.info("\n--- Test Summary ---")
logger.info(f"Total requests: {total_requests}")
logger.info(f"Total successes: {total_successes}")
logger.info(f"Total rate limited: {total_rate_limited}")
# Verify the behavior
assert total_successes > 0, "Expected some successful requests"
assert total_rate_limited > 0, "Expected some rate limited requests"
# Reset environment variables
os.environ.pop("EMBEDDING_RATE_LIMIT_ENABLED", None)
os.environ.pop("EMBEDDING_RATE_LIMIT_REQUESTS", None)
os.environ.pop("EMBEDDING_RATE_LIMIT_INTERVAL", None)
os.environ.pop("MOCK_EMBEDDING", None)
os.environ.pop("DISABLE_RETRIES", None)
async def test_with_mock_failures():
"""
Test with the mock engine's ability to generate controlled failures.
"""
# Setup rate limiting (more permissive settings)
os.environ["EMBEDDING_RATE_LIMIT_ENABLED"] = "true"
os.environ["EMBEDDING_RATE_LIMIT_REQUESTS"] = "10"
os.environ["EMBEDDING_RATE_LIMIT_INTERVAL"] = "5"
os.environ["DISABLE_RETRIES"] = "true"
# Clear caches
get_llm_config.cache_clear()
EmbeddingRateLimiter.reset_instance()
# Create a mock engine configured to fail every 3rd request
engine = MockEmbeddingEngine()
engine.configure_mock(fail_every_n_requests=3, add_delay=0.1)
logger.info("\n--- Testing controlled failures with mock ---")
# Send 10 requests, expecting every 3rd to fail
for i in range(10):
try:
logger.info(f"Making request #{i + 1}")
text = f"Test text {i}"
embedding = await engine.embed_text([text])
logger.info(f"Request #{i + 1} succeeded for {str(embedding)}")
except Exception as e:
logger.info(f"Request #{i + 1} failed as expected: {e}")
# Reset environment variables
os.environ.pop("EMBEDDING_RATE_LIMIT_ENABLED", None)
os.environ.pop("EMBEDDING_RATE_LIMIT_REQUESTS", None)
os.environ.pop("EMBEDDING_RATE_LIMIT_INTERVAL", None)
os.environ.pop("DISABLE_RETRIES", None)
if __name__ == "__main__":
asyncio.run(test_embedding_rate_limiting_realistic())
asyncio.run(test_with_mock_failures())


@ -0,0 +1,166 @@
import asyncio
import time
import os
from functools import lru_cache
from unittest.mock import patch
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.rate_limiter import llm_rate_limiter
from cognee.infrastructure.llm.config import get_llm_config
async def test_rate_limiting_realistic():
"""
Test the rate limiting feature with a smaller limit to demonstrate
how rate limiting works in practice.
"""
print("\n=== Testing Rate Limiting Feature (Realistic Test) ===")
# Configure a lower rate limit for faster testing: 5 requests per 10 seconds
os.environ["LLM_RATE_LIMIT_ENABLED"] = "true"
os.environ["LLM_RATE_LIMIT_REQUESTS"] = "5"
os.environ["LLM_RATE_LIMIT_INTERVAL"] = "10"
# Clear the cached config and limiter
get_llm_config.cache_clear()
llm_rate_limiter._instance = None
# Create fresh instances
config = get_llm_config()
print(
f"Rate limit settings: {config.llm_rate_limit_enabled=}, {config.llm_rate_limit_requests=}, {config.llm_rate_limit_interval=}"
)
# We'll use monkey patching to guarantee rate limiting for test purposes
with patch.object(llm_rate_limiter, "hit_limit") as mock_hit_limit:
# Setup mock behavior: first 5 calls succeed, then one fails, then all succeed
# This simulates the window moving after waiting
mock_hit_limit.side_effect = [
# First batch: 5 allowed, then 5 limited
True,
True,
True,
True,
True,
False,
False,
False,
False,
False,
# Second batch after waiting: 2 allowed (partial window)
True,
True,
False,
False,
False,
# Third batch after full window reset: all 5 allowed
True,
True,
True,
True,
True,
]
limiter = llm_rate_limiter()
print(f"Rate limiter initialized with {limiter._rate_per_minute} requests per minute")
# First batch - should allow 5 and limit the rest
print("\nBatch 1: Sending 10 requests (expecting only 5 to succeed)...")
batch1_success = []
batch1_failure = []
for i in range(10):
result = limiter.hit_limit()
if result:
batch1_success.append(i)
print(f"✓ Request {i}: Success")
else:
batch1_failure.append(i)
print(f"✗ Request {i}: Rate limited")
print(f"Batch 1 results: {len(batch1_success)} successes, {len(batch1_failure)} failures")
if len(batch1_failure) > 0:
print(f"First rate-limited request: #{batch1_failure[0]}")
# Wait for window to partially reset
wait_time = 5 # seconds - half the rate limit interval
print(f"\nWaiting for {wait_time} seconds to allow capacity to partially regenerate...")
await asyncio.sleep(wait_time)
# Second batch - should get some capacity back
print("\nBatch 2: Sending 5 more requests (expecting 2 to succeed)...")
batch2_success = []
batch2_failure = []
for i in range(5):
result = limiter.hit_limit()
if result:
batch2_success.append(i)
print(f"✓ Request {i}: Success")
else:
batch2_failure.append(i)
print(f"✗ Request {i}: Rate limited")
print(f"Batch 2 results: {len(batch2_success)} successes, {len(batch2_failure)} failures")
# Wait for full window to reset
full_wait = 10 # seconds - full rate limit interval
print(f"\nWaiting for {full_wait} seconds for full capacity to regenerate...")
await asyncio.sleep(full_wait)
# Third batch - should have full capacity again
print("\nBatch 3: Sending 5 requests (expecting all to succeed)...")
batch3_success = []
batch3_failure = []
for i in range(5):
result = limiter.hit_limit()
if result:
batch3_success.append(i)
print(f"✓ Request {i}: Success")
else:
batch3_failure.append(i)
print(f"✗ Request {i}: Rate limited")
print(f"Batch 3 results: {len(batch3_success)} successes, {len(batch3_failure)} failures")
# Calculate total successes and failures
total_success = len(batch1_success) + len(batch2_success) + len(batch3_success)
total_failure = len(batch1_failure) + len(batch2_failure) + len(batch3_failure)
print(f"\nTotal requests: {total_success + total_failure}")
print(f"Total successful: {total_success}")
print(f"Total rate limited: {total_failure}")
# Verify the rate limiting behavior
if len(batch1_success) == 5 and len(batch1_failure) == 5:
print("\n✅ PASS: Rate limiter correctly limited first batch to 5 requests")
else:
print(f"\n❌ FAIL: First batch should allow 5 requests, got {len(batch1_success)}")
if len(batch2_success) == 2 and len(batch2_failure) == 3:
print("✅ PASS: Rate limiter correctly allowed 2 requests after partial wait")
else:
print(f"❌ FAIL: Second batch should allow 2 requests, got {len(batch2_success)}")
if len(batch3_success) == 5 and len(batch3_failure) == 0:
print("✅ PASS: Rate limiter correctly allowed all requests after full window expired")
else:
print(f"❌ FAIL: Third batch should allow all 5 requests, got {len(batch3_success)}")
print("=== Rate Limiting Test Complete ===\n")
async def main():
"""Run the realistic rate limiting test."""
await test_rate_limiting_realistic()
if __name__ == "__main__":
logger = get_logger()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
finally:
loop.run_until_complete(loop.shutdown_asyncgens())


@ -0,0 +1,196 @@
import asyncio
import time
import os
from unittest.mock import patch, MagicMock
from functools import lru_cache
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.rate_limiter import (
sleep_and_retry_sync,
sleep_and_retry_async,
is_rate_limit_error,
)
logger = get_logger()
# Test function to be decorated
@sleep_and_retry_sync(max_retries=3, initial_backoff=0.1, backoff_factor=2.0)
def test_function_sync():
"""A test function that raises rate limit errors a few times, then succeeds."""
if hasattr(test_function_sync, "counter"):
test_function_sync.counter += 1
else:
test_function_sync.counter = 1
if test_function_sync.counter <= 2:
error_msg = "429 Too Many Requests: Rate limit exceeded"
logger.info(f"Attempt {test_function_sync.counter}: Raising rate limit error")
raise Exception(error_msg)
logger.info(f"Attempt {test_function_sync.counter}: Success!")
return f"Success on attempt {test_function_sync.counter}"
# Test async function to be decorated
@sleep_and_retry_async(max_retries=3, initial_backoff=0.1, backoff_factor=2.0)
async def test_function_async():
"""An async test function that raises rate limit errors a few times, then succeeds."""
if hasattr(test_function_async, "counter"):
test_function_async.counter += 1
else:
test_function_async.counter = 1
if test_function_async.counter <= 2:
error_msg = "429 Too Many Requests: Rate limit exceeded"
logger.info(f"Attempt {test_function_async.counter}: Raising rate limit error")
raise Exception(error_msg)
logger.info(f"Attempt {test_function_async.counter}: Success!")
return f"Success on attempt {test_function_async.counter}"
def test_is_rate_limit_error():
"""Test the rate limit error detection function."""
print("\n=== Testing Rate Limit Error Detection ===")
# Test various error messages that should be detected as rate limit errors
rate_limit_errors = [
"429 Rate limit exceeded",
"Too many requests",
"rate_limit_exceeded",
"ratelimit error",
"You have exceeded your quota",
"capacity has been exceeded",
"Service throttled",
]
# Test error messages that should not be detected as rate limit errors
non_rate_limit_errors = [
"404 Not Found",
"500 Internal Server Error",
"Invalid API Key",
"Bad Request",
]
# Check that rate limit errors are correctly identified
for error in rate_limit_errors:
error_obj = Exception(error)
result = is_rate_limit_error(error_obj)
print(f"Error '{error}': {'' if result else ''} {result}")
assert result, f"Failed to identify rate limit error: {error}"
print(f"✓ Correctly identified as rate limit error: {error}")
# Check that non-rate limit errors are not misidentified
for error in non_rate_limit_errors:
error_obj = Exception(error)
result = is_rate_limit_error(error_obj)
print(f"Error '{error}': {'' if not result else ''} {not result}")
assert not result, f"Incorrectly identified as rate limit error: {error}"
print(f"✓ Correctly identified as non-rate limit error: {error}")
print("✅ PASS: Rate limit error detection is working correctly")
def test_sync_retry():
"""Test the synchronous retry decorator."""
print("\n=== Testing Synchronous Sleep and Retry ===")
# Reset counter for the test function
if hasattr(test_function_sync, "counter"):
del test_function_sync.counter
# Time the execution to verify backoff is working
start_time = time.time()
try:
result = test_function_sync()
end_time = time.time()
elapsed = end_time - start_time
# Verify results
print(f"Result: {result}")
print(f"Test completed in {elapsed:.2f} seconds")
print(f"Number of attempts: {test_function_sync.counter}")
# The function should succeed on the 3rd attempt (after 2 failures)
assert test_function_sync.counter == 3, (
f"Expected 3 attempts, got {test_function_sync.counter}"
)
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
print("✅ PASS: Synchronous retry mechanism is working correctly")
except Exception as e:
print(f"❌ FAIL: Test encountered an unexpected error: {str(e)}")
raise
async def test_async_retry():
"""Test the asynchronous retry decorator."""
print("\n=== Testing Asynchronous Sleep and Retry ===")
# Reset counter for the test function
if hasattr(test_function_async, "counter"):
del test_function_async.counter
# Time the execution to verify backoff is working
start_time = time.time()
try:
result = await test_function_async()
end_time = time.time()
elapsed = end_time - start_time
# Verify results
print(f"Result: {result}")
print(f"Test completed in {elapsed:.2f} seconds")
print(f"Number of attempts: {test_function_async.counter}")
# The function should succeed on the 3rd attempt (after 2 failures)
assert test_function_async.counter == 3, (
f"Expected 3 attempts, got {test_function_async.counter}"
)
assert elapsed >= 0.3, f"Expected at least 0.3 seconds of backoff, got {elapsed:.2f}"
print("✅ PASS: Asynchronous retry mechanism is working correctly")
except Exception as e:
print(f"❌ FAIL: Test encountered an unexpected error: {str(e)}")
raise
async def test_retry_max_exceeded():
"""Test what happens when max retries is exceeded."""
print("\n=== Testing Max Retries Exceeded ===")
@sleep_and_retry_async(max_retries=2, initial_backoff=0.1)
async def always_fails():
"""A function that always raises a rate limit error."""
error_msg = "429 Too Many Requests: Rate limit always exceeded"
logger.info(f"Always fails with: {error_msg}")
raise Exception(error_msg)
try:
# This should fail after 2 retries (3 attempts total)
await always_fails()
print("❌ FAIL: Function should have failed but succeeded")
except Exception as e:
print(f"Expected error after max retries: {str(e)}")
print("✅ PASS: Function correctly failed after max retries exceeded")
async def main():
"""Run all the retry tests."""
test_is_rate_limit_error()
test_sync_retry()
await test_async_retry()
await test_retry_max_exceeded()
print("\n=== All Rate Limiting Retry Tests Complete ===")
if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
finally:
loop.run_until_complete(loop.shutdown_asyncgens())

poetry.lock (generated): file diff suppressed because it is too large (3002 lines).


@ -48,6 +48,7 @@ lancedb = "0.16.0"
alembic = "^1.13.3"
pre-commit = "^4.0.1"
scikit-learn = "^1.6.1"
limits = "^4.4.1"
fastapi = {version = "0.115.7"}
fastapi-users = {version = "14.0.0", extras = ["sqlalchemy"]}
uvicorn = {version = "0.34.0", optional = true}



@ -1,7 +0,0 @@
def test_import_cognee():
try:
import cognee
assert True # Pass the test if no error occurs
except ImportError as e:
assert False, f"Failed to import cognee: {e}"