diff --git a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
index a9cf68d6e..dbf0322b1 100644
--- a/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
+++ b/cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
@@ -66,23 +66,39 @@ class MistralAdapter(LLMInterface):
         --------
             - BaseModel: An instance of BaseModel containing the structured response.
         """
-        return await self.aclient.chat.completions.create(
-            model=self.model,
-            max_tokens=self.max_completion_tokens,
-            max_retries=5,
-            messages=[
-                {
-                    "role": "system",
-                    "content": system_prompt,
-                },
-                {
-                    "role": "user",
-                    "content": f"""Use the given format to extract information
-                    from the following input: {text_input}""",
-                },
-            ],
-            response_model=response_model,
-        )
+        try:
+            messages = [
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": f"""Use the given format to extract information
+                    from the following input: {text_input}""",
+                },
+            ]
+            try:
+                response = await self.aclient.chat.completions.create(
+                    model=self.model,
+                    max_tokens=self.max_completion_tokens,
+                    max_retries=5,
+                    messages=messages,
+                    response_model=response_model,
+                )
+                if response.choices and response.choices[0].message.content:
+                    content = response.choices[0].message.content
+                    return response_model.model_validate_json(content)
+                else:
+                    raise ValueError("Failed to get valid response after retries")
+            except litellm.exceptions.BadRequestError as e:
+                logger.error(f"Bad request error: {str(e)}")
+                raise ValueError(f"Invalid request: {str(e)}")
+
+        except JSONSchemaValidationError as e:
+            logger.error(f"Schema validation failed: {str(e)}")
+            logger.debug(f"Raw response: {e.raw_response}")
+            raise ValueError(f"Response failed schema validation: {str(e)}")
 
     def show_prompt(self, text_input: str, system_prompt: str) -> str:
         """
diff --git a/pyproject.toml b/pyproject.toml
index aa5e693c1..0e754914b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,7 +55,7 @@ dependencies = [
     "uvicorn>=0.34.0,<1.0.0",
     "gunicorn>=20.1.0,<24",
     "websockets>=15.0.1,<16.0.0",
-    "mistralai>=1.9.10",
+
 ]
 
 [project.optional-dependencies]
@@ -92,7 +92,7 @@ langchain = [
 llama-index = ["llama-index-core>=0.12.11,<0.13"]
 huggingface = ["transformers>=4.46.3,<5"]
 ollama = ["transformers>=4.46.3,<5"]
-mistral = ["mistral-common>=1.5.2,<2"]
+mistral = ["mistral-common>=1.5.2,<2","mistralai>=1.9.10"]
 anthropic = ["anthropic>=0.27"]
 deepeval = ["deepeval>=3.0.1,<4"]
 posthog = ["posthog>=3.5.0,<4"]
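
Note: after this patch, transport failures (litellm.exceptions.BadRequestError) and schema failures (JSONSchemaValidationError) both surface to callers as ValueError instead of propagating raw library exceptions. A minimal caller sketch under assumptions follows: the method name acreate_structured_output and the MistralAdapter constructor arguments are not shown in the hunk and are hypothetical here, inferred from the docstring and field names in the diff.

    import asyncio
    from pydantic import BaseModel
    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
        MistralAdapter,
    )

    class PersonInfo(BaseModel):
        name: str
        age: int

    async def main():
        # Hypothetical constructor arguments; check the adapter's actual signature.
        adapter = MistralAdapter(model="mistral-large-latest", max_completion_tokens=1024)
        try:
            result = await adapter.acreate_structured_output(
                text_input="Alice is 30 years old.",
                system_prompt="Extract the person's name and age.",
                response_model=PersonInfo,
            )
            print(result)
        except ValueError as error:
            # Bad requests, empty responses, and schema-validation failures
            # are all re-raised as ValueError by the patched adapter.
            print(f"Structured output failed: {error}")

    asyncio.run(main())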