Added Mistral support as LLM provider using litellm

parent a4ab65768b
commit 617c1f0d71

7 changed files with 163 additions and 3 deletions
@@ -170,7 +170,7 @@ async def add(
     - LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)

     Optional:
-    - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama"
+    - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
     - LLM_MODEL: Model name (default: "gpt-5-mini")
     - DEFAULT_USER_EMAIL: Custom default user email
     - DEFAULT_USER_PASSWORD: Custom default user password
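A minimal configuration sketch, assuming these environment variables are read at startup as the docstring above describes; the model value is one of the options added further down and the key is a placeholder:

import os

# Provider selection via the documented environment variables.
os.environ["LLM_PROVIDER"] = "mistral"
os.environ["LLM_MODEL"] = "mistral-large-2411"        # example model from get_settings()
os.environ["LLM_API_KEY"] = "<your-mistral-api-key>"  # placeholder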
@@ -21,7 +21,13 @@ class SettingsDTO(OutDTO):


 class LLMConfigInputDTO(InDTO):
-    provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"], Literal["gemini"]]
+    provider: Union[
+        Literal["openai"],
+        Literal["ollama"],
+        Literal["anthropic"],
+        Literal["gemini"],
+        Literal["mistral"],
+    ]
     model: str
     api_key: str

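With "mistral" added to the Literal union, a settings payload like the following now validates against this DTO; the field names come from the class above, and the concrete values are placeholders:

# Hedged sketch: the shape LLMConfigInputDTO accepts after this change.
llm_config_payload = {
    "provider": "mistral",                # newly permitted by the Literal union
    "model": "mistral-medium-2508",       # one of the model values added in get_settings()
    "api_key": "<your-mistral-api-key>",  # placeholder
}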
@@ -23,6 +23,7 @@ class LLMProvider(Enum):
     - ANTHROPIC: Represents the Anthropic provider.
     - CUSTOM: Represents a custom provider option.
     - GEMINI: Represents the Gemini provider.
+    - MISTRAL: Represents the Mistral AI provider.
     """

     OPENAI = "openai"
@@ -30,6 +31,7 @@ class LLMProvider(Enum):
     ANTHROPIC = "anthropic"
     CUSTOM = "custom"
     GEMINI = "gemini"
+    MISTRAL = "mistral"


 def get_llm_client(raise_api_key_error: bool = True):
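Because each member's value is the lowercase provider string, a configured "mistral" provider string maps directly onto the new member. A minimal check, with the import path assumed (in this diff the enum sits in the same module as get_llm_client):

# Assumed import path; only the member/value pair comes from the hunk above.
from cognee.infrastructure.llm.get_llm_client import LLMProvider

assert LLMProvider("mistral") is LLMProvider.MISTRAL
assert LLMProvider.MISTRAL.value == "mistral"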
@@ -145,5 +147,20 @@ def get_llm_client(raise_api_key_error: bool = True):
             api_version=llm_config.llm_api_version,
         )

+    elif provider == LLMProvider.MISTRAL:
+        if llm_config.llm_api_key is None:
+            raise LLMAPIKeyNotSetError()
+
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
+            MistralAdapter,
+        )
+
+        return MistralAdapter(
+            api_key=llm_config.llm_api_key,
+            model=llm_config.llm_model,
+            max_completion_tokens=max_completion_tokens,
+            endpoint=llm_config.llm_endpoint,
+        )
+
     else:
         raise UnsupportedLLMProviderError(provider)
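A dispatch sketch, assuming the environment is configured as in the earlier example and that get_llm_client is importable from the module containing the LLMProvider enum (path assumed):

# Hedged sketch: with LLM_PROVIDER=mistral and an API key set, the branch
# added above returns a MistralAdapter instance.
from cognee.infrastructure.llm.get_llm_client import get_llm_client  # assumed path

client = get_llm_client()
print(type(client).__name__)  # expected: "MistralAdapter"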
@@ -0,0 +1,113 @@
+import litellm
+import instructor
+from pydantic import BaseModel
+from typing import Type, Optional
+from litellm import acompletion, JSONSchemaValidationError
+
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.observability.get_observe import get_observe
+from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
+    LLMInterface,
+)
+from cognee.infrastructure.llm.LLMGateway import LLMGateway
+from cognee.infrastructure.llm.config import get_llm_config
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
+    rate_limit_async,
+    sleep_and_retry_async,
+)
+
+logger = get_logger()
+observe = get_observe()
+
+
+class MistralAdapter(LLMInterface):
+    """
+    Adapter for Mistral AI API, for structured output generation and prompt display.
+
+    Public methods:
+    - acreate_structured_output
+    - show_prompt
+    """
+
+    name = "Mistral"
+    model: str
+    api_key: str
+    max_completion_tokens: int
+
+    def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
+        from mistralai import Mistral
+
+        self.model = model
+        self.max_completion_tokens = max_completion_tokens
+
+        self.aclient = instructor.from_litellm(
+            litellm.acompletion,
+            mode=instructor.Mode.MISTRAL_TOOLS,
+            api_key=get_llm_config().llm_api_key,
+        )
+
+    @sleep_and_retry_async()
+    @rate_limit_async
+    async def acreate_structured_output(
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> BaseModel:
+        """
+        Generate a response from the user query.
+
+        Parameters:
+        -----------
+            - text_input (str): The input text from the user to be processed.
+            - system_prompt (str): A prompt that sets the context for the query.
+            - response_model (Type[BaseModel]): The model to structure the response according to
+              its format.
+
+        Returns:
+        --------
+            - BaseModel: An instance of BaseModel containing the structured response.
+        """
+        return await self.aclient.chat.completions.create(
+            model=self.model,
+            max_tokens=self.max_completion_tokens,
+            max_retries=5,
+            messages=[
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": f"""Use the given format to extract information
+                    from the following input: {text_input}""",
+                },
+            ],
+            response_model=response_model,
+        )
+
+    def show_prompt(self, text_input: str, system_prompt: str) -> str:
+        """
+        Format and display the prompt for a user query.
+
+        Parameters:
+        -----------
+            - text_input (str): Input text from the user to be included in the prompt.
+            - system_prompt (str): The system prompt that will be shown alongside the user input.
+
+        Returns:
+        --------
+            - str: The formatted prompt string combining system prompt and user input.
+        """
+        if not text_input:
+            text_input = "No user input provided."
+        if not system_prompt:
+            raise MissingSystemPromptPathError()
+
+        system_prompt = LLMGateway.read_query_prompt(system_prompt)
+
+        formatted_prompt = (
+            f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
+            if system_prompt
+            else None
+        )
+
+        return formatted_prompt
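A usage sketch for the new adapter, under stated assumptions: PersonInfo is an illustrative response model, the model name and token limit are placeholders, and (as the constructor above shows) the underlying instructor client reads the API key from get_llm_config(), so LLM_API_KEY still needs to be set in the environment.

import asyncio
import os

from pydantic import BaseModel

# The adapter module path is the one imported in get_llm_client above.
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
    MistralAdapter,
)


class PersonInfo(BaseModel):
    # Illustrative response model, not part of the commit.
    name: str
    occupation: str


async def main():
    # Placeholder key; the adapter wires litellm to get_llm_config().llm_api_key.
    os.environ["LLM_API_KEY"] = "<your-mistral-api-key>"

    adapter = MistralAdapter(
        api_key=os.environ["LLM_API_KEY"],
        model="mistral-large-2411",  # one of the model values added in get_settings()
        max_completion_tokens=4096,
    )
    person = await adapter.acreate_structured_output(
        text_input="Ada Lovelace was an English mathematician.",
        system_prompt="Extract the person described in the input text.",
        response_model=PersonInfo,
    )
    print(person)


asyncio.run(main())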
@@ -15,6 +15,7 @@ class ModelName(Enum):
     ollama = "ollama"
     anthropic = "anthropic"
     gemini = "gemini"
+    mistral = "mistral"


 class LLMConfig(BaseModel):
@@ -72,6 +73,10 @@ def get_settings() -> SettingsDict:
             "value": "gemini",
             "label": "Gemini",
         },
+        {
+            "value": "mistral",
+            "label": "Mistral",
+        },
     ]

     return SettingsDict.model_validate(
@@ -134,6 +139,24 @@ def get_settings() -> SettingsDict:
                     "label": "Gemini 2.0 Flash",
                 },
             ],
+            "mistral": [
+                {
+                    "value": "mistral-medium-2508",
+                    "label": "Mistral Medium 3.1",
+                },
+                {
+                    "value": "magistral-medium-2509",
+                    "label": "Magistral Medium 1.2",
+                },
+                {
+                    "value": "magistral-medium-2507",
+                    "label": "Magistral Medium 1.1",
+                },
+                {
+                    "value": "mistral-large-2411",
+                    "label": "Mistral Large 2.1",
+                },
+            ],
         },
     },
     vector_db={
@@ -54,7 +54,8 @@ dependencies = [
     "networkx>=3.4.2,<4",
     "uvicorn>=0.34.0,<1.0.0",
     "gunicorn>=20.1.0,<24",
-    "websockets>=15.0.1,<16.0.0"
+    "websockets>=15.0.1,<16.0.0",
+    "mistralai>=1.9.10",
 ]

 [project.optional-dependencies]
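A quick post-install check, assuming the package is installed with this updated dependency list; importlib.metadata is standard library.

# Hedged sketch: confirms the newly declared mistralai dependency resolved.
from importlib.metadata import version

print(version("mistralai"))  # expected to satisfy >=1.9.10 per the constraint above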