standardizing return type for transcription and some CR changes
This commit is contained in:
parent
d57d188459
commit
6260f9eb82
7 changed files with 39 additions and 20 deletions
|
|
@ -44,7 +44,7 @@ class AnthropicAdapter(GenericAPIAdapter):
|
|||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||
|
||||
self.aclient = instructor.patch(
|
||||
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
|
||||
create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create,
|
||||
mode=instructor.Mode(self.instructor_mode),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from instructor.core import InstructorRetryException
|
|||
|
||||
import logging
|
||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ from tenacity import (
|
|||
before_sleep_log,
|
||||
)
|
||||
|
||||
from ..types import TranscriptionReturnType
|
||||
|
||||
logger = get_logger()
|
||||
observe = get_observe()
|
||||
|
||||
|
|
@ -191,7 +193,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
async def create_transcript(self, input) -> Optional[BaseModel]:
|
||||
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
|
||||
"""
|
||||
Generate an audio transcript from a user query.
|
||||
|
||||
|
|
@ -214,7 +216,7 @@ class GenericAPIAdapter(LLMInterface):
|
|||
raise ValueError(
|
||||
f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
|
||||
)
|
||||
return litellm.completion(
|
||||
response = litellm.completion(
|
||||
model=self.transcription_model,
|
||||
messages=[
|
||||
{
|
||||
|
|
@ -234,6 +236,11 @@ class GenericAPIAdapter(LLMInterface):
|
|||
api_base=self.endpoint,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
if response and response.choices and len(response.choices) > 0:
|
||||
return TranscriptionReturnType(response.choices[0].message.content,response)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@observe(as_type="transcribe_image")
|
||||
@retry(
|
||||
|
|
|
|||
|
|
@ -97,10 +97,11 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|||
)
|
||||
|
||||
return OllamaAPIAdapter(
|
||||
llm_config.llm_endpoint,
|
||||
llm_config.llm_api_key,
|
||||
llm_config.llm_model,
|
||||
"Ollama",
|
||||
max_completion_tokens,
|
||||
llm_config.llm_endpoint,
|
||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
import litellm
|
||||
import instructor
|
||||
from pydantic import BaseModel
|
||||
from typing import Type
|
||||
from litellm import JSONSchemaValidationError, transcription
|
||||
|
||||
from typing import Type, Optional
|
||||
from litellm import JSONSchemaValidationError
|
||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
||||
|
|
@ -20,6 +20,7 @@ from tenacity import (
|
|||
retry_if_not_exception_type,
|
||||
before_sleep_log,
|
||||
)
|
||||
from ..types import TranscriptionReturnType
|
||||
from mistralai import Mistral
|
||||
|
||||
logger = get_logger()
|
||||
|
|
@ -47,8 +48,6 @@ class MistralAdapter(GenericAPIAdapter):
|
|||
image_transcribe_model: str = None,
|
||||
instructor_mode: str = None,
|
||||
):
|
||||
from mistralai import Mistral
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
|
|
@ -66,6 +65,7 @@ class MistralAdapter(GenericAPIAdapter):
|
|||
mode=instructor.Mode(self.instructor_mode),
|
||||
api_key=get_llm_config().llm_api_key,
|
||||
)
|
||||
self.mistral_client = Mistral(api_key=self.api_key)
|
||||
|
||||
@observe(as_type="generation")
|
||||
@retry(
|
||||
|
|
@ -135,7 +135,7 @@ class MistralAdapter(GenericAPIAdapter):
|
|||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
async def create_transcript(self, input):
|
||||
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
|
||||
"""
|
||||
Generate an audio transcript from a user query.
|
||||
|
||||
|
|
@ -154,14 +154,14 @@ class MistralAdapter(GenericAPIAdapter):
|
|||
if self.transcription_model.startswith("mistral"):
|
||||
transcription_model = self.transcription_model.split("/")[-1]
|
||||
file_name = input.split("/")[-1]
|
||||
client = Mistral(api_key=self.api_key)
|
||||
with open(input, "rb") as f:
|
||||
transcription_response = client.audio.transcriptions.complete(
|
||||
async with open_data_file(input, mode="rb") as f:
|
||||
transcription_response = self.mistral_client.audio.transcriptions.complete(
|
||||
model=transcription_model,
|
||||
file={
|
||||
"content": f,
|
||||
"file_name": file_name,
|
||||
},
|
||||
)
|
||||
# TODO: We need to standardize return type of create_transcript across different models.
|
||||
return transcription_response
|
||||
if transcription_response:
|
||||
return TranscriptionReturnType(transcription_response.text, transcription_response)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import litellm
|
||||
import instructor
|
||||
from typing import Type
|
||||
from typing import Type, Optional
|
||||
from pydantic import BaseModel
|
||||
from openai import ContentFilterFinishReasonError
|
||||
from litellm.exceptions import ContentPolicyViolationError
|
||||
|
|
@ -25,6 +25,7 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
|||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||
from cognee.modules.observability.get_observe import get_observe
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from ..types import TranscriptionReturnType
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -200,7 +201,7 @@ class OpenAIAdapter(GenericAPIAdapter):
|
|||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
async def create_transcript(self, input):
|
||||
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
|
||||
"""
|
||||
Generate an audio transcript from a user query.
|
||||
|
||||
|
|
@ -228,7 +229,9 @@ class OpenAIAdapter(GenericAPIAdapter):
|
|||
api_version=self.api_version,
|
||||
max_retries=self.MAX_RETRIES,
|
||||
)
|
||||
if transcription:
|
||||
return TranscriptionReturnType(transcription.text, transcription)
|
||||
|
||||
return transcription
|
||||
return None
|
||||
|
||||
# transcribe image inherited from GenericAdapter
|
||||
# transcribe_image is inherited from GenericAPIAdapter
|
||||
|
|
|
|||
|
|
from dataclasses import dataclass

from pydantic import BaseModel


@dataclass
class TranscriptionReturnType:
    """Standardized result of an audio transcription call.

    Used as the common return type of ``create_transcript`` across the
    different LLM adapters so callers can rely on one shape regardless of
    provider.

    Attributes:
        text: The plain transcribed text extracted from the response.
        payload: The raw provider response object the text came from
            (kept so callers can inspect provider-specific metadata).
    """

    # @dataclass generates __init__(self, text, payload) with the same
    # positional signature the adapters already use, plus __repr__/__eq__.
    text: str
    payload: BaseModel
Loading…
Add table
Reference in a new issue