Test audio image transcription (#1911)
<!-- .github/pull_request_template.md -->
## Description
Run CI/CD for audio/image transcription PR from contributor
@rajeevrajeshuni
## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to
test it locally;
* Proof that it's sufficiently tested.
-->
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
## Release Notes
* **New Features**
* Added audio transcription capability across LLM providers.
* Added image transcription and description capability.
* Enhanced observability and monitoring for AI operations.
* **Breaking Changes**
* Removed synchronous structured output method; use asynchronous
alternative instead.
* **Refactor**
* Unified LLM provider architecture for improved consistency and
maintainability.
<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
commit
aeda1d8eba
10 changed files with 280 additions and 213 deletions
|
|
@ -37,19 +37,6 @@ class LLMGateway:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_structured_output(
|
|
||||||
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
|
||||||
) -> BaseModel:
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
|
||||||
get_llm_client,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_client = get_llm_client()
|
|
||||||
return llm_client.create_structured_output(
|
|
||||||
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_transcript(input) -> Coroutine:
|
def create_transcript(input) -> Coroutine:
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,9 @@ from typing import Type
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
|
import anthropic
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
stop_after_delay,
|
stop_after_delay,
|
||||||
|
|
@ -12,38 +14,41 @@ from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
||||||
LLMInterface,
|
GenericAPIAdapter,
|
||||||
)
|
)
|
||||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class AnthropicAdapter(LLMInterface):
|
class AnthropicAdapter(GenericAPIAdapter):
|
||||||
"""
|
"""
|
||||||
Adapter for interfacing with the Anthropic API, enabling structured output generation
|
Adapter for interfacing with the Anthropic API, enabling structured output generation
|
||||||
and prompt display.
|
and prompt display.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "Anthropic"
|
|
||||||
model: str
|
|
||||||
default_instructor_mode = "anthropic_tools"
|
default_instructor_mode = "anthropic_tools"
|
||||||
|
|
||||||
def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
|
def __init__(
|
||||||
import anthropic
|
self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
model=model,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
name="Anthropic",
|
||||||
|
)
|
||||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
|
|
||||||
self.aclient = instructor.patch(
|
self.aclient = instructor.patch(
|
||||||
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
|
create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create,
|
||||||
mode=instructor.Mode(self.instructor_mode),
|
mode=instructor.Mode(self.instructor_mode),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model = model
|
@observe(as_type="generation")
|
||||||
self.max_completion_tokens = max_completion_tokens
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(8, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Adapter for Generic API LLM provider API"""
|
"""Adapter for Gemini API LLM provider"""
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
|
|
@ -8,13 +8,9 @@ from openai import ContentFilterFinishReasonError
|
||||||
from litellm.exceptions import ContentPolicyViolationError
|
from litellm.exceptions import ContentPolicyViolationError
|
||||||
from instructor.core import InstructorRetryException
|
from instructor.core import InstructorRetryException
|
||||||
|
|
||||||
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
|
||||||
LLMInterface,
|
|
||||||
)
|
|
||||||
import logging
|
import logging
|
||||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
from cognee.shared.logging_utils import get_logger
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
stop_after_delay,
|
stop_after_delay,
|
||||||
|
|
@ -23,55 +19,65 @@ from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
||||||
|
GenericAPIAdapter,
|
||||||
|
)
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class GeminiAdapter(LLMInterface):
|
class GeminiAdapter(GenericAPIAdapter):
|
||||||
"""
|
"""
|
||||||
Adapter for Gemini API LLM provider.
|
Adapter for Gemini API LLM provider.
|
||||||
|
|
||||||
This class initializes the API adapter with necessary credentials and configurations for
|
This class initializes the API adapter with necessary credentials and configurations for
|
||||||
interacting with the gemini LLM models. It provides methods for creating structured outputs
|
interacting with the gemini LLM models. It provides methods for creating structured outputs
|
||||||
based on user input and system prompts.
|
based on user input and system prompts, as well as multimodal processing capabilities.
|
||||||
|
|
||||||
Public methods:
|
Public methods:
|
||||||
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
|
- acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel
|
||||||
Type[BaseModel]) -> BaseModel
|
- create_transcript(input) -> BaseModel: Transcribe audio files to text
|
||||||
|
- transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
|
||||||
model: str
|
|
||||||
api_key: str
|
|
||||||
default_instructor_mode = "json_mode"
|
default_instructor_mode = "json_mode"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
endpoint,
|
|
||||||
api_key: str,
|
api_key: str,
|
||||||
model: str,
|
model: str,
|
||||||
api_version: str,
|
|
||||||
max_completion_tokens: int,
|
max_completion_tokens: int,
|
||||||
|
endpoint: str = None,
|
||||||
|
api_version: str = None,
|
||||||
|
transcription_model: str = None,
|
||||||
instructor_mode: str = None,
|
instructor_mode: str = None,
|
||||||
fallback_model: str = None,
|
fallback_model: str = None,
|
||||||
fallback_api_key: str = None,
|
fallback_api_key: str = None,
|
||||||
fallback_endpoint: str = None,
|
fallback_endpoint: str = None,
|
||||||
):
|
):
|
||||||
self.model = model
|
super().__init__(
|
||||||
self.api_key = api_key
|
api_key=api_key,
|
||||||
self.endpoint = endpoint
|
model=model,
|
||||||
self.api_version = api_version
|
max_completion_tokens=max_completion_tokens,
|
||||||
self.max_completion_tokens = max_completion_tokens
|
name="Gemini",
|
||||||
|
endpoint=endpoint,
|
||||||
self.fallback_model = fallback_model
|
api_version=api_version,
|
||||||
self.fallback_api_key = fallback_api_key
|
transcription_model=transcription_model,
|
||||||
self.fallback_endpoint = fallback_endpoint
|
fallback_model=fallback_model,
|
||||||
|
fallback_api_key=fallback_api_key,
|
||||||
|
fallback_endpoint=fallback_endpoint,
|
||||||
|
)
|
||||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
|
|
||||||
self.aclient = instructor.from_litellm(
|
self.aclient = instructor.from_litellm(
|
||||||
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@observe(as_type="generation")
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(8, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
"""Adapter for Generic API LLM provider API"""
|
"""Adapter for Generic API LLM provider API"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import mimetypes
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
from typing import Type
|
from typing import Type, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from openai import ContentFilterFinishReasonError
|
from openai import ContentFilterFinishReasonError
|
||||||
from litellm.exceptions import ContentPolicyViolationError
|
from litellm.exceptions import ContentPolicyViolationError
|
||||||
|
|
@ -12,6 +14,8 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
LLMInterface,
|
LLMInterface,
|
||||||
)
|
)
|
||||||
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
import logging
|
import logging
|
||||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
@ -23,7 +27,12 @@ from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
|
||||||
|
TranscriptionReturnType,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class GenericAPIAdapter(LLMInterface):
|
class GenericAPIAdapter(LLMInterface):
|
||||||
|
|
@ -39,18 +48,19 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
Type[BaseModel]) -> BaseModel
|
Type[BaseModel]) -> BaseModel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
MAX_RETRIES = 5
|
||||||
model: str
|
|
||||||
api_key: str
|
|
||||||
default_instructor_mode = "json_mode"
|
default_instructor_mode = "json_mode"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
endpoint,
|
|
||||||
api_key: str,
|
api_key: str,
|
||||||
model: str,
|
model: str,
|
||||||
name: str,
|
|
||||||
max_completion_tokens: int,
|
max_completion_tokens: int,
|
||||||
|
name: str,
|
||||||
|
endpoint: str = None,
|
||||||
|
api_version: str = None,
|
||||||
|
transcription_model: str = None,
|
||||||
|
image_transcribe_model: str = None,
|
||||||
instructor_mode: str = None,
|
instructor_mode: str = None,
|
||||||
fallback_model: str = None,
|
fallback_model: str = None,
|
||||||
fallback_api_key: str = None,
|
fallback_api_key: str = None,
|
||||||
|
|
@ -59,9 +69,11 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.api_version = api_version
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.max_completion_tokens = max_completion_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
self.transcription_model = transcription_model or model
|
||||||
|
self.image_transcribe_model = image_transcribe_model or model
|
||||||
self.fallback_model = fallback_model
|
self.fallback_model = fallback_model
|
||||||
self.fallback_api_key = fallback_api_key
|
self.fallback_api_key = fallback_api_key
|
||||||
self.fallback_endpoint = fallback_endpoint
|
self.fallback_endpoint = fallback_endpoint
|
||||||
|
|
@ -72,6 +84,7 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@observe(as_type="generation")
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(8, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
|
|
@ -173,3 +186,115 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
raise ContentPolicyFilterError(
|
raise ContentPolicyFilterError(
|
||||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||||
) from error
|
) from error
|
||||||
|
|
||||||
|
@observe(as_type="transcription")
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
async def create_transcript(self, input) -> TranscriptionReturnType:
|
||||||
|
"""
|
||||||
|
Generate an audio transcript from a user query.
|
||||||
|
|
||||||
|
This method creates a transcript from the specified audio file, raising a
|
||||||
|
FileNotFoundError if the file does not exist. The audio file is processed and the
|
||||||
|
transcription is retrieved from the API.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- input: The path to the audio file that needs to be transcribed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
The generated transcription of the audio file.
|
||||||
|
"""
|
||||||
|
async with open_data_file(input, mode="rb") as audio_file:
|
||||||
|
encoded_string = base64.b64encode(audio_file.read()).decode("utf-8")
|
||||||
|
mime_type, _ = mimetypes.guess_type(input)
|
||||||
|
if not mime_type or not mime_type.startswith("audio/"):
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
|
||||||
|
)
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model=self.transcription_model,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "file",
|
||||||
|
"file": {"file_data": f"data:{mime_type};base64,{encoded_string}"},
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Transcribe the following audio precisely."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
api_key=self.api_key,
|
||||||
|
api_version=self.api_version,
|
||||||
|
max_completion_tokens=self.max_completion_tokens,
|
||||||
|
api_base=self.endpoint,
|
||||||
|
max_retries=self.MAX_RETRIES,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TranscriptionReturnType(response.choices[0].message.content, response)
|
||||||
|
|
||||||
|
@observe(as_type="transcribe_image")
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
async def transcribe_image(self, input) -> BaseModel:
|
||||||
|
"""
|
||||||
|
Generate a transcription of an image from a user query.
|
||||||
|
|
||||||
|
This method encodes the image and sends a request to the API to obtain a
|
||||||
|
description of the contents of the image.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- input: The path to the image file that needs to be transcribed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
- BaseModel: A structured output generated by the model, returned as an instance of
|
||||||
|
BaseModel.
|
||||||
|
"""
|
||||||
|
async with open_data_file(input, mode="rb") as image_file:
|
||||||
|
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
|
||||||
|
mime_type, _ = mimetypes.guess_type(input)
|
||||||
|
if not mime_type or not mime_type.startswith("image/"):
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not determine MIME type for image file: {input}. Is the extension correct?"
|
||||||
|
)
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model=self.image_transcribe_model,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What's in this image?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:{mime_type};base64,{encoded_image}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
api_key=self.api_key,
|
||||||
|
api_base=self.endpoint,
|
||||||
|
api_version=self.api_version,
|
||||||
|
max_completion_tokens=300,
|
||||||
|
max_retries=self.MAX_RETRIES,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
|
||||||
|
|
@ -103,7 +103,7 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
llm_config.llm_api_key,
|
llm_config.llm_api_key,
|
||||||
llm_config.llm_model,
|
llm_config.llm_model,
|
||||||
"Ollama",
|
"Ollama",
|
||||||
max_completion_tokens=max_completion_tokens,
|
max_completion_tokens,
|
||||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -113,8 +113,9 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
)
|
)
|
||||||
|
|
||||||
return AnthropicAdapter(
|
return AnthropicAdapter(
|
||||||
max_completion_tokens=max_completion_tokens,
|
llm_config.llm_api_key,
|
||||||
model=llm_config.llm_model,
|
llm_config.llm_model,
|
||||||
|
max_completion_tokens,
|
||||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -127,11 +128,10 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
)
|
)
|
||||||
|
|
||||||
return GenericAPIAdapter(
|
return GenericAPIAdapter(
|
||||||
llm_config.llm_endpoint,
|
|
||||||
llm_config.llm_api_key,
|
llm_config.llm_api_key,
|
||||||
llm_config.llm_model,
|
llm_config.llm_model,
|
||||||
|
max_completion_tokens,
|
||||||
"Custom",
|
"Custom",
|
||||||
max_completion_tokens=max_completion_tokens,
|
|
||||||
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
||||||
fallback_api_key=llm_config.fallback_api_key,
|
fallback_api_key=llm_config.fallback_api_key,
|
||||||
fallback_endpoint=llm_config.fallback_endpoint,
|
fallback_endpoint=llm_config.fallback_endpoint,
|
||||||
|
|
|
||||||
|
|
@ -3,18 +3,14 @@
|
||||||
from typing import Type, Protocol
|
from typing import Type, Protocol
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
|
||||||
|
|
||||||
|
|
||||||
class LLMInterface(Protocol):
|
class LLMInterface(Protocol):
|
||||||
"""
|
"""
|
||||||
Define an interface for LLM models with methods for structured output and prompt
|
Define an interface for LLM models with methods for structured output, multimodal processing, and prompt display.
|
||||||
display.
|
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
|
- acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel])
|
||||||
Type[BaseModel])
|
|
||||||
- show_prompt(text_input: str, system_prompt: str)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Type
|
from typing import Type, Optional
|
||||||
from litellm import JSONSchemaValidationError
|
from litellm import JSONSchemaValidationError
|
||||||
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.modules.observability.get_observe import get_observe
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
||||||
LLMInterface,
|
GenericAPIAdapter,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
|
|
@ -20,12 +20,14 @@ from tenacity import (
|
||||||
retry_if_not_exception_type,
|
retry_if_not_exception_type,
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
from ..types import TranscriptionReturnType
|
||||||
|
from mistralai import Mistral
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
observe = get_observe()
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class MistralAdapter(LLMInterface):
|
class MistralAdapter(GenericAPIAdapter):
|
||||||
"""
|
"""
|
||||||
Adapter for Mistral AI API, for structured output generation and prompt display.
|
Adapter for Mistral AI API, for structured output generation and prompt display.
|
||||||
|
|
||||||
|
|
@ -34,10 +36,6 @@ class MistralAdapter(LLMInterface):
|
||||||
- show_prompt
|
- show_prompt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "Mistral"
|
|
||||||
model: str
|
|
||||||
api_key: str
|
|
||||||
max_completion_tokens: int
|
|
||||||
default_instructor_mode = "mistral_tools"
|
default_instructor_mode = "mistral_tools"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -46,12 +44,19 @@ class MistralAdapter(LLMInterface):
|
||||||
model: str,
|
model: str,
|
||||||
max_completion_tokens: int,
|
max_completion_tokens: int,
|
||||||
endpoint: str = None,
|
endpoint: str = None,
|
||||||
|
transcription_model: str = None,
|
||||||
|
image_transcribe_model: str = None,
|
||||||
instructor_mode: str = None,
|
instructor_mode: str = None,
|
||||||
):
|
):
|
||||||
from mistralai import Mistral
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
self.model = model
|
model=model,
|
||||||
self.max_completion_tokens = max_completion_tokens
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
name="Mistral",
|
||||||
|
endpoint=endpoint,
|
||||||
|
transcription_model=transcription_model,
|
||||||
|
image_transcribe_model=image_transcribe_model,
|
||||||
|
)
|
||||||
|
|
||||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
|
|
||||||
|
|
@ -60,7 +65,9 @@ class MistralAdapter(LLMInterface):
|
||||||
mode=instructor.Mode(self.instructor_mode),
|
mode=instructor.Mode(self.instructor_mode),
|
||||||
api_key=get_llm_config().llm_api_key,
|
api_key=get_llm_config().llm_api_key,
|
||||||
)
|
)
|
||||||
|
self.mistral_client = Mistral(api_key=self.api_key)
|
||||||
|
|
||||||
|
@observe(as_type="generation")
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(8, 128),
|
wait=wait_exponential_jitter(8, 128),
|
||||||
|
|
@ -119,3 +126,41 @@ class MistralAdapter(LLMInterface):
|
||||||
logger.error(f"Schema validation failed: {str(e)}")
|
logger.error(f"Schema validation failed: {str(e)}")
|
||||||
logger.debug(f"Raw response: {e.raw_response}")
|
logger.debug(f"Raw response: {e.raw_response}")
|
||||||
raise ValueError(f"Response failed schema validation: {str(e)}")
|
raise ValueError(f"Response failed schema validation: {str(e)}")
|
||||||
|
|
||||||
|
@observe(as_type="transcription")
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_delay(128),
|
||||||
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
|
||||||
|
"""
|
||||||
|
Generate an audio transcript from a user query.
|
||||||
|
|
||||||
|
This method creates a transcript from the specified audio file.
|
||||||
|
The audio file is processed and the transcription is retrieved from the API.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- input: The path to the audio file that needs to be transcribed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
The generated transcription of the audio file.
|
||||||
|
"""
|
||||||
|
transcription_model = self.transcription_model
|
||||||
|
if self.transcription_model.startswith("mistral"):
|
||||||
|
transcription_model = self.transcription_model.split("/")[-1]
|
||||||
|
file_name = input.split("/")[-1]
|
||||||
|
async with open_data_file(input, mode="rb") as f:
|
||||||
|
transcription_response = self.mistral_client.audio.transcriptions.complete(
|
||||||
|
model=transcription_model,
|
||||||
|
file={
|
||||||
|
"content": f,
|
||||||
|
"file_name": file_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return TranscriptionReturnType(transcription_response.text, transcription_response)
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
|
||||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
|
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
stop_after_delay,
|
stop_after_delay,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import base64
|
|
||||||
import litellm
|
import litellm
|
||||||
import instructor
|
import instructor
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
@ -16,8 +15,8 @@ from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
||||||
LLMInterface,
|
GenericAPIAdapter,
|
||||||
)
|
)
|
||||||
from cognee.infrastructure.llm.exceptions import (
|
from cognee.infrastructure.llm.exceptions import (
|
||||||
ContentPolicyFilterError,
|
ContentPolicyFilterError,
|
||||||
|
|
@ -26,13 +25,16 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
||||||
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
||||||
from cognee.modules.observability.get_observe import get_observe
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
|
||||||
|
TranscriptionReturnType,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
observe = get_observe()
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
class OpenAIAdapter(LLMInterface):
|
class OpenAIAdapter(GenericAPIAdapter):
|
||||||
"""
|
"""
|
||||||
Adapter for OpenAI's GPT-3, GPT-4 API.
|
Adapter for OpenAI's GPT-3, GPT-4 API.
|
||||||
|
|
||||||
|
|
@ -53,12 +55,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
- MAX_RETRIES
|
- MAX_RETRIES
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "OpenAI"
|
|
||||||
model: str
|
|
||||||
api_key: str
|
|
||||||
api_version: str
|
|
||||||
default_instructor_mode = "json_schema_mode"
|
default_instructor_mode = "json_schema_mode"
|
||||||
|
|
||||||
MAX_RETRIES = 5
|
MAX_RETRIES = 5
|
||||||
|
|
||||||
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
||||||
|
|
@ -66,17 +63,29 @@ class OpenAIAdapter(LLMInterface):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
endpoint: str,
|
|
||||||
api_version: str,
|
|
||||||
model: str,
|
model: str,
|
||||||
transcription_model: str,
|
|
||||||
max_completion_tokens: int,
|
max_completion_tokens: int,
|
||||||
|
endpoint: str = None,
|
||||||
|
api_version: str = None,
|
||||||
|
transcription_model: str = None,
|
||||||
instructor_mode: str = None,
|
instructor_mode: str = None,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
fallback_model: str = None,
|
fallback_model: str = None,
|
||||||
fallback_api_key: str = None,
|
fallback_api_key: str = None,
|
||||||
fallback_endpoint: str = None,
|
fallback_endpoint: str = None,
|
||||||
):
|
):
|
||||||
|
super().__init__(
|
||||||
|
api_key=api_key,
|
||||||
|
model=model,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
name="OpenAI",
|
||||||
|
endpoint=endpoint,
|
||||||
|
api_version=api_version,
|
||||||
|
transcription_model=transcription_model,
|
||||||
|
fallback_model=fallback_model,
|
||||||
|
fallback_api_key=fallback_api_key,
|
||||||
|
fallback_endpoint=fallback_endpoint,
|
||||||
|
)
|
||||||
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
||||||
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
|
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
|
||||||
# Make sure all new gpt models will work with this mode as well.
|
# Make sure all new gpt models will work with this mode as well.
|
||||||
|
|
@ -91,18 +100,8 @@ class OpenAIAdapter(LLMInterface):
|
||||||
self.aclient = instructor.from_litellm(litellm.acompletion)
|
self.aclient = instructor.from_litellm(litellm.acompletion)
|
||||||
self.client = instructor.from_litellm(litellm.completion)
|
self.client = instructor.from_litellm(litellm.completion)
|
||||||
|
|
||||||
self.transcription_model = transcription_model
|
|
||||||
self.model = model
|
|
||||||
self.api_key = api_key
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.api_version = api_version
|
|
||||||
self.max_completion_tokens = max_completion_tokens
|
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
|
|
||||||
self.fallback_model = fallback_model
|
|
||||||
self.fallback_api_key = fallback_api_key
|
|
||||||
self.fallback_endpoint = fallback_endpoint
|
|
||||||
|
|
||||||
@observe(as_type="generation")
|
@observe(as_type="generation")
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
|
|
@ -198,7 +197,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
||||||
) from error
|
) from error
|
||||||
|
|
||||||
@observe
|
@observe(as_type="transcription")
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(128),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(2, 128),
|
||||||
|
|
@ -206,58 +205,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def create_structured_output(
|
async def create_transcript(self, input, **kwargs) -> TranscriptionReturnType:
|
||||||
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
|
|
||||||
) -> BaseModel:
|
|
||||||
"""
|
|
||||||
Generate a response from a user query.
|
|
||||||
|
|
||||||
This method creates structured output by sending a synchronous request to the OpenAI API
|
|
||||||
using the provided parameters to generate a completion based on the user input and
|
|
||||||
system prompt.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- text_input (str): The input text provided by the user for generating a response.
|
|
||||||
- system_prompt (str): The system's prompt to guide the model's response.
|
|
||||||
- response_model (Type[BaseModel]): The expected model type for the response.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
- BaseModel: A structured output generated by the model, returned as an instance of
|
|
||||||
BaseModel.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self.client.chat.completions.create(
|
|
||||||
model=self.model,
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": f"""{text_input}""",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": system_prompt,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
api_key=self.api_key,
|
|
||||||
api_base=self.endpoint,
|
|
||||||
api_version=self.api_version,
|
|
||||||
response_model=response_model,
|
|
||||||
max_retries=self.MAX_RETRIES,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
@retry(
|
|
||||||
stop=stop_after_delay(128),
|
|
||||||
wait=wait_exponential_jitter(2, 128),
|
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
||||||
reraise=True,
|
|
||||||
)
|
|
||||||
async def create_transcript(self, input, **kwargs):
|
|
||||||
"""
|
"""
|
||||||
Generate an audio transcript from a user query.
|
Generate an audio transcript from a user query.
|
||||||
|
|
||||||
|
|
@ -286,60 +234,6 @@ class OpenAIAdapter(LLMInterface):
|
||||||
max_retries=self.MAX_RETRIES,
|
max_retries=self.MAX_RETRIES,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
return TranscriptionReturnType(transcription.text, transcription)
|
||||||
|
|
||||||
return transcription
|
# transcribe_image is inherited from GenericAPIAdapter
|
||||||
|
|
||||||
@retry(
|
|
||||||
stop=stop_after_delay(128),
|
|
||||||
wait=wait_exponential_jitter(2, 128),
|
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
||||||
reraise=True,
|
|
||||||
)
|
|
||||||
async def transcribe_image(self, input, **kwargs) -> BaseModel:
|
|
||||||
"""
|
|
||||||
Generate a transcription of an image from a user query.
|
|
||||||
|
|
||||||
This method encodes the image and sends a request to the OpenAI API to obtain a
|
|
||||||
description of the contents of the image.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
|
|
||||||
- input: The path to the image file that needs to be transcribed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
|
|
||||||
- BaseModel: A structured output generated by the model, returned as an instance of
|
|
||||||
BaseModel.
|
|
||||||
"""
|
|
||||||
async with open_data_file(input, mode="rb") as image_file:
|
|
||||||
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
|
|
||||||
|
|
||||||
return litellm.completion(
|
|
||||||
model=self.model,
|
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What's in this image?",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{encoded_image}",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
api_key=self.api_key,
|
|
||||||
api_base=self.endpoint,
|
|
||||||
api_version=self.api_version,
|
|
||||||
max_completion_tokens=300,
|
|
||||||
max_retries=self.MAX_RETRIES,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionReturnType:
|
||||||
|
text: str
|
||||||
|
payload: BaseModel
|
||||||
|
|
||||||
|
def __init__(self, text: str, payload: BaseModel):
|
||||||
|
self.text = text
|
||||||
|
self.payload = payload
|
||||||
Loading…
Add table
Reference in a new issue