refactor: make return type mandatory for transcription

This commit is contained in:
Igor Ilic 2025-12-16 15:54:25 +01:00
parent 13c034e2e4
commit a52873a71f
3 changed files with 15 additions and 17 deletions

View file

@ -193,7 +193,7 @@ class GenericAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
async def create_transcript(self, input) -> TranscriptionReturnType:
"""
Generate an audio transcript from a user query.
@ -216,7 +216,7 @@ class GenericAPIAdapter(LLMInterface):
raise ValueError(
f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
)
response = litellm.completion(
response = litellm.completion(
model=self.transcription_model,
messages=[
{
@ -236,11 +236,8 @@ class GenericAPIAdapter(LLMInterface):
api_base=self.endpoint,
max_retries=self.MAX_RETRIES,
)
if response and response.choices and len(response.choices) > 0:
return TranscriptionReturnType(response.choices[0].message.content,response)
else:
return None
return TranscriptionReturnType(response.choices[0].message.content, response)
@observe(as_type="transcribe_image")
@retry(
@ -250,7 +247,7 @@ class GenericAPIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input) -> Optional[BaseModel]:
async def transcribe_image(self, input) -> BaseModel:
"""
Generate a transcription of an image from a user query.

View file

@ -1,9 +1,11 @@
"""LLM Interface"""
from typing import Type, Protocol, Optional
from typing import Type, Protocol
from abc import abstractmethod
from pydantic import BaseModel
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
TranscriptionReturnType,
)
class LLMInterface(Protocol):
@ -37,7 +39,7 @@ class LLMInterface(Protocol):
raise NotImplementedError
@abstractmethod
async def create_transcript(self, input) -> Optional[BaseModel]:
async def create_transcript(self, input) -> TranscriptionReturnType:
"""
Transcribe audio content to text.
@ -55,7 +57,7 @@ class LLMInterface(Protocol):
raise NotImplementedError
@abstractmethod
async def transcribe_image(self, input) -> Optional[BaseModel]:
async def transcribe_image(self, input) -> BaseModel:
"""
Analyze image content and return text description.

View file

@ -25,7 +25,9 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.modules.observability.get_observe import get_observe
from cognee.shared.logging_utils import get_logger
from ..types import TranscriptionReturnType
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
TranscriptionReturnType,
)
logger = get_logger()
@ -203,7 +205,7 @@ class OpenAIAdapter(GenericAPIAdapter):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input, **kwargs) -> Optional[TranscriptionReturnType]:
async def create_transcript(self, input, **kwargs) -> TranscriptionReturnType:
"""
Generate an audio transcript from a user query.
@ -232,9 +234,6 @@ class OpenAIAdapter(GenericAPIAdapter):
max_retries=self.MAX_RETRIES,
**kwargs,
)
if transcription:
return TranscriptionReturnType(transcription.text, transcription)
return None
return TranscriptionReturnType(transcription.text, transcription)
# transcribe_image is inherited from GenericAPIAdapter