Merge branch 'feature/cog-3532-empower-test_search-db-retrievers-tests-reorg-3' into feature/cog-3532-empower-test_search-db-retrievers-tests-reorg-4
This commit is contained in: commit 623126eec1
10 changed files with 280 additions and 213 deletions
cognee/infrastructure/llm/LLMGateway.py

@@ -37,19 +37,6 @@ class LLMGateway:
             **kwargs,
         )
 
-    @staticmethod
-    def create_structured_output(
-        text_input: str, system_prompt: str, response_model: Type[BaseModel]
-    ) -> BaseModel:
-        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
-            get_llm_client,
-        )
-
-        llm_client = get_llm_client()
-        return llm_client.create_structured_output(
-            text_input=text_input, system_prompt=system_prompt, response_model=response_model
-        )
-
     @staticmethod
     def create_transcript(input) -> Coroutine:
         from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
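The deleted block is the synchronous `create_structured_output` pass-through on `LLMGateway`; call sites are expected to use the asynchronous path instead. A minimal migration sketch, assuming the gateway keeps an `acreate_structured_output` counterpart with the same parameters (that method is not shown in this diff, so treat the name as an assumption):

```python
from pydantic import BaseModel

class Summary(BaseModel):
    text: str

async def summarize(gateway, document: str) -> Summary:
    # Former sync call sites move to the awaitable variant; the parameter
    # names mirror the deleted method's signature.
    return await gateway.acreate_structured_output(
        text_input=document,
        system_prompt="Summarize the document in one sentence.",
        response_model=Summary,
    )
```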
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py

@@ -3,7 +3,9 @@ from typing import Type
 from pydantic import BaseModel
+import litellm
 import instructor
+import anthropic
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.observability.get_observe import get_observe
 from tenacity import (
     retry,
     stop_after_delay,
@@ -12,38 +14,41 @@ from tenacity import (
     before_sleep_log,
 )
 
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
 )
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
-from cognee.infrastructure.llm.config import get_llm_config
 
 logger = get_logger()
 observe = get_observe()
 
 
-class AnthropicAdapter(LLMInterface):
+class AnthropicAdapter(GenericAPIAdapter):
     """
     Adapter for interfacing with the Anthropic API, enabling structured output generation
    and prompt display.
     """
 
-    name = "Anthropic"
-    model: str
+    default_instructor_mode = "anthropic_tools"
 
-    def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
-        import anthropic
-
+    def __init__(
+        self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None
+    ):
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="Anthropic",
+        )
+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
 
         self.aclient = instructor.patch(
-            create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
+            create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create,
             mode=instructor.Mode(self.instructor_mode),
         )
 
-        self.model = model
-        self.max_completion_tokens = max_completion_tokens
-
     @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(8, 128),
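The key shape change here: `AnthropicAdapter` now receives its API key from the caller and pushes shared state into the `GenericAPIAdapter` base, instead of reading `get_llm_config()` inside the constructor. A standalone sketch of that injection pattern with stand-in classes (nothing below is the cognee API; values are examples):

```python
class GenericBase:
    """Stand-in for GenericAPIAdapter's constructor contract."""

    def __init__(self, api_key: str, model: str, max_completion_tokens: int, name: str):
        self.api_key = api_key
        self.model = model
        self.max_completion_tokens = max_completion_tokens
        self.name = name

class AnthropicLike(GenericBase):
    default_instructor_mode = "anthropic_tools"

    def __init__(self, api_key: str, model: str, max_completion_tokens: int,
                 instructor_mode: str = None):
        super().__init__(api_key, model, max_completion_tokens, name="Anthropic")
        # An explicit mode wins; otherwise fall back to the class default.
        self.instructor_mode = instructor_mode or self.default_instructor_mode

adapter = AnthropicLike("sk-ant-example", "claude-example", 4096)
print(adapter.name, adapter.instructor_mode)  # Anthropic anthropic_tools
```

Injecting the key also makes the adapter testable without global config, which is presumably why the same change repeats in the Mistral adapter below.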
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py

@@ -1,4 +1,4 @@
-"""Adapter for Generic API LLM provider API"""
+"""Adapter for Gemini API LLM provider"""
 
 import litellm
 import instructor
@@ -8,13 +8,9 @@ from openai import ContentFilterFinishReasonError
 from litellm.exceptions import ContentPolicyViolationError
+from instructor.core import InstructorRetryException
 
-from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
-)
 import logging
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
-from cognee.shared.logging_utils import get_logger
 
 from tenacity import (
     retry,
     stop_after_delay,
@@ -23,55 +19,65 @@ from tenacity import (
     before_sleep_log,
 )
 
+from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
+)
+from cognee.shared.logging_utils import get_logger
 from cognee.modules.observability.get_observe import get_observe
 
 logger = get_logger()
 observe = get_observe()
 
 
-class GeminiAdapter(LLMInterface):
+class GeminiAdapter(GenericAPIAdapter):
     """
     Adapter for Gemini API LLM provider.
 
     This class initializes the API adapter with necessary credentials and configurations for
     interacting with the gemini LLM models. It provides methods for creating structured outputs
-    based on user input and system prompts.
+    based on user input and system prompts, as well as multimodal processing capabilities.
 
     Public methods:
-    - acreate_structured_output(text_input: str, system_prompt: str, response_model:
-    Type[BaseModel]) -> BaseModel
+    - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel
+    - create_transcript(input) -> BaseModel: Transcribe audio files to text
+    - transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter
     """
 
     name: str
     model: str
     api_key: str
     default_instructor_mode = "json_mode"
 
     def __init__(
         self,
-        endpoint,
         api_key: str,
         model: str,
-        api_version: str,
         max_completion_tokens: int,
+        endpoint: str = None,
+        api_version: str = None,
+        transcription_model: str = None,
         instructor_mode: str = None,
         fallback_model: str = None,
         fallback_api_key: str = None,
         fallback_endpoint: str = None,
     ):
-        self.model = model
-        self.api_key = api_key
-        self.endpoint = endpoint
-        self.api_version = api_version
-        self.max_completion_tokens = max_completion_tokens
-
-        self.fallback_model = fallback_model
-        self.fallback_api_key = fallback_api_key
-        self.fallback_endpoint = fallback_endpoint
-
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="Gemini",
+            endpoint=endpoint,
+            api_version=api_version,
+            transcription_model=transcription_model,
+            fallback_model=fallback_model,
+            fallback_api_key=fallback_api_key,
+            fallback_endpoint=fallback_endpoint,
+        )
         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
 
         self.aclient = instructor.from_litellm(
             litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
         )
 
     @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(8, 128),
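This adapter, like the others in the diff, turns the configured mode string into an instructor mode via `instructor.Mode(self.instructor_mode)`. That works because `instructor.Mode` is a string-valued enum, so lookup-by-value applies. A dependency-free illustration (the `Mode` class below is a stand-in; only the three string values that appear in this diff are assumed):

```python
from enum import Enum

class Mode(Enum):  # stand-in for instructor.Mode
    JSON = "json_mode"
    ANTHROPIC_TOOLS = "anthropic_tools"
    MISTRAL_TOOLS = "mistral_tools"

# Enum lookup-by-value: the config string selects the member directly.
print(Mode("json_mode"))        # Mode.JSON
print(Mode("anthropic_tools"))  # Mode.ANTHROPIC_TOOLS
```

This is why each adapter only needs a `default_instructor_mode` string rather than importing specific enum members.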
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py

@@ -1,8 +1,10 @@
 """Adapter for Generic API LLM provider API"""
 
+import base64
+import mimetypes
 import litellm
 import instructor
-from typing import Type
+from typing import Type, Optional
 from pydantic import BaseModel
 from openai import ContentFilterFinishReasonError
 from litellm.exceptions import ContentPolicyViolationError
@@ -12,6 +14,8 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
     LLMInterface,
 )
+from cognee.infrastructure.files.utils.open_data_file import open_data_file
+from cognee.modules.observability.get_observe import get_observe
 import logging
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.shared.logging_utils import get_logger
@@ -23,7 +27,12 @@ from tenacity import (
     before_sleep_log,
 )
 
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
+    TranscriptionReturnType,
+)
+
 logger = get_logger()
+observe = get_observe()
 
 
 class GenericAPIAdapter(LLMInterface):
@@ -39,18 +48,19 @@ class GenericAPIAdapter(LLMInterface):
     Type[BaseModel]) -> BaseModel
     """
 
     name: str
     model: str
     api_key: str
+    MAX_RETRIES = 5
     default_instructor_mode = "json_mode"
 
     def __init__(
         self,
-        endpoint,
         api_key: str,
         model: str,
-        name: str,
         max_completion_tokens: int,
+        name: str,
+        endpoint: str = None,
         api_version: str = None,
+        transcription_model: str = None,
+        image_transcribe_model: str = None,
         instructor_mode: str = None,
         fallback_model: str = None,
         fallback_api_key: str = None,
@@ -59,9 +69,11 @@ class GenericAPIAdapter(LLMInterface):
         self.name = name
         self.model = model
         self.api_key = api_key
         self.api_version = api_version
         self.endpoint = endpoint
         self.max_completion_tokens = max_completion_tokens
 
+        self.transcription_model = transcription_model or model
+        self.image_transcribe_model = image_transcribe_model or model
         self.fallback_model = fallback_model
         self.fallback_api_key = fallback_api_key
         self.fallback_endpoint = fallback_endpoint
@@ -72,6 +84,7 @@ class GenericAPIAdapter(LLMInterface):
             litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
         )
 
+    @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(8, 128),
@@ -173,3 +186,115 @@ class GenericAPIAdapter(LLMInterface):
                 raise ContentPolicyFilterError(
                     f"The provided input contains content that is not aligned with our content policy: {text_input}"
                 ) from error
+
+    @observe(as_type="transcription")
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def create_transcript(self, input) -> TranscriptionReturnType:
+        """
+        Generate an audio transcript from a user query.
+
+        This method creates a transcript from the specified audio file, raising a
+        FileNotFoundError if the file does not exist. The audio file is processed and the
+        transcription is retrieved from the API.
+
+        Parameters:
+        -----------
+        - input: The path to the audio file that needs to be transcribed.
+
+        Returns:
+        --------
+        The generated transcription of the audio file.
+        """
+        async with open_data_file(input, mode="rb") as audio_file:
+            encoded_string = base64.b64encode(audio_file.read()).decode("utf-8")
+        mime_type, _ = mimetypes.guess_type(input)
+        if not mime_type or not mime_type.startswith("audio/"):
+            raise ValueError(
+                f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
+            )
+        response = await litellm.acompletion(
+            model=self.transcription_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "file",
+                            "file": {"file_data": f"data:{mime_type};base64,{encoded_string}"},
+                        },
+                        {"type": "text", "text": "Transcribe the following audio precisely."},
+                    ],
+                }
+            ],
+            api_key=self.api_key,
+            api_version=self.api_version,
+            max_completion_tokens=self.max_completion_tokens,
+            api_base=self.endpoint,
+            max_retries=self.MAX_RETRIES,
+        )
+
+        return TranscriptionReturnType(response.choices[0].message.content, response)
+
+    @observe(as_type="transcribe_image")
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def transcribe_image(self, input) -> BaseModel:
+        """
+        Generate a transcription of an image from a user query.
+
+        This method encodes the image and sends a request to the API to obtain a
+        description of the contents of the image.
+
+        Parameters:
+        -----------
+        - input: The path to the image file that needs to be transcribed.
+
+        Returns:
+        --------
+        - BaseModel: A structured output generated by the model, returned as an instance of
+          BaseModel.
+        """
+        async with open_data_file(input, mode="rb") as image_file:
+            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
+        mime_type, _ = mimetypes.guess_type(input)
+        if not mime_type or not mime_type.startswith("image/"):
+            raise ValueError(
+                f"Could not determine MIME type for image file: {input}. Is the extension correct?"
+            )
+        response = await litellm.acompletion(
+            model=self.image_transcribe_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "What's in this image?",
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:{mime_type};base64,{encoded_image}",
+                            },
+                        },
+                    ],
+                }
+            ],
+            api_key=self.api_key,
+            api_base=self.endpoint,
+            api_version=self.api_version,
+            max_completion_tokens=300,
+            max_retries=self.MAX_RETRIES,
+        )
+        return response
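The two new methods on `GenericAPIAdapter` share one pattern: guess the MIME type, refuse files whose type cannot be determined or doesn't match the expected family, and ship the bytes as a base64 data URL. A standalone sketch of that pattern using only the standard library (the function name and file names are illustrative):

```python
import base64
import mimetypes

def to_data_url(path: str, expected_prefix: str) -> str:
    """Encode a local file as a data URL, enforcing a MIME-type family."""
    mime_type, _ = mimetypes.guess_type(path)
    if not mime_type or not mime_type.startswith(expected_prefix):
        # Mirrors the guard in the diff: unknown or wrong-family types fail fast.
        raise ValueError(f"Could not determine MIME type for file: {path}")
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"

# e.g. to_data_url("speech.mp3", "audio/") -> "data:audio/mpeg;base64,..."
# e.g. to_data_url("photo.png", "image/")  -> "data:image/png;base64,..."
```

Note that `mimetypes.guess_type` works from the file extension alone, which is why the error message in the diff asks "Is the extension correct?".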
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py

@@ -103,7 +103,7 @@ def get_llm_client(raise_api_key_error: bool = True):
             llm_config.llm_api_key,
             llm_config.llm_model,
             "Ollama",
-            max_completion_tokens=max_completion_tokens,
+            max_completion_tokens,
             instructor_mode=llm_config.llm_instructor_mode.lower(),
         )
@@ -113,8 +113,9 @@ def get_llm_client(raise_api_key_error: bool = True):
         )
 
         return AnthropicAdapter(
-            max_completion_tokens=max_completion_tokens,
-            model=llm_config.llm_model,
+            llm_config.llm_api_key,
+            llm_config.llm_model,
+            max_completion_tokens,
             instructor_mode=llm_config.llm_instructor_mode.lower(),
         )
@@ -127,11 +128,10 @@ def get_llm_client(raise_api_key_error: bool = True):
         )
 
         return GenericAPIAdapter(
-            llm_config.llm_endpoint,
             llm_config.llm_api_key,
             llm_config.llm_model,
+            max_completion_tokens,
             "Custom",
-            max_completion_tokens=max_completion_tokens,
             instructor_mode=llm_config.llm_instructor_mode.lower(),
             fallback_api_key=llm_config.fallback_api_key,
             fallback_endpoint=llm_config.fallback_endpoint,
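All three call sites now pass `max_completion_tokens` positionally, so they must track the refactored constructors' parameter order exactly. A defensive alternative, shown purely as an illustration of the trade-off (this is not what the diff does), is to make trailing parameters keyword-only so a later reordering cannot silently misbind:

```python
def make_adapter(api_key: str, model: str, *, max_completion_tokens: int,
                 name: str = "Custom", instructor_mode: str = "json_mode"):
    # The bare '*' forces everything after `model` to be passed by keyword.
    return {"api_key": api_key, "model": model, "name": name,
            "max_completion_tokens": max_completion_tokens,
            "instructor_mode": instructor_mode}

adapter = make_adapter("key", "gpt-example", max_completion_tokens=4096)
# make_adapter("key", "gpt-example", 4096)  # TypeError: too many positional arguments
```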
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py

@@ -3,18 +3,14 @@
 from typing import Type, Protocol
 from abc import abstractmethod
 from pydantic import BaseModel
-from cognee.infrastructure.llm.LLMGateway import LLMGateway
 
 
 class LLMInterface(Protocol):
     """
-    Define an interface for LLM models with methods for structured output and prompt
-    display.
+    Define an interface for LLM models with methods for structured output, multimodal processing, and prompt display.
 
     Methods:
-    - acreate_structured_output(text_input: str, system_prompt: str, response_model:
-    Type[BaseModel])
-    - show_prompt(text_input: str, system_prompt: str)
+    - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel])
     """
 
     @abstractmethod
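Dropping the `LLMGateway` import here removes a gateway-to-interface circular dependency, and `Protocol` is what makes that cheap: adapters satisfy `LLMInterface` structurally, without importing or subclassing it. A trimmed, one-method illustration (`MiniInterface` and `FakeAdapter` are hypothetical names, not cognee classes):

```python
from abc import abstractmethod
from typing import Protocol, Type

from pydantic import BaseModel

class MiniInterface(Protocol):
    @abstractmethod
    async def acreate_structured_output(
        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
    ) -> BaseModel: ...

class FakeAdapter:  # no inheritance, yet a valid MiniInterface
    async def acreate_structured_output(self, text_input, system_prompt, response_model):
        return response_model()

def run_with(llm: MiniInterface) -> None:
    ...  # a type checker accepts FakeAdapter() here via structural subtyping

run_with(FakeAdapter())
```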
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py

@@ -1,13 +1,13 @@
 import litellm
 import instructor
 from pydantic import BaseModel
-from typing import Type
+from typing import Type, Optional
 from litellm import JSONSchemaValidationError
 
+from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.observability.get_observe import get_observe
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
 )
-from cognee.infrastructure.llm.config import get_llm_config
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
@@ -20,12 +20,14 @@ from tenacity import (
     retry_if_not_exception_type,
     before_sleep_log,
 )
+from ..types import TranscriptionReturnType
+from mistralai import Mistral
 
 logger = get_logger()
 observe = get_observe()
 
 
-class MistralAdapter(LLMInterface):
+class MistralAdapter(GenericAPIAdapter):
     """
     Adapter for Mistral AI API, for structured output generation and prompt display.
 
@@ -34,10 +36,6 @@ class MistralAdapter(LLMInterface):
     - show_prompt
     """
 
-    name = "Mistral"
-    model: str
-    api_key: str
-    max_completion_tokens: int
     default_instructor_mode = "mistral_tools"
 
     def __init__(
@@ -46,12 +44,19 @@ class MistralAdapter(LLMInterface):
         model: str,
         max_completion_tokens: int,
+        endpoint: str = None,
+        transcription_model: str = None,
+        image_transcribe_model: str = None,
         instructor_mode: str = None,
     ):
-        from mistralai import Mistral
-
-        self.model = model
-        self.max_completion_tokens = max_completion_tokens
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="Mistral",
+            endpoint=endpoint,
+            transcription_model=transcription_model,
+            image_transcribe_model=image_transcribe_model,
+        )
 
         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
 
@@ -60,7 +65,9 @@ class MistralAdapter(LLMInterface):
             mode=instructor.Mode(self.instructor_mode),
-            api_key=get_llm_config().llm_api_key,
+            api_key=self.api_key,
         )
+        self.mistral_client = Mistral(api_key=self.api_key)
 
     @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(8, 128),
@@ -119,3 +126,41 @@ class MistralAdapter(LLMInterface):
             logger.error(f"Schema validation failed: {str(e)}")
             logger.debug(f"Raw response: {e.raw_response}")
             raise ValueError(f"Response failed schema validation: {str(e)}")
+
+    @observe(as_type="transcription")
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
+        """
+        Generate an audio transcript from a user query.
+
+        This method creates a transcript from the specified audio file.
+        The audio file is processed and the transcription is retrieved from the API.
+
+        Parameters:
+        -----------
+        - input: The path to the audio file that needs to be transcribed.
+
+        Returns:
+        --------
+        The generated transcription of the audio file.
+        """
+        transcription_model = self.transcription_model
+        if self.transcription_model.startswith("mistral"):
+            transcription_model = self.transcription_model.split("/")[-1]
+        file_name = input.split("/")[-1]
+        async with open_data_file(input, mode="rb") as f:
+            transcription_response = self.mistral_client.audio.transcriptions.complete(
+                model=transcription_model,
+                file={
+                    "content": f,
+                    "file_name": file_name,
+                },
+            )
+
+        return TranscriptionReturnType(transcription_response.text, transcription_response)
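One subtlety in the new Mistral `create_transcript`: cognee configures models in litellm's `provider/model` form, while the native Mistral SDK client wants the bare model name, hence the prefix strip before the call. The same logic, isolated for clarity (the model strings are examples only):

```python
def strip_provider_prefix(transcription_model: str) -> str:
    # "mistral/voxtral-mini-latest" -> "voxtral-mini-latest"
    if transcription_model.startswith("mistral"):
        return transcription_model.split("/")[-1]
    return transcription_model

print(strip_provider_prefix("mistral/voxtral-mini-latest"))  # voxtral-mini-latest
print(strip_provider_prefix("whisper-large"))                # unchanged
```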
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py

@@ -12,7 +12,6 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
-from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 
 from tenacity import (
     retry,
     stop_after_delay,
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py

@@ -1,4 +1,3 @@
-import base64
 import litellm
 import instructor
 from typing import Type
@@ -16,8 +15,8 @@ from tenacity import (
     before_sleep_log,
 )
 
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
 )
 from cognee.infrastructure.llm.exceptions import (
     ContentPolicyFilterError,
@@ -26,13 +25,16 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.modules.observability.get_observe import get_observe
 from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
+    TranscriptionReturnType,
+)
 
 logger = get_logger()
 
 observe = get_observe()
 
 
-class OpenAIAdapter(LLMInterface):
+class OpenAIAdapter(GenericAPIAdapter):
     """
     Adapter for OpenAI's GPT-3, GPT-4 API.
@@ -53,12 +55,7 @@ class OpenAIAdapter(LLMInterface):
     - MAX_RETRIES
     """
 
-    name = "OpenAI"
-    model: str
-    api_key: str
-    api_version: str
     default_instructor_mode = "json_schema_mode"
 
     MAX_RETRIES = 5
 
-    """Adapter for OpenAI's GPT-3, GPT=4 API"""
@@ -66,17 +63,29 @@ class OpenAIAdapter(LLMInterface):
     def __init__(
         self,
         api_key: str,
-        endpoint: str,
-        api_version: str,
         model: str,
-        transcription_model: str,
         max_completion_tokens: int,
+        endpoint: str = None,
+        api_version: str = None,
+        transcription_model: str = None,
         instructor_mode: str = None,
         streaming: bool = False,
         fallback_model: str = None,
         fallback_api_key: str = None,
        fallback_endpoint: str = None,
     ):
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="OpenAI",
+            endpoint=endpoint,
+            api_version=api_version,
+            transcription_model=transcription_model,
+            fallback_model=fallback_model,
+            fallback_api_key=fallback_api_key,
+            fallback_endpoint=fallback_endpoint,
+        )
+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+        # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
+        # Make sure all new gpt models will work with this mode as well.
@@ -91,18 +100,8 @@ class OpenAIAdapter(LLMInterface):
         self.aclient = instructor.from_litellm(litellm.acompletion)
         self.client = instructor.from_litellm(litellm.completion)
 
-        self.transcription_model = transcription_model
-        self.model = model
-        self.api_key = api_key
-        self.endpoint = endpoint
-        self.api_version = api_version
-        self.max_completion_tokens = max_completion_tokens
         self.streaming = streaming
 
-        self.fallback_model = fallback_model
-        self.fallback_api_key = fallback_api_key
-        self.fallback_endpoint = fallback_endpoint
-
     @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
@@ -198,7 +197,7 @@ class OpenAIAdapter(LLMInterface):
                 f"The provided input contains content that is not aligned with our content policy: {text_input}"
             ) from error
 
-    @observe
+    @observe(as_type="transcription")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(2, 128),
@@ -206,58 +205,7 @@ class OpenAIAdapter(LLMInterface):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    def create_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
-    ) -> BaseModel:
-        """
-        Generate a response from a user query.
-
-        This method creates structured output by sending a synchronous request to the OpenAI API
-        using the provided parameters to generate a completion based on the user input and
-        system prompt.
-
-        Parameters:
-        -----------
-        - text_input (str): The input text provided by the user for generating a response.
-        - system_prompt (str): The system's prompt to guide the model's response.
-        - response_model (Type[BaseModel]): The expected model type for the response.
-
-        Returns:
-        --------
-        - BaseModel: A structured output generated by the model, returned as an instance of
-          BaseModel.
-        """
-        return self.client.chat.completions.create(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"""{text_input}""",
-                },
-                {
-                    "role": "system",
-                    "content": system_prompt,
-                },
-            ],
-            api_key=self.api_key,
-            api_base=self.endpoint,
-            api_version=self.api_version,
-            response_model=response_model,
-            max_retries=self.MAX_RETRIES,
-            **kwargs,
-        )
-
-    @retry(
-        stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
-        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
-        before_sleep=before_sleep_log(logger, logging.DEBUG),
-        reraise=True,
-    )
-    async def create_transcript(self, input, **kwargs):
+    async def create_transcript(self, input, **kwargs) -> TranscriptionReturnType:
         """
         Generate an audio transcript from a user query.
 
@@ -286,60 +234,6 @@ class OpenAIAdapter(LLMInterface):
             max_retries=self.MAX_RETRIES,
             **kwargs,
         )
+        return TranscriptionReturnType(transcription.text, transcription)
 
-        return transcription
-
-    @retry(
-        stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
-        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
-        before_sleep=before_sleep_log(logger, logging.DEBUG),
-        reraise=True,
-    )
-    async def transcribe_image(self, input, **kwargs) -> BaseModel:
-        """
-        Generate a transcription of an image from a user query.
-
-        This method encodes the image and sends a request to the OpenAI API to obtain a
-        description of the contents of the image.
-
-        Parameters:
-        -----------
-        - input: The path to the image file that needs to be transcribed.
-
-        Returns:
-        --------
-        - BaseModel: A structured output generated by the model, returned as an instance of
-          BaseModel.
-        """
-        async with open_data_file(input, mode="rb") as image_file:
-            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
-
-        return litellm.completion(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "What's in this image?",
-                        },
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{encoded_image}",
-                            },
-                        },
-                    ],
-                }
-            ],
-            api_key=self.api_key,
-            api_base=self.endpoint,
-            api_version=self.api_version,
-            max_completion_tokens=300,
-            max_retries=self.MAX_RETRIES,
-            **kwargs,
-        )
+    # transcribe_image is inherited from GenericAPIAdapter
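For callers, the visible OpenAI-adapter change is the transcript return type: `create_transcript` now hands back the new wrapper instead of the raw litellm response, so code reads `.text` for the transcript and `.payload` for the full response. A hedged usage sketch (adapter construction and the file path are illustrative):

```python
async def print_transcript(adapter, audio_path: str) -> None:
    # create_transcript now returns TranscriptionReturnType rather than
    # the raw provider response object.
    result = await adapter.create_transcript(audio_path)
    print(result.text)     # plain transcript string
    print(result.payload)  # full provider response, kept for logging/debugging
```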
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py

@@ -0,0 +1,10 @@
+from pydantic import BaseModel
+
+
+class TranscriptionReturnType:
+    text: str
+    payload: BaseModel
+
+    def __init__(self, text: str, payload: BaseModel):
+        self.text = text
+        self.payload = payload
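`TranscriptionReturnType` is a plain wrapper, so every adapter can return the same `.text`/`.payload` pair regardless of provider. A self-contained demo, with the class copied from the new file and a stand-in payload (the `FakeProviderResponse` model is illustrative; an equivalent `@dataclass` would remove the hand-written `__init__`, a possible simplification rather than what the diff does):

```python
from dataclasses import dataclass

from pydantic import BaseModel

class TranscriptionReturnType:  # copied from the new types.py
    text: str
    payload: BaseModel

    def __init__(self, text: str, payload: BaseModel):
        self.text = text
        self.payload = payload

class FakeProviderResponse(BaseModel):  # stand-in for a provider response
    content: str = "hello world"

resp = FakeProviderResponse()
result = TranscriptionReturnType(resp.content, resp)
print(result.text)  # hello world

# Equivalent dataclass form (illustrative alternative):
@dataclass
class TranscriptionResult:
    text: str
    payload: BaseModel
```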