refactor: make return type mandatory for transcription
This commit is contained in:
parent
13c034e2e4
commit
a52873a71f
3 changed files with 15 additions and 17 deletions
|
|
@ -193,7 +193,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
|
||||
async def create_transcript(self, input) -> TranscriptionReturnType:
|
||||
"""
|
||||
Generate an audio transcript from a user query.
|
||||
|
||||
|
|
@ -216,7 +216,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
raise ValueError(
|
||||
f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
|
||||
)
|
||||
response = litellm.completion(
|
||||
response = litellm.completion(
|
||||
model=self.transcription_model,
|
||||
messages=[
|
||||
{
|
||||
|
|
@ -236,11 +236,8 @@ class GenericAPIAdapter(LLMInterface):
|
|||
api_base=self.endpoint,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
if response and response.choices and len(response.choices) > 0:
|
||||
return TranscriptionReturnType(response.choices[0].message.content,response)
|
||||
else:
|
||||
return None
|
||||
|
||||
return TranscriptionReturnType(response.choices[0].message.content, response)
|
||||
|
||||
@observe(as_type="transcribe_image")
|
||||
@retry(
|
||||
|
|
@ -250,7 +247,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
async def transcribe_image(self, input) -> Optional[BaseModel]:
|
||||
async def transcribe_image(self, input) -> BaseModel:
|
||||
"""
|
||||
Generate a transcription of an image from a user query.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
"""LLM Interface"""
|
||||
|
||||
from typing import Type, Protocol, Optional
|
||||
from typing import Type, Protocol
|
||||
from abc import abstractmethod
|
||||
from pydantic import BaseModel
|
||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
|
||||
TranscriptionReturnType,
|
||||
)
|
||||
|
||||
|
||||
class LLMInterface(Protocol):
|
||||
|
|
@ -37,7 +39,7 @@ class LLMInterface(Protocol):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def create_transcript(self, input) -> Optional[BaseModel]:
|
||||
async def create_transcript(self, input) -> TranscriptionReturnType:
|
||||
"""
|
||||
Transcribe audio content to text.
|
||||
|
||||
|
|
@ -55,7 +57,7 @@ class LLMInterface(Protocol):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def transcribe_image(self, input) -> Optional[BaseModel]:
|
||||
async def transcribe_image(self, input) -> BaseModel:
|
||||
"""
|
||||
Analyze image content and return text description.
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,9 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
|||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from ..types import TranscriptionReturnType
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
|
||||
TranscriptionReturnType,
|
||||
)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -203,7 +205,7 @@ class OpenAIAdapter(GenericAPIAdapter):
|
|||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
async def create_transcript(self, input, **kwargs) -> Optional[TranscriptionReturnType]:
|
||||
async def create_transcript(self, input, **kwargs) -> TranscriptionReturnType:
|
||||
"""
|
||||
Generate an audio transcript from a user query.
|
||||
|
||||
|
|
@ -232,9 +234,6 @@ class OpenAIAdapter(GenericAPIAdapter):
|
|||
max_retries=self.MAX_RETRIES,
|
||||
**kwargs,
|
||||
)
|
||||
if transcription:
|
||||
return TranscriptionReturnType(transcription.text, transcription)
|
||||
|
||||
return None
|
||||
return TranscriptionReturnType(transcription.text, transcription)
|
||||
|
||||
# transcribe_image is inherited from GenericAPIAdapter
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue