From 02b17786588b7be4c582a2fdfe93a5412c074cda Mon Sep 17 00:00:00 2001 From: rajeevrajeshuni Date: Tue, 25 Nov 2025 12:22:15 +0530 Subject: [PATCH] Adding support for audio/image transcription for all other providers --- cognee/infrastructure/llm/LLMGateway.py | 13 -- .../llm/anthropic/adapter.py | 27 ++-- .../litellm_instructor/llm/gemini/adapter.py | 59 +++---- .../llm/generic_llm_api/adapter.py | 132 +++++++++++++++- .../litellm_instructor/llm/get_llm_client.py | 13 +- .../litellm_instructor/llm/llm_interface.py | 47 +++++- .../litellm_instructor/llm/mistral/adapter.py | 66 ++++++-- .../litellm_instructor/llm/ollama/adapter.py | 118 ++------------ .../litellm_instructor/llm/openai/adapter.py | 148 +++--------------- uv.lock | 2 +- 10 files changed, 313 insertions(+), 312 deletions(-) diff --git a/cognee/infrastructure/llm/LLMGateway.py b/cognee/infrastructure/llm/LLMGateway.py index ab5bb35d7..66a364110 100644 --- a/cognee/infrastructure/llm/LLMGateway.py +++ b/cognee/infrastructure/llm/LLMGateway.py @@ -34,19 +34,6 @@ class LLMGateway: text_input=text_input, system_prompt=system_prompt, response_model=response_model ) - @staticmethod - def create_structured_output( - text_input: str, system_prompt: str, response_model: Type[BaseModel] - ) -> BaseModel: - from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import ( - get_llm_client, - ) - - llm_client = get_llm_client() - return llm_client.create_structured_output( - text_input=text_input, system_prompt=system_prompt, response_model=response_model - ) - @staticmethod def create_transcript(input) -> Coroutine: from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import ( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index dbf0dfbea..818d3adb7 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -3,7 +3,9 @@ from typing import Type from pydantic import BaseModel import litellm import instructor +import anthropic from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe from tenacity import ( retry, stop_after_delay, @@ -12,27 +14,32 @@ from tenacity import ( before_sleep_log, ) -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.infrastructure.llm.config import get_llm_config logger = get_logger() +observe = get_observe() -class AnthropicAdapter(LLMInterface): +class AnthropicAdapter(GenericAPIAdapter): """ Adapter for interfacing with the Anthropic API, enabling structured output generation and prompt display. """ - name = "Anthropic" - model: str default_instructor_mode = "anthropic_tools" - def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None): - import anthropic - + def __init__( + self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None + ): + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Anthropic", + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.patch( @@ -40,9 +47,7 @@ class AnthropicAdapter(LLMInterface): mode=instructor.Mode(self.instructor_mode), ) - self.model = model - self.max_completion_tokens = max_completion_tokens - + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 226f291d7..bae665052 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -1,4 +1,4 @@ -"""Adapter for Generic API LLM provider API""" +"""Adapter for Gemini API LLM provider""" import litellm import instructor @@ -8,12 +8,7 @@ from openai import ContentFilterFinishReasonError from litellm.exceptions import ContentPolicyViolationError from instructor.core import InstructorRetryException -from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, -) import logging -from cognee.shared.logging_utils import get_logger from tenacity import ( retry, stop_after_delay, @@ -22,55 +17,65 @@ from tenacity import ( before_sleep_log, ) +from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, +) +from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe + logger = get_logger() +observe = get_observe() -class GeminiAdapter(LLMInterface): +class GeminiAdapter(GenericAPIAdapter): """ Adapter for Gemini API LLM provider. This class initializes the API adapter with necessary credentials and configurations for interacting with the gemini LLM models. It provides methods for creating structured outputs - based on user input and system prompts. + based on user input and system prompts, as well as multimodal processing capabilities. Public methods: - - acreate_structured_output(text_input: str, system_prompt: str, response_model: - Type[BaseModel]) -> BaseModel + - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel + - create_transcript(input) -> BaseModel: Transcribe audio files to text + - transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter """ - name: str - model: str - api_key: str default_instructor_mode = "json_mode" def __init__( self, - endpoint, api_key: str, model: str, - api_version: str, max_completion_tokens: int, + endpoint: str = None, + api_version: str = None, + transcription_model: str = None, instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.api_version = api_version - self.max_completion_tokens = max_completion_tokens - - self.fallback_model = fallback_model - self.fallback_api_key = fallback_api_key - self.fallback_endpoint = fallback_endpoint - + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Gemini", + endpoint=endpoint, + api_version=api_version, + transcription_model=transcription_model, + fallback_model=fallback_model, + fallback_api_key=fallback_api_key, + fallback_endpoint=fallback_endpoint, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_litellm( litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -118,7 +123,7 @@ class GeminiAdapter(LLMInterface): }, ], api_key=self.api_key, - max_retries=5, + max_retries=self.MAX_RETRIES, api_base=self.endpoint, api_version=self.api_version, response_model=response_model, @@ -152,7 +157,7 @@ class GeminiAdapter(LLMInterface): "content": system_prompt, }, ], - max_retries=5, + max_retries=self.MAX_RETRIES, api_key=self.fallback_api_key, api_base=self.fallback_endpoint, response_model=response_model, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index 9d7f25fc5..9987711b9 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -1,8 +1,10 @@ """Adapter for Generic API LLM provider API""" +import base64 +import mimetypes import litellm import instructor -from typing import Type +from typing import Type, Optional from pydantic import BaseModel from openai import ContentFilterFinishReasonError from litellm.exceptions import ContentPolicyViolationError @@ -12,6 +14,8 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( LLMInterface, ) +from cognee.infrastructure.files.utils.open_data_file import open_data_file +from cognee.modules.observability.get_observe import get_observe import logging from cognee.shared.logging_utils import get_logger from tenacity import ( @@ -23,6 +27,7 @@ from tenacity import ( ) logger = get_logger() +observe = get_observe() class GenericAPIAdapter(LLMInterface): @@ -38,18 +43,19 @@ class GenericAPIAdapter(LLMInterface): Type[BaseModel]) -> BaseModel """ - name: str - model: str - api_key: str + MAX_RETRIES = 5 default_instructor_mode = "json_mode" def __init__( self, - endpoint, api_key: str, model: str, - name: str, max_completion_tokens: int, + name: str, + endpoint: str = None, + api_version: str = None, + transcription_model: str = None, + image_transcribe_model: str = None, instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, @@ -58,9 +64,11 @@ class GenericAPIAdapter(LLMInterface): self.name = name self.model = model self.api_key = api_key + self.api_version = api_version self.endpoint = endpoint self.max_completion_tokens = max_completion_tokens - + self.transcription_model = transcription_model or model + self.image_transcribe_model = image_transcribe_model or model self.fallback_model = fallback_model self.fallback_api_key = fallback_api_key self.fallback_endpoint = fallback_endpoint @@ -71,6 +79,7 @@ class GenericAPIAdapter(LLMInterface): litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -170,3 +179,112 @@ class GenericAPIAdapter(LLMInterface): raise ContentPolicyFilterError( f"The provided input contains content that is not aligned with our content policy: {text_input}" ) from error + + @observe(as_type="transcription") + @retry( + stop=stop_after_delay(128), + wait=wait_exponential_jitter(2, 128), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) + async def create_transcript(self, input) -> Optional[BaseModel]: + """ + Generate an audio transcript from a user query. + + This method creates a transcript from the specified audio file, raising a + FileNotFoundError if the file does not exist. The audio file is processed and the + transcription is retrieved from the API. + + Parameters: + ----------- + - input: The path to the audio file that needs to be transcribed. + + Returns: + -------- + The generated transcription of the audio file. + """ + async with open_data_file(input, mode="rb") as audio_file: + encoded_string = base64.b64encode(audio_file.read()).decode("utf-8") + mime_type, _ = mimetypes.guess_type(input) + if not mime_type or not mime_type.startswith("audio/"): + raise ValueError( + f"Could not determine MIME type for audio file: {input}. Is the extension correct?" + ) + return litellm.completion( + model=self.transcription_model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "file", + "file": {"file_data": f"data:{mime_type};base64,{encoded_string}"}, + }, + {"type": "text", "text": "Transcribe the following audio precisely."}, + ], + } + ], + api_key=self.api_key, + api_version=self.api_version, + max_completion_tokens=self.max_completion_tokens, + api_base=self.endpoint, + max_retries=self.MAX_RETRIES, + ) + + @observe(as_type="transcribe_image") + @retry( + stop=stop_after_delay(128), + wait=wait_exponential_jitter(2, 128), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) + async def transcribe_image(self, input) -> Optional[BaseModel]: + """ + Generate a transcription of an image from a user query. + + This method encodes the image and sends a request to the API to obtain a + description of the contents of the image. + + Parameters: + ----------- + - input: The path to the image file that needs to be transcribed. + + Returns: + -------- + - BaseModel: A structured output generated by the model, returned as an instance of + BaseModel. + """ + async with open_data_file(input, mode="rb") as image_file: + encoded_image = base64.b64encode(image_file.read()).decode("utf-8") + mime_type, _ = mimetypes.guess_type(input) + if not mime_type or not mime_type.startswith("image/"): + raise ValueError( + f"Could not determine MIME type for image file: {input}. Is the extension correct?" + ) + return litellm.completion( + model=self.image_transcribe_model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this image?", + }, + { + "type": "image_url", + "image_url": { + "url": f"data:{mime_type};base64,{encoded_image}", + }, + }, + ], + } + ], + api_key=self.api_key, + api_base=self.endpoint, + api_version=self.api_version, + max_completion_tokens=300, + max_retries=self.MAX_RETRIES, + ) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py index 39558f36d..de6cfaf19 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py @@ -97,11 +97,10 @@ def get_llm_client(raise_api_key_error: bool = True): ) return OllamaAPIAdapter( - llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, - "Ollama", - max_completion_tokens=max_completion_tokens, + max_completion_tokens, + llm_config.llm_endpoint, instructor_mode=llm_config.llm_instructor_mode.lower(), ) @@ -111,8 +110,9 @@ def get_llm_client(raise_api_key_error: bool = True): ) return AnthropicAdapter( - max_completion_tokens=max_completion_tokens, - model=llm_config.llm_model, + llm_config.llm_api_key, + llm_config.llm_model, + max_completion_tokens, instructor_mode=llm_config.llm_instructor_mode.lower(), ) @@ -125,11 +125,10 @@ def get_llm_client(raise_api_key_error: bool = True): ) return GenericAPIAdapter( - llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, + max_completion_tokens, "Custom", - max_completion_tokens=max_completion_tokens, instructor_mode=llm_config.llm_instructor_mode.lower(), fallback_api_key=llm_config.fallback_api_key, fallback_endpoint=llm_config.fallback_endpoint, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py index b02105484..f8352737d 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py @@ -1,6 +1,6 @@ """LLM Interface""" -from typing import Type, Protocol +from typing import Type, Protocol, Optional from abc import abstractmethod from pydantic import BaseModel from cognee.infrastructure.llm.LLMGateway import LLMGateway @@ -8,13 +8,12 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway class LLMInterface(Protocol): """ - Define an interface for LLM models with methods for structured output and prompt - display. + Define an interface for LLM models with methods for structured output, multimodal processing, and prompt display. Methods: - - acreate_structured_output(text_input: str, system_prompt: str, response_model: - Type[BaseModel]) - - show_prompt(text_input: str, system_prompt: str) + - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) + - create_transcript(input): Transcribe audio files to text + - transcribe_image(input): Analyze image files and return text description """ @abstractmethod @@ -36,3 +35,39 @@ class LLMInterface(Protocol): output. """ raise NotImplementedError + + @abstractmethod + async def create_transcript(self, input) -> Optional[BaseModel]: + """ + Transcribe audio content to text. + + This method should be implemented by subclasses that support audio transcription. + If not implemented, returns None and should be handled gracefully by callers. + + Parameters: + ----------- + - input: The path to the audio file that needs to be transcribed. + + Returns: + -------- + - BaseModel: A structured output containing the transcription, or None if not supported. + """ + raise NotImplementedError + + @abstractmethod + async def transcribe_image(self, input) -> Optional[BaseModel]: + """ + Analyze image content and return text description. + + This method should be implemented by subclasses that support image analysis. + If not implemented, returns None and should be handled gracefully by callers. + + Parameters: + ----------- + - input: The path to the image file that needs to be analyzed. + + Returns: + -------- + - BaseModel: A structured output containing the image description, or None if not supported. + """ + raise NotImplementedError diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index 355cdae0b..0fa35923f 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -2,12 +2,12 @@ import litellm import instructor from pydantic import BaseModel from typing import Type -from litellm import JSONSchemaValidationError +from litellm import JSONSchemaValidationError, transcription from cognee.shared.logging_utils import get_logger from cognee.modules.observability.get_observe import get_observe -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.infrastructure.llm.config import get_llm_config @@ -19,12 +19,13 @@ from tenacity import ( retry_if_not_exception_type, before_sleep_log, ) +from mistralai import Mistral logger = get_logger() observe = get_observe() -class MistralAdapter(LLMInterface): +class MistralAdapter(GenericAPIAdapter): """ Adapter for Mistral AI API, for structured output generation and prompt display. @@ -33,10 +34,6 @@ class MistralAdapter(LLMInterface): - show_prompt """ - name = "Mistral" - model: str - api_key: str - max_completion_tokens: int default_instructor_mode = "mistral_tools" def __init__( @@ -45,12 +42,21 @@ class MistralAdapter(LLMInterface): model: str, max_completion_tokens: int, endpoint: str = None, + transcription_model: str = None, + image_transcribe_model: str = None, instructor_mode: str = None, ): from mistralai import Mistral - self.model = model - self.max_completion_tokens = max_completion_tokens + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Mistral", + endpoint=endpoint, + transcription_model=transcription_model, + image_transcribe_model=image_transcribe_model, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode @@ -60,6 +66,7 @@ class MistralAdapter(LLMInterface): api_key=get_llm_config().llm_api_key, ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -117,3 +124,42 @@ class MistralAdapter(LLMInterface): logger.error(f"Schema validation failed: {str(e)}") logger.debug(f"Raw response: {e.raw_response}") raise ValueError(f"Response failed schema validation: {str(e)}") + + @observe(as_type="transcription") + @retry( + stop=stop_after_delay(128), + wait=wait_exponential_jitter(2, 128), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) + async def create_transcript(self, input): + """ + Generate an audio transcript from a user query. + + This method creates a transcript from the specified audio file. + The audio file is processed and the transcription is retrieved from the API. + + Parameters: + ----------- + - input: The path to the audio file that needs to be transcribed. + + Returns: + -------- + The generated transcription of the audio file. + """ + transcription_model = self.transcription_model + if self.transcription_model.startswith("mistral"): + transcription_model = self.transcription_model.split("/")[-1] + file_name = input.split("/")[-1] + client = Mistral(api_key=self.api_key) + with open(input, "rb") as f: + transcription_response = client.audio.transcriptions.complete( + model=transcription_model, + file={ + "content": f, + "file_name": file_name, + }, + ) + # TODO: We need to standardize return type of create_transcript across different models. + return transcription_response diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index aabd19867..163637a95 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -5,12 +5,12 @@ import instructor from typing import Type from openai import OpenAI from pydantic import BaseModel - -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, -) from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, +) from tenacity import ( retry, stop_after_delay, @@ -20,9 +20,10 @@ from tenacity import ( ) logger = get_logger() +observe = get_observe() -class OllamaAPIAdapter(LLMInterface): +class OllamaAPIAdapter(GenericAPIAdapter): """ Adapter for a Generic API LLM provider using instructor with an OpenAI backend. @@ -46,18 +47,20 @@ class OllamaAPIAdapter(LLMInterface): def __init__( self, - endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int, + endpoint: str, instructor_mode: str = None, ): - self.name = name - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.max_completion_tokens = max_completion_tokens + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Ollama", + endpoint=endpoint, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode @@ -66,6 +69,7 @@ class OllamaAPIAdapter(LLMInterface): mode=instructor.Mode(self.instructor_mode), ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -113,95 +117,3 @@ class OllamaAPIAdapter(LLMInterface): ) return response - - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - async def create_transcript(self, input_file: str) -> str: - """ - Generate an audio transcript from a user query. - - This synchronous method takes an input audio file and returns its transcription. Raises - a FileNotFoundError if the input file does not exist, and raises a ValueError if - transcription fails or returns no text. - - Parameters: - ----------- - - - input_file (str): The path to the audio file to be transcribed. - - Returns: - -------- - - - str: The transcription of the audio as a string. - """ - - async with open_data_file(input_file, mode="rb") as audio_file: - transcription = self.aclient.audio.transcriptions.create( - model="whisper-1", # Ensure the correct model for transcription - file=audio_file, - language="en", - ) - - # Ensure the response contains a valid transcript - if not hasattr(transcription, "text"): - raise ValueError("Transcription failed. No text returned.") - - return transcription.text - - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - async def transcribe_image(self, input_file: str) -> str: - """ - Transcribe content from an image using base64 encoding. - - This synchronous method takes an input image file, encodes it as base64, and returns the - transcription of its content. Raises a FileNotFoundError if the input file does not - exist, and raises a ValueError if the transcription fails or no valid response is - received. - - Parameters: - ----------- - - - input_file (str): The path to the image file to be transcribed. - - Returns: - -------- - - - str: The transcription of the image's content as a string. - """ - - async with open_data_file(input_file, mode="rb") as image_file: - encoded_image = base64.b64encode(image_file.read()).decode("utf-8") - - response = self.aclient.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}, - }, - ], - } - ], - max_completion_tokens=300, - ) - - # Ensure response is valid before accessing .choices[0].message.content - if not hasattr(response, "choices") or not response.choices: - raise ValueError("Image transcription failed. No response received.") - - return response.choices[0].message.content diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index 778c8eec7..e9943c335 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -1,4 +1,3 @@ -import base64 import litellm import instructor from typing import Type @@ -16,8 +15,8 @@ from tenacity import ( before_sleep_log, ) -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.infrastructure.llm.exceptions import ( ContentPolicyFilterError, @@ -31,7 +30,7 @@ logger = get_logger() observe = get_observe() -class OpenAIAdapter(LLMInterface): +class OpenAIAdapter(GenericAPIAdapter): """ Adapter for OpenAI's GPT-3, GPT-4 API. @@ -52,12 +51,7 @@ class OpenAIAdapter(LLMInterface): - MAX_RETRIES """ - name = "OpenAI" - model: str - api_key: str - api_version: str default_instructor_mode = "json_schema_mode" - MAX_RETRIES = 5 """Adapter for OpenAI's GPT-3, GPT=4 API""" @@ -65,17 +59,29 @@ class OpenAIAdapter(LLMInterface): def __init__( self, api_key: str, - endpoint: str, - api_version: str, model: str, - transcription_model: str, max_completion_tokens: int, + endpoint: str = None, + api_version: str = None, + transcription_model: str = None, instructor_mode: str = None, streaming: bool = False, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="OpenAI", + endpoint=endpoint, + api_version=api_version, + transcription_model=transcription_model, + fallback_model=fallback_model, + fallback_api_key=fallback_api_key, + fallback_endpoint=fallback_endpoint, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs. # Make sure all new gpt models will work with this mode as well. @@ -90,18 +96,8 @@ class OpenAIAdapter(LLMInterface): self.aclient = instructor.from_litellm(litellm.acompletion) self.client = instructor.from_litellm(litellm.completion) - self.transcription_model = transcription_model - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.api_version = api_version - self.max_completion_tokens = max_completion_tokens self.streaming = streaming - self.fallback_model = fallback_model - self.fallback_api_key = fallback_api_key - self.fallback_endpoint = fallback_endpoint - @observe(as_type="generation") @retry( stop=stop_after_delay(128), @@ -174,7 +170,7 @@ class OpenAIAdapter(LLMInterface): }, ], api_key=self.fallback_api_key, - # api_base=self.fallback_endpoint, + api_base=self.fallback_endpoint, response_model=response_model, max_retries=self.MAX_RETRIES, ) @@ -193,57 +189,7 @@ class OpenAIAdapter(LLMInterface): f"The provided input contains content that is not aligned with our content policy: {text_input}" ) from error - @observe - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - def create_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel] - ) -> BaseModel: - """ - Generate a response from a user query. - - This method creates structured output by sending a synchronous request to the OpenAI API - using the provided parameters to generate a completion based on the user input and - system prompt. - - Parameters: - ----------- - - - text_input (str): The input text provided by the user for generating a response. - - system_prompt (str): The system's prompt to guide the model's response. - - response_model (Type[BaseModel]): The expected model type for the response. - - Returns: - -------- - - - BaseModel: A structured output generated by the model, returned as an instance of - BaseModel. - """ - - return self.client.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": f"""{text_input}""", - }, - { - "role": "system", - "content": system_prompt, - }, - ], - api_key=self.api_key, - api_base=self.endpoint, - api_version=self.api_version, - response_model=response_model, - max_retries=self.MAX_RETRIES, - ) - + @observe(as_type="transcription") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -282,56 +228,4 @@ class OpenAIAdapter(LLMInterface): return transcription - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - async def transcribe_image(self, input) -> BaseModel: - """ - Generate a transcription of an image from a user query. - - This method encodes the image and sends a request to the OpenAI API to obtain a - description of the contents of the image. - - Parameters: - ----------- - - - input: The path to the image file that needs to be transcribed. - - Returns: - -------- - - - BaseModel: A structured output generated by the model, returned as an instance of - BaseModel. - """ - async with open_data_file(input, mode="rb") as image_file: - encoded_image = base64.b64encode(image_file.read()).decode("utf-8") - - return litellm.completion( - model=self.model, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What's in this image?", - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{encoded_image}", - }, - }, - ], - } - ], - api_key=self.api_key, - api_base=self.endpoint, - api_version=self.api_version, - max_completion_tokens=300, - max_retries=self.MAX_RETRIES, - ) + # transcribe image inherited from GenericAdapter diff --git a/uv.lock b/uv.lock index cc66c3d7e..d8fb3805b 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10, <3.14" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'",