Enhance MistralAdapter error handling and response validation in chat completion method
This commit is contained in:
parent
617c1f0d71
commit
06dca5bf26
2 changed files with 35 additions and 19 deletions
|
|
@ -66,23 +66,39 @@ class MistralAdapter(LLMInterface):
|
||||||
--------
|
--------
|
||||||
- BaseModel: An instance of BaseModel containing the structured response.
|
- BaseModel: An instance of BaseModel containing the structured response.
|
||||||
"""
|
"""
|
||||||
return await self.aclient.chat.completions.create(
|
try:
|
||||||
model=self.model,
|
messages = [
|
||||||
max_tokens=self.max_completion_tokens,
|
{
|
||||||
max_retries=5,
|
"role": "system",
|
||||||
messages=[
|
"content": system_prompt,
|
||||||
{
|
},
|
||||||
"role": "system",
|
{
|
||||||
"content": system_prompt,
|
"role": "user",
|
||||||
},
|
"content": f"""Use the given format to extract information
|
||||||
{
|
from the following input: {text_input}""",
|
||||||
"role": "user",
|
},
|
||||||
"content": f"""Use the given format to extract information
|
]
|
||||||
from the following input: {text_input}""",
|
try:
|
||||||
},
|
response = await self.aclient.chat.completions.create(
|
||||||
],
|
model=self.model,
|
||||||
response_model=response_model,
|
max_tokens=self.max_completion_tokens,
|
||||||
)
|
max_retries=5,
|
||||||
|
messages=messages,
|
||||||
|
response_model=response_model,
|
||||||
|
)
|
||||||
|
if response.choices and response.choices[0].message.content:
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
return response_model.model_validate_json(content)
|
||||||
|
else:
|
||||||
|
raise ValueError("Failed to get valid response after retries")
|
||||||
|
except litellm.exceptions.BadRequestError as e:
|
||||||
|
logger.error(f"Bad request error: {str(e)}")
|
||||||
|
raise ValueError(f"Invalid request: {str(e)}")
|
||||||
|
|
||||||
|
except JSONSchemaValidationError as e:
|
||||||
|
logger.error(f"Schema validation failed: {str(e)}")
|
||||||
|
logger.debug(f"Raw response: {e.raw_response}")
|
||||||
|
raise ValueError(f"Response failed schema validation: {str(e)}")
|
||||||
|
|
||||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ dependencies = [
|
||||||
"uvicorn>=0.34.0,<1.0.0",
|
"uvicorn>=0.34.0,<1.0.0",
|
||||||
"gunicorn>=20.1.0,<24",
|
"gunicorn>=20.1.0,<24",
|
||||||
"websockets>=15.0.1,<16.0.0",
|
"websockets>=15.0.1,<16.0.0",
|
||||||
"mistralai>=1.9.10",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|
@ -92,7 +92,7 @@ langchain = [
|
||||||
llama-index = ["llama-index-core>=0.12.11,<0.13"]
|
llama-index = ["llama-index-core>=0.12.11,<0.13"]
|
||||||
huggingface = ["transformers>=4.46.3,<5"]
|
huggingface = ["transformers>=4.46.3,<5"]
|
||||||
ollama = ["transformers>=4.46.3,<5"]
|
ollama = ["transformers>=4.46.3,<5"]
|
||||||
mistral = ["mistral-common>=1.5.2,<2"]
|
mistral = ["mistral-common>=1.5.2,<2","mistralai>=1.9.10"]
|
||||||
anthropic = ["anthropic>=0.27"]
|
anthropic = ["anthropic>=0.27"]
|
||||||
deepeval = ["deepeval>=3.0.1,<4"]
|
deepeval = ["deepeval>=3.0.1,<4"]
|
||||||
posthog = ["posthog>=3.5.0,<4"]
|
posthog = ["posthog>=3.5.0,<4"]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue