Improve Azure OpenAI wrapper functions with full parameter support
• Add missing parameters to wrappers
• Update docstrings for clarity
• Ensure API consistency
• Fix parameter forwarding
• Maintain backward compatibility
parent 45f4f82392
commit ac9f2574a5
1 changed file with 48 additions and 13 deletions
@@ -205,23 +205,33 @@ async def openai_complete_if_cache(
    6. For non-streaming: COT content is prepended to regular content with <think> tags.

    Args:
        model: The OpenAI model to use.
        model: The OpenAI model to use. For Azure, this can be the deployment name.
        prompt: The prompt to complete.
        system_prompt: Optional system prompt to include.
        history_messages: Optional list of previous messages in the conversation.
        base_url: Optional base URL for the OpenAI API.
        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        token_tracker: Optional token usage tracker for monitoring API usage.
        enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
        base_url: Optional base URL for the OpenAI API. For Azure, this should be the
            Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/).
        api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment
            variable if None. For Azure, uses AZURE_OPENAI_API_KEY if None.
        token_tracker: Optional token usage tracker for monitoring API usage.
        stream: Whether to stream the response. Default is False.
        timeout: Request timeout in seconds. Default is None.
        keyword_extraction: Whether to enable keyword extraction mode. When True, triggers
            special response formatting for keyword extraction. Default is False.
        use_azure: Whether to use Azure OpenAI service instead of standard OpenAI.
            When True, creates an AsyncAzureOpenAI client. Default is False.
        azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True.
            If not specified, falls back to AZURE_OPENAI_DEPLOYMENT environment variable.
        api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used
            when use_azure=True. If not specified, falls back to AZURE_OPENAI_API_VERSION
            environment variable.
        **kwargs: Additional keyword arguments to pass to the OpenAI API.
            Special kwargs:
            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
              These will be passed to the client constructor but will be overridden by
              explicit parameters (api_key, base_url).
              explicit parameters (api_key, base_url). Supports proxy configuration,
              custom headers, retry policies, etc.

    Returns:
        The completed text (with integrated COT content if available) or an async iterator
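For reference, here is a minimal usage sketch of the unified function with Azure enabled (not part of the diff). The import path, endpoint, and deployment name are assumptions for illustration; the keyword parameters are the ones documented above.

import asyncio
from lightrag.llm.openai import openai_complete_if_cache  # assumed module path

async def main():
    text = await openai_complete_if_cache(
        model="gpt-4o-mini",                       # for Azure this may be the deployment name
        prompt="Summarize the change in one sentence.",
        system_prompt="You are a concise assistant.",
        use_azure=True,                            # route the call through AsyncAzureOpenAI
        azure_deployment="my-chat-deployment",     # hypothetical; else AZURE_OPENAI_DEPLOYMENT is used
        api_version="2024-02-15-preview",          # else AZURE_OPENAI_API_VERSION is used
        base_url="https://your-resource.openai.azure.com/",
        # api_key omitted: falls back to AZURE_OPENAI_API_KEY
    )
    print(text)

asyncio.run(main())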
@@ -684,21 +694,34 @@ async def openai_embed(
) -> np.ndarray:
    """Generate embeddings for a list of texts using OpenAI's API.

    This function supports both standard OpenAI and Azure OpenAI services.

    Args:
        texts: List of texts to embed.
        model: The OpenAI embedding model to use.
        base_url: Optional base URL for the OpenAI API.
        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        model: The embedding model to use. For standard OpenAI, e.g., "text-embedding-3-small".
            For Azure, this can be the deployment name.
        base_url: Optional base URL for the API. For standard OpenAI, uses the default OpenAI endpoint.
            For Azure, this should be the Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/).
        api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment variable if None.
            For Azure, uses AZURE_EMBEDDING_API_KEY environment variable if None.
        embedding_dim: Optional embedding dimension for dynamic dimension reduction.
            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
            Do NOT manually pass this parameter when calling the function directly.
            The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
            Manually passing a different value will trigger a warning and be ignored.
            When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
        client_configs: Additional configuration options for the AsyncOpenAI client.
        client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client.
            These will override any default configurations but will be overridden by
            explicit parameters (api_key, base_url).
            explicit parameters (api_key, base_url). Supports proxy configuration,
            custom headers, retry policies, etc.
        token_tracker: Optional token usage tracker for monitoring API usage.
        use_azure: Whether to use Azure OpenAI service instead of standard OpenAI.
            When True, creates an AsyncAzureOpenAI client. Default is False.
        azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True.
            If not specified, falls back to AZURE_EMBEDDING_DEPLOYMENT environment variable.
        api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used
            when use_azure=True. If not specified, falls back to AZURE_EMBEDDING_API_VERSION
            environment variable.

    Returns:
        A numpy array of embeddings, one per input text.
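A matching sketch for the embedding function (not part of the diff; import path and deployment name are assumptions). embedding_dim is deliberately omitted because, as documented above, the EmbeddingFunc wrapper injects it.

import asyncio
from lightrag.llm.openai import openai_embed  # assumed module path

async def main():
    vectors = await openai_embed(
        ["first chunk", "second chunk"],
        model="text-embedding-3-small",              # for Azure this may be the deployment name
        use_azure=True,
        azure_deployment="my-embedding-deployment",  # hypothetical; else AZURE_EMBEDDING_DEPLOYMENT is used
        api_version="2024-02-15-preview",            # else AZURE_EMBEDDING_API_VERSION is used
        base_url="https://your-resource.openai.azure.com/",
        # api_key omitted: falls back to AZURE_EMBEDDING_API_KEY
    )
    print(vectors.shape)                             # one embedding row per input text

asyncio.run(main())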
@@ -759,6 +782,9 @@ async def azure_openai_complete_if_cache(
    enable_cot: bool = False,
    base_url: str | None = None,
    api_key: str | None = None,
    token_tracker: Any | None = None,
    stream: bool | None = None,
    timeout: int | None = None,
    api_version: str | None = None,
    keyword_extraction: bool = False,
    **kwargs,
@@ -767,6 +793,9 @@ async def azure_openai_complete_if_cache(

    This function provides backward compatibility by wrapping the unified
    openai_complete_if_cache implementation with Azure-specific parameter handling.

    All parameters from the underlying openai_complete_if_cache are exposed to ensure
    full feature parity and API consistency.
    """
    # Handle Azure-specific environment variables and parameters
    deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
@@ -782,9 +811,6 @@ async def azure_openai_complete_if_cache(
        or os.getenv("OPENAI_API_VERSION")
    )

    # Pop timeout from kwargs if present (will be handled by openai_complete_if_cache)
    timeout = kwargs.pop("timeout", None)

    # Call the unified implementation with Azure-specific parameters
    return await openai_complete_if_cache(
        model=model,
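Read together with the deployment line in the previous hunk, the wrapper's configuration resolution amounts to the sketch below. This is an illustrative helper, not code from the module; the start of the api_version expression lies outside the hunk, so its first fallback is inferred from the docstring.

import os

def resolve_azure_chat_config(model: str | None, api_version: str | None):
    # Mirrors the fallback chains shown in the diff and docstrings
    deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
    version = (
        api_version
        or os.getenv("AZURE_OPENAI_API_VERSION")
        or os.getenv("OPENAI_API_VERSION")
    )
    return deployment, version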
@@ -794,6 +820,8 @@ async def azure_openai_complete_if_cache(
        enable_cot=enable_cot,
        base_url=base_url,
        api_key=api_key,
        token_tracker=token_tracker,
        stream=stream,
        timeout=timeout,
        use_azure=True,
        azure_deployment=deployment,
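With the forwarding above, the Azure chat wrapper is a thin pass-through to the unified implementation. A minimal call sketch (not part of the diff; the import path is an assumption and the parameter names are carried over from openai_complete_if_cache):

import asyncio
from lightrag.llm.openai import azure_openai_complete_if_cache  # assumed module path

async def main():
    answer = await azure_openai_complete_if_cache(
        model="gpt-4o",                    # AZURE_OPENAI_DEPLOYMENT, if set, wins as the deployment
        prompt="What does this wrapper add?",
        system_prompt="Answer briefly.",
        api_version="2024-02-15-preview",  # or rely on AZURE_OPENAI_API_VERSION
        timeout=60,                        # now an explicit parameter rather than a stray kwarg
    )
    print(answer)

asyncio.run(main())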
@@ -833,6 +861,8 @@ async def azure_openai_embed(
    model: str | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    token_tracker: Any | None = None,
    client_configs: dict[str, Any] | None = None,
    api_version: str | None = None,
) -> np.ndarray:
    """Azure OpenAI embedding wrapper function.
@@ -840,6 +870,9 @@ async def azure_openai_embed(
    This function provides backward compatibility by wrapping the unified
    openai_embed implementation with Azure-specific parameter handling.

    All parameters from the underlying openai_embed are exposed to ensure
    full feature parity and API consistency.

    IMPORTANT - Decorator Usage:

    1. This function is decorated with @wrap_embedding_func_with_attrs to provide
@@ -898,6 +931,8 @@ async def azure_openai_embed(
        model=model or deployment,
        base_url=base_url,
        api_key=api_key,
        token_tracker=token_tracker,
        client_configs=client_configs,
        use_azure=True,
        azure_deployment=deployment,
        api_version=api_version,
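A matching sketch for the embedding wrapper (not part of the diff; the import path and the texts keyword are assumptions carried over from openai_embed):

import asyncio
from lightrag.llm.openai import azure_openai_embed  # assumed module path

async def main():
    vectors = await azure_openai_embed(
        texts=["chunk one", "chunk two"],
        model="text-embedding-3-large",    # or omit and rely on AZURE_EMBEDDING_DEPLOYMENT
        api_version="2024-02-15-preview",  # or rely on AZURE_EMBEDDING_API_VERSION
        # api_key omitted: falls back to AZURE_EMBEDDING_API_KEY
    )
    print(vectors.shape)

asyncio.run(main())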