Merge branch 'feature/cog-3532-empower-test_search-db-retrievers-tests-reorg-3' into feature/cog-3532-empower-test_search-db-retrievers-tests-reorg-4

hajdul88 2025-12-17 10:07:58 +01:00 committed by GitHub
commit 623126eec1
10 changed files with 280 additions and 213 deletions

View file

@@ -37,19 +37,6 @@ class LLMGateway:
**kwargs,
)
@staticmethod
def create_structured_output(
text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
get_llm_client,
)
llm_client = get_llm_client()
return llm_client.create_structured_output(
text_input=text_input, system_prompt=system_prompt, response_model=response_model
)
@staticmethod
def create_transcript(input) -> Coroutine:
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (

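Note: this hunk removes the synchronous create_structured_output shim from LLMGateway. A minimal sketch of the surviving call path, assuming the gateway still exposes the async acreate_structured_output variant (the response model here is hypothetical):

    from pydantic import BaseModel

    class Summary(BaseModel):  # hypothetical response model
        text: str

    # inside an async function:
    result = await LLMGateway.acreate_structured_output(
        text_input="Summarize the release notes.",
        system_prompt="You are a precise summarizer.",
        response_model=Summary,
    )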
View file

@@ -3,7 +3,9 @@ from typing import Type
from pydantic import BaseModel
import litellm
import instructor
import anthropic
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from tenacity import (
retry,
stop_after_delay,
@@ -12,38 +14,41 @@ from tenacity import (
before_sleep_log,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
GenericAPIAdapter,
)
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.infrastructure.llm.config import get_llm_config
logger = get_logger()
observe = get_observe()
class AnthropicAdapter(LLMInterface):
class AnthropicAdapter(GenericAPIAdapter):
"""
Adapter for interfacing with the Anthropic API, enabling structured output generation
and prompt display.
"""
name = "Anthropic"
model: str
default_instructor_mode = "anthropic_tools"
def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
import anthropic
def __init__(
self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None
):
super().__init__(
api_key=api_key,
model=model,
max_completion_tokens=max_completion_tokens,
name="Anthropic",
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.patch(
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create,
mode=instructor.Mode(self.instructor_mode),
)
self.model = model
self.max_completion_tokens = max_completion_tokens
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(8, 128),

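Note: AnthropicAdapter now subclasses GenericAPIAdapter and receives its API key explicitly instead of reading get_llm_config() inside the constructor. A construction sketch (the token budget is an illustrative assumption):

    from cognee.infrastructure.llm.config import get_llm_config

    config = get_llm_config()
    adapter = AnthropicAdapter(
        api_key=config.llm_api_key,  # injected; no hidden config lookup in __init__
        model=config.llm_model,
        max_completion_tokens=4096,  # illustrative value
    )  # instructor_mode falls back to the class default "anthropic_tools"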
View file

@@ -1,4 +1,4 @@
"""Adapter for Generic API LLM provider API"""
"""Adapter for Gemini API LLM provider"""
import litellm
import instructor
@@ -8,13 +8,9 @@ from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
from instructor.core import InstructorRetryException
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
import logging
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.shared.logging_utils import get_logger
from tenacity import (
retry,
stop_after_delay,
@@ -23,55 +19,65 @@ from tenacity import (
before_sleep_log,
)
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
GenericAPIAdapter,
)
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
logger = get_logger()
observe = get_observe()
class GeminiAdapter(LLMInterface):
class GeminiAdapter(GenericAPIAdapter):
"""
Adapter for Gemini API LLM provider.
This class initializes the API adapter with necessary credentials and configurations for
interacting with the Gemini LLM models. It provides methods for creating structured outputs
based on user input and system prompts.
based on user input and system prompts, as well as multimodal processing capabilities.
Public methods:
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
Type[BaseModel]) -> BaseModel
- acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel
- create_transcript(input) -> BaseModel: Transcribe audio files to text
- transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter
"""
name: str
model: str
api_key: str
default_instructor_mode = "json_mode"
def __init__(
self,
endpoint,
api_key: str,
model: str,
api_version: str,
max_completion_tokens: int,
endpoint: str = None,
api_version: str = None,
transcription_model: str = None,
instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
self.max_completion_tokens = max_completion_tokens
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
super().__init__(
api_key=api_key,
model=model,
max_completion_tokens=max_completion_tokens,
name="Gemini",
endpoint=endpoint,
api_version=api_version,
transcription_model=transcription_model,
fallback_model=fallback_model,
fallback_api_key=fallback_api_key,
fallback_endpoint=fallback_endpoint,
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(8, 128),

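Note: GeminiAdapter now delegates its field bookkeeping to GenericAPIAdapter.__init__, forwarding endpoint, transcription, and fallback settings. A construction sketch with placeholder values (the model id is an assumption in litellm's provider/model form):

    adapter = GeminiAdapter(
        api_key="...",  # placeholder
        model="gemini/gemini-2.0-flash",  # assumed litellm-style model id
        max_completion_tokens=8192,
        transcription_model="gemini/gemini-2.0-flash",
        fallback_model=None,  # optional fallback chain, stored by the base class
    )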
View file

@@ -1,8 +1,10 @@
"""Adapter for Generic API LLM provider API"""
import base64
import mimetypes
import litellm
import instructor
from typing import Type
from typing import Type, Optional
from pydantic import BaseModel
from openai import ContentFilterFinishReasonError
from litellm.exceptions import ContentPolicyViolationError
@@ -12,6 +14,8 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.modules.observability.get_observe import get_observe
import logging
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.shared.logging_utils import get_logger
@@ -23,7 +27,12 @@ from tenacity import (
before_sleep_log,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
TranscriptionReturnType,
)
logger = get_logger()
observe = get_observe()
class GenericAPIAdapter(LLMInterface):
@@ -39,18 +48,19 @@ class GenericAPIAdapter(LLMInterface):
Type[BaseModel]) -> BaseModel
"""
name: str
model: str
api_key: str
MAX_RETRIES = 5
default_instructor_mode = "json_mode"
def __init__(
self,
endpoint,
api_key: str,
model: str,
name: str,
max_completion_tokens: int,
name: str,
endpoint: str = None,
api_version: str = None,
transcription_model: str = None,
image_transcribe_model: str = None,
instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
@@ -59,9 +69,11 @@ class GenericAPIAdapter(LLMInterface):
self.name = name
self.model = model
self.api_key = api_key
self.api_version = api_version
self.endpoint = endpoint
self.max_completion_tokens = max_completion_tokens
self.transcription_model = transcription_model or model
self.image_transcribe_model = image_transcribe_model or model
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
@@ -72,6 +84,7 @@ class GenericAPIAdapter(LLMInterface):
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(8, 128),
@@ -173,3 +186,115 @@ class GenericAPIAdapter(LLMInterface):
raise ContentPolicyFilterError(
f"The provided input contains content that is not aligned with our content policy: {text_input}"
) from error
@observe(as_type="transcription")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input) -> TranscriptionReturnType:
"""
Generate an audio transcript from a user query.
This method creates a transcript from the specified audio file, raising a
FileNotFoundError if the file does not exist. The audio file is processed and the
transcription is retrieved from the API.
Parameters:
-----------
- input: The path to the audio file that needs to be transcribed.
Returns:
--------
The generated transcription of the audio file.
"""
async with open_data_file(input, mode="rb") as audio_file:
encoded_string = base64.b64encode(audio_file.read()).decode("utf-8")
mime_type, _ = mimetypes.guess_type(input)
if not mime_type or not mime_type.startswith("audio/"):
raise ValueError(
f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
)
response = await litellm.acompletion(
model=self.transcription_model,
messages=[
{
"role": "user",
"content": [
{
"type": "file",
"file": {"file_data": f"data:{mime_type};base64,{encoded_string}"},
},
{"type": "text", "text": "Transcribe the following audio precisely."},
],
}
],
api_key=self.api_key,
api_version=self.api_version,
max_completion_tokens=self.max_completion_tokens,
api_base=self.endpoint,
max_retries=self.MAX_RETRIES,
)
return TranscriptionReturnType(response.choices[0].message.content, response)
@observe(as_type="transcribe_image")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input) -> BaseModel:
"""
Generate a transcription of an image from a user query.
This method encodes the image and sends a request to the API to obtain a
description of the contents of the image.
Parameters:
-----------
- input: The path to the image file that needs to be transcribed.
Returns:
--------
- BaseModel: A structured output generated by the model, returned as an instance of
BaseModel.
"""
async with open_data_file(input, mode="rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
mime_type, _ = mimetypes.guess_type(input)
if not mime_type or not mime_type.startswith("image/"):
raise ValueError(
f"Could not determine MIME type for image file: {input}. Is the extension correct?"
)
response = await litellm.acompletion(
model=self.image_transcribe_model,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?",
},
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{encoded_image}",
},
},
],
}
],
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
max_completion_tokens=300,
max_retries=self.MAX_RETRIES,
)
return response

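Note: both new GenericAPIAdapter methods inline the file as a base64 data URL and guard on the guessed MIME type before calling litellm.acompletion. A standalone sketch of that preparation step (the file path is hypothetical):

    import base64
    import mimetypes

    path = "meeting.mp3"  # hypothetical audio file
    mime_type, _ = mimetypes.guess_type(path)  # inferred from the file extension
    if not mime_type or not mime_type.startswith("audio/"):
        raise ValueError(f"Could not determine MIME type for audio file: {path}")
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    file_data = f"data:{mime_type};base64,{encoded}"  # sent as message content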
View file

@@ -103,7 +103,7 @@ def get_llm_client(raise_api_key_error: bool = True):
llm_config.llm_api_key,
llm_config.llm_model,
"Ollama",
max_completion_tokens=max_completion_tokens,
max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
@@ -113,8 +113,9 @@ def get_llm_client(raise_api_key_error: bool = True):
)
return AnthropicAdapter(
max_completion_tokens=max_completion_tokens,
model=llm_config.llm_model,
llm_config.llm_api_key,
llm_config.llm_model,
max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode.lower(),
)
@@ -127,11 +128,10 @@ def get_llm_client(raise_api_key_error: bool = True):
)
return GenericAPIAdapter(
llm_config.llm_endpoint,
llm_config.llm_api_key,
llm_config.llm_model,
max_completion_tokens,
"Custom",
max_completion_tokens=max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode.lower(),
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,

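Note: after this change the adapter constructors share a leading (api_key, model, max_completion_tokens) positional prefix, so get_llm_client can pass the common arguments uniformly. A caller-side sketch of the factory (the response model is hypothetical):

    client = get_llm_client()
    # inside an async function:
    answer = await client.acreate_structured_output(
        text_input="...",
        system_prompt="...",
        response_model=Answer,  # hypothetical pydantic model
    )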
View file

@@ -3,18 +3,14 @@
from typing import Type, Protocol
from abc import abstractmethod
from pydantic import BaseModel
from cognee.infrastructure.llm.LLMGateway import LLMGateway
class LLMInterface(Protocol):
"""
Define an interface for LLM models with methods for structured output and prompt
display.
Define an interface for LLM models with methods for structured output, multimodal processing, and prompt display.
Methods:
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
Type[BaseModel])
- show_prompt(text_input: str, system_prompt: str)
- acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel])
"""
@abstractmethod

View file

@@ -1,13 +1,13 @@
import litellm
import instructor
from pydantic import BaseModel
from typing import Type
from typing import Type, Optional
from litellm import JSONSchemaValidationError
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
GenericAPIAdapter,
)
from cognee.infrastructure.llm.config import get_llm_config
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
@@ -20,12 +20,14 @@ from tenacity import (
retry_if_not_exception_type,
before_sleep_log,
)
from ..types import TranscriptionReturnType
from mistralai import Mistral
logger = get_logger()
observe = get_observe()
class MistralAdapter(LLMInterface):
class MistralAdapter(GenericAPIAdapter):
"""
Adapter for Mistral AI API, for structured output generation and prompt display.
@@ -34,10 +36,6 @@ class MistralAdapter(LLMInterface):
- show_prompt
"""
name = "Mistral"
model: str
api_key: str
max_completion_tokens: int
default_instructor_mode = "mistral_tools"
def __init__(
@@ -46,12 +44,19 @@ class MistralAdapter(LLMInterface):
model: str,
max_completion_tokens: int,
endpoint: str = None,
transcription_model: str = None,
image_transcribe_model: str = None,
instructor_mode: str = None,
):
from mistralai import Mistral
self.model = model
self.max_completion_tokens = max_completion_tokens
super().__init__(
api_key=api_key,
model=model,
max_completion_tokens=max_completion_tokens,
name="Mistral",
endpoint=endpoint,
transcription_model=transcription_model,
image_transcribe_model=image_transcribe_model,
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
@@ -60,7 +65,9 @@ class MistralAdapter(LLMInterface):
mode=instructor.Mode(self.instructor_mode),
api_key=get_llm_config().llm_api_key,
)
self.mistral_client = Mistral(api_key=self.api_key)
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(8, 128),
@@ -119,3 +126,41 @@ class MistralAdapter(LLMInterface):
logger.error(f"Schema validation failed: {str(e)}")
logger.debug(f"Raw response: {e.raw_response}")
raise ValueError(f"Response failed schema validation: {str(e)}")
@observe(as_type="transcription")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
"""
Generate an audio transcript from a user query.
This method creates a transcript from the specified audio file.
The audio file is processed and the transcription is retrieved from the API.
Parameters:
-----------
- input: The path to the audio file that needs to be transcribed.
Returns:
--------
The generated transcription of the audio file.
"""
transcription_model = self.transcription_model
if self.transcription_model.startswith("mistral"):
transcription_model = self.transcription_model.split("/")[-1]
file_name = input.split("/")[-1]
async with open_data_file(input, mode="rb") as f:
transcription_response = self.mistral_client.audio.transcriptions.complete(
model=transcription_model,
file={
"content": f,
"file_name": file_name,
},
)
return TranscriptionReturnType(transcription_response.text, transcription_response)

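Note: Mistral's native SDK expects a bare model id, so create_transcript strips the litellm-style "mistral/" provider prefix before the call. A worked example of the same string handling (the model id is hypothetical):

    model_id = "mistral/voxtral-mini-latest"  # hypothetical litellm-style id
    if model_id.startswith("mistral"):
        model_id = model_id.split("/")[-1]  # -> "voxtral-mini-latest"
    file_name = "/tmp/interview.wav".split("/")[-1]  # -> "interview.wav"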
View file

@@ -12,7 +12,6 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.shared.logging_utils import get_logger
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from tenacity import (
retry,
stop_after_delay,

View file

@@ -1,4 +1,3 @@
import base64
import litellm
import instructor
from typing import Type
@@ -16,8 +15,8 @@ from tenacity import (
before_sleep_log,
)
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
GenericAPIAdapter,
)
from cognee.infrastructure.llm.exceptions import (
ContentPolicyFilterError,
@@ -26,13 +25,16 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
from cognee.infrastructure.files.utils.open_data_file import open_data_file
from cognee.modules.observability.get_observe import get_observe
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
TranscriptionReturnType,
)
logger = get_logger()
observe = get_observe()
class OpenAIAdapter(LLMInterface):
class OpenAIAdapter(GenericAPIAdapter):
"""
Adapter for OpenAI's GPT-3, GPT-4 API.
@@ -53,12 +55,7 @@ class OpenAIAdapter(LLMInterface):
- MAX_RETRIES
"""
name = "OpenAI"
model: str
api_key: str
api_version: str
default_instructor_mode = "json_schema_mode"
MAX_RETRIES = 5
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
@@ -66,17 +63,29 @@ class OpenAIAdapter(LLMInterface):
def __init__(
self,
api_key: str,
endpoint: str,
api_version: str,
model: str,
transcription_model: str,
max_completion_tokens: int,
endpoint: str = None,
api_version: str = None,
transcription_model: str = None,
instructor_mode: str = None,
streaming: bool = False,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
super().__init__(
api_key=api_key,
model=model,
max_completion_tokens=max_completion_tokens,
name="OpenAI",
endpoint=endpoint,
api_version=api_version,
transcription_model=transcription_model,
fallback_model=fallback_model,
fallback_api_key=fallback_api_key,
fallback_endpoint=fallback_endpoint,
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
# Make sure all new gpt models will work with this mode as well.
@@ -91,18 +100,8 @@ class OpenAIAdapter(LLMInterface):
self.aclient = instructor.from_litellm(litellm.acompletion)
self.client = instructor.from_litellm(litellm.completion)
self.transcription_model = transcription_model
self.model = model
self.api_key = api_key
self.endpoint = endpoint
self.api_version = api_version
self.max_completion_tokens = max_completion_tokens
self.streaming = streaming
self.fallback_model = fallback_model
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
@observe(as_type="generation")
@retry(
stop=stop_after_delay(128),
@@ -198,7 +197,7 @@ class OpenAIAdapter(LLMInterface):
f"The provided input contains content that is not aligned with our content policy: {text_input}"
) from error
@observe
@observe(as_type="transcription")
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
@@ -206,58 +205,7 @@ class OpenAIAdapter(LLMInterface):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def create_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
) -> BaseModel:
"""
Generate a response from a user query.
This method creates structured output by sending a synchronous request to the OpenAI API
using the provided parameters to generate a completion based on the user input and
system prompt.
Parameters:
-----------
- text_input (str): The input text provided by the user for generating a response.
- system_prompt (str): The system's prompt to guide the model's response.
- response_model (Type[BaseModel]): The expected model type for the response.
Returns:
--------
- BaseModel: A structured output generated by the model, returned as an instance of
BaseModel.
"""
return self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": f"""{text_input}""",
},
{
"role": "system",
"content": system_prompt,
},
],
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
response_model=response_model,
max_retries=self.MAX_RETRIES,
**kwargs,
)
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def create_transcript(self, input, **kwargs):
async def create_transcript(self, input, **kwargs) -> TranscriptionReturnType:
"""
Generate an audio transcript from a user query.
@@ -286,60 +234,6 @@ class OpenAIAdapter(LLMInterface):
max_retries=self.MAX_RETRIES,
**kwargs,
)
return TranscriptionReturnType(transcription.text, transcription)
return transcription
@retry(
stop=stop_after_delay(128),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
async def transcribe_image(self, input, **kwargs) -> BaseModel:
"""
Generate a transcription of an image from a user query.
This method encodes the image and sends a request to the OpenAI API to obtain a
description of the contents of the image.
Parameters:
-----------
- input: The path to the image file that needs to be transcribed.
Returns:
--------
- BaseModel: A structured output generated by the model, returned as an instance of
BaseModel.
"""
async with open_data_file(input, mode="rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
return litellm.completion(
model=self.model,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{encoded_image}",
},
},
],
}
],
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
max_completion_tokens=300,
max_retries=self.MAX_RETRIES,
**kwargs,
)
# transcribe_image is inherited from GenericAPIAdapter

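Note: transcription callers now receive a TranscriptionReturnType wrapper instead of the bare provider object. A usage sketch (the adapter variable and file path are hypothetical):

    # inside an async function:
    result = await adapter.create_transcript("interview.wav")
    print(result.text)  # plain transcription string
    raw = result.payload  # full provider response, kept for inspection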
View file

@@ -0,0 +1,10 @@
from pydantic import BaseModel
class TranscriptionReturnType:
text: str
payload: BaseModel
def __init__(self, text: str, payload: BaseModel):
self.text = text
self.payload = payload
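Note: TranscriptionReturnType is a plain wrapper pairing the extracted text with the raw provider payload. A functionally equivalent dataclass sketch, for comparison only (not in the diff):

    from dataclasses import dataclass
    from pydantic import BaseModel

    @dataclass
    class TranscriptionResult:  # hypothetical equivalent
        text: str  # extracted transcription text
        payload: BaseModel  # raw provider response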