Standardizing return type for transcription and some CR changes

This commit is contained in:
rajeevrajeshuni 2025-12-11 06:53:36 +05:30
parent d57d188459
commit 6260f9eb82
7 changed files with 39 additions and 20 deletions

View file

@ -44,7 +44,7 @@ class AnthropicAdapter(GenericAPIAdapter):
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.patch( self.aclient = instructor.patch(
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create, create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create,
mode=instructor.Mode(self.instructor_mode), mode=instructor.Mode(self.instructor_mode),
) )

View file

@ -10,7 +10,6 @@ from instructor.core import InstructorRetryException
import logging import logging
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.shared.logging_utils import get_logger
from tenacity import ( from tenacity import (
retry, retry,

View file

@ -27,6 +27,8 @@ from tenacity import (
before_sleep_log, before_sleep_log,
) )
from ..types import TranscriptionReturnType
logger = get_logger() logger = get_logger()
observe = get_observe() observe = get_observe()
@ -191,7 +193,7 @@ class GenericAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
async def create_transcript(self, input) -> Optional[BaseModel]: async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
""" """
Generate an audio transcript from a user query. Generate an audio transcript from a user query.
@ -214,7 +216,7 @@ class GenericAPIAdapter(LLMInterface):
raise ValueError( raise ValueError(
f"Could not determine MIME type for audio file: {input}. Is the extension correct?" f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
) )
return litellm.completion( response = litellm.completion(
model=self.transcription_model, model=self.transcription_model,
messages=[ messages=[
{ {
@ -234,6 +236,11 @@ class GenericAPIAdapter(LLMInterface):
api_base=self.endpoint, api_base=self.endpoint,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
) )
if response and response.choices and len(response.choices) > 0:
return TranscriptionReturnType(response.choices[0].message.content,response)
else:
return None
@observe(as_type="transcribe_image") @observe(as_type="transcribe_image")
@retry( @retry(

View file

@ -97,10 +97,11 @@ def get_llm_client(raise_api_key_error: bool = True):
) )
return OllamaAPIAdapter( return OllamaAPIAdapter(
llm_config.llm_endpoint,
llm_config.llm_api_key, llm_config.llm_api_key,
llm_config.llm_model, llm_config.llm_model,
"Ollama",
max_completion_tokens, max_completion_tokens,
llm_config.llm_endpoint,
instructor_mode=llm_config.llm_instructor_mode.lower(), instructor_mode=llm_config.llm_instructor_mode.lower(),
) )

View file

@ -1,9 +1,9 @@
import litellm import litellm
import instructor import instructor
from pydantic import BaseModel from pydantic import BaseModel
from typing import Type from typing import Type, Optional
from litellm import JSONSchemaValidationError, transcription from litellm import JSONSchemaValidationError
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import ( from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
@ -20,6 +20,7 @@ from tenacity import (
retry_if_not_exception_type, retry_if_not_exception_type,
before_sleep_log, before_sleep_log,
) )
from ..types import TranscriptionReturnType
from mistralai import Mistral from mistralai import Mistral
logger = get_logger() logger = get_logger()
@ -47,8 +48,6 @@ class MistralAdapter(GenericAPIAdapter):
image_transcribe_model: str = None, image_transcribe_model: str = None,
instructor_mode: str = None, instructor_mode: str = None,
): ):
from mistralai import Mistral
super().__init__( super().__init__(
api_key=api_key, api_key=api_key,
model=model, model=model,
@ -66,6 +65,7 @@ class MistralAdapter(GenericAPIAdapter):
mode=instructor.Mode(self.instructor_mode), mode=instructor.Mode(self.instructor_mode),
api_key=get_llm_config().llm_api_key, api_key=get_llm_config().llm_api_key,
) )
self.mistral_client = Mistral(api_key=self.api_key)
@observe(as_type="generation") @observe(as_type="generation")
@retry( @retry(
@ -135,7 +135,7 @@ class MistralAdapter(GenericAPIAdapter):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
async def create_transcript(self, input): async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
""" """
Generate an audio transcript from a user query. Generate an audio transcript from a user query.
@ -154,14 +154,14 @@ class MistralAdapter(GenericAPIAdapter):
if self.transcription_model.startswith("mistral"): if self.transcription_model.startswith("mistral"):
transcription_model = self.transcription_model.split("/")[-1] transcription_model = self.transcription_model.split("/")[-1]
file_name = input.split("/")[-1] file_name = input.split("/")[-1]
client = Mistral(api_key=self.api_key) async with open_data_file(input, mode="rb") as f:
with open(input, "rb") as f: transcription_response = self.mistral_client.audio.transcriptions.complete(
transcription_response = client.audio.transcriptions.complete(
model=transcription_model, model=transcription_model,
file={ file={
"content": f, "content": f,
"file_name": file_name, "file_name": file_name,
}, },
) )
# TODO: We need to standardize return type of create_transcript across different models. if transcription_response:
return transcription_response return TranscriptionReturnType(transcription_response.text, transcription_response)
return None

View file

@ -1,6 +1,6 @@
import litellm import litellm
import instructor import instructor
from typing import Type from typing import Type, Optional
from pydantic import BaseModel from pydantic import BaseModel
from openai import ContentFilterFinishReasonError from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError from litellm.exceptions import ContentPolicyViolationError
@ -25,6 +25,7 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.infrastructure.files.utils.open_data_file import open_data_file from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.modules.observability.get_observe import get_observe from cognee.modules.observability.get_observe import get_observe
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from ..types import TranscriptionReturnType
logger = get_logger() logger = get_logger()
@ -200,7 +201,7 @@ class OpenAIAdapter(GenericAPIAdapter):
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
async def create_transcript(self, input): async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
""" """
Generate an audio transcript from a user query. Generate an audio transcript from a user query.
@ -228,7 +229,9 @@ class OpenAIAdapter(GenericAPIAdapter):
api_version=self.api_version, api_version=self.api_version,
max_retries=self.MAX_RETRIES, max_retries=self.MAX_RETRIES,
) )
if transcription:
return TranscriptionReturnType(transcription.text, transcription)
return transcription return None
# transcribe image inherited from GenericAdapter # transcribe_image is inherited from GenericAPIAdapter

View file

@ -0,0 +1,9 @@
from pydantic import BaseModel
class TranscriptionReturnType:
    """Pair a transcript's plain text with the provider response it came from.

    Attributes:
        text: The transcribed text.
        payload: The provider response object the text was taken from,
            kept so callers can still reach provider-specific fields.
    """

    text: str
    payload: BaseModel

    def __init__(self, text: str, payload: BaseModel):
        """Store the transcript text and the originating response.

        Args:
            text: The transcribed text.
            payload: The response object that produced ``text``. It is
                stored as given; no validation or copying is performed.
        """
        self.text = text
        self.payload = payload

    def __repr__(self) -> str:
        # A readable repr makes log lines and assertion failures useful
        # instead of showing only the default object address.
        return (
            f"{type(self).__name__}"
            f"(text={self.text!r}, payload={self.payload!r})"
        )