Standardizing return type for transcription and some CR changes
This commit is contained in:
parent
d57d188459
commit
6260f9eb82
7 changed files with 39 additions and 20 deletions
|
|
@ -44,7 +44,7 @@ class AnthropicAdapter(GenericAPIAdapter):
|
||||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
|
|
||||||
self.aclient = instructor.patch(
|
self.aclient = instructor.patch(
|
||||||
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
|
create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create,
|
||||||
mode=instructor.Mode(self.instructor_mode),
|
mode=instructor.Mode(self.instructor_mode),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ from instructor.core import InstructorRetryException
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
from cognee.shared.logging_utils import get_logger
|
|
||||||
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,8 @@ from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ..types import TranscriptionReturnType
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
observe = get_observe()
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
@ -191,7 +193,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def create_transcript(self, input) -> Optional[BaseModel]:
|
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
|
||||||
"""
|
"""
|
||||||
Generate an audio transcript from a user query.
|
Generate an audio transcript from a user query.
|
||||||
|
|
||||||
|
|
@ -214,7 +216,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
|
f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
|
||||||
)
|
)
|
||||||
return litellm.completion(
|
response = litellm.completion(
|
||||||
model=self.transcription_model,
|
model=self.transcription_model,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
|
|
@ -234,6 +236,11 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
api_base=self.endpoint,
|
api_base=self.endpoint,
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
)
|
)
|
||||||
|
if response and response.choices and len(response.choices) > 0:
|
||||||
|
return TranscriptionReturnType(response.choices[0].message.content,response)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@observe(as_type="transcribe_image")
|
@observe(as_type="transcribe_image")
|
||||||
@retry(
|
@retry(
|
||||||
|
|
|
||||||
|
|
@ -97,10 +97,11 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
)
|
)
|
||||||
|
|
||||||
return OllamaAPIAdapter(
|
return OllamaAPIAdapter(
|
||||||
|
llm_config.llm_endpoint,
|
||||||
llm_config.llm_api_key,
|
llm_config.llm_api_key,
|
||||||
llm_config.llm_model,
|
llm_config.llm_model,
|
||||||
|
"Ollama",
|
||||||
max_completion_tokens,
|
max_completion_tokens,
|
||||||
llm_config.llm_endpoint,
|
|
||||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Type
|
from typing import Type, Optional
|
||||||
from litellm import JSONSchemaValidationError, transcription
|
from litellm import JSONSchemaValidationError
|
||||||
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.observability.get_observe import get_observe
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
||||||
|
|
@ -20,6 +20,7 @@ from tenacity import (
|
||||||
retry_if_not_exception_type,
|
retry_if_not_exception_type,
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
from ..types import TranscriptionReturnType
|
||||||
from mistralai import Mistral
|
from mistralai import Mistral
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
@ -47,8 +48,6 @@ class MistralAdapter(GenericAPIAdapter):
|
||||||
image_transcribe_model: str = None,
|
image_transcribe_model: str = None,
|
||||||
instructor_mode: str = None,
|
instructor_mode: str = None,
|
||||||
):
|
):
|
||||||
from mistralai import Mistral
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -66,6 +65,7 @@ class MistralAdapter(GenericAPIAdapter):
|
||||||
mode=instructor.Mode(self.instructor_mode),
|
mode=instructor.Mode(self.instructor_mode),
|
||||||
api_key=get_llm_config().llm_api_key,
|
api_key=get_llm_config().llm_api_key,
|
||||||
)
|
)
|
||||||
|
self.mistral_client = Mistral(api_key=self.api_key)
|
||||||
|
|
||||||
@observe(as_type="generation")
|
@observe(as_type="generation")
|
||||||
@retry(
|
@retry(
|
||||||
|
|
@ -135,7 +135,7 @@ class MistralAdapter(GenericAPIAdapter):
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def create_transcript(self, input):
|
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
|
||||||
"""
|
"""
|
||||||
Generate an audio transcript from a user query.
|
Generate an audio transcript from a user query.
|
||||||
|
|
||||||
|
|
@ -154,14 +154,14 @@ class MistralAdapter(GenericAPIAdapter):
|
||||||
if self.transcription_model.startswith("mistral"):
|
if self.transcription_model.startswith("mistral"):
|
||||||
transcription_model = self.transcription_model.split("/")[-1]
|
transcription_model = self.transcription_model.split("/")[-1]
|
||||||
file_name = input.split("/")[-1]
|
file_name = input.split("/")[-1]
|
||||||
client = Mistral(api_key=self.api_key)
|
async with open_data_file(input, mode="rb") as f:
|
||||||
with open(input, "rb") as f:
|
transcription_response = self.mistral_client.audio.transcriptions.complete(
|
||||||
transcription_response = client.audio.transcriptions.complete(
|
|
||||||
model=transcription_model,
|
model=transcription_model,
|
||||||
file={
|
file={
|
||||||
"content": f,
|
"content": f,
|
||||||
"file_name": file_name,
|
"file_name": file_name,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# TODO: We need to standardize return type of create_transcript across different models.
|
if transcription_response:
|
||||||
return transcription_response
|
return TranscriptionReturnType(transcription_response.text, transcription_response)
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
from typing import Type
|
from typing import Type, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from openai import ContentFilterFinishReasonError
|
from openai import ContentFilterFinishReasonError
|
||||||
from litellm.exceptions import ContentPolicyViolationError
|
from litellm.exceptions import ContentPolicyViolationError
|
||||||
|
|
@ -25,6 +25,7 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
from cognee.modules.observability.get_observe import get_observe
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from ..types import TranscriptionReturnType
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -200,7 +201,7 @@ class OpenAIAdapter(GenericAPIAdapter):
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def create_transcript(self, input):
|
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
|
||||||
"""
|
"""
|
||||||
Generate an audio transcript from a user query.
|
Generate an audio transcript from a user query.
|
||||||
|
|
||||||
|
|
@ -228,7 +229,9 @@ class OpenAIAdapter(GenericAPIAdapter):
|
||||||
api_version=self.api_version,
|
api_version=self.api_version,
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
)
|
)
|
||||||
|
if transcription:
|
||||||
|
return TranscriptionReturnType(transcription.text, transcription)
|
||||||
|
|
||||||
return transcription
|
return None
|
||||||
|
|
||||||
# transcribe image inherited from GenericAdapter
|
# transcribe_image is inherited from GenericAPIAdapter
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class TranscriptionReturnType:
    """Uniform result object returned by the adapters' ``create_transcript``.

    Pairs the transcribed text with the raw response object it was taken
    from, so callers can always read ``text`` while still being able to
    inspect the full response through ``payload``.

    Args:
        text: The transcript text extracted from the response.
        payload: The response object the text was extracted from.
    """

    # Transcript text extracted from the response.
    text: str
    # Response object that produced ``text``. The annotation is quoted so it
    # is never evaluated at runtime.
    payload: "BaseModel"

    def __init__(self, text: str, payload: "BaseModel"):
        self.text = text
        self.payload = payload

    def __repr__(self) -> str:
        # Show only the payload's type: a full response dump would make log
        # lines unreadable.
        return (
            f"{type(self).__name__}"
            f"(text={self.text!r}, payload=<{type(self.payload).__name__}>)"
        )
|
||||||
Loading…
Add table
Reference in a new issue