fix: Pr 1449 (#1533)
## Description

## Type of Change
- [x] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)

## Pre-submission Checklist
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the issue/feature**
- [ ] My code follows the project's coding standards and style guidelines
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
Commit fae0e240d3: 9 changed files with 5575 additions and 5086 deletions.
@@ -170,7 +170,7 @@ async def add(
     - LLM_API_KEY: API key for your LLM provider (OpenAI, Anthropic, etc.)
 
     Optional:
-    - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama"
+    - LLM_PROVIDER: "openai" (default), "anthropic", "gemini", "ollama", "mistral"
     - LLM_MODEL: Model name (default: "gpt-5-mini")
     - DEFAULT_USER_EMAIL: Custom default user email
    - DEFAULT_USER_PASSWORD: Custom default user password
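For reviewers, a minimal sketch (not part of this diff) of selecting the new provider through the environment variables documented above; the values are placeholders and would normally live in a `.env` file rather than in code.

```python
import os

# Configure cognee's LLM settings via the documented environment variables.
os.environ["LLM_PROVIDER"] = "mistral"
os.environ["LLM_MODEL"] = "mistral-large-2411"      # any model offered by the provider
os.environ["LLM_API_KEY"] = "<your-mistral-api-key>"
```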
@@ -21,7 +21,13 @@ class SettingsDTO(OutDTO):
 
 
 class LLMConfigInputDTO(InDTO):
-    provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"], Literal["gemini"]]
+    provider: Union[
+        Literal["openai"],
+        Literal["ollama"],
+        Literal["anthropic"],
+        Literal["gemini"],
+        Literal["mistral"],
+    ]
     model: str
     api_key: str
 
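To show what the widened field accepts, here is a self-contained sketch using a stand-in model; the real `LLMConfigInputDTO` derives from `InDTO` and lives in the settings payload module, so only the field shapes are mirrored here.

```python
from typing import Literal, Union

from pydantic import BaseModel


# Stand-in for LLMConfigInputDTO, mirroring only the widened provider field.
class LLMConfigInput(BaseModel):
    provider: Union[
        Literal["openai"],
        Literal["ollama"],
        Literal["anthropic"],
        Literal["gemini"],
        Literal["mistral"],
    ]
    model: str
    api_key: str


payload = LLMConfigInput(provider="mistral", model="mistral-large-2411", api_key="sk-...")
print(payload.provider)  # "mistral" now validates instead of being rejected
```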
@@ -23,6 +23,7 @@ class LLMProvider(Enum):
     - ANTHROPIC: Represents the Anthropic provider.
     - CUSTOM: Represents a custom provider option.
     - GEMINI: Represents the Gemini provider.
+    - MISTRAL: Represents the Mistral AI provider.
     """
 
     OPENAI = "openai"

@@ -30,6 +31,7 @@ class LLMProvider(Enum):
     ANTHROPIC = "anthropic"
     CUSTOM = "custom"
     GEMINI = "gemini"
+    MISTRAL = "mistral"
 
 
 def get_llm_client(raise_api_key_error: bool = True):
@@ -145,5 +147,35 @@ def get_llm_client(raise_api_key_error: bool = True):
             api_version=llm_config.llm_api_version,
         )
 
+    elif provider == LLMProvider.MISTRAL:
+        if llm_config.llm_api_key is None:
+            raise LLMAPIKeyNotSetError()
+
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
+            MistralAdapter,
+        )
+
+        return MistralAdapter(
+            api_key=llm_config.llm_api_key,
+            model=llm_config.llm_model,
+            max_completion_tokens=max_completion_tokens,
+            endpoint=llm_config.llm_endpoint,
+        )
+
+    elif provider == LLMProvider.MISTRAL:
+        if llm_config.llm_api_key is None:
+            raise LLMAPIKeyNotSetError()
+
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
+            MistralAdapter,
+        )
+
+        return MistralAdapter(
+            api_key=llm_config.llm_api_key,
+            model=llm_config.llm_model,
+            max_completion_tokens=max_completion_tokens,
+            endpoint=llm_config.llm_endpoint,
+        )
+
     else:
         raise UnsupportedLLMProviderError(provider)
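A short sketch of how the new branch is reached; the import path below is an assumption (the diff does not show the module name of `get_llm_client`), and the provider, model, and API key are expected to be set in the LLM config beforehand, for example via the environment variables shown earlier.

```python
# Sketch only: with LLM_PROVIDER=mistral configured, the factory patched above
# should return a MistralAdapter instead of raising UnsupportedLLMProviderError.
from cognee.infrastructure.llm.get_llm_client import get_llm_client  # assumed module path

client = get_llm_client()
```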
@@ -0,0 +1,129 @@
+import litellm
+import instructor
+from pydantic import BaseModel
+from typing import Type, Optional
+from litellm import acompletion, JSONSchemaValidationError
+
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.observability.get_observe import get_observe
+from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
+    LLMInterface,
+)
+from cognee.infrastructure.llm.LLMGateway import LLMGateway
+from cognee.infrastructure.llm.config import get_llm_config
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
+    rate_limit_async,
+    sleep_and_retry_async,
+)
+
+logger = get_logger()
+observe = get_observe()
+
+
+class MistralAdapter(LLMInterface):
+    """
+    Adapter for Mistral AI API, for structured output generation and prompt display.
+
+    Public methods:
+    - acreate_structured_output
+    - show_prompt
+    """
+
+    name = "Mistral"
+    model: str
+    api_key: str
+    max_completion_tokens: int
+
+    def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
+        from mistralai import Mistral
+
+        self.model = model
+        self.max_completion_tokens = max_completion_tokens
+
+        self.aclient = instructor.from_litellm(
+            litellm.acompletion,
+            mode=instructor.Mode.MISTRAL_TOOLS,
+            api_key=get_llm_config().llm_api_key,
+        )
+
+    @sleep_and_retry_async()
+    @rate_limit_async
+    async def acreate_structured_output(
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> BaseModel:
+        """
+        Generate a response from the user query.
+
+        Parameters:
+        -----------
+        - text_input (str): The input text from the user to be processed.
+        - system_prompt (str): A prompt that sets the context for the query.
+        - response_model (Type[BaseModel]): The model to structure the response according to
+          its format.
+
+        Returns:
+        --------
+        - BaseModel: An instance of BaseModel containing the structured response.
+        """
+        try:
+            messages = [
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": f"""Use the given format to extract information
+                    from the following input: {text_input}""",
+                },
+            ]
+            try:
+                response = await self.aclient.chat.completions.create(
+                    model=self.model,
+                    max_tokens=self.max_completion_tokens,
+                    max_retries=5,
+                    messages=messages,
+                    response_model=response_model,
+                )
+                if response.choices and response.choices[0].message.content:
+                    content = response.choices[0].message.content
+                    return response_model.model_validate_json(content)
+                else:
+                    raise ValueError("Failed to get valid response after retries")
+            except litellm.exceptions.BadRequestError as e:
+                logger.error(f"Bad request error: {str(e)}")
+                raise ValueError(f"Invalid request: {str(e)}")
+
+        except JSONSchemaValidationError as e:
+            logger.error(f"Schema validation failed: {str(e)}")
+            logger.debug(f"Raw response: {e.raw_response}")
+            raise ValueError(f"Response failed schema validation: {str(e)}")
+
+    def show_prompt(self, text_input: str, system_prompt: str) -> str:
+        """
+        Format and display the prompt for a user query.
+
+        Parameters:
+        -----------
+        - text_input (str): Input text from the user to be included in the prompt.
+        - system_prompt (str): The system prompt that will be shown alongside the user input.
+
+        Returns:
+        --------
+        - str: The formatted prompt string combining system prompt and user input.
+        """
+        if not text_input:
+            text_input = "No user input provided."
+        if not system_prompt:
+            raise MissingSystemPromptPathError()
+
+        system_prompt = LLMGateway.read_query_prompt(system_prompt)
+
+        formatted_prompt = (
+            f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
+            if system_prompt
+            else None
+        )
+
+        return formatted_prompt
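A usage sketch for the new adapter, using a hypothetical response model purely for illustration; the argument names follow the `__init__` signature above and the values are placeholders.

```python
import asyncio

from pydantic import BaseModel

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
    MistralAdapter,
)


# Hypothetical response model, used only to illustrate structured output.
class PersonInfo(BaseModel):
    name: str
    role: str


async def main():
    adapter = MistralAdapter(
        api_key="<your-mistral-api-key>",
        model="mistral-large-2411",
        max_completion_tokens=4096,
    )
    result = await adapter.acreate_structured_output(
        text_input="Ada Lovelace wrote the first published algorithm.",
        system_prompt="Extract the person's name and role from the text.",
        response_model=PersonInfo,
    )
    print(result)


asyncio.run(main())
```

Note that, as written in the diff, `__init__` passes `get_llm_config().llm_api_key` to the instructor client rather than the `api_key` argument, so the LLM_API_KEY setting still needs to be populated for the call to authenticate.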
@@ -15,6 +15,7 @@ class ModelName(Enum):
     ollama = "ollama"
     anthropic = "anthropic"
     gemini = "gemini"
+    mistral = "mistral"
 
 
 class LLMConfig(BaseModel):
@@ -72,6 +73,10 @@ def get_settings() -> SettingsDict:
             "value": "gemini",
             "label": "Gemini",
         },
+        {
+            "value": "mistral",
+            "label": "Mistral",
+        },
     ]
 
     return SettingsDict.model_validate(
@@ -134,6 +139,24 @@ def get_settings() -> SettingsDict:
                             "label": "Gemini 2.0 Flash",
                         },
                     ],
+                    "mistral": [
+                        {
+                            "value": "mistral-medium-2508",
+                            "label": "Mistral Medium 3.1",
+                        },
+                        {
+                            "value": "magistral-medium-2509",
+                            "label": "Magistral Medium 1.2",
+                        },
+                        {
+                            "value": "magistral-medium-2507",
+                            "label": "Magistral Medium 1.1",
+                        },
+                        {
+                            "value": "mistral-large-2411",
+                            "label": "Mistral Large 2.1",
+                        },
+                    ],
                 },
             },
             vector_db={
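For clarity, the same additions written out as plain data, roughly how a settings UI would pair the new provider option with its model list; the surrounding `SettingsDict` nesting is not reproduced here.

```python
# Illustrative only: the new entries surfaced by get_settings().
mistral_provider_option = {"value": "mistral", "label": "Mistral"}

mistral_model_options = [
    {"value": "mistral-medium-2508", "label": "Mistral Medium 3.1"},
    {"value": "magistral-medium-2509", "label": "Magistral Medium 1.2"},
    {"value": "magistral-medium-2507", "label": "Magistral Medium 1.1"},
    {"value": "mistral-large-2411", "label": "Mistral Large 2.1"},
]
```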
poetry.lock (generated, 2202 changed lines): diff suppressed because it is too large.
@@ -54,7 +54,8 @@ dependencies = [
     "networkx>=3.4.2,<4",
     "uvicorn>=0.34.0,<1.0.0",
     "gunicorn>=20.1.0,<24",
-    "websockets>=15.0.1,<16.0.0"
+    "websockets>=15.0.1,<16.0.0",
+    "mistralai>=1.9.10",
 ]
 
 [project.optional-dependencies]
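A quick sanity check, as a sketch, that the newly pinned dependency resolves in the environment after reinstalling the project.

```python
# Confirm the mistralai package is installed; expected to satisfy >=1.9.10 per the constraint above.
from importlib.metadata import version

print(version("mistralai"))
```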
@@ -147,6 +148,7 @@ Homepage = "https://www.cognee.ai"
Repository = "https://github.com/topoteretes/cognee"

[project.scripts]
cognee = "cognee.cli._cognee:main"
cognee-cli = "cognee.cli._cognee:main"

[build-system]