From ac9f2574a572604262c7bfa40b911f5029b8aa2d Mon Sep 17 00:00:00 2001 From: yangdx Date: Fri, 21 Nov 2025 19:24:32 +0800 Subject: [PATCH] Improve Azure OpenAI wrapper functions with full parameter support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Add missing parameters to wrappers • Update docstrings for clarity • Ensure API consistency • Fix parameter forwarding • Maintain backward compatibility --- lightrag/llm/openai.py | 61 +++++++++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 8e265d2c..a314d597 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -205,23 +205,33 @@ async def openai_complete_if_cache( 6. For non-streaming: COT content is prepended to regular content with tags. Args: - model: The OpenAI model to use. + model: The OpenAI model to use. For Azure, this can be the deployment name. prompt: The prompt to complete. system_prompt: Optional system prompt to include. history_messages: Optional list of previous messages in the conversation. - base_url: Optional base URL for the OpenAI API. - api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. - token_tracker: Optional token usage tracker for monitoring API usage. enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False. + base_url: Optional base URL for the OpenAI API. For Azure, this should be the + Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/). + api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment + variable if None. For Azure, uses AZURE_OPENAI_API_KEY if None. + token_tracker: Optional token usage tracker for monitoring API usage. stream: Whether to stream the response. Default is False. timeout: Request timeout in seconds. Default is None. keyword_extraction: Whether to enable keyword extraction mode. When True, triggers special response formatting for keyword extraction. Default is False. + use_azure: Whether to use Azure OpenAI service instead of standard OpenAI. + When True, creates an AsyncAzureOpenAI client. Default is False. + azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True. + If not specified, falls back to AZURE_OPENAI_DEPLOYMENT environment variable. + api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used + when use_azure=True. If not specified, falls back to AZURE_OPENAI_API_VERSION + environment variable. **kwargs: Additional keyword arguments to pass to the OpenAI API. Special kwargs: - openai_client_configs: Dict of configuration options for the AsyncOpenAI client. These will be passed to the client constructor but will be overridden by - explicit parameters (api_key, base_url). + explicit parameters (api_key, base_url). Supports proxy configuration, + custom headers, retry policies, etc. Returns: The completed text (with integrated COT content if available) or an async iterator @@ -684,21 +694,34 @@ async def openai_embed( ) -> np.ndarray: """Generate embeddings for a list of texts using OpenAI's API. + This function supports both standard OpenAI and Azure OpenAI services. + Args: texts: List of texts to embed. - model: The OpenAI embedding model to use. - base_url: Optional base URL for the OpenAI API. - api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. + model: The embedding model to use. For standard OpenAI (e.g., "text-embedding-3-small"). + For Azure, this can be the deployment name. + base_url: Optional base URL for the API. For standard OpenAI, uses default OpenAI endpoint. + For Azure, this should be the Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/). + api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment variable if None. + For Azure, uses AZURE_EMBEDDING_API_KEY environment variable if None. embedding_dim: Optional embedding dimension for dynamic dimension reduction. **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper. Do NOT manually pass this parameter when calling the function directly. The dimension is controlled by the @wrap_embedding_func_with_attrs decorator. Manually passing a different value will trigger a warning and be ignored. When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction. - client_configs: Additional configuration options for the AsyncOpenAI client. + client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client. These will override any default configurations but will be overridden by - explicit parameters (api_key, base_url). + explicit parameters (api_key, base_url). Supports proxy configuration, + custom headers, retry policies, etc. token_tracker: Optional token usage tracker for monitoring API usage. + use_azure: Whether to use Azure OpenAI service instead of standard OpenAI. + When True, creates an AsyncAzureOpenAI client. Default is False. + azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True. + If not specified, falls back to AZURE_EMBEDDING_DEPLOYMENT environment variable. + api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used + when use_azure=True. If not specified, falls back to AZURE_EMBEDDING_API_VERSION + environment variable. Returns: A numpy array of embeddings, one per input text. @@ -759,6 +782,9 @@ async def azure_openai_complete_if_cache( enable_cot: bool = False, base_url: str | None = None, api_key: str | None = None, + token_tracker: Any | None = None, + stream: bool | None = None, + timeout: int | None = None, api_version: str | None = None, keyword_extraction: bool = False, **kwargs, @@ -767,6 +793,9 @@ async def azure_openai_complete_if_cache( This function provides backward compatibility by wrapping the unified openai_complete_if_cache implementation with Azure-specific parameter handling. + + All parameters from the underlying openai_complete_if_cache are exposed to ensure + full feature parity and API consistency. """ # Handle Azure-specific environment variables and parameters deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL") @@ -782,9 +811,6 @@ async def azure_openai_complete_if_cache( or os.getenv("OPENAI_API_VERSION") ) - # Pop timeout from kwargs if present (will be handled by openai_complete_if_cache) - timeout = kwargs.pop("timeout", None) - # Call the unified implementation with Azure-specific parameters return await openai_complete_if_cache( model=model, @@ -794,6 +820,8 @@ async def azure_openai_complete_if_cache( enable_cot=enable_cot, base_url=base_url, api_key=api_key, + token_tracker=token_tracker, + stream=stream, timeout=timeout, use_azure=True, azure_deployment=deployment, @@ -833,6 +861,8 @@ async def azure_openai_embed( model: str | None = None, base_url: str | None = None, api_key: str | None = None, + token_tracker: Any | None = None, + client_configs: dict[str, Any] | None = None, api_version: str | None = None, ) -> np.ndarray: """Azure OpenAI embedding wrapper function. @@ -840,6 +870,9 @@ async def azure_openai_embed( This function provides backward compatibility by wrapping the unified openai_embed implementation with Azure-specific parameter handling. + All parameters from the underlying openai_embed are exposed to ensure + full feature parity and API consistency. + IMPORTANT - Decorator Usage: 1. This function is decorated with @wrap_embedding_func_with_attrs to provide @@ -898,6 +931,8 @@ async def azure_openai_embed( model=model or deployment, base_url=base_url, api_key=api_key, + token_tracker=token_tracker, + client_configs=client_configs, use_azure=True, azure_deployment=deployment, api_version=api_version,