Adding support for audio/image transcription for all other providers
This commit is contained in:
parent
c2c64a417c
commit
02b1778658
10 changed files with 313 additions and 312 deletions
|
|
@ -34,19 +34,6 @@ class LLMGateway:
|
||||||
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
def create_structured_output(
    text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
    """
    Produce a structured output by delegating to the configured LLM client.

    Parameters:
    -----------

        - text_input (str): The user text forwarded to the client.
        - system_prompt (str): The system prompt forwarded to the client.
        - response_model (Type[BaseModel]): The model class forwarded to the client as
          the requested response shape.

    Returns:
    --------

        - BaseModel: Whatever the client's create_structured_output call returns.
    """
    # Imported at call time, exactly as the original did, rather than at module load.
    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
        get_llm_client as resolve_client,
    )

    request = {
        "text_input": text_input,
        "system_prompt": system_prompt,
        "response_model": response_model,
    }
    return resolve_client().create_structured_output(**request)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_transcript(input) -> Coroutine:
|
def create_transcript(input) -> Coroutine:
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,9 @@ from typing import Type
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
|
import anthropic
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
stop_after_delay,
|
stop_after_delay,
|
||||||
|
|
@ -12,27 +14,32 @@ from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
||||||
LLMInterface,
|
GenericAPIAdapter,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class AnthropicAdapter(LLMInterface):
|
class AnthropicAdapter(GenericAPIAdapter):
|
||||||
"""
|
"""
|
||||||
Adapter for interfacing with the Anthropic API, enabling structured output generation
|
Adapter for interfacing with the Anthropic API, enabling structured output generation
|
||||||
and prompt display.
|
and prompt display.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "Anthropic"
|
|
||||||
model: str
|
|
||||||
default_instructor_mode = "anthropic_tools"
|
default_instructor_mode = "anthropic_tools"
|
||||||
|
|
||||||
def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
|
def __init__(
|
||||||
import anthropic
|
self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
model=model,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
name="Anthropic",
|
||||||
|
)
|
||||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
|
|
||||||
self.aclient = instructor.patch(
|
self.aclient = instructor.patch(
|
||||||
|
|
@ -40,9 +47,7 @@ class AnthropicAdapter(LLMInterface):
|
||||||
mode=instructor.Mode(self.instructor_mode),
|
mode=instructor.Mode(self.instructor_mode),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model = model
|
@observe(as_type="generation")
|
||||||
self.max_completion_tokens = max_completion_tokens
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Adapter for Generic API LLM provider API"""
|
"""Adapter for Gemini API LLM provider"""
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
|
|
@ -8,12 +8,7 @@ from openai import ContentFilterFinishReasonError
|
||||||
from litellm.exceptions import ContentPolicyViolationError
|
from litellm.exceptions import ContentPolicyViolationError
|
||||||
from instructor.core import InstructorRetryException
|
from instructor.core import InstructorRetryException
|
||||||
|
|
||||||
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
|
||||||
LLMInterface,
|
|
||||||
)
|
|
||||||
import logging
|
import logging
|
||||||
from cognee.shared.logging_utils import get_logger
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
stop_after_delay,
|
stop_after_delay,
|
||||||
|
|
@ -22,55 +17,65 @@ from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
||||||
|
GenericAPIAdapter,
|
||||||
|
)
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class GeminiAdapter(LLMInterface):
|
class GeminiAdapter(GenericAPIAdapter):
|
||||||
"""
|
"""
|
||||||
Adapter for Gemini API LLM provider.
|
Adapter for Gemini API LLM provider.
|
||||||
|
|
||||||
This class initializes the API adapter with necessary credentials and configurations for
|
This class initializes the API adapter with necessary credentials and configurations for
|
||||||
interacting with the gemini LLM models. It provides methods for creating structured outputs
|
interacting with the gemini LLM models. It provides methods for creating structured outputs
|
||||||
based on user input and system prompts.
|
based on user input and system prompts, as well as multimodal processing capabilities.
|
||||||
|
|
||||||
Public methods:
|
Public methods:
|
||||||
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
|
- acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel
|
||||||
Type[BaseModel]) -> BaseModel
|
- create_transcript(input) -> BaseModel: Transcribe audio files to text
|
||||||
|
- transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
|
||||||
model: str
|
|
||||||
api_key: str
|
|
||||||
default_instructor_mode = "json_mode"
|
default_instructor_mode = "json_mode"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
endpoint,
|
|
||||||
api_key: str,
|
api_key: str,
|
||||||
model: str,
|
model: str,
|
||||||
api_version: str,
|
|
||||||
max_completion_tokens: int,
|
max_completion_tokens: int,
|
||||||
|
endpoint: str = None,
|
||||||
|
api_version: str = None,
|
||||||
|
transcription_model: str = None,
|
||||||
instructor_mode: str = None,
|
instructor_mode: str = None,
|
||||||
fallback_model: str = None,
|
fallback_model: str = None,
|
||||||
fallback_api_key: str = None,
|
fallback_api_key: str = None,
|
||||||
fallback_endpoint: str = None,
|
fallback_endpoint: str = None,
|
||||||
):
|
):
|
||||||
self.model = model
|
super().__init__(
|
||||||
self.api_key = api_key
|
api_key=api_key,
|
||||||
self.endpoint = endpoint
|
model=model,
|
||||||
self.api_version = api_version
|
max_completion_tokens=max_completion_tokens,
|
||||||
self.max_completion_tokens = max_completion_tokens
|
name="Gemini",
|
||||||
|
endpoint=endpoint,
|
||||||
self.fallback_model = fallback_model
|
api_version=api_version,
|
||||||
self.fallback_api_key = fallback_api_key
|
transcription_model=transcription_model,
|
||||||
self.fallback_endpoint = fallback_endpoint
|
fallback_model=fallback_model,
|
||||||
|
fallback_api_key=fallback_api_key,
|
||||||
|
fallback_endpoint=fallback_endpoint,
|
||||||
|
)
|
||||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
|
|
||||||
self.aclient = instructor.from_litellm(
|
self.aclient = instructor.from_litellm(
|
||||||
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@observe(as_type="generation")
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
|
@ -118,7 +123,7 @@ class GeminiAdapter(LLMInterface):
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
max_retries=5,
|
max_retries=self.MAX_RETRIES,
|
||||||
api_base=self.endpoint,
|
api_base=self.endpoint,
|
||||||
api_version=self.api_version,
|
api_version=self.api_version,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
|
|
@ -152,7 +157,7 @@ class GeminiAdapter(LLMInterface):
|
||||||
"content": system_prompt,
|
"content": system_prompt,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
max_retries=5,
|
max_retries=self.MAX_RETRIES,
|
||||||
api_key=self.fallback_api_key,
|
api_key=self.fallback_api_key,
|
||||||
api_base=self.fallback_endpoint,
|
api_base=self.fallback_endpoint,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
"""Adapter for Generic API LLM provider API"""
|
"""Adapter for Generic API LLM provider API"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import mimetypes
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
from typing import Type
|
from typing import Type, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from openai import ContentFilterFinishReasonError
|
from openai import ContentFilterFinishReasonError
|
||||||
from litellm.exceptions import ContentPolicyViolationError
|
from litellm.exceptions import ContentPolicyViolationError
|
||||||
|
|
@ -12,6 +14,8 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
import logging
|
import logging
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
|
|
@ -23,6 +27,7 @@ from tenacity import (
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class GenericAPIAdapter(LLMInterface):
|
class GenericAPIAdapter(LLMInterface):
|
||||||
|
|
@ -38,18 +43,19 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
Type[BaseModel]) -> BaseModel
|
Type[BaseModel]) -> BaseModel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
MAX_RETRIES = 5
|
||||||
model: str
|
|
||||||
api_key: str
|
|
||||||
default_instructor_mode = "json_mode"
|
default_instructor_mode = "json_mode"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
endpoint,
|
|
||||||
api_key: str,
|
api_key: str,
|
||||||
model: str,
|
model: str,
|
||||||
name: str,
|
|
||||||
max_completion_tokens: int,
|
max_completion_tokens: int,
|
||||||
|
name: str,
|
||||||
|
endpoint: str = None,
|
||||||
|
api_version: str = None,
|
||||||
|
transcription_model: str = None,
|
||||||
|
image_transcribe_model: str = None,
|
||||||
instructor_mode: str = None,
|
instructor_mode: str = None,
|
||||||
fallback_model: str = None,
|
fallback_model: str = None,
|
||||||
fallback_api_key: str = None,
|
fallback_api_key: str = None,
|
||||||
|
|
@ -58,9 +64,11 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.api_version = api_version
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.max_completion_tokens = max_completion_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
self.transcription_model = transcription_model or model
|
||||||
|
self.image_transcribe_model = image_transcribe_model or model
|
||||||
self.fallback_model = fallback_model
|
self.fallback_model = fallback_model
|
||||||
self.fallback_api_key = fallback_api_key
|
self.fallback_api_key = fallback_api_key
|
||||||
self.fallback_endpoint = fallback_endpoint
|
self.fallback_endpoint = fallback_endpoint
|
||||||
|
|
@ -71,6 +79,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@observe(as_type="generation")
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
|
@ -170,3 +179,112 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
raise ContentPolicyFilterError(
|
raise ContentPolicyFilterError(
|
||||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||||
) from error
|
) from error
|
||||||
|
|
||||||
|
@observe(as_type="transcription")
@retry(
    stop=stop_after_delay(128),
    wait=wait_exponential_jitter(2, 128),
    # A missing model, a missing file and an unsupported file type are deterministic
    # failures: retrying them only delays the same error by up to 128 seconds.
    retry=retry_if_not_exception_type(
        (litellm.exceptions.NotFoundError, FileNotFoundError, ValueError)
    ),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
async def create_transcript(self, input) -> Optional[BaseModel]:
    """
    Generate an audio transcript from a user query.

    The audio file is opened, its MIME type is guessed from the file name, and the
    base64-encoded content is sent to the transcription model as a data URL together
    with a fixed transcription instruction.

    Parameters:
    -----------

        - input: The path to the audio file that needs to be transcribed.

    Returns:
    --------

        The response object produced by the litellm completion call.
        NOTE(review): the annotation says Optional[BaseModel]; the raw completion
        response is returned as-is, so callers presumably extract the text themselves -
        confirm against callers.

    Raises:
    -------

        - ValueError: If the MIME type cannot be guessed or is not an audio type.
    """
    async with open_data_file(input, mode="rb") as audio_file:
        mime_type, _ = mimetypes.guess_type(input)
        # Validate before reading so an unsupported file is never loaded and encoded.
        if not mime_type or not mime_type.startswith("audio/"):
            raise ValueError(
                f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
            )
        encoded_string = base64.b64encode(audio_file.read()).decode("utf-8")

    # Use the async entry point: the synchronous litellm.completion would block the
    # event loop for the whole duration of the request.
    return await litellm.acompletion(
        model=self.transcription_model,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "file",
                        "file": {"file_data": f"data:{mime_type};base64,{encoded_string}"},
                    },
                    {"type": "text", "text": "Transcribe the following audio precisely."},
                ],
            }
        ],
        api_key=self.api_key,
        api_version=self.api_version,
        max_completion_tokens=self.max_completion_tokens,
        api_base=self.endpoint,
        max_retries=self.MAX_RETRIES,
    )
|
||||||
|
|
||||||
|
@observe(as_type="transcribe_image")
@retry(
    stop=stop_after_delay(128),
    wait=wait_exponential_jitter(2, 128),
    # A missing model, a missing file and an unsupported file type are deterministic
    # failures: retrying them only delays the same error by up to 128 seconds.
    retry=retry_if_not_exception_type(
        (litellm.exceptions.NotFoundError, FileNotFoundError, ValueError)
    ),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
async def transcribe_image(self, input) -> Optional[BaseModel]:
    """
    Generate a transcription of an image from a user query.

    The image file is opened, its MIME type is guessed from the file name, and the
    base64-encoded content is sent to the image model as a data URL together with a
    fixed "What's in this image?" prompt.

    Parameters:
    -----------

        - input: The path to the image file that needs to be transcribed.

    Returns:
    --------

        The response object produced by the litellm completion call.
        NOTE(review): the annotation says Optional[BaseModel]; the raw completion
        response is returned as-is, so callers presumably extract the text themselves -
        confirm against callers.

    Raises:
    -------

        - ValueError: If the MIME type cannot be guessed or is not an image type.
    """
    async with open_data_file(input, mode="rb") as image_file:
        mime_type, _ = mimetypes.guess_type(input)
        # Validate before reading so an unsupported file is never loaded and encoded.
        if not mime_type or not mime_type.startswith("image/"):
            raise ValueError(
                f"Could not determine MIME type for image file: {input}. Is the extension correct?"
            )
        encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

    # Use the async entry point: the synchronous litellm.completion would block the
    # event loop for the whole duration of the request.
    return await litellm.acompletion(
        model=self.image_transcribe_model,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "What's in this image?",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{encoded_image}",
                        },
                    },
                ],
            }
        ],
        api_key=self.api_key,
        api_base=self.endpoint,
        api_version=self.api_version,
        # Descriptions are capped at a fixed 300 completion tokens, independent of
        # self.max_completion_tokens.
        max_completion_tokens=300,
        max_retries=self.MAX_RETRIES,
    )
|
||||||
|
|
|
||||||
|
|
@ -97,11 +97,10 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
)
|
)
|
||||||
|
|
||||||
return OllamaAPIAdapter(
|
return OllamaAPIAdapter(
|
||||||
llm_config.llm_endpoint,
|
|
||||||
llm_config.llm_api_key,
|
llm_config.llm_api_key,
|
||||||
llm_config.llm_model,
|
llm_config.llm_model,
|
||||||
"Ollama",
|
max_completion_tokens,
|
||||||
max_completion_tokens=max_completion_tokens,
|
llm_config.llm_endpoint,
|
||||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -111,8 +110,9 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
)
|
)
|
||||||
|
|
||||||
return AnthropicAdapter(
|
return AnthropicAdapter(
|
||||||
max_completion_tokens=max_completion_tokens,
|
llm_config.llm_api_key,
|
||||||
model=llm_config.llm_model,
|
llm_config.llm_model,
|
||||||
|
max_completion_tokens,
|
||||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -125,11 +125,10 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
)
|
)
|
||||||
|
|
||||||
return GenericAPIAdapter(
|
return GenericAPIAdapter(
|
||||||
llm_config.llm_endpoint,
|
|
||||||
llm_config.llm_api_key,
|
llm_config.llm_api_key,
|
||||||
llm_config.llm_model,
|
llm_config.llm_model,
|
||||||
|
max_completion_tokens,
|
||||||
"Custom",
|
"Custom",
|
||||||
max_completion_tokens=max_completion_tokens,
|
|
||||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||||
fallback_api_key=llm_config.fallback_api_key,
|
fallback_api_key=llm_config.fallback_api_key,
|
||||||
fallback_endpoint=llm_config.fallback_endpoint,
|
fallback_endpoint=llm_config.fallback_endpoint,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""LLM Interface"""
|
"""LLM Interface"""
|
||||||
|
|
||||||
from typing import Type, Protocol
|
from typing import Type, Protocol, Optional
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
|
|
@ -8,13 +8,12 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
|
|
||||||
class LLMInterface(Protocol):
|
class LLMInterface(Protocol):
|
||||||
"""
|
"""
|
||||||
Define an interface for LLM models with methods for structured output and prompt
|
Define an interface for LLM models with methods for structured output, multimodal processing, and prompt display.
|
||||||
display.
|
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
|
- acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel])
|
||||||
Type[BaseModel])
|
- create_transcript(input): Transcribe audio files to text
|
||||||
- show_prompt(text_input: str, system_prompt: str)
|
- transcribe_image(input): Analyze image files and return text description
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
@ -36,3 +35,39 @@ class LLMInterface(Protocol):
|
||||||
output.
|
output.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
async def create_transcript(self, input) -> Optional[BaseModel]:
    """
    Transcribe audio content to text.

    This method should be implemented by subclasses that support audio transcription.
    This default body does not return None: it raises NotImplementedError, so callers
    that want to tolerate a provider without transcription support must catch it.

    Parameters:
    -----------

        - input: The path to the audio file that needs to be transcribed.

    Returns:
    --------

        - BaseModel: A structured output containing the transcription (implementations
          only; the default body never returns).

    Raises:
    -------

        - NotImplementedError: Always, from this default body.
    """
    raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
async def transcribe_image(self, input) -> Optional[BaseModel]:
    """
    Analyze image content and return text description.

    This method should be implemented by subclasses that support image analysis.
    This default body does not return None: it raises NotImplementedError, so callers
    that want to tolerate a provider without image support must catch it.

    Parameters:
    -----------

        - input: The path to the image file that needs to be analyzed.

    Returns:
    --------

        - BaseModel: A structured output containing the image description
          (implementations only; the default body never returns).

    Raises:
    -------

        - NotImplementedError: Always, from this default body.
    """
    raise NotImplementedError
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,12 @@ import litellm
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from litellm import JSONSchemaValidationError
|
from litellm import JSONSchemaValidationError, transcription
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.observability.get_observe import get_observe
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
||||||
LLMInterface,
|
GenericAPIAdapter,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
|
|
@ -19,12 +19,13 @@ from tenacity import (
|
||||||
retry_if_not_exception_type,
|
retry_if_not_exception_type,
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
from mistralai import Mistral
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
observe = get_observe()
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class MistralAdapter(LLMInterface):
|
class MistralAdapter(GenericAPIAdapter):
|
||||||
"""
|
"""
|
||||||
Adapter for Mistral AI API, for structured output generation and prompt display.
|
Adapter for Mistral AI API, for structured output generation and prompt display.
|
||||||
|
|
||||||
|
|
@ -33,10 +34,6 @@ class MistralAdapter(LLMInterface):
|
||||||
- show_prompt
|
- show_prompt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "Mistral"
|
|
||||||
model: str
|
|
||||||
api_key: str
|
|
||||||
max_completion_tokens: int
|
|
||||||
default_instructor_mode = "mistral_tools"
|
default_instructor_mode = "mistral_tools"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -45,12 +42,21 @@ class MistralAdapter(LLMInterface):
|
||||||
model: str,
|
model: str,
|
||||||
max_completion_tokens: int,
|
max_completion_tokens: int,
|
||||||
endpoint: str = None,
|
endpoint: str = None,
|
||||||
|
transcription_model: str = None,
|
||||||
|
image_transcribe_model: str = None,
|
||||||
instructor_mode: str = None,
|
instructor_mode: str = None,
|
||||||
):
|
):
|
||||||
from mistralai import Mistral
|
from mistralai import Mistral
|
||||||
|
|
||||||
self.model = model
|
super().__init__(
|
||||||
self.max_completion_tokens = max_completion_tokens
|
api_key=api_key,
|
||||||
|
model=model,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
name="Mistral",
|
||||||
|
endpoint=endpoint,
|
||||||
|
transcription_model=transcription_model,
|
||||||
|
image_transcribe_model=image_transcribe_model,
|
||||||
|
)
|
||||||
|
|
||||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
|
|
||||||
|
|
@ -60,6 +66,7 @@ class MistralAdapter(LLMInterface):
|
||||||
api_key=get_llm_config().llm_api_key,
|
api_key=get_llm_config().llm_api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@observe(as_type="generation")
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
|
@ -117,3 +124,42 @@ class MistralAdapter(LLMInterface):
|
||||||
logger.error(f"Schema validation failed: {str(e)}")
|
logger.error(f"Schema validation failed: {str(e)}")
|
||||||
logger.debug(f"Raw response: {e.raw_response}")
|
logger.debug(f"Raw response: {e.raw_response}")
|
||||||
raise ValueError(f"Response failed schema validation: {str(e)}")
|
raise ValueError(f"Response failed schema validation: {str(e)}")
|
||||||
|
|
||||||
|
@observe(as_type="transcription")
@retry(
    stop=stop_after_delay(128),
    wait=wait_exponential_jitter(2, 128),
    # A missing model or a missing input file is a deterministic failure: retrying it
    # only delays the same error by up to 128 seconds.
    retry=retry_if_not_exception_type((litellm.exceptions.NotFoundError, FileNotFoundError)),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
async def create_transcript(self, input):
    """
    Generate an audio transcript from a user query.

    The audio file at the given path is uploaded through the Mistral SDK's audio
    transcription endpoint. When the configured transcription model name starts with
    "mistral", only the part after the last "/" is sent as the model name.

    Parameters:
    -----------

        - input: The path to the audio file that needs to be transcribed.

    Returns:
    --------

        The transcription response object returned by the Mistral SDK, unmodified.

    Raises:
    -------

        - FileNotFoundError: If the audio file does not exist (raised without retrying).
    """
    # Function-scope imports keep this block self-contained.
    import asyncio
    import os

    transcription_model = self.transcription_model
    if transcription_model.startswith("mistral"):
        # Drop the provider prefix from the configured name, e.g. "mistral/<model>".
        transcription_model = transcription_model.split("/")[-1]

    # os.path.basename matches split("/")[-1] on POSIX paths and also handles the
    # platform's native separator.
    file_name = os.path.basename(input)
    client = Mistral(api_key=self.api_key)

    def _transcribe():
        """Read the file and call the synchronous SDK endpoint."""
        with open(input, "rb") as audio_file:
            return client.audio.transcriptions.complete(
                model=transcription_model,
                file={
                    "content": audio_file,
                    "file_name": file_name,
                },
            )

    # The file read and the SDK call are both blocking; run them in a worker thread so
    # the event loop stays responsive while the upload and transcription are in flight.
    # TODO: We need to standardize return type of create_transcript across different models.
    return await asyncio.to_thread(_transcribe)
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,12 @@ import instructor
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
|
||||||
LLMInterface,
|
|
||||||
)
|
|
||||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
||||||
|
GenericAPIAdapter,
|
||||||
|
)
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
stop_after_delay,
|
stop_after_delay,
|
||||||
|
|
@ -20,9 +20,10 @@ from tenacity import (
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class OllamaAPIAdapter(LLMInterface):
|
class OllamaAPIAdapter(GenericAPIAdapter):
|
||||||
"""
|
"""
|
||||||
Adapter for a Generic API LLM provider using instructor with an OpenAI backend.
|
Adapter for a Generic API LLM provider using instructor with an OpenAI backend.
|
||||||
|
|
||||||
|
|
@ -46,18 +47,20 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
endpoint: str,
|
|
||||||
api_key: str,
|
api_key: str,
|
||||||
model: str,
|
model: str,
|
||||||
name: str,
|
name: str,
|
||||||
max_completion_tokens: int,
|
max_completion_tokens: int,
|
||||||
|
endpoint: str,
|
||||||
instructor_mode: str = None,
|
instructor_mode: str = None,
|
||||||
):
|
):
|
||||||
self.name = name
|
super().__init__(
|
||||||
self.model = model
|
api_key=api_key,
|
||||||
self.api_key = api_key
|
model=model,
|
||||||
self.endpoint = endpoint
|
max_completion_tokens=max_completion_tokens,
|
||||||
self.max_completion_tokens = max_completion_tokens
|
name="Ollama",
|
||||||
|
endpoint=endpoint,
|
||||||
|
)
|
||||||
|
|
||||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
|
|
||||||
|
|
@ -66,6 +69,7 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
mode=instructor.Mode(self.instructor_mode),
|
mode=instructor.Mode(self.instructor_mode),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@observe(as_type="generation")
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
|
@ -113,95 +117,3 @@ class OllamaAPIAdapter(LLMInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@retry(
|
|
||||||
stop=stop_after_delay(128),
|
|
||||||
wait=wait_exponential_jitter(2, 128),
|
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
||||||
reraise=True,
|
|
||||||
)
|
|
||||||
async def create_transcript(self, input_file: str) -> str:
|
|
||||||
"""
|
|
||||||
Generate an audio transcript from a user query.
|
|
||||||
|
|
||||||
This synchronous method takes an input audio file and returns its transcription. Raises
|
|
||||||
a FileNotFoundError if the input file does not exist, and raises a ValueError if
|
|
||||||
transcription fails or returns no text.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- input_file (str): The path to the audio file to be transcribed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
- str: The transcription of the audio as a string.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async with open_data_file(input_file, mode="rb") as audio_file:
|
|
||||||
transcription = self.aclient.audio.transcriptions.create(
|
|
||||||
model="whisper-1", # Ensure the correct model for transcription
|
|
||||||
file=audio_file,
|
|
||||||
language="en",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure the response contains a valid transcript
|
|
||||||
if not hasattr(transcription, "text"):
|
|
||||||
raise ValueError("Transcription failed. No text returned.")
|
|
||||||
|
|
||||||
return transcription.text
|
|
||||||
|
|
||||||
@retry(
|
|
||||||
stop=stop_after_delay(128),
|
|
||||||
wait=wait_exponential_jitter(2, 128),
|
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
||||||
reraise=True,
|
|
||||||
)
|
|
||||||
async def transcribe_image(self, input_file: str) -> str:
|
|
||||||
"""
|
|
||||||
Transcribe content from an image using base64 encoding.
|
|
||||||
|
|
||||||
This synchronous method takes an input image file, encodes it as base64, and returns the
|
|
||||||
transcription of its content. Raises a FileNotFoundError if the input file does not
|
|
||||||
exist, and raises a ValueError if the transcription fails or no valid response is
|
|
||||||
received.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- input_file (str): The path to the image file to be transcribed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
- str: The transcription of the image's content as a string.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async with open_data_file(input_file, mode="rb") as image_file:
|
|
||||||
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
|
|
||||||
|
|
||||||
response = self.aclient.chat.completions.create(
|
|
||||||
model=self.model,
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "What's in this image?"},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
max_completion_tokens=300,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure response is valid before accessing .choices[0].message.content
|
|
||||||
if not hasattr(response, "choices") or not response.choices:
|
|
||||||
raise ValueError("Image transcription failed. No response received.")
|
|
||||||
|
|
||||||
return response.choices[0].message.content
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import base64
|
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
@ -16,8 +15,8 @@ from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
||||||
LLMInterface,
|
GenericAPIAdapter,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.exceptions import (
|
from cognee.infrastructure.llm.exceptions import (
|
||||||
ContentPolicyFilterError,
|
ContentPolicyFilterError,
|
||||||
|
|
@ -31,7 +30,7 @@ logger = get_logger()
|
||||||
observe = get_observe()
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class OpenAIAdapter(LLMInterface):
|
class OpenAIAdapter(GenericAPIAdapter):
|
||||||
"""
|
"""
|
||||||
Adapter for OpenAI's GPT-3, GPT-4 API.
|
Adapter for OpenAI's GPT-3, GPT-4 API.
|
||||||
|
|
||||||
|
|
@ -52,12 +51,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
- MAX_RETRIES
|
- MAX_RETRIES
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "OpenAI"
|
|
||||||
model: str
|
|
||||||
api_key: str
|
|
||||||
api_version: str
|
|
||||||
default_instructor_mode = "json_schema_mode"
|
default_instructor_mode = "json_schema_mode"
|
||||||
|
|
||||||
MAX_RETRIES = 5
|
MAX_RETRIES = 5
|
||||||
|
|
||||||
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
||||||
|
|
@ -65,17 +59,29 @@ class OpenAIAdapter(LLMInterface):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
endpoint: str,
|
|
||||||
api_version: str,
|
|
||||||
model: str,
|
model: str,
|
||||||
transcription_model: str,
|
|
||||||
max_completion_tokens: int,
|
max_completion_tokens: int,
|
||||||
|
endpoint: str = None,
|
||||||
|
api_version: str = None,
|
||||||
|
transcription_model: str = None,
|
||||||
instructor_mode: str = None,
|
instructor_mode: str = None,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
fallback_model: str = None,
|
fallback_model: str = None,
|
||||||
fallback_api_key: str = None,
|
fallback_api_key: str = None,
|
||||||
fallback_endpoint: str = None,
|
fallback_endpoint: str = None,
|
||||||
):
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
model=model,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
name="OpenAI",
|
||||||
|
endpoint=endpoint,
|
||||||
|
api_version=api_version,
|
||||||
|
transcription_model=transcription_model,
|
||||||
|
fallback_model=fallback_model,
|
||||||
|
fallback_api_key=fallback_api_key,
|
||||||
|
fallback_endpoint=fallback_endpoint,
|
||||||
|
)
|
||||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
|
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
|
||||||
# Make sure all new gpt models will work with this mode as well.
|
# Make sure all new gpt models will work with this mode as well.
|
||||||
|
|
@ -90,18 +96,8 @@ class OpenAIAdapter(LLMInterface):
|
||||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||||
self.client = instructor.from_litellm(litellm.completion)
|
self.client = instructor.from_litellm(litellm.completion)
|
||||||
|
|
||||||
self.transcription_model = transcription_model
|
|
||||||
self.model = model
|
|
||||||
self.api_key = api_key
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.api_version = api_version
|
|
||||||
self.max_completion_tokens = max_completion_tokens
|
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
|
|
||||||
self.fallback_model = fallback_model
|
|
||||||
self.fallback_api_key = fallback_api_key
|
|
||||||
self.fallback_endpoint = fallback_endpoint
|
|
||||||
|
|
||||||
@observe(as_type="generation")
|
@observe(as_type="generation")
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
|
|
@ -174,7 +170,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
api_key=self.fallback_api_key,
|
api_key=self.fallback_api_key,
|
||||||
# api_base=self.fallback_endpoint,
|
api_base=self.fallback_endpoint,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
)
|
)
|
||||||
|
|
@ -193,57 +189,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||||
) from error
|
) from error
|
||||||
|
|
||||||
@observe
|
@observe(as_type="transcription")
|
||||||
@retry(
|
|
||||||
stop=stop_after_delay(128),
|
|
||||||
wait=wait_exponential_jitter(2, 128),
|
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
||||||
reraise=True,
|
|
||||||
)
|
|
||||||
def create_structured_output(
|
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
|
||||||
) -> BaseModel:
|
|
||||||
"""
|
|
||||||
Generate a response from a user query.
|
|
||||||
|
|
||||||
This method creates structured output by sending a synchronous request to the OpenAI API
|
|
||||||
using the provided parameters to generate a completion based on the user input and
|
|
||||||
system prompt.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- text_input (str): The input text provided by the user for generating a response.
|
|
||||||
- system_prompt (str): The system's prompt to guide the model's response.
|
|
||||||
- response_model (Type[BaseModel]): The expected model type for the response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
- BaseModel: A structured output generated by the model, returned as an instance of
|
|
||||||
BaseModel.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self.client.chat.completions.create(
|
|
||||||
model=self.model,
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": f"""{text_input}""",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": system_prompt,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
api_key=self.api_key,
|
|
||||||
api_base=self.endpoint,
|
|
||||||
api_version=self.api_version,
|
|
||||||
response_model=response_model,
|
|
||||||
max_retries=self.MAX_RETRIES,
|
|
||||||
)
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
|
@ -282,56 +228,4 @@ class OpenAIAdapter(LLMInterface):
|
||||||
|
|
||||||
return transcription
|
return transcription
|
||||||
|
|
||||||
@retry(
|
# transcribe_image is inherited from GenericAPIAdapter
|
||||||
stop=stop_after_delay(128),
|
|
||||||
wait=wait_exponential_jitter(2, 128),
|
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
||||||
reraise=True,
|
|
||||||
)
|
|
||||||
async def transcribe_image(self, input) -> BaseModel:
|
|
||||||
"""
|
|
||||||
Generate a transcription of an image from a user query.
|
|
||||||
|
|
||||||
This method encodes the image and sends a request to the OpenAI API to obtain a
|
|
||||||
description of the contents of the image.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- input: The path to the image file that needs to be transcribed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
- BaseModel: A structured output generated by the model, returned as an instance of
|
|
||||||
BaseModel.
|
|
||||||
"""
|
|
||||||
async with open_data_file(input, mode="rb") as image_file:
|
|
||||||
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
|
|
||||||
|
|
||||||
return litellm.completion(
|
|
||||||
model=self.model,
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What's in this image?",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{encoded_image}",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
api_key=self.api_key,
|
|
||||||
api_base=self.endpoint,
|
|
||||||
api_version=self.api_version,
|
|
||||||
max_completion_tokens=300,
|
|
||||||
max_retries=self.MAX_RETRIES,
|
|
||||||
)
|
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -1,5 +1,5 @@
|
||||||
version = 1
|
version = 1
|
||||||
revision = 2
|
revision = 3
|
||||||
requires-python = ">=3.10, <3.14"
|
requires-python = ">=3.10, <3.14"
|
||||||
resolution-markers = [
|
resolution-markers = [
|
||||||
"python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'",
|
"python_full_version >= '3.13' and platform_python_implementation != 'PyPy' and sys_platform == 'darwin'",
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue