diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py
index cbb1c8fd5..65394f1ec 100644
--- a/cognee/api/v1/add/add.py
+++ b/cognee/api/v1/add/add.py
@@ -170,7 +170,7 @@ async def add(
             - LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
 
         Optional:
-            - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama"
+            - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
             - LLM_MODEL: Model name (default: "gpt-5-mini")
             - DEFAULT_USER_EMAIL: Custom default user email
             - DEFAULT_USER_PASSWORD: Custom default user password
diff --git a/cognee/api/v1/settings/routers/get_settings_router.py b/cognee/api/v1/settings/routers/get_settings_router.py
index c85352746..c64ce2072 100644
--- a/cognee/api/v1/settings/routers/get_settings_router.py
+++ b/cognee/api/v1/settings/routers/get_settings_router.py
@@ -21,7 +21,13 @@ class SettingsDTO(OutDTO):
 
 
 class LLMConfigInputDTO(InDTO):
-    provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"], Literal["gemini"]]
+    provider: Union[
+        Literal["openai"],
+        Literal["ollama"],
+        Literal["anthropic"],
+        Literal["gemini"],
+        Literal["mistral"],
+    ]
     model: str
     api_key: str
diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
index 0ae621428..bbdfe49e9 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
@@ -23,6 +23,7 @@ class LLMProvider(Enum):
     - ANTHROPIC: Represents the Anthropic provider.
     - CUSTOM: Represents a custom provider option.
     - GEMINI: Represents the Gemini provider.
+    - MISTRAL: Represents the Mistral AI provider.
""" OPENAI = "openai" @@ -30,6 +31,7 @@ class LLMProvider(Enum): ANTHROPIC = "anthropic" CUSTOM = "custom" GEMINI = "gemini" + MISTRAL = "mistral" def get_llm_client(raise_api_key_error: bool = True): @@ -145,5 +147,20 @@ def get_llm_client(raise_api_key_error: bool = True): api_version=llm_config.llm_api_version, ) + elif provider == LLMProvider.MISTRAL: + if llm_config.llm_api_key is None: + raise LLMAPIKeyNotSetError() + + from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import ( + MistralAdapter, + ) + + return MistralAdapter( + api_key=llm_config.llm_api_key, + model=llm_config.llm_model, + max_completion_tokens=max_completion_tokens, + endpoint=llm_config.llm_endpoint, + ) + else: raise UnsupportedLLMProviderError(provider) diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/__init__.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py new file mode 100644 index 000000000..a9cf68d6e --- /dev/null +++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py @@ -0,0 +1,113 @@ +import litellm +import instructor +from pydantic import BaseModel +from typing import Type, Optional +from litellm import acompletion, JSONSchemaValidationError + +from cognee.shared.logging_utils import get_logger +from cognee.modules.observability.get_observe import get_observe +from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import ( + LLMInterface, +) +from cognee.infrastructure.llm.LLMGateway import LLMGateway +from cognee.infrastructure.llm.config import get_llm_config +from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import ( + rate_limit_async, + sleep_and_retry_async, +) + +logger = get_logger() +observe = get_observe() + + +class MistralAdapter(LLMInterface): + """ + Adapter for Mistral AI API, for structured output generation and prompt display. + + Public methods: + - acreate_structured_output + - show_prompt + """ + + name = "Mistral" + model: str + api_key: str + max_completion_tokens: int + + def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None): + from mistralai import Mistral + + self.model = model + self.max_completion_tokens = max_completion_tokens + + self.aclient = instructor.from_litellm( + litellm.acompletion, + mode=instructor.Mode.MISTRAL_TOOLS, + api_key=get_llm_config().llm_api_key, + ) + + @sleep_and_retry_async() + @rate_limit_async + async def acreate_structured_output( + self, text_input: str, system_prompt: str, response_model: Type[BaseModel] + ) -> BaseModel: + """ + Generate a response from the user query. + + Parameters: + ----------- + - text_input (str): The input text from the user to be processed. + - system_prompt (str): A prompt that sets the context for the query. + - response_model (Type[BaseModel]): The model to structure the response according to + its format. + + Returns: + -------- + - BaseModel: An instance of BaseModel containing the structured response. 
+ """ + return await self.aclient.chat.completions.create( + model=self.model, + max_tokens=self.max_completion_tokens, + max_retries=5, + messages=[ + { + "role": "system", + "content": system_prompt, + }, + { + "role": "user", + "content": f"""Use the given format to extract information + from the following input: {text_input}""", + }, + ], + response_model=response_model, + ) + + def show_prompt(self, text_input: str, system_prompt: str) -> str: + """ + Format and display the prompt for a user query. + + Parameters: + ----------- + - text_input (str): Input text from the user to be included in the prompt. + - system_prompt (str): The system prompt that will be shown alongside the user input. + + Returns: + -------- + - str: The formatted prompt string combining system prompt and user input. + """ + if not text_input: + text_input = "No user input provided." + if not system_prompt: + raise MissingSystemPromptPathError() + + system_prompt = LLMGateway.read_query_prompt(system_prompt) + + formatted_prompt = ( + f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n""" + if system_prompt + else None + ) + + return formatted_prompt diff --git a/cognee/modules/settings/get_settings.py b/cognee/modules/settings/get_settings.py index fa92c8043..071bcca36 100644 --- a/cognee/modules/settings/get_settings.py +++ b/cognee/modules/settings/get_settings.py @@ -15,6 +15,7 @@ class ModelName(Enum): ollama = "ollama" anthropic = "anthropic" gemini = "gemini" + mistral = "mistral" class LLMConfig(BaseModel): @@ -72,6 +73,10 @@ def get_settings() -> SettingsDict: "value": "gemini", "label": "Gemini", }, + { + "value": "mistral", + "label": "Mistral", + }, ] return SettingsDict.model_validate( @@ -134,6 +139,24 @@ def get_settings() -> SettingsDict: "label": "Gemini 2.0 Flash", }, ], + "mistral": [ + { + "value": "mistral-medium-2508", + "label": "Mistral Medium 3.1", + }, + { + "value": "magistral-medium-2509", + "label": "Magistral Medium 1.2", + }, + { + "value": "magistral-medium-2507", + "label": "Magistral Medium 1.1", + }, + { + "value": "mistral-large-2411", + "label": "Mistral Large 2.1", + }, + ], }, }, vector_db={ diff --git a/pyproject.toml b/pyproject.toml index 4baf40e2f..aa5e693c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,8 @@ dependencies = [ "networkx>=3.4.2,<4", "uvicorn>=0.34.0,<1.0.0", "gunicorn>=20.1.0,<24", - "websockets>=15.0.1,<16.0.0" + "websockets>=15.0.1,<16.0.0", + "mistralai>=1.9.10", ] [project.optional-dependencies]