standardizing return type for transcription and some CR changes

This commit is contained in:
rajeevrajeshuni 2025-12-11 06:53:36 +05:30
parent d57d188459
commit 6260f9eb82
7 changed files with 39 additions and 20 deletions

View file

@ -44,7 +44,7 @@ class AnthropicAdapter(GenericAPIAdapter):
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.patch(
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create,
mode=instructor.Mode(self.instructor_mode),
)

View file

@ -10,7 +10,6 @@ from instructor.core import InstructorRetryException
import logging
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.shared.logging_utils import get_logger
from tenacity import (
retry,

View file

@ -27,6 +27,8 @@ from tenacity import (
before_sleep_log,
)
from ..types import TranscriptionReturnType
logger = get_logger()
observe = get_observe()
@ -191,7 +193,7 @@ class GenericAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input) -> Optional[BaseModel]:
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
"""
Generate an audio transcript from a user query.
@ -214,7 +216,7 @@ class GenericAPIAdapter(LLMInterface):
raise ValueError(
f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
)
return litellm.completion(
response = litellm.completion(
model=self.transcription_model,
messages=[
{
@ -234,6 +236,11 @@ class GenericAPIAdapter(LLMInterface):
api_base=self.endpoint,
max_retries=self.MAX_RETRIES,
)
if response and response.choices and len(response.choices) > 0:
return TranscriptionReturnType(response.choices[0].message.content,response)
else:
return None
@observe(as_type="transcribe_image")
@retry(

View file

@ -97,10 +97,11 @@ def get_llm_client(raise_api_key_error: bool = True):
)
return OllamaAPIAdapter(
llm_config.llm_endpoint,
llm_config.llm_api_key,
llm_config.llm_model,
"Ollama",
max_completion_tokens,
llm_config.llm_endpoint,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)

View file

@ -1,9 +1,9 @@
import litellm
import instructor
from pydantic import BaseModel
from typing import Type
from litellm import JSONSchemaValidationError, transcription
from typing import Type, Optional
from litellm import JSONSchemaValidationError
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
@ -20,6 +20,7 @@ from tenacity import (
retry_if_not_exception_type,
before_sleep_log,
)
from ..types import TranscriptionReturnType
from mistralai import Mistral
logger = get_logger()
@ -47,8 +48,6 @@ class MistralAdapter(GenericAPIAdapter):
image_transcribe_model: str = None,
instructor_mode: str = None,
):
from mistralai import Mistral
super().__init__(
api_key=api_key,
model=model,
@ -66,6 +65,7 @@ class MistralAdapter(GenericAPIAdapter):
mode=instructor.Mode(self.instructor_mode),
api_key=get_llm_config().llm_api_key,
)
self.mistral_client = Mistral(api_key=self.api_key)
@observe(as_type="generation")
@retry(
@ -135,7 +135,7 @@ class MistralAdapter(GenericAPIAdapter):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input):
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
"""
Generate an audio transcript from a user query.
@ -154,14 +154,14 @@ class MistralAdapter(GenericAPIAdapter):
if self.transcription_model.startswith("mistral"):
transcription_model = self.transcription_model.split("/")[-1]
file_name = input.split("/")[-1]
client = Mistral(api_key=self.api_key)
with open(input, "rb") as f:
transcription_response = client.audio.transcriptions.complete(
async with open_data_file(input, mode="rb") as f:
transcription_response = self.mistral_client.audio.transcriptions.complete(
model=transcription_model,
file={
"content": f,
"file_name": file_name,
},
)
# TODO: We need to standardize return type of create_transcript across different models.
return transcription_response
if transcription_response:
return TranscriptionReturnType(transcription_response.text, transcription_response)
return None

View file

@ -1,6 +1,6 @@
import litellm
import instructor
from typing import Type
from typing import Type, Optional
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
@ -25,6 +25,7 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.modules.observability.get_observe import get_observe
from cognee.shared.logging_utils import get_logger
from ..types import TranscriptionReturnType
logger = get_logger()
@ -200,7 +201,7 @@ class OpenAIAdapter(GenericAPIAdapter):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input):
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
"""
Generate an audio transcript from a user query.
@ -228,7 +229,9 @@ class OpenAIAdapter(GenericAPIAdapter):
api_version=self.api_version,
max_retries=self.MAX_RETRIES,
)
if transcription:
return TranscriptionReturnType(transcription.text, transcription)
return transcription
return None
# transcribe image inherited from GenericAdapter
# transcribe_image is inherited from GenericAPIAdapter

View file

@ -0,0 +1,9 @@
from pydantic import BaseModel
class TranscriptionReturnType:
    """Standardized return type for ``create_transcript`` across LLM adapters.

    Wraps the transcribed text together with the raw provider response so that
    callers can rely on ``.text`` regardless of which adapter (OpenAI, Mistral,
    generic litellm, ...) produced the transcription, while still having access
    to the full provider payload for provider-specific fields.
    """

    # Transcribed text extracted from the provider response.
    text: str
    # Raw provider response object (typically a pydantic model). The string
    # annotation is a deliberate forward reference: it avoids evaluating the
    # imported BaseModel name at class-creation time.
    payload: "BaseModel"

    def __init__(self, text: str, payload: "BaseModel"):
        self.text = text
        self.payload = payload

    def __repr__(self) -> str:
        # Keep the repr short: payload objects can be very large, so show only
        # their type name alongside the transcribed text.
        return f"{type(self).__name__}(text={self.text!r}, payload={type(self.payload).__name__})"