From 65d26c6476b4928c29d55b322fb25914513f5d45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?= Date: Thu, 4 Dec 2025 19:14:32 +0800 Subject: [PATCH] cherry-pick ac9f2574 --- lightrag/llm/openai.py | 322 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 284 insertions(+), 38 deletions(-) diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index cea85b04..a314d597 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -77,46 +77,86 @@ class InvalidResponseError(Exception): def create_openai_async_client( api_key: str | None = None, base_url: str | None = None, + use_azure: bool = False, + azure_deployment: str | None = None, + api_version: str | None = None, + timeout: int | None = None, client_configs: dict[str, Any] | None = None, ) -> AsyncOpenAI: - """Create an AsyncOpenAI client with the given configuration. + """Create an AsyncOpenAI or AsyncAzureOpenAI client with the given configuration. Args: api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL. + use_azure: Whether to create an Azure OpenAI client. Default is False. + azure_deployment: Azure OpenAI deployment name (only used when use_azure=True). + api_version: Azure OpenAI API version (only used when use_azure=True). + timeout: Request timeout in seconds. client_configs: Additional configuration options for the AsyncOpenAI client. These will override any default configurations but will be overridden by explicit parameters (api_key, base_url). Returns: - An AsyncOpenAI client instance. + An AsyncOpenAI or AsyncAzureOpenAI client instance. """ - if not api_key: - api_key = os.environ["OPENAI_API_KEY"] + if use_azure: + from openai import AsyncAzureOpenAI - default_headers = { - "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", - "Content-Type": "application/json", - } + if not api_key: + api_key = os.environ.get("AZURE_OPENAI_API_KEY") or os.environ.get( + "LLM_BINDING_API_KEY" + ) - if client_configs is None: - client_configs = {} + if client_configs is None: + client_configs = {} - # Create a merged config dict with precedence: explicit params > client_configs > defaults - merged_configs = { - **client_configs, - "default_headers": default_headers, - "api_key": api_key, - } + # Create a merged config dict with precedence: explicit params > client_configs + merged_configs = { + **client_configs, + "api_key": api_key, + } - if base_url is not None: - merged_configs["base_url"] = base_url + # Add explicit parameters (override client_configs) + if base_url is not None: + merged_configs["azure_endpoint"] = base_url + if azure_deployment is not None: + merged_configs["azure_deployment"] = azure_deployment + if api_version is not None: + merged_configs["api_version"] = api_version + if timeout is not None: + merged_configs["timeout"] = timeout + + return AsyncAzureOpenAI(**merged_configs) else: - merged_configs["base_url"] = os.environ.get( - "OPENAI_API_BASE", "https://api.openai.com/v1" - ) + if not api_key: + api_key = os.environ["OPENAI_API_KEY"] - return AsyncOpenAI(**merged_configs) + default_headers = { + "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", + "Content-Type": "application/json", + } + + if client_configs is None: + client_configs = {} + + # Create a merged config dict with precedence: explicit params > client_configs > defaults + merged_configs = { + **client_configs, + "default_headers": default_headers, + "api_key": api_key, + } + + if base_url is not None: + merged_configs["base_url"] = base_url + else: + merged_configs["base_url"] = os.environ.get( + "OPENAI_API_BASE", "https://api.openai.com/v1" + ) + + if timeout is not None: + merged_configs["timeout"] = timeout + + return AsyncOpenAI(**merged_configs) @retry( @@ -141,6 +181,9 @@ async def openai_complete_if_cache( stream: bool | None = None, timeout: int | None = None, keyword_extraction: bool = False, + use_azure: bool = False, + azure_deployment: str | None = None, + api_version: str | None = None, **kwargs: Any, ) -> str: """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration. @@ -162,23 +205,33 @@ async def openai_complete_if_cache( 6. For non-streaming: COT content is prepended to regular content with tags. Args: - model: The OpenAI model to use. + model: The OpenAI model to use. For Azure, this can be the deployment name. prompt: The prompt to complete. system_prompt: Optional system prompt to include. history_messages: Optional list of previous messages in the conversation. - base_url: Optional base URL for the OpenAI API. - api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. - token_tracker: Optional token usage tracker for monitoring API usage. enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False. + base_url: Optional base URL for the OpenAI API. For Azure, this should be the + Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/). + api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment + variable if None. For Azure, uses AZURE_OPENAI_API_KEY if None. + token_tracker: Optional token usage tracker for monitoring API usage. stream: Whether to stream the response. Default is False. timeout: Request timeout in seconds. Default is None. keyword_extraction: Whether to enable keyword extraction mode. When True, triggers special response formatting for keyword extraction. Default is False. + use_azure: Whether to use Azure OpenAI service instead of standard OpenAI. + When True, creates an AsyncAzureOpenAI client. Default is False. + azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True. + If not specified, falls back to AZURE_OPENAI_DEPLOYMENT environment variable. + api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used + when use_azure=True. If not specified, falls back to AZURE_OPENAI_API_VERSION + environment variable. **kwargs: Additional keyword arguments to pass to the OpenAI API. Special kwargs: - openai_client_configs: Dict of configuration options for the AsyncOpenAI client. These will be passed to the client constructor but will be overridden by - explicit parameters (api_key, base_url). + explicit parameters (api_key, base_url). Supports proxy configuration, + custom headers, retry policies, etc. Returns: The completed text (with integrated COT content if available) or an async iterator @@ -207,10 +260,14 @@ async def openai_complete_if_cache( if keyword_extraction: kwargs["response_format"] = GPTKeywordExtractionFormat - # Create the OpenAI client + # Create the OpenAI client (supports both OpenAI and Azure) openai_async_client = create_openai_async_client( api_key=api_key, base_url=base_url, + use_azure=use_azure, + azure_deployment=azure_deployment, + api_version=api_version, + timeout=timeout, client_configs=client_configs, ) @@ -241,7 +298,7 @@ async def openai_complete_if_cache( try: # Don't use async with context manager, use client directly if "response_format" in kwargs: - response = await openai_async_client.beta.chat.completions.parse( + response = await openai_async_client.chat.completions.parse( model=model, messages=messages, **kwargs ) else: @@ -453,7 +510,7 @@ async def openai_complete_if_cache( raise InvalidResponseError("Invalid response from OpenAI API") message = response.choices[0].message - + # Handle parsed responses (structured output via response_format) # When using beta.chat.completions.parse(), the response is in message.parsed if hasattr(message, "parsed") and message.parsed is not None: @@ -492,7 +549,9 @@ async def openai_complete_if_cache( reasoning_content = safe_unicode_decode( reasoning_content.encode("utf-8") ) - final_content = f"{reasoning_content}{final_content}" + final_content = ( + f"{reasoning_content}{final_content}" + ) else: # COT disabled, only use regular content final_content = content or "" @@ -629,24 +688,40 @@ async def openai_embed( embedding_dim: int | None = None, client_configs: dict[str, Any] | None = None, token_tracker: Any | None = None, + use_azure: bool = False, + azure_deployment: str | None = None, + api_version: str | None = None, ) -> np.ndarray: """Generate embeddings for a list of texts using OpenAI's API. + This function supports both standard OpenAI and Azure OpenAI services. + Args: texts: List of texts to embed. - model: The OpenAI embedding model to use. - base_url: Optional base URL for the OpenAI API. - api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. + model: The embedding model to use. For standard OpenAI (e.g., "text-embedding-3-small"). + For Azure, this can be the deployment name. + base_url: Optional base URL for the API. For standard OpenAI, uses default OpenAI endpoint. + For Azure, this should be the Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/). + api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment variable if None. + For Azure, uses AZURE_EMBEDDING_API_KEY environment variable if None. embedding_dim: Optional embedding dimension for dynamic dimension reduction. **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper. Do NOT manually pass this parameter when calling the function directly. The dimension is controlled by the @wrap_embedding_func_with_attrs decorator. Manually passing a different value will trigger a warning and be ignored. When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction. - client_configs: Additional configuration options for the AsyncOpenAI client. + client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client. These will override any default configurations but will be overridden by - explicit parameters (api_key, base_url). + explicit parameters (api_key, base_url). Supports proxy configuration, + custom headers, retry policies, etc. token_tracker: Optional token usage tracker for monitoring API usage. + use_azure: Whether to use Azure OpenAI service instead of standard OpenAI. + When True, creates an AsyncAzureOpenAI client. Default is False. + azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True. + If not specified, falls back to AZURE_EMBEDDING_DEPLOYMENT environment variable. + api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used + when use_azure=True. If not specified, falls back to AZURE_EMBEDDING_API_VERSION + environment variable. Returns: A numpy array of embeddings, one per input text. @@ -656,9 +731,14 @@ async def openai_embed( RateLimitError: If the OpenAI API rate limit is exceeded. APITimeoutError: If the OpenAI API request times out. """ - # Create the OpenAI client + # Create the OpenAI client (supports both OpenAI and Azure) openai_async_client = create_openai_async_client( - api_key=api_key, base_url=base_url, client_configs=client_configs + api_key=api_key, + base_url=base_url, + use_azure=use_azure, + azure_deployment=azure_deployment, + api_version=api_version, + client_configs=client_configs, ) async with openai_async_client: @@ -691,3 +771,169 @@ async def openai_embed( for dp in response.data ] ) + + +# Azure OpenAI wrapper functions for backward compatibility +async def azure_openai_complete_if_cache( + model, + prompt, + system_prompt: str | None = None, + history_messages: list[dict[str, Any]] | None = None, + enable_cot: bool = False, + base_url: str | None = None, + api_key: str | None = None, + token_tracker: Any | None = None, + stream: bool | None = None, + timeout: int | None = None, + api_version: str | None = None, + keyword_extraction: bool = False, + **kwargs, +): + """Azure OpenAI completion wrapper function. + + This function provides backward compatibility by wrapping the unified + openai_complete_if_cache implementation with Azure-specific parameter handling. + + All parameters from the underlying openai_complete_if_cache are exposed to ensure + full feature parity and API consistency. + """ + # Handle Azure-specific environment variables and parameters + deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL") + base_url = ( + base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST") + ) + api_key = ( + api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY") + ) + api_version = ( + api_version + or os.getenv("AZURE_OPENAI_API_VERSION") + or os.getenv("OPENAI_API_VERSION") + ) + + # Call the unified implementation with Azure-specific parameters + return await openai_complete_if_cache( + model=model, + prompt=prompt, + system_prompt=system_prompt, + history_messages=history_messages, + enable_cot=enable_cot, + base_url=base_url, + api_key=api_key, + token_tracker=token_tracker, + stream=stream, + timeout=timeout, + use_azure=True, + azure_deployment=deployment, + api_version=api_version, + keyword_extraction=keyword_extraction, + **kwargs, + ) + + +async def azure_openai_complete( + prompt, + system_prompt=None, + history_messages=None, + keyword_extraction=False, + **kwargs, +) -> str: + """Azure OpenAI complete wrapper function. + + Provides backward compatibility for azure_openai_complete calls. + """ + if history_messages is None: + history_messages = [] + result = await azure_openai_complete_if_cache( + os.getenv("LLM_MODEL", "gpt-4o-mini"), + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + keyword_extraction=keyword_extraction, + **kwargs, + ) + return result + + +@wrap_embedding_func_with_attrs(embedding_dim=1536) +async def azure_openai_embed( + texts: list[str], + model: str | None = None, + base_url: str | None = None, + api_key: str | None = None, + token_tracker: Any | None = None, + client_configs: dict[str, Any] | None = None, + api_version: str | None = None, +) -> np.ndarray: + """Azure OpenAI embedding wrapper function. + + This function provides backward compatibility by wrapping the unified + openai_embed implementation with Azure-specific parameter handling. + + All parameters from the underlying openai_embed are exposed to ensure + full feature parity and API consistency. + + IMPORTANT - Decorator Usage: + + 1. This function is decorated with @wrap_embedding_func_with_attrs to provide + the EmbeddingFunc interface for users who need to access embedding_dim + and other attributes. + + 2. This function does NOT use @retry decorator to avoid double-wrapping, + since the underlying openai_embed.func already has retry logic. + + 3. This function calls openai_embed.func (the unwrapped function) instead of + openai_embed (the EmbeddingFunc instance) to avoid double decoration issues: + + ✅ Correct: await openai_embed.func(...) # Calls unwrapped function with retry + ❌ Wrong: await openai_embed(...) # Would cause double EmbeddingFunc wrapping + + Double decoration causes: + - Double injection of embedding_dim parameter + - Incorrect parameter passing to the underlying implementation + - Runtime errors due to parameter conflicts + + The call chain with correct implementation: + azure_openai_embed(texts) + → EmbeddingFunc.__call__(texts) # azure's decorator + → azure_openai_embed_impl(texts, embedding_dim=1536) + → openai_embed.func(texts, ...) + → @retry_wrapper(texts, ...) # openai's retry (only one layer) + → openai_embed_impl(texts, ...) + → actual embedding computation + """ + # Handle Azure-specific environment variables and parameters + deployment = ( + os.getenv("AZURE_EMBEDDING_DEPLOYMENT") + or model + or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small") + ) + base_url = ( + base_url + or os.getenv("AZURE_EMBEDDING_ENDPOINT") + or os.getenv("EMBEDDING_BINDING_HOST") + ) + api_key = ( + api_key + or os.getenv("AZURE_EMBEDDING_API_KEY") + or os.getenv("EMBEDDING_BINDING_API_KEY") + ) + api_version = ( + api_version + or os.getenv("AZURE_EMBEDDING_API_VERSION") + or os.getenv("OPENAI_API_VERSION") + ) + + # CRITICAL: Call openai_embed.func (unwrapped) to avoid double decoration + # openai_embed is an EmbeddingFunc instance, .func accesses the underlying function + return await openai_embed.func( + texts=texts, + model=model or deployment, + base_url=base_url, + api_key=api_key, + token_tracker=token_tracker, + client_configs=client_configs, + use_azure=True, + azure_deployment=deployment, + api_version=api_version, + )