cherry-pick fafa1791
This commit is contained in:
parent
fd486c287a
commit
f16de69415
1 changed files with 13 additions and 5 deletions
|
|
@ -272,15 +272,19 @@ async def openai_complete_if_cache(
|
||||||
if timeout is not None:
|
if timeout is not None:
|
||||||
kwargs["timeout"] = timeout
|
kwargs["timeout"] = timeout
|
||||||
|
|
||||||
|
# Determine the correct model identifier to use
|
||||||
|
# For Azure OpenAI, we must use the deployment name instead of the model name
|
||||||
|
api_model = azure_deployment if use_azure and azure_deployment else model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Don't use async with context manager, use client directly
|
# Don't use async with context manager, use client directly
|
||||||
if "response_format" in kwargs:
|
if "response_format" in kwargs:
|
||||||
response = await openai_async_client.chat.completions.parse(
|
response = await openai_async_client.chat.completions.parse(
|
||||||
model=model, messages=messages, **kwargs
|
model=api_model, messages=messages, **kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = await openai_async_client.chat.completions.create(
|
response = await openai_async_client.chat.completions.create(
|
||||||
model=model, messages=messages, **kwargs
|
model=api_model, messages=messages, **kwargs
|
||||||
)
|
)
|
||||||
except APIConnectionError as e:
|
except APIConnectionError as e:
|
||||||
logger.error(f"OpenAI API Connection Error: {e}")
|
logger.error(f"OpenAI API Connection Error: {e}")
|
||||||
|
|
@ -706,9 +710,13 @@ async def openai_embed(
|
||||||
)
|
)
|
||||||
|
|
||||||
async with openai_async_client:
|
async with openai_async_client:
|
||||||
|
# Determine the correct model identifier to use
|
||||||
|
# For Azure OpenAI, we must use the deployment name instead of the model name
|
||||||
|
api_model = azure_deployment if use_azure and azure_deployment else model
|
||||||
|
|
||||||
# Prepare API call parameters
|
# Prepare API call parameters
|
||||||
api_params = {
|
api_params = {
|
||||||
"model": model,
|
"model": api_model,
|
||||||
"input": texts,
|
"input": texts,
|
||||||
"encoding_format": "base64",
|
"encoding_format": "base64",
|
||||||
}
|
}
|
||||||
|
|
@ -774,7 +782,7 @@ async def azure_openai_complete_if_cache(
|
||||||
|
|
||||||
# Call the unified implementation with Azure-specific parameters
|
# Call the unified implementation with Azure-specific parameters
|
||||||
return await openai_complete_if_cache(
|
return await openai_complete_if_cache(
|
||||||
model=model,
|
model=deployment,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
|
|
@ -855,7 +863,7 @@ async def azure_openai_embed(
|
||||||
# Call the unified implementation with Azure-specific parameters
|
# Call the unified implementation with Azure-specific parameters
|
||||||
return await openai_embed(
|
return await openai_embed(
|
||||||
texts=texts,
|
texts=texts,
|
||||||
model=model or deployment,
|
model=deployment,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
use_azure=True,
|
use_azure=True,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue