fix: Fix based on PR comments

Andrej Milicevic 2025-11-14 11:05:39 +01:00
parent 2337d36f7b
commit 205f5a9e0c
7 changed files with 40 additions and 43 deletions
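In substance, every adapter in this commit stops reading llm_instructor_mode from the global LLM config inside its own __init__ and instead accepts an optional instructor_mode constructor argument, falling back to the class-level default_instructor_mode when none is given; get_llm_client becomes the single place that reads the config value and passes it along. A minimal sketch of the resulting adapter-side pattern, assuming only the instructor and litellm packages (the class name is invented, and "json_mode" stands in for whichever provider default a real adapter uses):

import instructor
import litellm


class ExampleAdapter:
    # Real adapters keep provider-appropriate defaults, e.g. "anthropic_tools",
    # "mistral_tools" or "json_mode" (values that appear in this diff).
    default_instructor_mode = "json_mode"

    def __init__(self, max_completion_tokens: int, instructor_mode: str = None):
        self.max_completion_tokens = max_completion_tokens
        # Fall back to the class default when the caller passes nothing.
        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
        # The resolved mode is handed to instructor once, at construction time.
        self.aclient = instructor.from_litellm(
            litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
        )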

View file

@@ -30,17 +30,14 @@ class AnthropicAdapter(LLMInterface):
model: str
default_instructor_mode = "anthropic_tools"
def __init__(self, max_completion_tokens: int, model: str = None):
def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
import anthropic
config_instructor_mode = get_llm_config().llm_instructor_mode
instructor_mode = (
config_instructor_mode if config_instructor_mode else self.default_instructor_mode
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.patch(
create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
mode=instructor.Mode(instructor_mode),
mode=instructor.Mode(self.instructor_mode),
)
self.model = model
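One behavioral consequence of this hunk: the adapter no longer consults the config for this value at all, so code that constructs AnthropicAdapter directly (rather than through get_llm_client) must pass instructor_mode explicitly to keep honoring llm_instructor_mode. A stand-alone sketch of the fallback expression, using the "anthropic_tools" default from the diff; "anthropic_json" is just an example override, not something this commit introduces:

default_instructor_mode = "anthropic_tools"

def resolve_mode(instructor_mode: str = None) -> str:
    # Mirrors the truthiness check in the new __init__.
    return instructor_mode if instructor_mode else default_instructor_mode

assert resolve_mode() == "anthropic_tools"
assert resolve_mode("anthropic_json") == "anthropic_json"
assert resolve_mode("") == "anthropic_tools"  # falsy values fall back too, not only None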

View file

@@ -50,6 +50,7 @@ class GeminiAdapter(LLMInterface):
model: str,
api_version: str,
max_completion_tokens: int,
instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
@@ -64,15 +65,10 @@ class GeminiAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
from cognee.infrastructure.llm.config import get_llm_config
config_instructor_mode = get_llm_config().llm_instructor_mode
instructor_mode = (
config_instructor_mode if config_instructor_mode else self.default_instructor_mode
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode(instructor_mode)
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
@retry(

View file

@@ -50,6 +50,7 @@ class GenericAPIAdapter(LLMInterface):
model: str,
name: str,
max_completion_tokens: int,
instructor_mode: str = None,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
@@ -64,15 +65,10 @@ class GenericAPIAdapter(LLMInterface):
self.fallback_api_key = fallback_api_key
self.fallback_endpoint = fallback_endpoint
from cognee.infrastructure.llm.config import get_llm_config
config_instructor_mode = get_llm_config().llm_instructor_mode
instructor_mode = (
config_instructor_mode if config_instructor_mode else self.default_instructor_mode
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode(instructor_mode)
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
@retry(
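Because the Gemini and generic adapters both wrap the configured string in instructor.Mode(...), a side effect of threading the raw string through is early validation: a mistyped mode now fails at adapter construction rather than later in a request. A quick hedged check (the two valid strings are taken from this diff; the invalid one is invented for illustration):

import instructor

print(instructor.Mode("json_mode"))      # resolves to the JSON-mode member
print(instructor.Mode("mistral_tools"))  # resolves to the Mistral tools member

try:
    instructor.Mode("json-mode")  # typo: hyphen instead of underscore
except ValueError as exc:
    print(f"rejected: {exc}")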

View file

@@ -81,6 +81,7 @@ def get_llm_client(raise_api_key_error: bool = True):
model=llm_config.llm_model,
transcription_model=llm_config.transcription_model,
max_completion_tokens=max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode,
streaming=llm_config.llm_streaming,
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
@@ -101,6 +102,7 @@ def get_llm_client(raise_api_key_error: bool = True):
llm_config.llm_model,
"Ollama",
max_completion_tokens=max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode,
)
elif provider == LLMProvider.ANTHROPIC:
@@ -109,7 +111,9 @@ def get_llm_client(raise_api_key_error: bool = True):
)
return AnthropicAdapter(
max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
max_completion_tokens=max_completion_tokens,
model=llm_config.llm_model,
instructor_mode=llm_config.llm_instructor_mode,
)
elif provider == LLMProvider.CUSTOM:
@@ -126,6 +130,7 @@ def get_llm_client(raise_api_key_error: bool = True):
llm_config.llm_model,
"Custom",
max_completion_tokens=max_completion_tokens,
instructor_mode=llm_config.llm_instructor_mode,
fallback_api_key=llm_config.fallback_api_key,
fallback_endpoint=llm_config.fallback_endpoint,
fallback_model=llm_config.fallback_model,
@@ -145,6 +150,7 @@ def get_llm_client(raise_api_key_error: bool = True):
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
api_version=llm_config.llm_api_version,
instructor_mode=llm_config.llm_instructor_mode,
)
elif provider == LLMProvider.MISTRAL:
@@ -160,6 +166,7 @@ def get_llm_client(raise_api_key_error: bool = True):
model=llm_config.llm_model,
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
instructor_mode=llm_config.llm_instructor_mode,
)
elif provider == LLMProvider.MISTRAL:
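After these hunks, get_llm_client is the only place that reads llm_config.llm_instructor_mode; every provider branch forwards it, and the adapters themselves no longer need the config for this value. A caller-side sketch, hedged: the import path of get_llm_config is shown in the diff, but the module path of get_llm_client is an assumption, and actually running this requires a configured provider and API key:

from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.get_llm_client import get_llm_client  # ASSUMED import path

llm_config = get_llm_config()
client = get_llm_client()

# Whichever branch was taken, the adapter's resolved mode is either the
# configured llm_instructor_mode or that adapter's default_instructor_mode.
print(llm_config.llm_instructor_mode, client.instructor_mode)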

View file

@@ -39,20 +39,24 @@ class MistralAdapter(LLMInterface):
max_completion_tokens: int
default_instructor_mode = "mistral_tools"
def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
def __init__(
self,
api_key: str,
model: str,
max_completion_tokens: int,
endpoint: str = None,
instructor_mode: str = None,
):
from mistralai import Mistral
self.model = model
self.max_completion_tokens = max_completion_tokens
config_instructor_mode = get_llm_config().llm_instructor_mode
instructor_mode = (
config_instructor_mode if config_instructor_mode else self.default_instructor_mode
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_litellm(
litellm.acompletion,
mode=instructor.Mode(instructor_mode),
mode=instructor.Mode(self.instructor_mode),
api_key=get_llm_config().llm_api_key,
)

View file

@@ -45,7 +45,13 @@ class OllamaAPIAdapter(LLMInterface):
default_instructor_mode = "json_mode"
def __init__(
self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
self,
endpoint: str,
api_key: str,
model: str,
name: str,
max_completion_tokens: int,
instructor_mode: str = None,
):
self.name = name
self.model = model
@@ -53,16 +59,11 @@ class OllamaAPIAdapter(LLMInterface):
self.endpoint = endpoint
self.max_completion_tokens = max_completion_tokens
from cognee.infrastructure.llm.config import get_llm_config
config_instructor_mode = get_llm_config().llm_instructor_mode
instructor_mode = (
config_instructor_mode if config_instructor_mode else self.default_instructor_mode
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
self.aclient = instructor.from_openai(
OpenAI(base_url=self.endpoint, api_key=self.api_key),
mode=instructor.Mode(instructor_mode),
mode=instructor.Mode(self.instructor_mode),
)
@retry(

View file

@@ -70,25 +70,21 @@ class OpenAIAdapter(LLMInterface):
model: str,
transcription_model: str,
max_completion_tokens: int,
instructor_mode: str = None,
streaming: bool = False,
fallback_model: str = None,
fallback_api_key: str = None,
fallback_endpoint: str = None,
):
from cognee.infrastructure.llm.config import get_llm_config
config_instructor_mode = get_llm_config().llm_instructor_mode
instructor_mode = (
config_instructor_mode if config_instructor_mode else self.default_instructor_mode
)
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
# TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
# Make sure all new gpt models will work with this mode as well.
if "gpt-5" in model:
self.aclient = instructor.from_litellm(
litellm.acompletion, mode=instructor.Mode(instructor_mode)
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
)
self.client = instructor.from_litellm(
litellm.completion, mode=instructor.Mode(instructor_mode)
litellm.completion, mode=instructor.Mode(self.instructor_mode)
)
else:
self.aclient = instructor.from_litellm(litellm.acompletion)
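Note that in the OpenAI adapter the resolved self.instructor_mode currently reaches only the gpt-5 branch; the final context line above shows the generic branch still building the client without an explicit mode. A simplified sketch of that distinction (the helper name and shape are illustrative, not the adapter's actual method):

import instructor
import litellm

def build_async_client(model: str, resolved_mode: str):
    if "gpt-5" in model:
        # Per the TODO in the diff, gpt-5 structured outputs expect a
        # JSON-schema style mode, so the explicitly resolved mode is applied.
        return instructor.from_litellm(litellm.acompletion, mode=instructor.Mode(resolved_mode))
    # Other models keep instructor's default mode, matching the unchanged else-branch.
    return instructor.from_litellm(litellm.acompletion)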