diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
index 4d75c886a..49b13fcaa 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
@@ -44,7 +44,7 @@ class AnthropicAdapter(GenericAPIAdapter):
         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
 
         self.aclient = instructor.patch(
-            create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
+            create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create,
             mode=instructor.Mode(self.instructor_mode),
         )
 
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
index ffb7bf77b..99dfd6179 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
@@ -10,7 +10,6 @@
 from instructor.core import InstructorRetryException
 import logging
 
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
-from cognee.shared.logging_utils import get_logger
 from tenacity import (
     retry,
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
index 408058d3e..7905c25bf 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
@@ -27,6 +27,8 @@ from tenacity import (
     before_sleep_log,
 )
 
+from ..types import TranscriptionReturnType
+
 logger = get_logger()
 observe = get_observe()
 
@@ -191,7 +193,7 @@ class GenericAPIAdapter(LLMInterface):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def create_transcript(self, input) -> Optional[BaseModel]:
+    async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
         """
         Generate an audio transcript from a user query.
 
@@ -214,7 +216,7 @@ class GenericAPIAdapter(LLMInterface):
             raise ValueError(
                 f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
             )
-        return litellm.completion(
+        response = litellm.completion(
             model=self.transcription_model,
             messages=[
                 {
@@ -234,6 +236,11 @@ class GenericAPIAdapter(LLMInterface):
             api_base=self.endpoint,
             max_retries=self.MAX_RETRIES,
         )
+        if response and response.choices and len(response.choices) > 0:
+            return TranscriptionReturnType(response.choices[0].message.content, response)
+        else:
+            return None
+
 
     @observe(as_type="transcribe_image")
     @retry(
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
index de6cfaf19..e5f4bd1b1 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
@@ -97,10 +97,11 @@ def get_llm_client(raise_api_key_error: bool = True):
             )
 
         return OllamaAPIAdapter(
+            llm_config.llm_endpoint,
             llm_config.llm_api_key,
             llm_config.llm_model,
+            "Ollama",
             max_completion_tokens,
-            llm_config.llm_endpoint,
             instructor_mode=llm_config.llm_instructor_mode.lower(),
         )
 
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
index 954510a25..b141f7585 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
@@ -1,9 +1,9 @@
 import litellm
 import instructor
 from pydantic import BaseModel
-from typing import Type
-from litellm import JSONSchemaValidationError, transcription
-
+from typing import Type, Optional
+from litellm import JSONSchemaValidationError
+from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.observability.get_observe import get_observe
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
@@ -20,6 +20,7 @@ from tenacity import (
     retry_if_not_exception_type,
     before_sleep_log,
 )
+from ..types import TranscriptionReturnType
 from mistralai import Mistral
 
 logger = get_logger()
@@ -47,8 +48,6 @@ class MistralAdapter(GenericAPIAdapter):
         image_transcribe_model: str = None,
         instructor_mode: str = None,
     ):
-        from mistralai import Mistral
-
         super().__init__(
             api_key=api_key,
             model=model,
@@ -66,6 +65,7 @@
             mode=instructor.Mode(self.instructor_mode),
             api_key=get_llm_config().llm_api_key,
         )
+        self.mistral_client = Mistral(api_key=self.api_key)
 
     @observe(as_type="generation")
     @retry(
@@ -135,7 +135,7 @@ class MistralAdapter(GenericAPIAdapter):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def create_transcript(self, input):
+    async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
         """
         Generate an audio transcript from a user query.
 
@@ -154,14 +154,14 @@ class MistralAdapter(GenericAPIAdapter):
         if self.transcription_model.startswith("mistral"):
             transcription_model = self.transcription_model.split("/")[-1]
             file_name = input.split("/")[-1]
-            client = Mistral(api_key=self.api_key)
-            with open(input, "rb") as f:
-                transcription_response = client.audio.transcriptions.complete(
+            async with open_data_file(input, mode="rb") as f:
+                transcription_response = self.mistral_client.audio.transcriptions.complete(
                     model=transcription_model,
                     file={
                         "content": f,
                         "file_name": file_name,
                     },
                 )
-            # TODO: We need to standardize return type of create_transcript across different models.
-            return transcription_response
+            if transcription_response:
+                return TranscriptionReturnType(transcription_response.text, transcription_response)
+            return None
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
index 57b6d339a..94c6aed6d 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
@@ -1,6 +1,6 @@
 import litellm
 import instructor
-from typing import Type
+from typing import Type, Optional
 from pydantic import BaseModel
 from openai import ContentFilterFinishReasonError
 from litellm.exceptions import ContentPolicyViolationError
@@ -25,6 +25,7 @@
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.modules.observability.get_observe import get_observe
 from cognee.shared.logging_utils import get_logger
+from ..types import TranscriptionReturnType
 
 logger = get_logger()
@@ -200,7 +201,7 @@ class OpenAIAdapter(GenericAPIAdapter):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def create_transcript(self, input):
+    async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
         """
         Generate an audio transcript from a user query.
 
@@ -228,7 +229,9 @@ class OpenAIAdapter(GenericAPIAdapter):
             api_version=self.api_version,
             max_retries=self.MAX_RETRIES,
         )
+        if transcription:
+            return TranscriptionReturnType(transcription.text, transcription)
 
-        return transcription
+        return None
 
-    # transcribe image inherited from GenericAdapter
+    # transcribe_image is inherited from GenericAPIAdapter
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py
new file mode 100644
index 000000000..887cdd88d
--- /dev/null
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py
@@ -0,0 +1,12 @@
+from pydantic import BaseModel
+
+
+class TranscriptionReturnType:
+    """Standardized return value of create_transcript across LLM adapters."""
+
+    text: str
+    payload: BaseModel
+
+    def __init__(self, text: str, payload: BaseModel):
+        self.text = text
+        self.payload = payload
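
Reviewer note: a minimal usage sketch of the standardized transcript return type introduced above, assuming an installed and configured cognee environment. The audio file name and the wrapper script are illustrative assumptions; only `get_llm_client`, `create_transcript`, and the `text`/`payload` attributes come from this diff.

```python
import asyncio

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
    get_llm_client,
)


async def transcribe_example():
    # The adapter (OpenAI, Mistral, generic, ...) is selected from the configured provider;
    # with this change, every adapter's create_transcript returns Optional[TranscriptionReturnType].
    adapter = get_llm_client()

    # "sample.mp3" is a hypothetical local file used only for illustration.
    result = await adapter.create_transcript("sample.mp3")

    if result is None:
        print("No transcript returned")
    else:
        print(result.text)             # transcript text, uniform across providers
        raw_response = result.payload  # original provider/litellm response, if callers need it


if __name__ == "__main__":
    asyncio.run(transcribe_example())
```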