refactor: Switch to using tenacity for rate limiting
This commit is contained in:
parent
d01523e6fc
commit
96496f38ed
6 changed files with 142 additions and 78 deletions
|
|
@ -1,19 +1,24 @@
|
||||||
|
import logging
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_delay,
|
||||||
|
wait_exponential_jitter,
|
||||||
|
retry_if_not_exception_type,
|
||||||
|
before_sleep_log,
|
||||||
|
)
|
||||||
|
|
||||||
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
|
||||||
rate_limit_async,
|
|
||||||
sleep_and_retry_async,
|
|
||||||
)
|
|
||||||
|
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class AnthropicAdapter(LLMInterface):
|
class AnthropicAdapter(LLMInterface):
|
||||||
"""
|
"""
|
||||||
|
|
@ -35,8 +40,13 @@ class AnthropicAdapter(LLMInterface):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_completion_tokens = max_completion_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
|
||||||
@sleep_and_retry_async()
|
@retry(
|
||||||
@rate_limit_async
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
|
|
|
||||||
|
|
@ -12,11 +12,18 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
import logging
|
||||||
rate_limit_async,
|
from cognee.shared.logging_utils import get_logger
|
||||||
sleep_and_retry_async,
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_delay,
|
||||||
|
wait_exponential_jitter,
|
||||||
|
retry_if_not_exception_type,
|
||||||
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class GeminiAdapter(LLMInterface):
|
class GeminiAdapter(LLMInterface):
|
||||||
"""
|
"""
|
||||||
|
|
@ -58,8 +65,13 @@ class GeminiAdapter(LLMInterface):
|
||||||
|
|
||||||
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
|
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
|
||||||
|
|
||||||
@sleep_and_retry_async()
|
@retry(
|
||||||
@rate_limit_async
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
|
|
|
||||||
|
|
@ -12,11 +12,18 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
import logging
|
||||||
rate_limit_async,
|
from cognee.shared.logging_utils import get_logger
|
||||||
sleep_and_retry_async,
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_delay,
|
||||||
|
wait_exponential_jitter,
|
||||||
|
retry_if_not_exception_type,
|
||||||
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class GenericAPIAdapter(LLMInterface):
|
class GenericAPIAdapter(LLMInterface):
|
||||||
"""
|
"""
|
||||||
|
|
@ -58,8 +65,13 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
|
|
||||||
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
|
self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
|
||||||
|
|
||||||
@sleep_and_retry_async()
|
@retry(
|
||||||
@rate_limit_async
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,23 @@
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Type, Optional
|
from typing import Type
|
||||||
from litellm import acompletion, JSONSchemaValidationError
|
from litellm import JSONSchemaValidationError
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.observability.get_observe import get_observe
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
|
||||||
rate_limit_async,
|
import logging
|
||||||
sleep_and_retry_async,
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_delay,
|
||||||
|
wait_exponential_jitter,
|
||||||
|
retry_if_not_exception_type,
|
||||||
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
@ -47,8 +50,13 @@ class MistralAdapter(LLMInterface):
|
||||||
api_key=get_llm_config().llm_api_key,
|
api_key=get_llm_config().llm_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@sleep_and_retry_async()
|
@retry(
|
||||||
@rate_limit_async
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
|
|
@ -99,31 +107,3 @@ class MistralAdapter(LLMInterface):
|
||||||
logger.error(f"Schema validation failed: {str(e)}")
|
logger.error(f"Schema validation failed: {str(e)}")
|
||||||
logger.debug(f"Raw response: {e.raw_response}")
|
logger.debug(f"Raw response: {e.raw_response}")
|
||||||
raise ValueError(f"Response failed schema validation: {str(e)}")
|
raise ValueError(f"Response failed schema validation: {str(e)}")
|
||||||
|
|
||||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
|
||||||
"""
|
|
||||||
Format and display the prompt for a user query.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
- text_input (str): Input text from the user to be included in the prompt.
|
|
||||||
- system_prompt (str): The system prompt that will be shown alongside the user input.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
- str: The formatted prompt string combining system prompt and user input.
|
|
||||||
"""
|
|
||||||
if not text_input:
|
|
||||||
text_input = "No user input provided."
|
|
||||||
if not system_prompt:
|
|
||||||
raise MissingSystemPromptPathError()
|
|
||||||
|
|
||||||
system_prompt = LLMGateway.read_query_prompt(system_prompt)
|
|
||||||
|
|
||||||
formatted_prompt = (
|
|
||||||
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
|
||||||
if system_prompt
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
return formatted_prompt
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
import base64
|
import base64
|
||||||
|
import litellm
|
||||||
|
import logging
|
||||||
import instructor
|
import instructor
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
@ -7,11 +9,17 @@ from pydantic import BaseModel
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
|
||||||
rate_limit_async,
|
|
||||||
sleep_and_retry_async,
|
|
||||||
)
|
|
||||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_delay,
|
||||||
|
wait_exponential_jitter,
|
||||||
|
retry_if_not_exception_type,
|
||||||
|
before_sleep_log,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class OllamaAPIAdapter(LLMInterface):
|
class OllamaAPIAdapter(LLMInterface):
|
||||||
|
|
@ -47,8 +55,13 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
|
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
|
||||||
)
|
)
|
||||||
|
|
||||||
@sleep_and_retry_async()
|
@retry(
|
||||||
@rate_limit_async
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
|
|
@ -90,7 +103,13 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@rate_limit_async
|
@retry(
|
||||||
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def create_transcript(self, input_file: str) -> str:
|
async def create_transcript(self, input_file: str) -> str:
|
||||||
"""
|
"""
|
||||||
Generate an audio transcript from a user query.
|
Generate an audio transcript from a user query.
|
||||||
|
|
@ -123,7 +142,13 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
|
|
||||||
return transcription.text
|
return transcription.text
|
||||||
|
|
||||||
@rate_limit_async
|
@retry(
|
||||||
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def transcribe_image(self, input_file: str) -> str:
|
async def transcribe_image(self, input_file: str) -> str:
|
||||||
"""
|
"""
|
||||||
Transcribe content from an image using base64 encoding.
|
Transcribe content from an image using base64 encoding.
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,15 @@ from openai import ContentFilterFinishReasonError
|
||||||
from litellm.exceptions import ContentPolicyViolationError
|
from litellm.exceptions import ContentPolicyViolationError
|
||||||
from instructor.core import InstructorRetryException
|
from instructor.core import InstructorRetryException
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_delay,
|
||||||
|
wait_exponential_jitter,
|
||||||
|
retry_if_not_exception_type,
|
||||||
|
before_sleep_log,
|
||||||
|
)
|
||||||
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
|
|
@ -14,19 +23,13 @@ from cognee.infrastructure.llm.exceptions import (
|
||||||
ContentPolicyFilterError,
|
ContentPolicyFilterError,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
|
||||||
rate_limit_async,
|
|
||||||
rate_limit_sync,
|
|
||||||
sleep_and_retry_async,
|
|
||||||
sleep_and_retry_sync,
|
|
||||||
)
|
|
||||||
from cognee.modules.observability.get_observe import get_observe
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
observe = get_observe()
|
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class OpenAIAdapter(LLMInterface):
|
class OpenAIAdapter(LLMInterface):
|
||||||
"""
|
"""
|
||||||
|
|
@ -97,8 +100,13 @@ class OpenAIAdapter(LLMInterface):
|
||||||
self.fallback_endpoint = fallback_endpoint
|
self.fallback_endpoint = fallback_endpoint
|
||||||
|
|
||||||
@observe(as_type="generation")
|
@observe(as_type="generation")
|
||||||
@sleep_and_retry_async()
|
@retry(
|
||||||
@rate_limit_async
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def acreate_structured_output(
|
async def acreate_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
|
|
@ -186,8 +194,13 @@ class OpenAIAdapter(LLMInterface):
|
||||||
) from error
|
) from error
|
||||||
|
|
||||||
@observe
|
@observe
|
||||||
@sleep_and_retry_sync()
|
@retry(
|
||||||
@rate_limit_sync
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
def create_structured_output(
|
def create_structured_output(
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
) -> BaseModel:
|
) -> BaseModel:
|
||||||
|
|
@ -231,7 +244,13 @@ class OpenAIAdapter(LLMInterface):
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
)
|
)
|
||||||
|
|
||||||
@rate_limit_async
|
@retry(
|
||||||
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def create_transcript(self, input):
|
async def create_transcript(self, input):
|
||||||
"""
|
"""
|
||||||
Generate an audio transcript from a user query.
|
Generate an audio transcript from a user query.
|
||||||
|
|
@ -263,7 +282,13 @@ class OpenAIAdapter(LLMInterface):
|
||||||
|
|
||||||
return transcription
|
return transcription
|
||||||
|
|
||||||
@rate_limit_async
|
@retry(
|
||||||
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def transcribe_image(self, input) -> BaseModel:
|
async def transcribe_image(self, input) -> BaseModel:
|
||||||
"""
|
"""
|
||||||
Generate a transcription of an image from a user query.
|
Generate a transcription of an image from a user query.
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue