refactor: Switch to using tenacity for rate limiting

This commit is contained in:
Igor Ilic 2025-10-15 18:08:18 +02:00
parent d01523e6fc
commit 96496f38ed
6 changed files with 142 additions and 78 deletions

View file

@@ -1,19 +1,24 @@
import logging
from typing import Type from typing import Type
from pydantic import BaseModel from pydantic import BaseModel
import litellm
import instructor import instructor
from cognee.shared.logging_utils import get_logger
from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface, LLMInterface,
) )
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import get_llm_config from cognee.infrastructure.llm.config import get_llm_config
logger = get_logger()
class AnthropicAdapter(LLMInterface): class AnthropicAdapter(LLMInterface):
""" """
@@ -35,8 +40,13 @@ class AnthropicAdapter(LLMInterface):
self.model = model self.model = model
self.max_completion_tokens = max_completion_tokens self.max_completion_tokens = max_completion_tokens
@sleep_and_retry_async() @retry(
@rate_limit_async stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel: ) -> BaseModel:

View file

@@ -12,11 +12,18 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface, LLMInterface,
) )
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import ( import logging
rate_limit_async, from cognee.shared.logging_utils import get_logger
sleep_and_retry_async, from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
) )
logger = get_logger()
class GeminiAdapter(LLMInterface): class GeminiAdapter(LLMInterface):
""" """
@@ -58,8 +65,13 @@ class GeminiAdapter(LLMInterface):
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
@sleep_and_retry_async() @retry(
@rate_limit_async stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel: ) -> BaseModel:

View file

@@ -12,11 +12,18 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface, LLMInterface,
) )
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import ( import logging
rate_limit_async, from cognee.shared.logging_utils import get_logger
sleep_and_retry_async, from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
) )
logger = get_logger()
class GenericAPIAdapter(LLMInterface): class GenericAPIAdapter(LLMInterface):
""" """
@@ -58,8 +65,13 @@ class GenericAPIAdapter(LLMInterface):
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON) self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
@sleep_and_retry_async() @retry(
@rate_limit_async stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel: ) -> BaseModel:

View file

@@ -1,20 +1,23 @@
import litellm import litellm
import instructor import instructor
from pydantic import BaseModel from pydantic import BaseModel
from typing import Type, Optional from typing import Type
from litellm import acompletion, JSONSchemaValidationError from litellm import JSONSchemaValidationError
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface, LLMInterface,
) )
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import get_llm_config from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async, import logging
sleep_and_retry_async, from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
) )
logger = get_logger() logger = get_logger()
@@ -47,8 +50,13 @@ class MistralAdapter(LLMInterface):
api_key=get_llm_config().llm_api_key, api_key=get_llm_config().llm_api_key,
) )
@sleep_and_retry_async() @retry(
@rate_limit_async stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel: ) -> BaseModel:
@@ -99,31 +107,3 @@ class MistralAdapter(LLMInterface):
logger.error(f"Schema validation failed: {str(e)}") logger.error(f"Schema validation failed: {str(e)}")
logger.debug(f"Raw response: {e.raw_response}") logger.debug(f"Raw response: {e.raw_response}")
raise ValueError(f"Response failed schema validation: {str(e)}") raise ValueError(f"Response failed schema validation: {str(e)}")
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""
Format and display the prompt for a user query.
Parameters:
-----------
- text_input (str): Input text from the user to be included in the prompt.
- system_prompt (str): The system prompt that will be shown alongside the user input.
Returns:
--------
- str: The formatted prompt string combining system prompt and user input.
"""
if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise MissingSystemPromptPathError()
system_prompt = LLMGateway.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
if system_prompt
else None
)
return formatted_prompt

View file

@@ -1,4 +1,6 @@
import base64 import base64
import litellm
import logging
import instructor import instructor
from typing import Type from typing import Type
from openai import OpenAI from openai import OpenAI
@@ -7,11 +9,17 @@ from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface, LLMInterface,
) )
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
logger = get_logger()
class OllamaAPIAdapter(LLMInterface): class OllamaAPIAdapter(LLMInterface):
@@ -47,8 +55,13 @@ class OllamaAPIAdapter(LLMInterface):
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
) )
@sleep_and_retry_async() @retry(
@rate_limit_async stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel: ) -> BaseModel:
@@ -90,7 +103,13 @@ class OllamaAPIAdapter(LLMInterface):
return response return response
@rate_limit_async @retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input_file: str) -> str: async def create_transcript(self, input_file: str) -> str:
""" """
Generate an audio transcript from a user query. Generate an audio transcript from a user query.
@@ -123,7 +142,13 @@ class OllamaAPIAdapter(LLMInterface):
return transcription.text return transcription.text
@rate_limit_async @retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input_file: str) -> str: async def transcribe_image(self, input_file: str) -> str:
""" """
Transcribe content from an image using base64 encoding. Transcribe content from an image using base64 encoding.

View file

@@ -7,6 +7,15 @@ from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError from litellm.exceptions import ContentPolicyViolationError
from instructor.core import InstructorRetryException from instructor.core import InstructorRetryException
import logging
from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface, LLMInterface,
) )
@@ -14,19 +23,13 @@ from cognee.infrastructure.llm.exceptions import (
ContentPolicyFilterError, ContentPolicyFilterError,
) )
from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
rate_limit_sync,
sleep_and_retry_async,
sleep_and_retry_sync,
)
from cognee.modules.observability.get_observe import get_observe from cognee.modules.observability.get_observe import get_observe
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
observe = get_observe()
logger = get_logger() logger = get_logger()
observe = get_observe()
class OpenAIAdapter(LLMInterface): class OpenAIAdapter(LLMInterface):
""" """
@@ -97,8 +100,13 @@ class OpenAIAdapter(LLMInterface):
self.fallback_endpoint = fallback_endpoint self.fallback_endpoint = fallback_endpoint
@observe(as_type="generation") @observe(as_type="generation")
@sleep_and_retry_async() @retry(
@rate_limit_async stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def acreate_structured_output( async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel: ) -> BaseModel:
@@ -186,8 +194,13 @@ class OpenAIAdapter(LLMInterface):
) from error ) from error
@observe @observe
@sleep_and_retry_sync() @retry(
@rate_limit_sync stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def create_structured_output( def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel] self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel: ) -> BaseModel:
@@ -231,7 +244,13 @@ class OpenAIAdapter(LLMInterface):
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
) )
@rate_limit_async @retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input): async def create_transcript(self, input):
""" """
Generate an audio transcript from a user query. Generate an audio transcript from a user query.
@@ -263,7 +282,13 @@ class OpenAIAdapter(LLMInterface):
return transcription return transcription
@rate_limit_async @retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input) -> BaseModel: async def transcribe_image(self, input) -> BaseModel:
""" """
Generate a transcription of an image from a user query. Generate a transcription of an image from a user query.