diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 9f5457224..039365a19 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -28,7 +28,7 @@ import json_repair
 import litellm
 import openai
 from openai import AsyncOpenAI, OpenAI
-from openai.lib.azure import AzureOpenAI
+from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
 from strenum import StrEnum
 
 from common.token_utils import num_tokens_from_string, total_token_count_from_response
@@ -535,6 +535,8 @@ class AzureChat(Base):
         api_version = json.loads(key).get("api_version", "2024-02-01")
         super().__init__(key, model_name, base_url, **kwargs)
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
+        # Mirror the sync client: Azure needs the parsed api_key (not the raw JSON key blob), azure_endpoint and api_version, none of which plain AsyncOpenAI handles.
+        self.async_client = AsyncAzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
         self.model_name = model_name
 
     @property