diff --git a/cognee/infrastructure/llm/LLMGateway.py b/cognee/infrastructure/llm/LLMGateway.py index 7bec9ca01..2d1d3f97e 100644 --- a/cognee/infrastructure/llm/LLMGateway.py +++ b/cognee/infrastructure/llm/LLMGateway.py @@ -37,19 +37,6 @@ class LLMGateway: **kwargs, ) - @staticmethod - def create_structured_output( - text_input: str, system_prompt: str, response_model: Type[BaseModel] - ) -> BaseModel: - from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import ( - get_llm_client, - ) - - llm_client = get_llm_client() - return llm_client.create_structured_output( - text_input=text_input, system_prompt=system_prompt, response_model=response_model - ) - @staticmethod def create_transcript(input) -> Coroutine: from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import ( diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py index 58b68436c..26cb471cd 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py @@ -3,7 +3,9 @@ from typing import Type from pydantic import BaseModel import litellm import instructor +import anthropic from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe from tenacity import ( retry, stop_after_delay, @@ -12,38 +14,41 @@ from tenacity import ( before_sleep_log, ) -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.shared.rate_limiting import llm_rate_limiter_context_manager from cognee.infrastructure.llm.config import get_llm_config logger = get_logger() +observe = get_observe() -class AnthropicAdapter(LLMInterface): +class AnthropicAdapter(GenericAPIAdapter): """ Adapter for interfacing with the Anthropic API, enabling structured output generation and prompt display. 
""" - name = "Anthropic" - model: str default_instructor_mode = "anthropic_tools" - def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None): - import anthropic - + def __init__( + self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None + ): + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Anthropic", + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.patch( - create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create, + create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create, mode=instructor.Mode(self.instructor_mode), ) - self.model = model - self.max_completion_tokens = max_completion_tokens - + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(8, 128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py index 208c3729d..fb7ea8a6c 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py @@ -1,4 +1,4 @@ -"""Adapter for Generic API LLM provider API""" +"""Adapter for Gemini API LLM provider""" import litellm import instructor @@ -8,13 +8,9 @@ from openai import ContentFilterFinishReasonError from litellm.exceptions import ContentPolicyViolationError from instructor.core import InstructorRetryException -from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, -) import logging from cognee.shared.rate_limiting import llm_rate_limiter_context_manager -from cognee.shared.logging_utils import get_logger + from tenacity import ( retry, stop_after_delay, @@ -23,55 +19,65 @@ from tenacity import ( before_sleep_log, ) +from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, +) +from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe + logger = get_logger() +observe = get_observe() -class GeminiAdapter(LLMInterface): +class GeminiAdapter(GenericAPIAdapter): """ Adapter for Gemini API LLM provider. This class initializes the API adapter with necessary credentials and configurations for interacting with the gemini LLM models. It provides methods for creating structured outputs - based on user input and system prompts. + based on user input and system prompts, as well as multimodal processing capabilities. 
Public methods: - - acreate_structured_output(text_input: str, system_prompt: str, response_model: - Type[BaseModel]) -> BaseModel + - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel + - create_transcript(input) -> BaseModel: Transcribe audio files to text + - transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter """ - name: str - model: str - api_key: str default_instructor_mode = "json_mode" def __init__( self, - endpoint, api_key: str, model: str, - api_version: str, max_completion_tokens: int, + endpoint: str = None, + api_version: str = None, + transcription_model: str = None, instructor_mode: str = None, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.api_version = api_version - self.max_completion_tokens = max_completion_tokens - - self.fallback_model = fallback_model - self.fallback_api_key = fallback_api_key - self.fallback_endpoint = fallback_endpoint - + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Gemini", + endpoint=endpoint, + api_version=api_version, + transcription_model=transcription_model, + fallback_model=fallback_model, + fallback_api_key=fallback_api_key, + fallback_endpoint=fallback_endpoint, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.aclient = instructor.from_litellm( litellm.acompletion, mode=instructor.Mode(self.instructor_mode) ) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(8, 128), diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py index d6e00d40a..4fd7b45c1 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py @@ -1,8 +1,10 @@ """Adapter for Generic API LLM provider API""" +import base64 +import mimetypes import litellm import instructor -from typing import Type +from typing import Type, Optional from pydantic import BaseModel from openai import ContentFilterFinishReasonError from litellm.exceptions import ContentPolicyViolationError @@ -12,6 +14,8 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( LLMInterface, ) +from cognee.infrastructure.files.utils.open_data_file import open_data_file +from cognee.modules.observability.get_observe import get_observe import logging from cognee.shared.rate_limiting import llm_rate_limiter_context_manager from cognee.shared.logging_utils import get_logger @@ -23,7 +27,12 @@ from tenacity import ( before_sleep_log, ) +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import ( + TranscriptionReturnType, +) + logger = get_logger() +observe = get_observe() class GenericAPIAdapter(LLMInterface): @@ -39,18 +48,19 @@ class GenericAPIAdapter(LLMInterface): Type[BaseModel]) -> BaseModel """ - name: str - model: str - api_key: str + MAX_RETRIES = 5 default_instructor_mode = "json_mode" def __init__( self, - endpoint, api_key: str, model: str, - name: str, 
         max_completion_tokens: int,
+        name: str,
+        endpoint: str = None,
+        api_version: str = None,
+        transcription_model: str = None,
+        image_transcribe_model: str = None,
         instructor_mode: str = None,
         fallback_model: str = None,
         fallback_api_key: str = None,
@@ -59,9 +69,11 @@ class GenericAPIAdapter(LLMInterface):
         self.name = name
         self.model = model
         self.api_key = api_key
+        self.api_version = api_version
         self.endpoint = endpoint
         self.max_completion_tokens = max_completion_tokens
-
+        self.transcription_model = transcription_model or model
+        self.image_transcribe_model = image_transcribe_model or model
         self.fallback_model = fallback_model
         self.fallback_api_key = fallback_api_key
         self.fallback_endpoint = fallback_endpoint
@@ -72,6 +84,7 @@ class GenericAPIAdapter(LLMInterface):
             litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
         )
 
+    @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(8, 128),
@@ -173,3 +186,115 @@ class GenericAPIAdapter(LLMInterface):
             raise ContentPolicyFilterError(
                 f"The provided input contains content that is not aligned with our content policy: {text_input}"
             ) from error
+
+    @observe(as_type="transcription")
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def create_transcript(self, input) -> TranscriptionReturnType:
+        """
+        Generate a transcript from the provided audio file.
+
+        This method creates a transcript from the specified audio file, raising a
+        FileNotFoundError if the file does not exist and a ValueError if its MIME type is
+        not recognized as audio. The file is base64-encoded before being sent to the API.
+
+        Parameters:
+        -----------
+        - input: The path to the audio file that needs to be transcribed.
+
+        Returns:
+        --------
+        The generated transcription of the audio file.
+        """
+        async with open_data_file(input, mode="rb") as audio_file:
+            encoded_string = base64.b64encode(audio_file.read()).decode("utf-8")
+        mime_type, _ = mimetypes.guess_type(input)
+        if not mime_type or not mime_type.startswith("audio/"):
+            raise ValueError(
+                f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
+            )
+        response = await litellm.acompletion(
+            model=self.transcription_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "file",
+                            "file": {"file_data": f"data:{mime_type};base64,{encoded_string}"},
+                        },
+                        {"type": "text", "text": "Transcribe the following audio precisely."},
+                    ],
+                }
+            ],
+            api_key=self.api_key,
+            api_version=self.api_version,
+            max_completion_tokens=self.max_completion_tokens,
+            api_base=self.endpoint,
+            max_retries=self.MAX_RETRIES,
+        )
+
+        return TranscriptionReturnType(response.choices[0].message.content, response)
+
+    @observe(as_type="transcribe_image")
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def transcribe_image(self, input) -> BaseModel:
+        """
+        Generate a transcription of an image file.
+
+        This method encodes the image and sends a request to the API to obtain a
+        description of the contents of the image.
+
+        Parameters:
+        -----------
+        - input: The path to the image file that needs to be transcribed.
+
+        Returns:
+        --------
+        - BaseModel: The raw model response containing the generated description of the
+          image contents.
+        """
+        async with open_data_file(input, mode="rb") as image_file:
+            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
+        mime_type, _ = mimetypes.guess_type(input)
+        if not mime_type or not mime_type.startswith("image/"):
+            raise ValueError(
+                f"Could not determine MIME type for image file: {input}. Is the extension correct?"
+            )
+        response = await litellm.acompletion(
+            model=self.image_transcribe_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "What's in this image?",
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:{mime_type};base64,{encoded_image}",
+                            },
+                        },
+                    ],
+                }
+            ],
+            api_key=self.api_key,
+            api_base=self.endpoint,
+            api_version=self.api_version,
+            max_completion_tokens=300,
+            max_retries=self.MAX_RETRIES,
+        )
+        return response
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
index 954d85c1d..b5a255559 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
@@ -103,7 +103,7 @@ def get_llm_client(raise_api_key_error: bool = True):
             llm_config.llm_api_key,
             llm_config.llm_model,
             "Ollama",
-            max_completion_tokens=max_completion_tokens,
+            max_completion_tokens,
             instructor_mode=llm_config.llm_instructor_mode.lower(),
         )
@@ -113,8 +113,9 @@ def get_llm_client(raise_api_key_error: bool = True):
         )

         return AnthropicAdapter(
-            max_completion_tokens=max_completion_tokens,
-            model=llm_config.llm_model,
+            llm_config.llm_api_key,
+            llm_config.llm_model,
+            max_completion_tokens,
             instructor_mode=llm_config.llm_instructor_mode.lower(),
         )
@@ -127,11 +128,11 @@ def get_llm_client(raise_api_key_error: bool = True):
         )

         return GenericAPIAdapter(
-            llm_config.llm_endpoint,
             llm_config.llm_api_key,
             llm_config.llm_model,
+            max_completion_tokens,
             "Custom",
-            max_completion_tokens=max_completion_tokens,
+            endpoint=llm_config.llm_endpoint,
             instructor_mode=llm_config.llm_instructor_mode.lower(),
             fallback_api_key=llm_config.fallback_api_key,
             fallback_endpoint=llm_config.fallback_endpoint,
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py
index b02105484..6afd4138c 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py
@@ -3,18 +3,14 @@
 from typing import Type, Protocol
 from abc import abstractmethod

 from pydantic import BaseModel
-from cognee.infrastructure.llm.LLMGateway import LLMGateway


 class LLMInterface(Protocol):
     """
-    Define an interface for LLM models with methods for structured output and prompt
-    display.
+    Define an interface for LLM models with methods for structured output, multimodal processing, and prompt display.
Methods: - - acreate_structured_output(text_input: str, system_prompt: str, response_model: - Type[BaseModel]) - - show_prompt(text_input: str, system_prompt: str) + - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) """ @abstractmethod diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py index e1131524d..a52bbe281 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -1,13 +1,13 @@ import litellm import instructor from pydantic import BaseModel -from typing import Type +from typing import Type, Optional from litellm import JSONSchemaValidationError - +from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.shared.logging_utils import get_logger from cognee.modules.observability.get_observe import get_observe -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.infrastructure.llm.config import get_llm_config from cognee.shared.rate_limiting import llm_rate_limiter_context_manager @@ -20,12 +20,14 @@ from tenacity import ( retry_if_not_exception_type, before_sleep_log, ) +from ..types import TranscriptionReturnType +from mistralai import Mistral logger = get_logger() observe = get_observe() -class MistralAdapter(LLMInterface): +class MistralAdapter(GenericAPIAdapter): """ Adapter for Mistral AI API, for structured output generation and prompt display. 
@@ -34,10 +36,6 @@ class MistralAdapter(LLMInterface): - show_prompt """ - name = "Mistral" - model: str - api_key: str - max_completion_tokens: int default_instructor_mode = "mistral_tools" def __init__( @@ -46,12 +44,19 @@ class MistralAdapter(LLMInterface): model: str, max_completion_tokens: int, endpoint: str = None, + transcription_model: str = None, + image_transcribe_model: str = None, instructor_mode: str = None, ): - from mistralai import Mistral - - self.model = model - self.max_completion_tokens = max_completion_tokens + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="Mistral", + endpoint=endpoint, + transcription_model=transcription_model, + image_transcribe_model=image_transcribe_model, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode @@ -60,7 +65,9 @@ class MistralAdapter(LLMInterface): mode=instructor.Mode(self.instructor_mode), api_key=get_llm_config().llm_api_key, ) + self.mistral_client = Mistral(api_key=self.api_key) + @observe(as_type="generation") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(8, 128), @@ -119,3 +126,41 @@ class MistralAdapter(LLMInterface): logger.error(f"Schema validation failed: {str(e)}") logger.debug(f"Raw response: {e.raw_response}") raise ValueError(f"Response failed schema validation: {str(e)}") + + @observe(as_type="transcription") + @retry( + stop=stop_after_delay(128), + wait=wait_exponential_jitter(2, 128), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) + async def create_transcript(self, input) -> Optional[TranscriptionReturnType]: + """ + Generate an audio transcript from a user query. + + This method creates a transcript from the specified audio file. + The audio file is processed and the transcription is retrieved from the API. + + Parameters: + ----------- + - input: The path to the audio file that needs to be transcribed. + + Returns: + -------- + The generated transcription of the audio file. 
+ """ + transcription_model = self.transcription_model + if self.transcription_model.startswith("mistral"): + transcription_model = self.transcription_model.split("/")[-1] + file_name = input.split("/")[-1] + async with open_data_file(input, mode="rb") as f: + transcription_response = self.mistral_client.audio.transcriptions.complete( + model=transcription_model, + file={ + "content": f, + "file_name": file_name, + }, + ) + + return TranscriptionReturnType(transcription_response.text, transcription_response) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py index 211e49694..c702ce8bf 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py @@ -12,7 +12,6 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.shared.logging_utils import get_logger from cognee.shared.rate_limiting import llm_rate_limiter_context_manager - from tenacity import ( retry, stop_after_delay, diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py index ca9b583b7..582c3a08f 100644 --- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py @@ -1,4 +1,3 @@ -import base64 import litellm import instructor from typing import Type @@ -16,8 +15,8 @@ from tenacity import ( before_sleep_log, ) -from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( - LLMInterface, +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( + GenericAPIAdapter, ) from cognee.infrastructure.llm.exceptions import ( ContentPolicyFilterError, @@ -26,13 +25,16 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.modules.observability.get_observe import get_observe from cognee.shared.logging_utils import get_logger +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import ( + TranscriptionReturnType, +) logger = get_logger() observe = get_observe() -class OpenAIAdapter(LLMInterface): +class OpenAIAdapter(GenericAPIAdapter): """ Adapter for OpenAI's GPT-3, GPT-4 API. 
@@ -53,12 +55,7 @@ class OpenAIAdapter(LLMInterface): - MAX_RETRIES """ - name = "OpenAI" - model: str - api_key: str - api_version: str default_instructor_mode = "json_schema_mode" - MAX_RETRIES = 5 """Adapter for OpenAI's GPT-3, GPT=4 API""" @@ -66,17 +63,29 @@ class OpenAIAdapter(LLMInterface): def __init__( self, api_key: str, - endpoint: str, - api_version: str, model: str, - transcription_model: str, max_completion_tokens: int, + endpoint: str = None, + api_version: str = None, + transcription_model: str = None, instructor_mode: str = None, streaming: bool = False, fallback_model: str = None, fallback_api_key: str = None, fallback_endpoint: str = None, ): + super().__init__( + api_key=api_key, + model=model, + max_completion_tokens=max_completion_tokens, + name="OpenAI", + endpoint=endpoint, + api_version=api_version, + transcription_model=transcription_model, + fallback_model=fallback_model, + fallback_api_key=fallback_api_key, + fallback_endpoint=fallback_endpoint, + ) self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs. # Make sure all new gpt models will work with this mode as well. @@ -91,18 +100,8 @@ class OpenAIAdapter(LLMInterface): self.aclient = instructor.from_litellm(litellm.acompletion) self.client = instructor.from_litellm(litellm.completion) - self.transcription_model = transcription_model - self.model = model - self.api_key = api_key - self.endpoint = endpoint - self.api_version = api_version - self.max_completion_tokens = max_completion_tokens self.streaming = streaming - self.fallback_model = fallback_model - self.fallback_api_key = fallback_api_key - self.fallback_endpoint = fallback_endpoint - @observe(as_type="generation") @retry( stop=stop_after_delay(128), @@ -198,7 +197,7 @@ class OpenAIAdapter(LLMInterface): f"The provided input contains content that is not aligned with our content policy: {text_input}" ) from error - @observe + @observe(as_type="transcription") @retry( stop=stop_after_delay(128), wait=wait_exponential_jitter(2, 128), @@ -206,58 +205,7 @@ class OpenAIAdapter(LLMInterface): before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) - def create_structured_output( - self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs - ) -> BaseModel: - """ - Generate a response from a user query. - - This method creates structured output by sending a synchronous request to the OpenAI API - using the provided parameters to generate a completion based on the user input and - system prompt. - - Parameters: - ----------- - - - text_input (str): The input text provided by the user for generating a response. - - system_prompt (str): The system's prompt to guide the model's response. - - response_model (Type[BaseModel]): The expected model type for the response. - - Returns: - -------- - - - BaseModel: A structured output generated by the model, returned as an instance of - BaseModel. 
- """ - - return self.client.chat.completions.create( - model=self.model, - messages=[ - { - "role": "user", - "content": f"""{text_input}""", - }, - { - "role": "system", - "content": system_prompt, - }, - ], - api_key=self.api_key, - api_base=self.endpoint, - api_version=self.api_version, - response_model=response_model, - max_retries=self.MAX_RETRIES, - **kwargs, - ) - - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - async def create_transcript(self, input, **kwargs): + async def create_transcript(self, input, **kwargs) -> TranscriptionReturnType: """ Generate an audio transcript from a user query. @@ -286,60 +234,6 @@ class OpenAIAdapter(LLMInterface): max_retries=self.MAX_RETRIES, **kwargs, ) + return TranscriptionReturnType(transcription.text, transcription) - return transcription - - @retry( - stop=stop_after_delay(128), - wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, - ) - async def transcribe_image(self, input, **kwargs) -> BaseModel: - """ - Generate a transcription of an image from a user query. - - This method encodes the image and sends a request to the OpenAI API to obtain a - description of the contents of the image. - - Parameters: - ----------- - - - input: The path to the image file that needs to be transcribed. - - Returns: - -------- - - - BaseModel: A structured output generated by the model, returned as an instance of - BaseModel. - """ - async with open_data_file(input, mode="rb") as image_file: - encoded_image = base64.b64encode(image_file.read()).decode("utf-8") - - return litellm.completion( - model=self.model, - messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What's in this image?", - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{encoded_image}", - }, - }, - ], - } - ], - api_key=self.api_key, - api_base=self.endpoint, - api_version=self.api_version, - max_completion_tokens=300, - max_retries=self.MAX_RETRIES, - **kwargs, - ) + # transcribe_image is inherited from GenericAPIAdapter diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py new file mode 100644 index 000000000..fc850830d --- /dev/null +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + + +class TranscriptionReturnType: + text: str + payload: BaseModel + + def __init__(self, text: str, payload: BaseModel): + self.text = text + self.payload = payload