Adding support for audio/image transcription for all other providers

This commit is contained in:
rajeevrajeshuni 2025-11-25 12:22:15 +05:30
parent c2c64a417c
commit 02b1778658
10 changed files with 313 additions and 312 deletions

View file

@ -34,19 +34,6 @@ class LLMGateway:
text_input=text_input, system_prompt=system_prompt, response_model=response_model
)
@staticmethod
def create_structured_output(
text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
get_llm_client,
)
llm_client = get_llm_client()
return llm_client.create_structured_output(
text_input=text_input, system_prompt=system_prompt, response_model=response_model
)
@staticmethod
def create_transcript(input) -> Coroutine:
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (

View file

@ -3,7 +3,9 @@ from typing import Type
from pydantic import BaseModel
import litellm
import instructor
import anthropic
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from tenacity import (
retry,
stop_after_delay,
@ -12,27 +14,32 @@ from tenacity import (
before_sleep_log,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
GenericAPIAdapter,
)
from cognee.infrastructure.llm.config import get_llm_config
logger = get_logger()
observe = get_observe()
class AnthropicAdapter(LLMInterface):
class AnthropicAdapter(GenericAPIAdapter):
"""
Adapter for interfacing with the Anthropic API, enabling structured output generation
and prompt display.
"""
name = "Anthropic"
model: str
default_instructor_mode = "anthropic_tools"
def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
import anthropic
def __init__(
self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None
):
super().__init__(
api_key=api_key,
model=model,
max_completion_tokens=max_completion_tokens,
name="Anthropic",
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.patch(
@ -40,9 +47,7 @@ class AnthropicAdapter(LLMInterface):
mode=instructor.Mode(self.instructor_mode),
)
self.model = model
self.max_completion_tokens = max_completion_tokens
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),

View file

@ -1,4 +1,4 @@
"""Adapter for Generic API LLM provider API"""
"""Adapter for Gemini API LLM provider"""
import litellm
import instructor
@ -8,12 +8,7 @@ from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.core import InstructorRetryException
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
import logging
from cognee.shared.logging_utils import get_logger
from tenacity import (
retry,
stop_after_delay,
@ -22,55 +17,65 @@ from tenacity import (
before_sleep_log,
)
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
GenericAPIAdapter,
)
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
logger = get_logger()
observe = get_observe()
class GeminiAdapter(LLMInterface):
class GeminiAdapter(GenericAPIAdapter):
"""
Adapter for Gemini API LLM provider.
This class initializes the API adapter with necessary credentials and configurations for
interacting with the gemini LLM models. It provides methods for creating structured outputs
based on user input and system prompts.
based on user input and system prompts, as well as multimodal processing capabilities.
Public methods:
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
Type[BaseModel]) -> BaseModel
- acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel
- create_transcript(input) -> BaseModel: Transcribe audio files to text
- transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter
"""
name: str
model: str
api_key: str
default_instructor_mode = "json_mode"
def __init__(
self,
endpoint,
api_key: str,
model: str,
api_version: str,
max_completion_tokens: int,
endpoint: str = None,
api_version: str = None,
transcription_model: str = None,
instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
self.max_completion_tokens = max_completion_tokens
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
super().__init__(
api_key=api_key,
model=model,
max_completion_tokens=max_completion_tokens,
name="Gemini",
endpoint=endpoint,
api_version=api_version,
transcription_model=transcription_model,
fallback_model=fallback_model,
fallback_api_key=fallback_api_key,
fallback_endpoint=fallback_endpoint,
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
@ -118,7 +123,7 @@ class GeminiAdapter(LLMInterface):
},
],
api_key=self.api_key,
max_retries=5,
max_retries=self.MAX_RETRIES,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
@ -152,7 +157,7 @@ class GeminiAdapter(LLMInterface):
"content": system_prompt,
},
],
max_retries=5,
max_retries=self.MAX_RETRIES,
api_key=self.fallback_api_key,
api_base=self.fallback_endpoint,
response_model=response_model,

View file

@ -1,8 +1,10 @@
"""Adapter for Generic API LLM provider API"""
import base64
import mimetypes
import litellm
import instructor
from typing import Type
from typing import Type, Optional
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
@ -12,6 +14,8 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.modules.observability.get_observe import get_observe
import logging
from cognee.shared.logging_utils import get_logger
from tenacity import (
@ -23,6 +27,7 @@ from tenacity import (
)
logger = get_logger()
observe = get_observe()
class GenericAPIAdapter(LLMInterface):
@ -38,18 +43,19 @@ class GenericAPIAdapter(LLMInterface):
Type[BaseModel]) -> BaseModel
"""
name: str
model: str
api_key: str
MAX_RETRIES = 5
default_instructor_mode = "json_mode"
def __init__(
self,
endpoint,
api_key: str,
model: str,
name: str,
max_completion_tokens: int,
name: str,
endpoint: str = None,
api_version: str = None,
transcription_model: str = None,
image_transcribe_model: str = None,
instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
@ -58,9 +64,11 @@ class GenericAPIAdapter(LLMInterface):
self.name = name
self.model = model
self.api_key = api_key
self.api_version = api_version
self.endpoint = endpoint
self.max_completion_tokens = max_completion_tokens
self.transcription_model = transcription_model or model
self.image_transcribe_model = image_transcribe_model or model
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
@ -71,6 +79,7 @@ class GenericAPIAdapter(LLMInterface):
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
@ -170,3 +179,112 @@ class GenericAPIAdapter(LLMInterface):
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
) from error
@observe(as_type="transcription")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input) -> Optional[BaseModel]:
"""
Generate an audio transcript from a user query.
This method creates a transcript from the specified audio file, raising a
FileNotFoundError if the file does not exist. The audio file is processed and the
transcription is retrieved from the API.
Parameters:
-----------
- input: The path to the audio file that needs to be transcribed.
Returns:
--------
The generated transcription of the audio file.
"""
async with open_data_file(input, mode="rb") as audio_file:
encoded_string = base64.b64encode(audio_file.read()).decode("utf-8")
mime_type, _ = mimetypes.guess_type(input)
if not mime_type or not mime_type.startswith("audio/"):
raise ValueError(
f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
)
return litellm.completion(
model=self.transcription_model,
messages=[
{
"role": "user",
"content": [
{
"type": "file",
"file": {"file_data": f"data:{mime_type};base64,{encoded_string}"},
},
{"type": "text", "text": "Transcribe the following audio precisely."},
],
}
],
api_key=self.api_key,
api_version=self.api_version,
max_completion_tokens=self.max_completion_tokens,
api_base=self.endpoint,
max_retries=self.MAX_RETRIES,
)
@observe(as_type="transcribe_image")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input) -> Optional[BaseModel]:
"""
Generate a transcription of an image from a user query.
This method encodes the image and sends a request to the API to obtain a
description of the contents of the image.
Parameters:
-----------
- input: The path to the image file that needs to be transcribed.
Returns:
--------
- BaseModel: A structured output generated by the model, returned as an instance of
BaseModel.
"""
async with open_data_file(input, mode="rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
mime_type, _ = mimetypes.guess_type(input)
if not mime_type or not mime_type.startswith("image/"):
raise ValueError(
f"Could not determine MIME type for image file: {input}. Is the extension correct?"
)
return litellm.completion(
model=self.image_transcribe_model,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?",
},
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{encoded_image}",
},
},
],
}
],
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
max_completion_tokens=300,
max_retries=self.MAX_RETRIES,
)

View file

@ -97,11 +97,10 @@ def get_llm_client(raise_api_key_error: bool = True):
)
return OllamaAPIAdapter(
llm_config.llm_endpoint,
llm_config.llm_api_key,
llm_config.llm_model,
"Ollama",
max_completion_tokens=max_completion_tokens,
max_completion_tokens,
llm_config.llm_endpoint,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
@ -111,8 +110,9 @@ def get_llm_client(raise_api_key_error: bool = True):
)
return AnthropicAdapter(
max_completion_tokens=max_completion_tokens,
model=llm_config.llm_model,
llm_config.llm_api_key,
llm_config.llm_model,
max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
@ -125,11 +125,10 @@ def get_llm_client(raise_api_key_error: bool = True):
)
return GenericAPIAdapter(
llm_config.llm_endpoint,
llm_config.llm_api_key,
llm_config.llm_model,
max_completion_tokens,
"Custom",
max_completion_tokens=max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode.lower(),
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,

View file

@ -1,6 +1,6 @@
"""LLM Interface"""
from typing import Type, Protocol
from typing import Type, Protocol, Optional
from abc import abstractmethod
from pydantic import BaseModel
from cognee.infrastructure.llm.LLMGateway import LLMGateway
@ -8,13 +8,12 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
class LLMInterface(Protocol):
"""
Define an interface for LLM models with methods for structured output and prompt
display.
Define an interface for LLM models with methods for structured output, multimodal processing, and prompt display.
Methods:
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
Type[BaseModel])
- show_prompt(text_input: str, system_prompt: str)
- acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel])
- create_transcript(input): Transcribe audio files to text
- transcribe_image(input): Analyze image files and return text description
"""
@abstractmethod
@ -36,3 +35,39 @@ class LLMInterface(Protocol):
output.
"""
raise NotImplementedError
@abstractmethod
async def create_transcript(self, input) -> Optional[BaseModel]:
"""
Transcribe audio content to text.
This method should be implemented by subclasses that support audio transcription.
If not implemented, returns None and should be handled gracefully by callers.
Parameters:
-----------
- input: The path to the audio file that needs to be transcribed.
Returns:
--------
- BaseModel: A structured output containing the transcription, or None if not supported.
"""
raise NotImplementedError
@abstractmethod
async def transcribe_image(self, input) -> Optional[BaseModel]:
"""
Analyze image content and return text description.
This method should be implemented by subclasses that support image analysis.
If not implemented, returns None and should be handled gracefully by callers.
Parameters:
-----------
- input: The path to the image file that needs to be analyzed.
Returns:
--------
- BaseModel: A structured output containing the image description, or None if not supported.
"""
raise NotImplementedError

View file

@ -2,12 +2,12 @@ import litellm
import instructor
from pydantic import BaseModel
from typing import Type
from litellm import JSONSchemaValidationError
from litellm import JSONSchemaValidationError, transcription
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
GenericAPIAdapter,
)
from cognee.infrastructure.llm.config import get_llm_config
@ -19,12 +19,13 @@ from tenacity import (
retry_if_not_exception_type,
before_sleep_log,
)
from mistralai import Mistral
logger = get_logger()
observe = get_observe()
class MistralAdapter(LLMInterface):
class MistralAdapter(GenericAPIAdapter):
"""
Adapter for Mistral AI API, for structured output generation and prompt display.
@ -33,10 +34,6 @@ class MistralAdapter(LLMInterface):
- show_prompt
"""
name = "Mistral"
model: str
api_key: str
max_completion_tokens: int
default_instructor_mode = "mistral_tools"
def __init__(
@ -45,12 +42,21 @@ class MistralAdapter(LLMInterface):
model: str,
max_completion_tokens: int,
endpoint: str = None,
transcription_model: str = None,
image_transcribe_model: str = None,
instructor_mode: str = None,
):
from mistralai import Mistral
self.model = model
self.max_completion_tokens = max_completion_tokens
super().__init__(
api_key=api_key,
model=model,
max_completion_tokens=max_completion_tokens,
name="Mistral",
endpoint=endpoint,
transcription_model=transcription_model,
image_transcribe_model=image_transcribe_model,
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
@ -60,6 +66,7 @@ class MistralAdapter(LLMInterface):
api_key=get_llm_config().llm_api_key,
)
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
@ -117,3 +124,42 @@ class MistralAdapter(LLMInterface):
logger.error(f"Schema validation failed: {str(e)}")
logger.debug(f"Raw response: {e.raw_response}")
raise ValueError(f"Response failed schema validation: {str(e)}")
@observe(as_type="transcription")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input):
"""
Generate an audio transcript from a user query.
This method creates a transcript from the specified audio file.
The audio file is processed and the transcription is retrieved from the API.
Parameters:
-----------
- input: The path to the audio file that needs to be transcribed.
Returns:
--------
The generated transcription of the audio file.
"""
transcription_model = self.transcription_model
if self.transcription_model.startswith("mistral"):
transcription_model = self.transcription_model.split("/")[-1]
file_name = input.split("/")[-1]
client = Mistral(api_key=self.api_key)
with open(input, "rb") as f:
transcription_response = client.audio.transcriptions.complete(
model=transcription_model,
file={
"content": f,
"file_name": file_name,
},
)
# TODO: We need to standardize return type of create_transcript across different models.
return transcription_response

View file

@ -5,12 +5,12 @@ import instructor
from typing import Type
from openai import OpenAI
from pydantic import BaseModel
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
GenericAPIAdapter,
)
from tenacity import (
retry,
stop_after_delay,
@ -20,9 +20,10 @@ from tenacity import (
)
logger = get_logger()
observe = get_observe()
class OllamaAPIAdapter(LLMInterface):
class OllamaAPIAdapter(GenericAPIAdapter):
"""
Adapter for a Generic API LLM provider using instructor with an OpenAI backend.
@ -46,18 +47,20 @@ class OllamaAPIAdapter(LLMInterface):
def __init__(
self,
endpoint: str,
api_key: str,
model: str,
name: str,
max_completion_tokens: int,
endpoint: str,
instructor_mode: str = None,
):
self.name = name
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.max_completion_tokens = max_completion_tokens
super().__init__(
api_key=api_key,
model=model,
max_completion_tokens=max_completion_tokens,
name="Ollama",
endpoint=endpoint,
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
@ -66,6 +69,7 @@ class OllamaAPIAdapter(LLMInterface):
mode=instructor.Mode(self.instructor_mode),
)
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
@ -113,95 +117,3 @@ class OllamaAPIAdapter(LLMInterface):
)
return response
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input_file: str) -> str:
"""
Generate an audio transcript from a user query.
This synchronous method takes an input audio file and returns its transcription. Raises
a FileNotFoundError if the input file does not exist, and raises a ValueError if
transcription fails or returns no text.
Parameters:
-----------
- input_file (str): The path to the audio file to be transcribed.
Returns:
--------
- str: The transcription of the audio as a string.
"""
async with open_data_file(input_file, mode="rb") as audio_file:
transcription = self.aclient.audio.transcriptions.create(
model="whisper-1", # Ensure the correct model for transcription
file=audio_file,
language="en",
)
# Ensure the response contains a valid transcript
if not hasattr(transcription, "text"):
raise ValueError("Transcription failed. No text returned.")
return transcription.text
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input_file: str) -> str:
"""
Transcribe content from an image using base64 encoding.
This synchronous method takes an input image file, encodes it as base64, and returns the
transcription of its content. Raises a FileNotFoundError if the input file does not
exist, and raises a ValueError if the transcription fails or no valid response is
received.
Parameters:
-----------
- input_file (str): The path to the image file to be transcribed.
Returns:
--------
- str: The transcription of the image's content as a string.
"""
async with open_data_file(input_file, mode="rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
response = self.aclient.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
},
],
}
],
max_completion_tokens=300,
)
# Ensure response is valid before accessing .choices[0].message.content
if not hasattr(response, "choices") or not response.choices:
raise ValueError("Image transcription failed. No response received.")
return response.choices[0].message.content

View file

@ -1,4 +1,3 @@
import base64
import litellm
import instructor
from typing import Type
@ -16,8 +15,8 @@ from tenacity import (
before_sleep_log,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
GenericAPIAdapter,
)
from cognee.infrastructure.llm.exceptions import (
ContentPolicyFilterError,
@ -31,7 +30,7 @@ logger = get_logger()
observe = get_observe()
class OpenAIAdapter(LLMInterface):
class OpenAIAdapter(GenericAPIAdapter):
"""
Adapter for OpenAI's GPT-3, GPT-4 API.
@ -52,12 +51,7 @@ class OpenAIAdapter(LLMInterface):
- MAX_RETRIES
"""
name = "OpenAI"
model: str
api_key: str
api_version: str
default_instructor_mode = "json_schema_mode"
MAX_RETRIES = 5
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
@ -65,17 +59,29 @@ class OpenAIAdapter(LLMInterface):
def __init__(
self,
api_key: str,
endpoint: str,
api_version: str,
model: str,
transcription_model: str,
max_completion_tokens: int,
endpoint: str = None,
api_version: str = None,
transcription_model: str = None,
instructor_mode: str = None,
streaming: bool = False,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
super().__init__(
api_key=api_key,
model=model,
max_completion_tokens=max_completion_tokens,
name="OpenAI",
endpoint=endpoint,
api_version=api_version,
transcription_model=transcription_model,
fallback_model=fallback_model,
fallback_api_key=fallback_api_key,
fallback_endpoint=fallback_endpoint,
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
# Make sure all new gpt models will work with this mode as well.
@ -90,18 +96,8 @@ class OpenAIAdapter(LLMInterface):
self.aclient = instructor.from_litellm(litellm.acompletion)
self.client = instructor.from_litellm(litellm.completion)
self.transcription_model = transcription_model
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
self.max_completion_tokens = max_completion_tokens
self.streaming = streaming
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
@ -174,7 +170,7 @@ class OpenAIAdapter(LLMInterface):
},
],
api_key=self.fallback_api_key,
# api_base=self.fallback_endpoint,
api_base=self.fallback_endpoint,
response_model=response_model,
max_retries=self.MAX_RETRIES,
)
@ -193,57 +189,7 @@ class OpenAIAdapter(LLMInterface):
f"The provided input contains content that is not aligned with our content policy: {text_input}"
) from error
@observe
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""
Generate a response from a user query.
This method creates structured output by sending a synchronous request to the OpenAI API
using the provided parameters to generate a completion based on the user input and
system prompt.
Parameters:
-----------
- text_input (str): The input text provided by the user for generating a response.
- system_prompt (str): The system's prompt to guide the model's response.
- response_model (Type[BaseModel]): The expected model type for the response.
Returns:
--------
- BaseModel: A structured output generated by the model, returned as an instance of
BaseModel.
"""
return self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
)
@observe(as_type="transcription")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
@ -282,56 +228,4 @@ class OpenAIAdapter(LLMInterface):
return transcription
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input) -> BaseModel:
"""
Generate a transcription of an image from a user query.
This method encodes the image and sends a request to the OpenAI API to obtain a
description of the contents of the image.
Parameters:
-----------
- input: The path to the image file that needs to be transcribed.
Returns:
--------
- BaseModel: A structured output generated by the model, returned as an instance of
BaseModel.
"""
async with open_data_file(input, mode="rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
return litellm.completion(
model=self.model,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{encoded_image}",
},
},
],
}
],
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
max_completion_tokens=300,
max_retries=self.MAX_RETRIES,
)
# transcribe image inherited from GenericAdapter

2
uv.lock generated
View file

@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.10, <3.14"
resolution-markers = [
"python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'",