From b709f8f869c528f045416fc9c584bbc1661bf450 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 21 Nov 2025 17:12:33 +0800
Subject: [PATCH] Consolidate Azure OpenAI implementation into main OpenAI module
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

• Unified OpenAI/Azure client creation in create_openai_async_client
• Azure module now re-exports its functions from lightrag.llm.openai
• Backward compatibility maintained for existing import paths
• Reduced code duplication
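
Existing imports keep working; a quick sketch (it assumes nothing beyond
the re-exports this patch introduces):

    # The old import path resolves through the compatibility layer...
    from lightrag.llm.azure_openai import azure_openai_embed
    # ...to the same object exported by the unified module.
    from lightrag.llm.openai import azure_openai_embed as unified_embed
    assert azure_openai_embed is unified_embed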
---
 lightrag/llm/azure_openai.py | 205 +++-----------------------------
 lightrag/llm/openai.py       | 223 ++++++++++++++++++++++++++++++-----
 2 files changed, 213 insertions(+), 215 deletions(-)

diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py
index cb8d68df..1fc6feef 100644
--- a/lightrag/llm/azure_openai.py
+++ b/lightrag/llm/azure_openai.py
@@ -1,193 +1,22 @@
-from collections.abc import Iterable
-import os
-import pipmaster as pm  # Pipmaster for dynamic library install
+"""
+Azure OpenAI compatibility layer.
 
-# install specific modules
-if not pm.is_installed("openai"):
-    pm.install("openai")
+This module provides backward compatibility by re-exporting Azure OpenAI functions
+from the main openai module where the actual implementation resides.
 
-from openai import (
-    AsyncAzureOpenAI,
-    APIConnectionError,
-    RateLimitError,
-    APITimeoutError,
-)
-from openai.types.chat import ChatCompletionMessageParam
+All core logic for both OpenAI and Azure OpenAI now lives in lightrag.llm.openai,
+with this module serving as a thin compatibility wrapper for existing code that
+imports from lightrag.llm.azure_openai.
+"""
 
-from tenacity import (
-    retry,
-    stop_after_attempt,
-    wait_exponential,
-    retry_if_exception_type,
+from lightrag.llm.openai import (
+    azure_openai_complete_if_cache,
+    azure_openai_complete,
+    azure_openai_embed,
 )
-from lightrag.utils import (
-    wrap_embedding_func_with_attrs,
-    safe_unicode_decode,
-    logger,
-)
-from lightrag.types import GPTKeywordExtractionFormat
-
-import numpy as np
-
-
-@retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type(
-        (RateLimitError, APIConnectionError, APIConnectionError)
-    ),
-)
-async def azure_openai_complete_if_cache(
-    model,
-    prompt,
-    system_prompt: str | None = None,
-    history_messages: Iterable[ChatCompletionMessageParam] | None = None,
-    enable_cot: bool = False,
-    base_url: str | None = None,
-    api_key: str | None = None,
-    api_version: str | None = None,
-    keyword_extraction: bool = False,
-    **kwargs,
-):
-    if enable_cot:
-        logger.debug(
-            "enable_cot=True is not supported for the Azure OpenAI API and will be ignored."
-        )
-    deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
-    base_url = (
-        base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
-    )
-    api_key = (
-        api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
-    )
-    api_version = (
-        api_version
-        or os.getenv("AZURE_OPENAI_API_VERSION")
-        or os.getenv("OPENAI_API_VERSION")
-    )
-
-    kwargs.pop("hashing_kv", None)
-    timeout = kwargs.pop("timeout", None)
-
-    # Handle keyword extraction mode
-    if keyword_extraction:
-        kwargs["response_format"] = GPTKeywordExtractionFormat
-
-    openai_async_client = AsyncAzureOpenAI(
-        azure_endpoint=base_url,
-        azure_deployment=deployment,
-        api_key=api_key,
-        api_version=api_version,
-        timeout=timeout,
-    )
-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": system_prompt})
-    if history_messages:
-        messages.extend(history_messages)
-    if prompt is not None:
-        messages.append({"role": "user", "content": prompt})
-
-    if "response_format" in kwargs:
-        response = await openai_async_client.chat.completions.parse(
-            model=model, messages=messages, **kwargs
-        )
-    else:
-        response = await openai_async_client.chat.completions.create(
-            model=model, messages=messages, **kwargs
-        )
-
-    if hasattr(response, "__aiter__"):
-
-        async def inner():
-            async for chunk in response:
-                if len(chunk.choices) == 0:
-                    continue
-                content = chunk.choices[0].delta.content
-                if content is None:
-                    continue
-                if r"\u" in content:
-                    content = safe_unicode_decode(content.encode("utf-8"))
-                yield content
-
-        return inner()
-    else:
-        message = response.choices[0].message
-
-        # Handle parsed responses (structured output via response_format)
-        # When using beta.chat.completions.parse(), the response is in message.parsed
-        if hasattr(message, "parsed") and message.parsed is not None:
-            # Serialize the parsed structured response to JSON
-            content = message.parsed.model_dump_json()
-            logger.debug("Using parsed structured response from API")
-        else:
-            # Handle regular content responses
-            content = message.content
-            if content and r"\u" in content:
-                content = safe_unicode_decode(content.encode("utf-8"))
-
-        return content
-
-
-async def azure_openai_complete(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
-) -> str:
-    result = await azure_openai_complete_if_cache(
-        os.getenv("LLM_MODEL", "gpt-4o-mini"),
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        keyword_extraction=keyword_extraction,
-        **kwargs,
-    )
-    return result
-
-
-@wrap_embedding_func_with_attrs(embedding_dim=1536)
-@retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type(
-        (RateLimitError, APIConnectionError, APITimeoutError)
-    ),
-)
-async def azure_openai_embed(
-    texts: list[str],
-    model: str | None = None,
-    base_url: str | None = None,
-    api_key: str | None = None,
-    api_version: str | None = None,
-) -> np.ndarray:
-    deployment = (
-        os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
-        or model
-        or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
-    )
-    base_url = (
-        base_url
-        or os.getenv("AZURE_EMBEDDING_ENDPOINT")
-        or os.getenv("EMBEDDING_BINDING_HOST")
-    )
-    api_key = (
-        api_key
-        or os.getenv("AZURE_EMBEDDING_API_KEY")
-        or os.getenv("EMBEDDING_BINDING_API_KEY")
-    )
-    api_version = (
-        api_version
-        or os.getenv("AZURE_EMBEDDING_API_VERSION")
-        or os.getenv("OPENAI_API_VERSION")
-    )
-
-    openai_async_client = AsyncAzureOpenAI(
-        azure_endpoint=base_url,
-        azure_deployment=deployment,
-        api_key=api_key,
-        api_version=api_version,
-    )
-
-    response = await openai_async_client.embeddings.create(
-        model=model or deployment, input=texts, encoding_format="float"
-    )
-    return np.array([dp.embedding for dp in response.data])
+
+__all__ = [
+    "azure_openai_complete_if_cache",
+    "azure_openai_complete",
+    "azure_openai_embed",
+]

diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 6da79c2c..829cf736 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -77,46 +77,73 @@ class InvalidResponseError(Exception):
 def create_openai_async_client(
     api_key: str | None = None,
     base_url: str | None = None,
+    use_azure: bool = False,
+    azure_deployment: str | None = None,
+    api_version: str | None = None,
+    timeout: int | None = None,
     client_configs: dict[str, Any] | None = None,
 ) -> AsyncOpenAI:
-    """Create an AsyncOpenAI client with the given configuration.
+    """Create an AsyncOpenAI or AsyncAzureOpenAI client with the given configuration.
 
     Args:
         api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
         base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
+        use_azure: Whether to create an Azure OpenAI client. Default is False.
+        azure_deployment: Azure OpenAI deployment name (only used when use_azure=True).
+        api_version: Azure OpenAI API version (only used when use_azure=True).
+        timeout: Request timeout in seconds.
         client_configs: Additional configuration options for the AsyncOpenAI client.
             These will override any default configurations but will be overridden by
             explicit parameters (api_key, base_url).
 
     Returns:
-        An AsyncOpenAI client instance.
+        An AsyncOpenAI or AsyncAzureOpenAI client instance.
     """
-    if not api_key:
-        api_key = os.environ["OPENAI_API_KEY"]
+    if use_azure:
+        from openai import AsyncAzureOpenAI
 
-    default_headers = {
-        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
-        "Content-Type": "application/json",
-    }
+        if not api_key:
+            api_key = os.environ.get("AZURE_OPENAI_API_KEY") or os.environ.get(
+                "LLM_BINDING_API_KEY"
+            )
 
-    if client_configs is None:
-        client_configs = {}
-
-    # Create a merged config dict with precedence: explicit params > client_configs > defaults
-    merged_configs = {
-        **client_configs,
-        "default_headers": default_headers,
-        "api_key": api_key,
-    }
-
-    if base_url is not None:
-        merged_configs["base_url"] = base_url
-    else:
-        merged_configs["base_url"] = os.environ.get(
-            "OPENAI_API_BASE", "https://api.openai.com/v1"
+        return AsyncAzureOpenAI(
+            azure_endpoint=base_url,
+            azure_deployment=azure_deployment,
+            api_key=api_key,
+            api_version=api_version,
+            timeout=timeout,
         )
+    else:
+        if not api_key:
+            api_key = os.environ["OPENAI_API_KEY"]
 
-    return AsyncOpenAI(**merged_configs)
+        default_headers = {
+            "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
+            "Content-Type": "application/json",
+        }
+
+        if client_configs is None:
+            client_configs = {}
+
+        # Create a merged config dict with precedence: explicit params > client_configs > defaults
+        merged_configs = {
+            **client_configs,
+            "default_headers": default_headers,
+            "api_key": api_key,
+        }
+
+        if base_url is not None:
+            merged_configs["base_url"] = base_url
+        else:
+            merged_configs["base_url"] = os.environ.get(
+                "OPENAI_API_BASE", "https://api.openai.com/v1"
+            )
+
+        if timeout is not None:
+            merged_configs["timeout"] = timeout
+
+        return AsyncOpenAI(**merged_configs)
 
 
 @retry(
@@ -141,6 +168,9 @@ async def openai_complete_if_cache(
     stream: bool | None = None,
     timeout: int | None = None,
     keyword_extraction: bool = False,
+    use_azure: bool = False,
+    azure_deployment: str | None = None,
+    api_version: str | None = None,
     **kwargs: Any,
 ) -> str:
     """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@@ -207,10 +237,14 @@ async def openai_complete_if_cache(
     if keyword_extraction:
         kwargs["response_format"] = GPTKeywordExtractionFormat
 
-    # Create the OpenAI client
+    # Create the OpenAI client (supports both OpenAI and Azure)
     openai_async_client = create_openai_async_client(
         api_key=api_key,
         base_url=base_url,
+        use_azure=use_azure,
+        azure_deployment=azure_deployment,
+        api_version=api_version,
+        timeout=timeout,
         client_configs=client_configs,
     )
 
@@ -631,6 +665,9 @@ async def openai_embed(
     embedding_dim: int | None = None,
     client_configs: dict[str, Any] | None = None,
     token_tracker: Any | None = None,
+    use_azure: bool = False,
+    azure_deployment: str | None = None,
+    api_version: str | None = None,
 ) -> np.ndarray:
     """Generate embeddings for a list of texts using OpenAI's API.
 
@@ -658,9 +695,14 @@ async def openai_embed(
         RateLimitError: If the OpenAI API rate limit is exceeded.
         APITimeoutError: If the OpenAI API request times out.
     """
-    # Create the OpenAI client
+    # Create the OpenAI client (supports both OpenAI and Azure)
     openai_async_client = create_openai_async_client(
-        api_key=api_key, base_url=base_url, client_configs=client_configs
+        api_key=api_key,
+        base_url=base_url,
+        use_azure=use_azure,
+        azure_deployment=azure_deployment,
+        api_version=api_version,
+        client_configs=client_configs,
     )
 
     async with openai_async_client:
@@ -693,3 +735,130 @@ async def openai_embed(
             for dp in response.data
         ]
     )
+
+
+# Azure OpenAI wrapper functions for backward compatibility
+async def azure_openai_complete_if_cache(
+    model,
+    prompt,
+    system_prompt: str | None = None,
+    history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
+    base_url: str | None = None,
+    api_key: str | None = None,
+    api_version: str | None = None,
+    keyword_extraction: bool = False,
+    **kwargs,
+):
+    """Azure OpenAI completion wrapper function.
+
+    This function provides backward compatibility by wrapping the unified
+    openai_complete_if_cache implementation with Azure-specific parameter handling.
+ """ + # Handle Azure-specific environment variables and parameters + deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL") + base_url = ( + base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST") + ) + api_key = ( + api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY") + ) + api_version = ( + api_version + or os.getenv("AZURE_OPENAI_API_VERSION") + or os.getenv("OPENAI_API_VERSION") + ) + + # Pop timeout from kwargs if present (will be handled by openai_complete_if_cache) + timeout = kwargs.pop("timeout", None) + + # Call the unified implementation with Azure-specific parameters + return await openai_complete_if_cache( + model=model, + prompt=prompt, + system_prompt=system_prompt, + history_messages=history_messages, + enable_cot=enable_cot, + base_url=base_url, + api_key=api_key, + timeout=timeout, + use_azure=True, + azure_deployment=deployment, + api_version=api_version, + keyword_extraction=keyword_extraction, + **kwargs, + ) + + +async def azure_openai_complete( + prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs +) -> str: + """Azure OpenAI complete wrapper function. + + Provides backward compatibility for azure_openai_complete calls. + """ + if history_messages is None: + history_messages = [] + result = await azure_openai_complete_if_cache( + os.getenv("LLM_MODEL", "gpt-4o-mini"), + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + keyword_extraction=keyword_extraction, + **kwargs, + ) + return result + + +@wrap_embedding_func_with_attrs(embedding_dim=1536) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + (RateLimitError, APIConnectionError, APITimeoutError) + ), +) +async def azure_openai_embed( + texts: list[str], + model: str | None = None, + base_url: str | None = None, + api_key: str | None = None, + api_version: str | None = None, +) -> np.ndarray: + """Azure OpenAI embedding wrapper function. + + This function provides backward compatibility by wrapping the unified + openai_embed implementation with Azure-specific parameter handling. + """ + # Handle Azure-specific environment variables and parameters + deployment = ( + os.getenv("AZURE_EMBEDDING_DEPLOYMENT") + or model + or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small") + ) + base_url = ( + base_url + or os.getenv("AZURE_EMBEDDING_ENDPOINT") + or os.getenv("EMBEDDING_BINDING_HOST") + ) + api_key = ( + api_key + or os.getenv("AZURE_EMBEDDING_API_KEY") + or os.getenv("EMBEDDING_BINDING_API_KEY") + ) + api_version = ( + api_version + or os.getenv("AZURE_EMBEDDING_API_VERSION") + or os.getenv("OPENAI_API_VERSION") + ) + + # Call the unified implementation with Azure-specific parameters + return await openai_embed( + texts=texts, + model=model or deployment, + base_url=base_url, + api_key=api_key, + use_azure=True, + azure_deployment=deployment, + api_version=api_version, + )