Added Mistral support as LLM provider using litellm

This commit is contained in:
Aniruddha Mandal 2025-09-23 13:16:44 +05:30 committed by vasilije
parent a4ab65768b
commit 617c1f0d71
7 changed files with 163 additions and 3 deletions

View file

@ -170,7 +170,7 @@ async def add(
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
Optional:
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama"
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
- LLM_MODEL: Model name (default: "gpt-5-mini")
- DEFAULT_USER_EMAIL: Custom default user email
- DEFAULT_USER_PASSWORD: Custom default user password

View file

@ -21,7 +21,13 @@ class SettingsDTO(OutDTO):
class LLMConfigInputDTO(InDTO):
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"], Literal["gemini"]]
provider: Union[
Literal["openai"],
Literal["ollama"],
Literal["anthropic"],
Literal["gemini"],
Literal["mistral"],
]
model: str
api_key: str

View file

@ -23,6 +23,7 @@ class LLMProvider(Enum):
- ANTHROPIC: Represents the Anthropic provider.
- CUSTOM: Represents a custom provider option.
- GEMINI: Represents the Gemini provider.
- MISTRAL: Represents the Mistral AI provider.
"""
OPENAI = "openai"
@ -30,6 +31,7 @@ class LLMProvider(Enum):
ANTHROPIC = "anthropic"
CUSTOM = "custom"
GEMINI = "gemini"
MISTRAL = "mistral"
def get_llm_client(raise_api_key_error: bool = True):
@ -145,5 +147,20 @@ def get_llm_client(raise_api_key_error: bool = True):
api_version=llm_config.llm_api_version,
)
elif provider == LLMProvider.MISTRAL:
if llm_config.llm_api_key is None:
raise LLMAPIKeyNotSetError()
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
MistralAdapter,
)
return MistralAdapter(
api_key=llm_config.llm_api_key,
model=llm_config.llm_model,
max_completion_tokens=max_completion_tokens,
endpoint=llm_config.llm_endpoint,
)
else:
raise UnsupportedLLMProviderError(provider)

View file

@ -0,0 +1,113 @@
import litellm
import instructor
from pydantic import BaseModel
from typing import Type, Optional
from litellm import acompletion, JSONSchemaValidationError
from cognee.shared.logging_utils import get_logger
from cognee.modules.observability.get_observe import get_observe
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
LLMInterface,
)
from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import get_llm_config
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
rate_limit_async,
sleep_and_retry_async,
)
logger = get_logger()
observe = get_observe()
class MistralAdapter(LLMInterface):
"""
Adapter for Mistral AI API, for structured output generation and prompt display.
Public methods:
- acreate_structured_output
- show_prompt
"""
name = "Mistral"
model: str
api_key: str
max_completion_tokens: int
def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
from mistralai import Mistral
self.model = model
self.max_completion_tokens = max_completion_tokens
self.aclient = instructor.from_litellm(
litellm.acompletion,
mode=instructor.Mode.MISTRAL_TOOLS,
api_key=get_llm_config().llm_api_key,
)
@sleep_and_retry_async()
@rate_limit_async
async def acreate_structured_output(
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
) -> BaseModel:
"""
Generate a response from the user query.
Parameters:
-----------
- text_input (str): The input text from the user to be processed.
- system_prompt (str): A prompt that sets the context for the query.
- response_model (Type[BaseModel]): The model to structure the response according to
its format.
Returns:
--------
- BaseModel: An instance of BaseModel containing the structured response.
"""
return await self.aclient.chat.completions.create(
model=self.model,
max_tokens=self.max_completion_tokens,
max_retries=5,
messages=[
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": f"""Use the given format to extract information
from the following input: {text_input}""",
},
],
response_model=response_model,
)
def show_prompt(self, text_input: str, system_prompt: str) -> str:
"""
Format and display the prompt for a user query.
Parameters:
-----------
- text_input (str): Input text from the user to be included in the prompt.
- system_prompt (str): The system prompt that will be shown alongside the user input.
Returns:
--------
- str: The formatted prompt string combining system prompt and user input.
"""
if not text_input:
text_input = "No user input provided."
if not system_prompt:
raise MissingSystemPromptPathError()
system_prompt = LLMGateway.read_query_prompt(system_prompt)
formatted_prompt = (
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
if system_prompt
else None
)
return formatted_prompt

View file

@ -15,6 +15,7 @@ class ModelName(Enum):
ollama = "ollama"
anthropic = "anthropic"
gemini = "gemini"
mistral = "mistral"
class LLMConfig(BaseModel):
@ -72,6 +73,10 @@ def get_settings() -> SettingsDict:
"value": "gemini",
"label": "Gemini",
},
{
"value": "mistral",
"label": "Mistral",
},
]
return SettingsDict.model_validate(
@ -134,6 +139,24 @@ def get_settings() -> SettingsDict:
"label": "Gemini 2.0 Flash",
},
],
"mistral": [
{
"value": "mistral-medium-2508",
"label": "Mistral Medium 3.1",
},
{
"value": "magistral-medium-2509",
"label": "Magistral Medium 1.2",
},
{
"value": "magistral-medium-2507",
"label": "Magistral Medium 1.1",
},
{
"value": "mistral-large-2411",
"label": "Mistral Large 2.1",
},
],
},
},
vector_db={

View file

@ -54,7 +54,8 @@ dependencies = [
"networkx>=3.4.2,<4",
"uvicorn>=0.34.0,<1.0.0",
"gunicorn>=20.1.0,<24",
"websockets>=15.0.1,<16.0.0"
"websockets>=15.0.1,<16.0.0",
"mistralai>=1.9.10",
]
[project.optional-dependencies]