feat: Add RPM limiting to Cognee
parent 0c97a400b0
commit 7deaa6e8e9
10 changed files with 146 additions and 683 deletions
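The call sites in the hunks below import llm_rate_limiter_context_manager and embedding_rate_limiter_context_manager from cognee.shared.rate_limiting, a module that is not part of the hunks shown here. As orientation only, here is a minimal sketch of what such an async context manager could look like, assuming a sliding-window requests-per-interval limiter; every name, default value, and the 0.5 s polling interval below are assumptions, not code from this commit:

# Hypothetical sketch of cognee/shared/rate_limiting.py; the real module is not
# shown in this diff, so class names, defaults, and behavior are assumptions.
import asyncio
import time
from contextlib import asynccontextmanager


class _SlidingWindowLimiter:
    """Allow at most `limit` calls per `interval` seconds (assumed semantics)."""

    def __init__(self, limit: int = 60, interval: float = 60.0):
        self.limit = limit
        self.interval = interval
        self._timestamps = []
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        while True:
            async with self._lock:
                now = time.monotonic()
                # Drop timestamps that fell out of the window, then try to take a slot.
                self._timestamps = [t for t in self._timestamps if now - t < self.interval]
                if len(self._timestamps) < self.limit:
                    self._timestamps.append(now)
                    return
            await asyncio.sleep(0.5)  # poll until a slot frees up


_llm_limiter = _SlidingWindowLimiter()
_embedding_limiter = _SlidingWindowLimiter()


@asynccontextmanager
async def llm_rate_limiter_context_manager():
    await _llm_limiter.acquire()  # block until a slot is free, then run the wrapped call
    yield


@asynccontextmanager
async def embedding_rate_limiter_context_manager():
    await _embedding_limiter.acquire()
    yield

Whatever the real implementation looks like, the call sites only rely on an object usable as "async with ..._rate_limiter_context_manager():" around the outbound request.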
@@ -17,6 +17,7 @@ from cognee.infrastructure.databases.exceptions import EmbeddingException
 from cognee.infrastructure.llm.tokenizer.TikToken import (
     TikTokenTokenizer,
 )
+from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager

 litellm.set_verbose = False
 logger = get_logger("FastembedEmbeddingEngine")
@@ -68,7 +69,7 @@ class FastembedEmbeddingEngine(EmbeddingEngine):

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -96,11 +97,12 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
         if self.mock:
             return [[0.0] * self.dimensions for _ in text]
         else:
-            embeddings = self.embedding_model.embed(
-                text,
-                batch_size=len(text),
-                parallel=None,
-            )
+            async with embedding_rate_limiter_context_manager():
+                embeddings = self.embedding_model.embed(
+                    text,
+                    batch_size=len(text),
+                    parallel=None,
+                )

         return list(embeddings)
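Besides the context manager, every engine and adapter in this commit widens its tenacity backoff from wait_exponential_jitter(2, 128) to wait_exponential_jitter(8, 128). In tenacity that signature is wait_exponential_jitter(initial, max), so the first retry now waits roughly 8 s instead of 2 s, the per-attempt wait is still capped at 128 s, and stop_after_delay(128) still abandons the call after 128 s overall. A self-contained illustration of the policy (call_provider is a placeholder, and the litellm NotFoundError filter used in the real decorators is left out so the snippet has no extra dependencies):

import logging

from tenacity import (
    retry,
    stop_after_delay,
    wait_exponential_jitter,
    before_sleep_log,
)

logger = logging.getLogger(__name__)


@retry(
    stop=stop_after_delay(128),            # stop retrying once 128 s have elapsed
    wait=wait_exponential_jitter(8, 128),  # ~8 s, 16 s, 32 s, ... plus jitter, capped at 128 s
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,                          # surface the last exception instead of RetryError
)
def call_provider():
    ...  # placeholder for the embedding / LLM request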
@@ -18,10 +18,7 @@ from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import Em
 from cognee.infrastructure.llm.tokenizer.HuggingFace import (
     HuggingFaceTokenizer,
 )
-from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
-    embedding_rate_limit_async,
-    embedding_sleep_and_retry_async,
-)
+from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
 from cognee.shared.utils import create_secure_ssl_context

 logger = get_logger("OllamaEmbeddingEngine")
@@ -101,7 +98,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -120,14 +117,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
         ssl_context = create_secure_ssl_context()
         connector = aiohttp.TCPConnector(ssl=ssl_context)
         async with aiohttp.ClientSession(connector=connector) as session:
-            async with session.post(
-                self.endpoint, json=payload, headers=headers, timeout=60.0
-            ) as response:
-                data = await response.json()
-                if "embeddings" in data:
-                    return data["embeddings"][0]
-                else:
-                    return data["data"][0]["embedding"]
+            async with embedding_rate_limiter_context_manager():
+                async with session.post(
+                    self.endpoint, json=payload, headers=headers, timeout=60.0
+                ) as response:
+                    data = await response.json()
+                    if "embeddings" in data:
+                        return data["embeddings"][0]
+                    else:
+                        return data["data"][0]["embedding"]

     def get_vector_size(self) -> int:
         """
Deleted: cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py
@@ -1,544 +0,0 @@
import threading
import logging
import functools
import os
import time
import asyncio
import random
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.config import get_llm_config


logger = get_logger()

# Common error patterns that indicate rate limiting
RATE_LIMIT_ERROR_PATTERNS = [
    "rate limit",
    "rate_limit",
    "ratelimit",
    "too many requests",
    "retry after",
    "capacity",
    "quota",
    "limit exceeded",
    "tps limit exceeded",
    "request limit exceeded",
    "maximum requests",
    "exceeded your current quota",
    "throttled",
    "throttling",
]

# Default retry settings
DEFAULT_MAX_RETRIES = 5
DEFAULT_INITIAL_BACKOFF = 1.0  # seconds
DEFAULT_BACKOFF_FACTOR = 2.0  # exponential backoff multiplier
DEFAULT_JITTER = 0.1  # 10% jitter to avoid thundering herd


class EmbeddingRateLimiter:
    """
    Rate limiter for embedding API calls.

    This class implements a singleton pattern to ensure that rate limiting
    is consistent across all embedding requests. It uses the limits
    library with a moving window strategy to control request rates.

    The rate limiter uses the same configuration as the LLM API rate limiter
    but uses a separate key to track embedding API calls independently.

    Public Methods:
    - get_instance
    - reset_instance
    - hit_limit
    - wait_if_needed
    - async_wait_if_needed

    Instance Variables:
    - enabled
    - requests_limit
    - interval_seconds
    - request_times
    - lock
    """

    _instance = None
    lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        """
        Retrieve the singleton instance of the EmbeddingRateLimiter.

        This method ensures that only one instance of the class exists and
        is thread-safe. It lazily initializes the instance if it doesn't
        already exist.

        Returns:
        --------

        The singleton instance of the EmbeddingRateLimiter class.
        """
        if cls._instance is None:
            with cls.lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @classmethod
    def reset_instance(cls):
        """
        Reset the singleton instance of the EmbeddingRateLimiter.

        This method is thread-safe and sets the instance to None, allowing
        for a new instance to be created when requested again.
        """
        with cls.lock:
            cls._instance = None

    def __init__(self):
        config = get_llm_config()
        self.enabled = config.embedding_rate_limit_enabled
        self.requests_limit = config.embedding_rate_limit_requests
        self.interval_seconds = config.embedding_rate_limit_interval
        self.request_times = []
        self.lock = threading.Lock()

        logging.info(
            f"EmbeddingRateLimiter initialized: enabled={self.enabled}, "
            f"requests_limit={self.requests_limit}, interval_seconds={self.interval_seconds}"
        )

    def hit_limit(self) -> bool:
        """
        Check if the current request would exceed the rate limit.

        This method checks if the rate limiter is enabled and evaluates
        the number of requests made in the elapsed interval.

        Returns:
        - bool: True if the rate limit would be exceeded, False otherwise.

        Returns:
        --------

        - bool: True if the rate limit would be exceeded, otherwise False.
        """
        if not self.enabled:
            return False

        current_time = time.time()

        with self.lock:
            # Remove expired request times
            cutoff_time = current_time - self.interval_seconds
            self.request_times = [t for t in self.request_times if t > cutoff_time]

            # Check if adding a new request would exceed the limit
            if len(self.request_times) >= self.requests_limit:
                logger.info(
                    f"Rate limit hit: {len(self.request_times)} requests in the last {self.interval_seconds} seconds"
                )
                return True

            # Otherwise, we're under the limit
            return False

    def wait_if_needed(self) -> float:
        """
        Block until a request can be made without exceeding the rate limit.

        This method will wait if the current request would exceed the
        rate limit and returns the time waited in seconds.

        Returns:
        - float: Time waited in seconds before a request is allowed.

        Returns:
        --------

        - float: Time waited in seconds before proceeding.
        """
        if not self.enabled:
            return 0

        wait_time = 0
        start_time = time.time()

        while self.hit_limit():
            time.sleep(0.5)  # Poll every 0.5 seconds
            wait_time = time.time() - start_time

        # Record this request
        with self.lock:
            self.request_times.append(time.time())

        return wait_time

    async def async_wait_if_needed(self) -> float:
        """
        Asynchronously wait until a request can be made without exceeding the rate limit.

        This method will wait if the current request would exceed the
        rate limit and returns the time waited in seconds.

        Returns:
        - float: Time waited in seconds before a request is allowed.

        Returns:
        --------

        - float: Time waited in seconds before proceeding.
        """
        if not self.enabled:
            return 0

        wait_time = 0
        start_time = time.time()

        while self.hit_limit():
            await asyncio.sleep(0.5)  # Poll every 0.5 seconds
            wait_time = time.time() - start_time

        # Record this request
        with self.lock:
            self.request_times.append(time.time())

        return wait_time


def embedding_rate_limit_sync(func):
    """
    Apply rate limiting to a synchronous embedding function.

    Parameters:
    -----------

    - func: Function to decorate with rate limiting logic.

    Returns:
    --------

    Returns the decorated function that applies rate limiting.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        """
        Wrap the given function with rate limiting logic to control the embedding API usage.

        Checks if the rate limit has been exceeded before allowing the function to execute. If
        the limit is hit, it logs a warning and raises an EmbeddingException. Otherwise, it
        updates the request count and proceeds to call the original function.

        Parameters:
        -----------

        - *args: Variable length argument list for the wrapped function.
        - **kwargs: Keyword arguments for the wrapped function.

        Returns:
        --------

        Returns the result of the wrapped function if rate limiting conditions are met.
        """
        limiter = EmbeddingRateLimiter.get_instance()

        # Check if rate limiting is enabled and if we're at the limit
        if limiter.hit_limit():
            error_msg = "Embedding API rate limit exceeded"
            logger.warning(error_msg)

            # Create a custom embedding rate limit exception
            from cognee.infrastructure.databases.exceptions import EmbeddingException

            raise EmbeddingException(error_msg)

        # Add this request to the counter and proceed
        limiter.wait_if_needed()
        return func(*args, **kwargs)

    return wrapper


def embedding_rate_limit_async(func):
    """
    Decorator that applies rate limiting to an asynchronous embedding function.

    Parameters:
    -----------

    - func: Async function to decorate.

    Returns:
    --------

    Returns the decorated async function that applies rate limiting.
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        """
        Handle function calls with embedding rate limiting.

        This asynchronous wrapper checks if the embedding API rate limit is exceeded before
        allowing the function to execute. If the limit is exceeded, it logs a warning and raises
        an EmbeddingException. If not, it waits as necessary and proceeds with the function
        call.

        Parameters:
        -----------

        - *args: Positional arguments passed to the wrapped function.
        - **kwargs: Keyword arguments passed to the wrapped function.

        Returns:
        --------

        Returns the result of the wrapped function after handling rate limiting.
        """
        limiter = EmbeddingRateLimiter.get_instance()

        # Check if rate limiting is enabled and if we're at the limit
        if limiter.hit_limit():
            error_msg = "Embedding API rate limit exceeded"
            logger.warning(error_msg)

            # Create a custom embedding rate limit exception
            from cognee.infrastructure.databases.exceptions import EmbeddingException

            raise EmbeddingException(error_msg)

        # Add this request to the counter and proceed
        await limiter.async_wait_if_needed()
        return await func(*args, **kwargs)

    return wrapper


def embedding_sleep_and_retry_sync(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Add retry with exponential backoff for synchronous embedding functions.

    Parameters:
    -----------

    - max_retries: Maximum number of retries before giving up. (default 5)
    - base_backoff: Base backoff time in seconds for retry intervals. (default 1.0)
    - jitter: Jitter factor to randomize the backoff time to avoid collision. (default
      0.5)

    Returns:
    --------

    A decorator that retries the wrapped function on rate limit errors, applying
    exponential backoff with jitter.
    """

    def decorator(func):
        """
        Wraps a function to apply retry logic on rate limit errors.

        Parameters:
        -----------

        - func: The function to be wrapped with retry logic.

        Returns:
        --------

        Returns the wrapped function with retry logic applied.
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            """
            Retry the execution of a function with backoff on failure due to rate limit errors.

            This wrapper function will call the specified function and if it raises an exception, it
            will handle retries according to defined conditions. It will check the environment for a
            DISABLE_RETRIES flag to determine whether to retry or propagate errors immediately
            during tests. If the error is identified as a rate limit error, it will apply an
            exponential backoff strategy with jitter before retrying, up to a maximum number of
            retries. If the retries are exhausted, it raises the last encountered error.

            Parameters:
            -----------

            - *args: Positional arguments passed to the wrapped function.
            - **kwargs: Keyword arguments passed to the wrapped function.

            Returns:
            --------

            Returns the result of the wrapped function if successful; otherwise, raises the last
            error encountered after maximum retries are exhausted.
            """
            # If DISABLE_RETRIES is set, don't retry for testing purposes
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )

            retries = 0
            last_error = None

            while retries <= max_retries:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # Check if this is a rate limit error
                    error_str = str(e).lower()
                    error_type = type(e).__name__
                    is_rate_limit = any(
                        pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )

                    if disable_retries:
                        # For testing, propagate the exception immediately
                        raise

                    if is_rate_limit and retries < max_retries:
                        # Calculate backoff with jitter
                        backoff = (
                            base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
                        )

                        logger.warning(
                            f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                            f"(attempt {retries + 1}/{max_retries}): "
                            f"({error_str!r}, {error_type!r})"
                        )

                        time.sleep(backoff)
                        retries += 1
                        last_error = e
                    else:
                        # Not a rate limit error or max retries reached, raise
                        raise

            # If we exit the loop due to max retries, raise the last error
            if last_error:
                raise last_error

        return wrapper

    return decorator


def embedding_sleep_and_retry_async(max_retries=5, base_backoff=1.0, jitter=0.5):
    """
    Add retry logic with exponential backoff for asynchronous embedding functions.

    This decorator retries the wrapped asynchronous function upon encountering rate limit
    errors, utilizing exponential backoff with optional jitter to space out retry attempts.
    It allows for a maximum number of retries before giving up and raising the last error
    encountered.

    Parameters:
    -----------

    - max_retries: Maximum number of retries allowed before giving up. (default 5)
    - base_backoff: Base amount of time in seconds to wait before retrying after a rate
      limit error. (default 1.0)
    - jitter: Amount of randomness to add to the backoff duration to help mitigate burst
      issues on retries. (default 0.5)

    Returns:
    --------

    Returns a decorated asynchronous function that implements the retry logic on rate
    limit errors.
    """

    def decorator(func):
        """
        Handle retries for an async function with exponential backoff and jitter.

        Parameters:
        -----------

        - func: An asynchronous function to be wrapped with retry logic.

        Returns:
        --------

        Returns the wrapper function that manages the retry behavior for the wrapped async
        function.
        """

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            """
            Handle retries for an async function with exponential backoff and jitter.

            If the environment variable DISABLE_RETRIES is set to true, 1, or yes, the function will
            not retry on errors.
            It attempts to call the wrapped function until it succeeds or the maximum number of
            retries is reached. If an exception occurs, it checks if it's a rate limit error to
            determine if a retry is needed.

            Parameters:
            -----------

            - *args: Positional arguments passed to the wrapped function.
            - **kwargs: Keyword arguments passed to the wrapped function.

            Returns:
            --------

            Returns the result of the wrapped async function if successful; raises the last
            encountered error if all retries fail.
            """
            # If DISABLE_RETRIES is set, don't retry for testing purposes
            disable_retries = os.environ.get("DISABLE_RETRIES", "false").lower() in (
                "true",
                "1",
                "yes",
            )

            retries = 0
            last_error = None

            while retries <= max_retries:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    # Check if this is a rate limit error
                    error_str = str(e).lower()
                    error_type = type(e).__name__
                    is_rate_limit = any(
                        pattern in error_str.lower() for pattern in RATE_LIMIT_ERROR_PATTERNS
                    )

                    if disable_retries:
                        # For testing, propagate the exception immediately
                        raise

                    if is_rate_limit and retries < max_retries:
                        # Calculate backoff with jitter
                        backoff = (
                            base_backoff * (2**retries) * (1 + random.uniform(-jitter, jitter))
                        )

                        logger.warning(
                            f"Embedding rate limit hit, retrying in {backoff:.2f}s "
                            f"(attempt {retries + 1}/{max_retries}): "
                            f"({error_str!r}, {error_type!r})"
                        )

                        await asyncio.sleep(backoff)
                        retries += 1
                        last_error = e
                    else:
                        # Not a rate limit error or max retries reached, raise
                        raise

            # If we exit the loop due to max retries, raise the last error
            if last_error:
                raise last_error

        return wrapper

    return decorator
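With this file gone, the decorator pair embedding_sleep_and_retry_async / embedding_rate_limit_async no longer exists; the remaining hunks move each call site to the context-manager form instead. A condensed before/after of that call-site pattern, where SomeEmbeddingEngine and _call_api are illustrative stand-ins rather than names from the repository:

from typing import List

from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager


class SomeEmbeddingEngine:  # illustrative stand-in for the real engine classes
    # Before: module-level decorators from the deleted embedding_rate_limiter.py
    #
    #   @embedding_sleep_and_retry_async()
    #   @embedding_rate_limit_async
    #   async def embed_text(self, text: List[str]) -> List[List[float]]:
    #       return await self._call_api(text)

    # After: tenacity handles retries, and only the outbound call is throttled.
    async def embed_text(self, text: List[str]) -> List[List[float]]:
        async with embedding_rate_limiter_context_manager():
            return await self._call_api(text)  # hypothetical helper for the provider request

    async def _call_api(self, text: List[str]) -> List[List[float]]:
        raise NotImplementedError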
@@ -15,6 +15,7 @@ from tenacity import (
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
     LLMInterface,
 )
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.infrastructure.llm.config import get_llm_config

 logger = get_logger()
@@ -45,7 +46,7 @@ class AnthropicAdapter(LLMInterface):

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -69,17 +70,17 @@ class AnthropicAdapter(LLMInterface):

         - BaseModel: An instance of BaseModel containing the structured response.
         """

-        return await self.aclient(
-            model=self.model,
-            max_tokens=4096,
-            max_retries=5,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"""Use the given format to extract information
-                    from the following input: {text_input}. {system_prompt}""",
-                }
-            ],
-            response_model=response_model,
-        )
+        async with llm_rate_limiter_context_manager():
+            return await self.aclient(
+                model=self.model,
+                max_tokens=4096,
+                max_retries=2,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"""Use the given format to extract information
+                        from the following input: {text_input}. {system_prompt}""",
+                    }
+                ],
+                response_model=response_model,
+            )
@@ -13,6 +13,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
     LLMInterface,
 )
 import logging
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.shared.logging_utils import get_logger
 from tenacity import (
     retry,
@@ -73,7 +74,7 @@ class GeminiAdapter(LLMInterface):

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -105,24 +106,25 @@ class GeminiAdapter(LLMInterface):
         """

         try:
-            return await self.aclient.chat.completions.create(
-                model=self.model,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": f"""{text_input}""",
-                    },
-                    {
-                        "role": "system",
-                        "content": system_prompt,
-                    },
-                ],
-                api_key=self.api_key,
-                max_retries=5,
-                api_base=self.endpoint,
-                api_version=self.api_version,
-                response_model=response_model,
-            )
+            async with llm_rate_limiter_context_manager():
+                return await self.aclient.chat.completions.create(
+                    model=self.model,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": f"""{text_input}""",
+                        },
+                        {
+                            "role": "system",
+                            "content": system_prompt,
+                        },
+                    ],
+                    api_key=self.api_key,
+                    max_retries=2,
+                    api_base=self.endpoint,
+                    api_version=self.api_version,
+                    response_model=response_model,
+                )
         except (
             ContentFilterFinishReasonError,
             ContentPolicyViolationError,
@@ -140,23 +142,24 @@ class GeminiAdapter(LLMInterface):
             )

         try:
-            return await self.aclient.chat.completions.create(
-                model=self.fallback_model,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": f"""{text_input}""",
-                    },
-                    {
-                        "role": "system",
-                        "content": system_prompt,
-                    },
-                ],
-                max_retries=5,
-                api_key=self.fallback_api_key,
-                api_base=self.fallback_endpoint,
-                response_model=response_model,
-            )
+            async with llm_rate_limiter_context_manager():
+                return await self.aclient.chat.completions.create(
+                    model=self.fallback_model,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": f"""{text_input}""",
+                        },
+                        {
+                            "role": "system",
+                            "content": system_prompt,
+                        },
+                    ],
+                    max_retries=2,
+                    api_key=self.fallback_api_key,
+                    api_base=self.fallback_endpoint,
+                    response_model=response_model,
+                )
         except (
             ContentFilterFinishReasonError,
             ContentPolicyViolationError,
@@ -13,6 +13,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
     LLMInterface,
 )
 import logging
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.shared.logging_utils import get_logger
 from tenacity import (
     retry,
@@ -73,7 +74,7 @@ class GenericAPIAdapter(LLMInterface):

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -105,23 +106,24 @@ class GenericAPIAdapter(LLMInterface):
         """

         try:
-            return await self.aclient.chat.completions.create(
-                model=self.model,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": f"""{text_input}""",
-                    },
-                    {
-                        "role": "system",
-                        "content": system_prompt,
-                    },
-                ],
-                max_retries=5,
-                api_key=self.api_key,
-                api_base=self.endpoint,
-                response_model=response_model,
-            )
+            async with llm_rate_limiter_context_manager():
+                return await self.aclient.chat.completions.create(
+                    model=self.model,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": f"""{text_input}""",
+                        },
+                        {
+                            "role": "system",
+                            "content": system_prompt,
+                        },
+                    ],
+                    max_retries=2,
+                    api_key=self.api_key,
+                    api_base=self.endpoint,
+                    response_model=response_model,
+                )
         except (
             ContentFilterFinishReasonError,
             ContentPolicyViolationError,
@@ -139,23 +141,24 @@ class GenericAPIAdapter(LLMInterface):
             ) from error

         try:
-            return await self.aclient.chat.completions.create(
-                model=self.fallback_model,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": f"""{text_input}""",
-                    },
-                    {
-                        "role": "system",
-                        "content": system_prompt,
-                    },
-                ],
-                max_retries=5,
-                api_key=self.fallback_api_key,
-                api_base=self.fallback_endpoint,
-                response_model=response_model,
-            )
+            async with llm_rate_limiter_context_manager():
+                return await self.aclient.chat.completions.create(
+                    model=self.fallback_model,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": f"""{text_input}""",
+                        },
+                        {
+                            "role": "system",
+                            "content": system_prompt,
+                        },
+                    ],
+                    max_retries=2,
+                    api_key=self.fallback_api_key,
+                    api_base=self.fallback_endpoint,
+                    response_model=response_model,
+                )
         except (
             ContentFilterFinishReasonError,
             ContentPolicyViolationError,
@@ -10,6 +10,7 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
     LLMInterface,
 )
 from cognee.infrastructure.llm.config import get_llm_config
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager

 import logging
 from tenacity import (
@@ -62,7 +63,7 @@ class MistralAdapter(LLMInterface):

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -97,13 +98,14 @@ class MistralAdapter(LLMInterface):
             },
         ]
         try:
-            response = await self.aclient.chat.completions.create(
-                model=self.model,
-                max_tokens=self.max_completion_tokens,
-                max_retries=5,
-                messages=messages,
-                response_model=response_model,
-            )
+            async with llm_rate_limiter_context_manager():
+                response = await self.aclient.chat.completions.create(
+                    model=self.model,
+                    max_tokens=self.max_completion_tokens,
+                    max_retries=2,
+                    messages=messages,
+                    response_model=response_model,
+                )
             if response.choices and response.choices[0].message.content:
                 content = response.choices[0].message.content
                 return response_model.model_validate_json(content)
@@ -11,6 +11,8 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
 )
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
+
 from tenacity import (
     retry,
     stop_after_delay,
@@ -68,7 +70,7 @@ class OllamaAPIAdapter(LLMInterface):

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -95,28 +97,28 @@ class OllamaAPIAdapter(LLMInterface):

         - BaseModel: A structured output that conforms to the specified response model.
         """

-        response = self.aclient.chat.completions.create(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"{text_input}",
-                },
-                {
-                    "role": "system",
-                    "content": system_prompt,
-                },
-            ],
-            max_retries=5,
-            response_model=response_model,
-        )
+        async with llm_rate_limiter_context_manager():
+            response = self.aclient.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"{text_input}",
+                    },
+                    {
+                        "role": "system",
+                        "content": system_prompt,
+                    },
+                ],
+                max_retries=2,
+                response_model=response_model,
+            )

         return response

     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -106,7 +106,7 @@ class OpenAIAdapter(LLMInterface):
     @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -4,10 +4,7 @@ from typing import List
 from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
     LiteLLMEmbeddingEngine,
 )
-from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
-    embedding_rate_limit_async,
-    embedding_sleep_and_retry_async,
-)
+from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager


 class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
@@ -34,8 +31,6 @@ class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
         self.fail_every_n_requests = fail_every_n_requests
         self.add_delay = add_delay

-    @embedding_sleep_and_retry_async()
-    @embedding_rate_limit_async
     async def embed_text(self, text: List[str]) -> List[List[float]]:
         """
         Mock implementation that returns fixed embeddings and can
@@ -52,4 +47,5 @@ class MockEmbeddingEngine(LiteLLMEmbeddingEngine):
             raise Exception(f"Mock failure on request #{self.request_count}")

         # Return mock embeddings of the correct dimension
-        return [[0.1] * self.dimensions for _ in text]
+        async with embedding_rate_limiter_context_manager():
+            return [[0.1] * self.dimensions for _ in text]