Added Mistral support as LLM provider using litellm
This commit is contained in:
parent
a4ab65768b
commit
617c1f0d71
7 changed files with 163 additions and 3 deletions
|
|
@ -170,7 +170,7 @@ async def add(
|
||||||
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
|
- LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
|
||||||
|
|
||||||
Optional:
|
Optional:
|
||||||
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama"
|
- LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
|
||||||
- LLM_MODEL: Model name (default: "gpt-5-mini")
|
- LLM_MODEL: Model name (default: "gpt-5-mini")
|
||||||
- DEFAULT_USER_EMAIL: Custom default user email
|
- DEFAULT_USER_EMAIL: Custom default user email
|
||||||
- DEFAULT_USER_PASSWORD: Custom default user password
|
- DEFAULT_USER_PASSWORD: Custom default user password
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,13 @@ class SettingsDTO(OutDTO):
|
||||||
|
|
||||||
|
|
||||||
class LLMConfigInputDTO(InDTO):
|
class LLMConfigInputDTO(InDTO):
|
||||||
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"], Literal["gemini"]]
|
provider: Union[
|
||||||
|
Literal["openai"],
|
||||||
|
Literal["ollama"],
|
||||||
|
Literal["anthropic"],
|
||||||
|
Literal["gemini"],
|
||||||
|
Literal["mistral"],
|
||||||
|
]
|
||||||
model: str
|
model: str
|
||||||
api_key: str
|
api_key: str
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ class LLMProvider(Enum):
|
||||||
- ANTHROPIC: Represents the Anthropic provider.
|
- ANTHROPIC: Represents the Anthropic provider.
|
||||||
- CUSTOM: Represents a custom provider option.
|
- CUSTOM: Represents a custom provider option.
|
||||||
- GEMINI: Represents the Gemini provider.
|
- GEMINI: Represents the Gemini provider.
|
||||||
|
- MISTRAL: Represents the Mistral AI provider.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
OPENAI = "openai"
|
OPENAI = "openai"
|
||||||
|
|
@ -30,6 +31,7 @@ class LLMProvider(Enum):
|
||||||
ANTHROPIC = "anthropic"
|
ANTHROPIC = "anthropic"
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
GEMINI = "gemini"
|
GEMINI = "gemini"
|
||||||
|
MISTRAL = "mistral"
|
||||||
|
|
||||||
|
|
||||||
def get_llm_client(raise_api_key_error: bool = True):
|
def get_llm_client(raise_api_key_error: bool = True):
|
||||||
|
|
@ -145,5 +147,20 @@ def get_llm_client(raise_api_key_error: bool = True):
|
||||||
api_version=llm_config.llm_api_version,
|
api_version=llm_config.llm_api_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif provider == LLMProvider.MISTRAL:
|
||||||
|
if llm_config.llm_api_key is None:
|
||||||
|
raise LLMAPIKeyNotSetError()
|
||||||
|
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
|
||||||
|
MistralAdapter,
|
||||||
|
)
|
||||||
|
|
||||||
|
return MistralAdapter(
|
||||||
|
api_key=llm_config.llm_api_key,
|
||||||
|
model=llm_config.llm_model,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
endpoint=llm_config.llm_endpoint,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise UnsupportedLLMProviderError(provider)
|
raise UnsupportedLLMProviderError(provider)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
import litellm
|
||||||
|
import instructor
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Type, Optional
|
||||||
|
from litellm import acompletion, JSONSchemaValidationError
|
||||||
|
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
from cognee.modules.observability.get_observe import get_observe
|
||||||
|
from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
||||||
|
LLMInterface,
|
||||||
|
)
|
||||||
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
||||||
|
rate_limit_async,
|
||||||
|
sleep_and_retry_async,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
observe = get_observe()
|
||||||
|
|
||||||
|
|
||||||
|
class MistralAdapter(LLMInterface):
|
||||||
|
"""
|
||||||
|
Adapter for Mistral AI API, for structured output generation and prompt display.
|
||||||
|
|
||||||
|
Public methods:
|
||||||
|
- acreate_structured_output
|
||||||
|
- show_prompt
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "Mistral"
|
||||||
|
model: str
|
||||||
|
api_key: str
|
||||||
|
max_completion_tokens: int
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
|
||||||
|
from mistralai import Mistral
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
|
||||||
|
self.aclient = instructor.from_litellm(
|
||||||
|
litellm.acompletion,
|
||||||
|
mode=instructor.Mode.MISTRAL_TOOLS,
|
||||||
|
api_key=get_llm_config().llm_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
@sleep_and_retry_async()
|
||||||
|
@rate_limit_async
|
||||||
|
async def acreate_structured_output(
|
||||||
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
||||||
|
) -> BaseModel:
|
||||||
|
"""
|
||||||
|
Generate a response from the user query.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- text_input (str): The input text from the user to be processed.
|
||||||
|
- system_prompt (str): A prompt that sets the context for the query.
|
||||||
|
- response_model (Type[BaseModel]): The model to structure the response according to
|
||||||
|
its format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
- BaseModel: An instance of BaseModel containing the structured response.
|
||||||
|
"""
|
||||||
|
return await self.aclient.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
max_tokens=self.max_completion_tokens,
|
||||||
|
max_retries=5,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": system_prompt,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"""Use the given format to extract information
|
||||||
|
from the following input: {text_input}""",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||||
|
"""
|
||||||
|
Format and display the prompt for a user query.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
- text_input (str): Input text from the user to be included in the prompt.
|
||||||
|
- system_prompt (str): The system prompt that will be shown alongside the user input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
- str: The formatted prompt string combining system prompt and user input.
|
||||||
|
"""
|
||||||
|
if not text_input:
|
||||||
|
text_input = "No user input provided."
|
||||||
|
if not system_prompt:
|
||||||
|
raise MissingSystemPromptPathError()
|
||||||
|
|
||||||
|
system_prompt = LLMGateway.read_query_prompt(system_prompt)
|
||||||
|
|
||||||
|
formatted_prompt = (
|
||||||
|
f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
|
||||||
|
if system_prompt
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return formatted_prompt
|
||||||
|
|
@ -15,6 +15,7 @@ class ModelName(Enum):
|
||||||
ollama = "ollama"
|
ollama = "ollama"
|
||||||
anthropic = "anthropic"
|
anthropic = "anthropic"
|
||||||
gemini = "gemini"
|
gemini = "gemini"
|
||||||
|
mistral = "mistral"
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig(BaseModel):
|
class LLMConfig(BaseModel):
|
||||||
|
|
@ -72,6 +73,10 @@ def get_settings() -> SettingsDict:
|
||||||
"value": "gemini",
|
"value": "gemini",
|
||||||
"label": "Gemini",
|
"label": "Gemini",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"value": "mistral",
|
||||||
|
"label": "Mistral",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
return SettingsDict.model_validate(
|
return SettingsDict.model_validate(
|
||||||
|
|
@ -134,6 +139,24 @@ def get_settings() -> SettingsDict:
|
||||||
"label": "Gemini 2.0 Flash",
|
"label": "Gemini 2.0 Flash",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"mistral": [
|
||||||
|
{
|
||||||
|
"value": "mistral-medium-2508",
|
||||||
|
"label": "Mistral Medium 3.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"value": "magistral-medium-2509",
|
||||||
|
"label": "Magistral Medium 1.2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"value": "magistral-medium-2507",
|
||||||
|
"label": "Magistral Medium 1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"value": "mistral-large-2411",
|
||||||
|
"label": "Mistral Large 2.1",
|
||||||
|
},
|
||||||
|
],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
vector_db={
|
vector_db={
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,8 @@ dependencies = [
|
||||||
"networkx>=3.4.2,<4",
|
"networkx>=3.4.2,<4",
|
||||||
"uvicorn>=0.34.0,<1.0.0",
|
"uvicorn>=0.34.0,<1.0.0",
|
||||||
"gunicorn>=20.1.0,<24",
|
"gunicorn>=20.1.0,<24",
|
||||||
"websockets>=15.0.1,<16.0.0"
|
"websockets>=15.0.1,<16.0.0",
|
||||||
|
"mistralai>=1.9.10",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue