Enhance MistralAdapter error handling and response validation in chat completion method

This commit is contained in:
Aniruddha Mandal 2025-09-24 10:40:17 +05:30 committed by vasilije
parent 617c1f0d71
commit 06dca5bf26
2 changed files with 35 additions and 19 deletions

View file

@@ -66,23 +66,39 @@ class MistralAdapter(LLMInterface):
-------- --------
- BaseModel: An instance of BaseModel containing the structured response. - BaseModel: An instance of BaseModel containing the structured response.
""" """
return await self.aclient.chat.completions.create( try:
model=self.model, messages = [
max_tokens=self.max_completion_tokens, {
max_retries=5, "role": "system",
messages=[ "content": system_prompt,
{ },
"role": "system", {
"content": system_prompt, "role": "user",
}, "content": f"""Use the given format to extract information
{ from the following input: {text_input}""",
"role": "user", },
"content": f"""Use the given format to extract information ]
from the following input: {text_input}""", try:
}, response = await self.aclient.chat.completions.create(
], model=self.model,
response_model=response_model, max_tokens=self.max_completion_tokens,
) max_retries=5,
messages=messages,
response_model=response_model,
)
if response.choices and response.choices[0].message.content:
content = response.choices[0].message.content
return response_model.model_validate_json(content)
else:
raise ValueError("Failed to get valid response after retries")
except litellm.exceptions.BadRequestError as e:
logger.error(f"Bad request error: {str(e)}")
raise ValueError(f"Invalid request: {str(e)}")
except JSONSchemaValidationError as e:
logger.error(f"Schema validation failed: {str(e)}")
logger.debug(f"Raw response: {e.raw_response}")
raise ValueError(f"Response failed schema validation: {str(e)}")
def show_prompt(self, text_input: str, system_prompt: str) -> str: def show_prompt(self, text_input: str, system_prompt: str) -> str:
""" """

View file

@@ -55,7 +55,7 @@ dependencies = [
"uvicorn>=0.34.0,<1.0.0", "uvicorn>=0.34.0,<1.0.0",
"gunicorn>=20.1.0,<24", "gunicorn>=20.1.0,<24",
"websockets>=15.0.1,<16.0.0", "websockets>=15.0.1,<16.0.0",
"mistralai>=1.9.10",
] ]
[project.optional-dependencies] [project.optional-dependencies]
@@ -92,7 +92,7 @@ langchain = [
llama-index = ["llama-index-core>=0.12.11,<0.13"] llama-index = ["llama-index-core>=0.12.11,<0.13"]
huggingface = ["transformers>=4.46.3,<5"] huggingface = ["transformers>=4.46.3,<5"]
ollama = ["transformers>=4.46.3,<5"] ollama = ["transformers>=4.46.3,<5"]
mistral = ["mistral-common>=1.5.2,<2"] mistral = ["mistral-common>=1.5.2,<2","mistralai>=1.9.10"]
anthropic = ["anthropic>=0.27"] anthropic = ["anthropic>=0.27"]
deepeval = ["deepeval>=3.0.1,<4"] deepeval = ["deepeval>=3.0.1,<4"]
posthog = ["posthog>=3.5.0,<4"] posthog = ["posthog>=3.5.0,<4"]