From fd486c287a6936852975804da890bf03f2d6afbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?= Date: Thu, 4 Dec 2025 19:14:32 +0800 Subject: [PATCH] cherry-pick b709f8f8 --- lightrag/llm/azure_openai.py | 205 +++-------------------------------- lightrag/llm/openai.py | 145 ++++++------------------- 2 files changed, 52 insertions(+), 298 deletions(-) diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py index c67bae10..1fc6feef 100644 --- a/lightrag/llm/azure_openai.py +++ b/lightrag/llm/azure_openai.py @@ -1,193 +1,22 @@ -from collections.abc import Iterable -import os -import pipmaster as pm # Pipmaster for dynamic library install +""" +Azure OpenAI compatibility layer. -# install specific modules -if not pm.is_installed("openai"): - pm.install("openai") +This module provides backward compatibility by re-exporting Azure OpenAI functions +from the main openai module where the actual implementation resides. -from openai import ( - AsyncAzureOpenAI, - APIConnectionError, - RateLimitError, - APITimeoutError, -) -from openai.types.chat import ChatCompletionMessageParam +All core logic for both OpenAI and Azure OpenAI now lives in lightrag.llm.openai, +with this module serving as a thin compatibility wrapper for existing code that +imports from lightrag.llm.azure_openai. +""" -from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - retry_if_exception_type, +from lightrag.llm.openai import ( + azure_openai_complete_if_cache, + azure_openai_complete, + azure_openai_embed, ) -from lightrag.utils import ( - wrap_embedding_func_with_attrs, - safe_unicode_decode, - logger, -) -from lightrag.types import GPTKeywordExtractionFormat - -import numpy as np - - -@retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - (RateLimitError, APIConnectionError, APIConnectionError) - ), -) -async def azure_openai_complete_if_cache( - model, - prompt, - system_prompt: str | None = None, - history_messages: Iterable[ChatCompletionMessageParam] | None = None, - enable_cot: bool = False, - base_url: str | None = None, - api_key: str | None = None, - api_version: str | None = None, - keyword_extraction: bool = False, - **kwargs, -): - if enable_cot: - logger.debug( - "enable_cot=True is not supported for the Azure OpenAI API and will be ignored." 
- ) - deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL") - base_url = ( - base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST") - ) - api_key = ( - api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY") - ) - api_version = ( - api_version - or os.getenv("AZURE_OPENAI_API_VERSION") - or os.getenv("OPENAI_API_VERSION") - ) - - kwargs.pop("hashing_kv", None) - timeout = kwargs.pop("timeout", None) - - # Handle keyword extraction mode - if keyword_extraction: - kwargs["response_format"] = GPTKeywordExtractionFormat - - openai_async_client = AsyncAzureOpenAI( - azure_endpoint=base_url, - azure_deployment=deployment, - api_key=api_key, - api_version=api_version, - timeout=timeout, - ) - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - if history_messages: - messages.extend(history_messages) - if prompt is not None: - messages.append({"role": "user", "content": prompt}) - - if "response_format" in kwargs: - response = await openai_async_client.beta.chat.completions.parse( - model=model, messages=messages, **kwargs - ) - else: - response = await openai_async_client.chat.completions.create( - model=model, messages=messages, **kwargs - ) - - if hasattr(response, "__aiter__"): - - async def inner(): - async for chunk in response: - if len(chunk.choices) == 0: - continue - content = chunk.choices[0].delta.content - if content is None: - continue - if r"\u" in content: - content = safe_unicode_decode(content.encode("utf-8")) - yield content - - return inner() - else: - message = response.choices[0].message - - # Handle parsed responses (structured output via response_format) - # When using beta.chat.completions.parse(), the response is in message.parsed - if hasattr(message, "parsed") and message.parsed is not None: - # Serialize the parsed structured response to JSON - content = message.parsed.model_dump_json() - logger.debug("Using parsed structured response from API") - else: - # Handle regular content responses - content = message.content - if content and r"\u" in content: - content = safe_unicode_decode(content.encode("utf-8")) - - return content - - -async def azure_openai_complete( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs -) -> str: - result = await azure_openai_complete_if_cache( - os.getenv("LLM_MODEL", "gpt-4o-mini"), - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - keyword_extraction=keyword_extraction, - **kwargs, - ) - return result - - -@wrap_embedding_func_with_attrs(embedding_dim=1536) -@retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - (RateLimitError, APIConnectionError, APITimeoutError) - ), -) -async def azure_openai_embed( - texts: list[str], - model: str | None = None, - base_url: str | None = None, - api_key: str | None = None, - api_version: str | None = None, -) -> np.ndarray: - deployment = ( - os.getenv("AZURE_EMBEDDING_DEPLOYMENT") - or model - or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small") - ) - base_url = ( - base_url - or os.getenv("AZURE_EMBEDDING_ENDPOINT") - or os.getenv("EMBEDDING_BINDING_HOST") - ) - api_key = ( - api_key - or os.getenv("AZURE_EMBEDDING_API_KEY") - or os.getenv("EMBEDDING_BINDING_API_KEY") - ) - api_version = ( - api_version - or os.getenv("AZURE_EMBEDDING_API_VERSION") - or os.getenv("OPENAI_API_VERSION") - ) - - openai_async_client = AsyncAzureOpenAI( - 
azure_endpoint=base_url, - azure_deployment=deployment, - api_key=api_key, - api_version=api_version, - ) - - response = await openai_async_client.embeddings.create( - model=model or deployment, input=texts, encoding_format="float" - ) - return np.array([dp.embedding for dp in response.data]) +__all__ = [ + "azure_openai_complete_if_cache", + "azure_openai_complete", + "azure_openai_embed", +] diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index a314d597..829cf736 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -107,26 +107,13 @@ def create_openai_async_client( "LLM_BINDING_API_KEY" ) - if client_configs is None: - client_configs = {} - - # Create a merged config dict with precedence: explicit params > client_configs - merged_configs = { - **client_configs, - "api_key": api_key, - } - - # Add explicit parameters (override client_configs) - if base_url is not None: - merged_configs["azure_endpoint"] = base_url - if azure_deployment is not None: - merged_configs["azure_deployment"] = azure_deployment - if api_version is not None: - merged_configs["api_version"] = api_version - if timeout is not None: - merged_configs["timeout"] = timeout - - return AsyncAzureOpenAI(**merged_configs) + return AsyncAzureOpenAI( + azure_endpoint=base_url, + azure_deployment=azure_deployment, + api_key=api_key, + api_version=api_version, + timeout=timeout, + ) else: if not api_key: api_key = os.environ["OPENAI_API_KEY"] @@ -205,33 +192,23 @@ async def openai_complete_if_cache( 6. For non-streaming: COT content is prepended to regular content with tags. Args: - model: The OpenAI model to use. For Azure, this can be the deployment name. + model: The OpenAI model to use. prompt: The prompt to complete. system_prompt: Optional system prompt to include. history_messages: Optional list of previous messages in the conversation. - enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False. - base_url: Optional base URL for the OpenAI API. For Azure, this should be the - Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/). - api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment - variable if None. For Azure, uses AZURE_OPENAI_API_KEY if None. + base_url: Optional base URL for the OpenAI API. + api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. token_tracker: Optional token usage tracker for monitoring API usage. + enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False. stream: Whether to stream the response. Default is False. timeout: Request timeout in seconds. Default is None. keyword_extraction: Whether to enable keyword extraction mode. When True, triggers special response formatting for keyword extraction. Default is False. - use_azure: Whether to use Azure OpenAI service instead of standard OpenAI. - When True, creates an AsyncAzureOpenAI client. Default is False. - azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True. - If not specified, falls back to AZURE_OPENAI_DEPLOYMENT environment variable. - api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used - when use_azure=True. If not specified, falls back to AZURE_OPENAI_API_VERSION - environment variable. **kwargs: Additional keyword arguments to pass to the OpenAI API. Special kwargs: - openai_client_configs: Dict of configuration options for the AsyncOpenAI client. 
These will be passed to the client constructor but will be overridden by - explicit parameters (api_key, base_url). Supports proxy configuration, - custom headers, retry policies, etc. + explicit parameters (api_key, base_url). Returns: The completed text (with integrated COT content if available) or an async iterator @@ -694,34 +671,21 @@ async def openai_embed( ) -> np.ndarray: """Generate embeddings for a list of texts using OpenAI's API. - This function supports both standard OpenAI and Azure OpenAI services. - Args: texts: List of texts to embed. - model: The embedding model to use. For standard OpenAI (e.g., "text-embedding-3-small"). - For Azure, this can be the deployment name. - base_url: Optional base URL for the API. For standard OpenAI, uses default OpenAI endpoint. - For Azure, this should be the Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/). - api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment variable if None. - For Azure, uses AZURE_EMBEDDING_API_KEY environment variable if None. + model: The OpenAI embedding model to use. + base_url: Optional base URL for the OpenAI API. + api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. embedding_dim: Optional embedding dimension for dynamic dimension reduction. **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper. Do NOT manually pass this parameter when calling the function directly. The dimension is controlled by the @wrap_embedding_func_with_attrs decorator. Manually passing a different value will trigger a warning and be ignored. When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction. - client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client. + client_configs: Additional configuration options for the AsyncOpenAI client. These will override any default configurations but will be overridden by - explicit parameters (api_key, base_url). Supports proxy configuration, - custom headers, retry policies, etc. + explicit parameters (api_key, base_url). token_tracker: Optional token usage tracker for monitoring API usage. - use_azure: Whether to use Azure OpenAI service instead of standard OpenAI. - When True, creates an AsyncAzureOpenAI client. Default is False. - azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True. - If not specified, falls back to AZURE_EMBEDDING_DEPLOYMENT environment variable. - api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used - when use_azure=True. If not specified, falls back to AZURE_EMBEDDING_API_VERSION - environment variable. Returns: A numpy array of embeddings, one per input text. @@ -782,20 +746,14 @@ async def azure_openai_complete_if_cache( enable_cot: bool = False, base_url: str | None = None, api_key: str | None = None, - token_tracker: Any | None = None, - stream: bool | None = None, - timeout: int | None = None, api_version: str | None = None, keyword_extraction: bool = False, **kwargs, ): """Azure OpenAI completion wrapper function. - + This function provides backward compatibility by wrapping the unified openai_complete_if_cache implementation with Azure-specific parameter handling. - - All parameters from the underlying openai_complete_if_cache are exposed to ensure - full feature parity and API consistency. 
""" # Handle Azure-specific environment variables and parameters deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL") @@ -810,7 +768,10 @@ async def azure_openai_complete_if_cache( or os.getenv("AZURE_OPENAI_API_VERSION") or os.getenv("OPENAI_API_VERSION") ) - + + # Pop timeout from kwargs if present (will be handled by openai_complete_if_cache) + timeout = kwargs.pop("timeout", None) + # Call the unified implementation with Azure-specific parameters return await openai_complete_if_cache( model=model, @@ -820,8 +781,6 @@ async def azure_openai_complete_if_cache( enable_cot=enable_cot, base_url=base_url, api_key=api_key, - token_tracker=token_tracker, - stream=stream, timeout=timeout, use_azure=True, azure_deployment=deployment, @@ -832,14 +791,10 @@ async def azure_openai_complete_if_cache( async def azure_openai_complete( - prompt, - system_prompt=None, - history_messages=None, - keyword_extraction=False, - **kwargs, + prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs ) -> str: """Azure OpenAI complete wrapper function. - + Provides backward compatibility for azure_openai_complete calls. """ if history_messages is None: @@ -856,51 +811,24 @@ async def azure_openai_complete( @wrap_embedding_func_with_attrs(embedding_dim=1536) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + (RateLimitError, APIConnectionError, APITimeoutError) + ), +) async def azure_openai_embed( texts: list[str], model: str | None = None, base_url: str | None = None, api_key: str | None = None, - token_tracker: Any | None = None, - client_configs: dict[str, Any] | None = None, api_version: str | None = None, ) -> np.ndarray: """Azure OpenAI embedding wrapper function. - + This function provides backward compatibility by wrapping the unified openai_embed implementation with Azure-specific parameter handling. - - All parameters from the underlying openai_embed are exposed to ensure - full feature parity and API consistency. - - IMPORTANT - Decorator Usage: - - 1. This function is decorated with @wrap_embedding_func_with_attrs to provide - the EmbeddingFunc interface for users who need to access embedding_dim - and other attributes. - - 2. This function does NOT use @retry decorator to avoid double-wrapping, - since the underlying openai_embed.func already has retry logic. - - 3. This function calls openai_embed.func (the unwrapped function) instead of - openai_embed (the EmbeddingFunc instance) to avoid double decoration issues: - - ✅ Correct: await openai_embed.func(...) # Calls unwrapped function with retry - ❌ Wrong: await openai_embed(...) # Would cause double EmbeddingFunc wrapping - - Double decoration causes: - - Double injection of embedding_dim parameter - - Incorrect parameter passing to the underlying implementation - - Runtime errors due to parameter conflicts - - The call chain with correct implementation: - azure_openai_embed(texts) - → EmbeddingFunc.__call__(texts) # azure's decorator - → azure_openai_embed_impl(texts, embedding_dim=1536) - → openai_embed.func(texts, ...) - → @retry_wrapper(texts, ...) # openai's retry (only one layer) - → openai_embed_impl(texts, ...) 
- → actual embedding computation """ # Handle Azure-specific environment variables and parameters deployment = ( @@ -923,16 +851,13 @@ async def azure_openai_embed( or os.getenv("AZURE_EMBEDDING_API_VERSION") or os.getenv("OPENAI_API_VERSION") ) - - # CRITICAL: Call openai_embed.func (unwrapped) to avoid double decoration - # openai_embed is an EmbeddingFunc instance, .func accesses the underlying function - return await openai_embed.func( + + # Call the unified implementation with Azure-specific parameters + return await openai_embed( texts=texts, model=model or deployment, base_url=base_url, api_key=api_key, - token_tracker=token_tracker, - client_configs=client_configs, use_azure=True, azure_deployment=deployment, api_version=api_version,
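
Usage note (not part of the patch): a minimal sketch of how downstream code is expected to keep working after this change, since lightrag.llm.azure_openai becomes a thin re-export of the unified implementations in lightrag.llm.openai. It assumes the usual AZURE_OPENAI_* / AZURE_EMBEDDING_* environment variables (endpoint, API key, deployment, API version) are already configured.

    import asyncio

    # Legacy import path still resolves, via the compatibility re-exports
    # added to lightrag/llm/azure_openai.py in this patch.
    from lightrag.llm.azure_openai import azure_openai_complete, azure_openai_embed

    async def main() -> None:
        # Completion goes through the unified openai_complete_if_cache with use_azure=True.
        answer = await azure_openai_complete("Summarize LightRAG in one sentence.")
        print(answer)

        # Embedding goes through the unified openai_embed with use_azure=True;
        # the wrapper advertises embedding_dim=1536 by default.
        vectors = await azure_openai_embed(["hello", "world"])
        print(vectors.shape)

    asyncio.run(main())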