Enhance MistralAdapter error handling and response validation in chat completion method
This commit is contained in:
parent
617c1f0d71
commit
06dca5bf26
2 changed files with 35 additions and 19 deletions
|
|
@ -66,23 +66,39 @@ class MistralAdapter(LLMInterface):
|
|||
--------
|
||||
- BaseModel: An instance of BaseModel containing the structured response.
|
||||
"""
|
||||
return await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
max_tokens=self.max_completion_tokens,
|
||||
max_retries=5,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to extract information
|
||||
from the following input: {text_input}""",
|
||||
},
|
||||
],
|
||||
response_model=response_model,
|
||||
)
|
||||
try:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""Use the given format to extract information
|
||||
from the following input: {text_input}""",
|
||||
},
|
||||
]
|
||||
try:
|
||||
response = await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
max_tokens=self.max_completion_tokens,
|
||||
max_retries=5,
|
||||
messages=messages,
|
||||
response_model=response_model,
|
||||
)
|
||||
if response.choices and response.choices[0].message.content:
|
||||
content = response.choices[0].message.content
|
||||
return response_model.model_validate_json(content)
|
||||
else:
|
||||
raise ValueError("Failed to get valid response after retries")
|
||||
except litellm.exceptions.BadRequestError as e:
|
||||
logger.error(f"Bad request error: {str(e)}")
|
||||
raise ValueError(f"Invalid request: {str(e)}")
|
||||
|
||||
except JSONSchemaValidationError as e:
|
||||
logger.error(f"Schema validation failed: {str(e)}")
|
||||
logger.debug(f"Raw response: {e.raw_response}")
|
||||
raise ValueError(f"Response failed schema validation: {str(e)}")
|
||||
|
||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ dependencies = [
|
|||
"uvicorn>=0.34.0,<1.0.0",
|
||||
"gunicorn>=20.1.0,<24",
|
||||
"websockets>=15.0.1,<16.0.0",
|
||||
"mistralai>=1.9.10",
|
||||
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
@ -92,7 +92,7 @@ langchain = [
|
|||
llama-index = ["llama-index-core>=0.12.11,<0.13"]
|
||||
huggingface = ["transformers>=4.46.3,<5"]
|
||||
ollama = ["transformers>=4.46.3,<5"]
|
||||
mistral = ["mistral-common>=1.5.2,<2"]
|
||||
mistral = ["mistral-common>=1.5.2,<2","mistralai>=1.9.10"]
|
||||
anthropic = ["anthropic>=0.27"]
|
||||
deepeval = ["deepeval>=3.0.1,<4"]
|
||||
posthog = ["posthog>=3.5.0,<4"]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue