refactor: Switch to using tenacity for rate limiting

This commit is contained in:
Igor Ilic 2025-10-15 18:08:18 +02:00
parent d01523e6fc
commit 96496f38ed
6 changed files with 142 additions and 78 deletions

View file

@ -1,19 +1,24 @@
import logging
from typing import Type
from pydantic import BaseModel
import litellm
import instructor
from cognee.shared.logging_utils import get_logger
from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import get_llm_config
logger = get_logger()
class AnthropicAdapter(LLMInterface):
"""
@ -35,8 +40,13 @@ class AnthropicAdapter(LLMInterface):
self.model = model
self.max_completion_tokens = max_completion_tokens
@sleep_and_retry_async()
@rate_limit_async
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:

View file

@ -12,11 +12,18 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
import logging
from cognee.shared.logging_utils import get_logger
from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
logger = get_logger()
class GeminiAdapter(LLMInterface):
"""
@ -58,8 +65,13 @@ class GeminiAdapter(LLMInterface):
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
@sleep_and_retry_async()
@rate_limit_async
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:

View file

@ -12,11 +12,18 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
import logging
from cognee.shared.logging_utils import get_logger
from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
logger = get_logger()
class GenericAPIAdapter(LLMInterface):
"""
@ -58,8 +65,13 @@ class GenericAPIAdapter(LLMInterface):
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
@sleep_and_retry_async()
@rate_limit_async
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:

View file

@ -1,20 +1,23 @@
import litellm
import instructor
from pydantic import BaseModel
from typing import Type, Optional
from litellm import acompletion, JSONSchemaValidationError
from typing import Type
from litellm import JSONSchemaValidationError
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
import logging
from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
logger = get_logger()
@ -47,8 +50,13 @@ class MistralAdapter(LLMInterface):
api_key=get_llm_config().llm_api_key,
)
@sleep_and_retry_async()
@rate_limit_async
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
@ -99,31 +107,3 @@ class MistralAdapter(LLMInterface):
logger.error(f"Schema validation failed: {str(e)}")
logger.debug(f"Raw response: {e.raw_response}")
raise ValueError(f"Response failed schema validation: {str(e)}")
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""
Format and display the prompt for a user query.
Parameters:
-----------
- text_input (str): Input text from the user to be included in the prompt.
- system_prompt (str): The system prompt that will be shown alongside the user input.
Returns:
--------
- str: The formatted prompt string combining system prompt and user input.
"""
if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise MissingSystemPromptPathError()
system_prompt = LLMGateway.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
if system_prompt
else None
)
return formatted_prompt

View file

@ -1,4 +1,6 @@
import base64
import litellm
import logging
import instructor
from typing import Type
from openai import OpenAI
@ -7,11 +9,17 @@ from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
logger = get_logger()
class OllamaAPIAdapter(LLMInterface):
@ -47,8 +55,13 @@ class OllamaAPIAdapter(LLMInterface):
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
)
@sleep_and_retry_async()
@rate_limit_async
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
@ -90,7 +103,13 @@ class OllamaAPIAdapter(LLMInterface):
return response
@rate_limit_async
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input_file: str) -> str:
"""
Generate an audio transcript from a user query.
@ -123,7 +142,13 @@ class OllamaAPIAdapter(LLMInterface):
return transcription.text
@rate_limit_async
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input_file: str) -> str:
"""
Transcribe content from an image using base64 encoding.

View file

@ -7,6 +7,15 @@ from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.core import InstructorRetryException
import logging
from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
@ -14,19 +23,13 @@ from cognee.infrastructure.llm.exceptions import (
ContentPolicyFilterError,
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
rate_limit_sync,
sleep_and_retry_async,
sleep_and_retry_sync,
)
from cognee.modules.observability.get_observe import get_observe
from cognee.shared.logging_utils import get_logger
observe = get_observe()
logger = get_logger()
observe = get_observe()
class OpenAIAdapter(LLMInterface):
"""
@ -97,8 +100,13 @@ class OpenAIAdapter(LLMInterface):
self.fallback_endpoint = fallback_endpoint
@observe(as_type="generation")
@sleep_and_retry_async()
@rate_limit_async
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
@ -186,8 +194,13 @@ class OpenAIAdapter(LLMInterface):
) from error
@observe
@sleep_and_retry_sync()
@rate_limit_sync
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
@ -231,7 +244,13 @@ class OpenAIAdapter(LLMInterface):
max_retries=self.MAX_RETRIES,
)
@rate_limit_async
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input):
"""
Generate an audio transcript from a user query.
@ -263,7 +282,13 @@ class OpenAIAdapter(LLMInterface):
return transcription
@rate_limit_async
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input) -> BaseModel:
"""
Generate a transcription of an image from a user query.