From 79698f6faee658478fe53b47f3931dba40e20f93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20MANSUY?= Date: Thu, 4 Dec 2025 19:19:21 +0800 Subject: [PATCH] cherry-pick b46c1523 --- lightrag/llm/openai.py | 435 +++++++++++++++++++++++++++++++---------- 1 file changed, 331 insertions(+), 104 deletions(-) diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 704edc8c..1e840d08 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -11,7 +11,6 @@ if not pm.is_installed("openai"): pm.install("openai") from openai import ( - AsyncOpenAI, APIConnectionError, RateLimitError, APITimeoutError, @@ -28,18 +27,6 @@ from lightrag.utils import ( logger, ) -# Try to import Langfuse for LLM observability (optional) -# Falls back to standard OpenAI client if not available -try: - from langfuse.openai import AsyncOpenAI - - LANGFUSE_ENABLED = True - logger.info("Langfuse observability enabled for OpenAI client") -except ImportError: - from openai import AsyncOpenAI - - LANGFUSE_ENABLED = False - logger.debug("Langfuse not available, using standard OpenAI client") from lightrag.types import GPTKeywordExtractionFormat from lightrag.api import __api_version__ @@ -49,6 +36,32 @@ from typing import Any, Union from dotenv import load_dotenv +# Try to import Langfuse for LLM observability (optional) +# Falls back to standard OpenAI client if not available +# Langfuse requires proper configuration to work correctly +LANGFUSE_ENABLED = False +try: + # Check if required Langfuse environment variables are set + langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY") + langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY") + + # Only enable Langfuse if both keys are configured + if langfuse_public_key and langfuse_secret_key: + from langfuse.openai import AsyncOpenAI # type: ignore[import-untyped] + + LANGFUSE_ENABLED = True + logger.info("Langfuse observability enabled for OpenAI client") + else: + from openai import AsyncOpenAI + + logger.debug( + "Langfuse environment variables not configured, using standard OpenAI client" + ) +except ImportError: + from openai import AsyncOpenAI + + logger.debug("Langfuse not available, using standard OpenAI client") + # use the .env that is inside the current folder # allows to use different .env file for each lightrag instance # the OS environment variables take precedence over the .env file @@ -64,46 +77,73 @@ class InvalidResponseError(Exception): def create_openai_async_client( api_key: str | None = None, base_url: str | None = None, + use_azure: bool = False, + azure_deployment: str | None = None, + api_version: str | None = None, + timeout: int | None = None, client_configs: dict[str, Any] | None = None, ) -> AsyncOpenAI: - """Create an AsyncOpenAI client with the given configuration. + """Create an AsyncOpenAI or AsyncAzureOpenAI client with the given configuration. Args: api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL. + use_azure: Whether to create an Azure OpenAI client. Default is False. + azure_deployment: Azure OpenAI deployment name (only used when use_azure=True). + api_version: Azure OpenAI API version (only used when use_azure=True). + timeout: Request timeout in seconds. client_configs: Additional configuration options for the AsyncOpenAI client. These will override any default configurations but will be overridden by explicit parameters (api_key, base_url). Returns: - An AsyncOpenAI client instance. 
+        An AsyncOpenAI or AsyncAzureOpenAI client instance.
     """
-    if not api_key:
-        api_key = os.environ["OPENAI_API_KEY"]
+    if use_azure:
+        from openai import AsyncAzureOpenAI
 
-    default_headers = {
-        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
-        "Content-Type": "application/json",
-    }
+        if not api_key:
+            api_key = os.environ.get("AZURE_OPENAI_API_KEY") or os.environ.get(
+                "LLM_BINDING_API_KEY"
+            )
 
-    if client_configs is None:
-        client_configs = {}
-
-    # Create a merged config dict with precedence: explicit params > client_configs > defaults
-    merged_configs = {
-        **client_configs,
-        "default_headers": default_headers,
-        "api_key": api_key,
-    }
-
-    if base_url is not None:
-        merged_configs["base_url"] = base_url
-    else:
-        merged_configs["base_url"] = os.environ.get(
-            "OPENAI_API_BASE", "https://api.openai.com/v1"
+        return AsyncAzureOpenAI(
+            azure_endpoint=base_url,
+            azure_deployment=azure_deployment,
+            api_key=api_key,
+            api_version=api_version,
+            timeout=timeout,
         )
+    else:
+        if not api_key:
+            api_key = os.environ["OPENAI_API_KEY"]
 
-    return AsyncOpenAI(**merged_configs)
+        default_headers = {
+            "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
+            "Content-Type": "application/json",
+        }
+
+        if client_configs is None:
+            client_configs = {}
+
+        # Create a merged config dict with precedence: explicit params > client_configs > defaults
+        merged_configs = {
+            **client_configs,
+            "default_headers": default_headers,
+            "api_key": api_key,
+        }
+
+        if base_url is not None:
+            merged_configs["base_url"] = base_url
+        else:
+            merged_configs["base_url"] = os.environ.get(
+                "OPENAI_API_BASE", "https://api.openai.com/v1"
+            )
+
+        if timeout is not None:
+            merged_configs["timeout"] = timeout
+
+        return AsyncOpenAI(**merged_configs)
 
 
 @retry(
@@ -125,6 +165,12 @@ async def openai_complete_if_cache(
     base_url: str | None = None,
     api_key: str | None = None,
     token_tracker: Any | None = None,
+    stream: bool | None = None,
+    timeout: int | None = None,
+    keyword_extraction: bool = False,
+    use_azure: bool = False,
+    azure_deployment: str | None = None,
+    api_version: str | None = None,
     **kwargs: Any,
 ) -> str:
     """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@@ -154,13 +200,15 @@ async def openai_complete_if_cache(
         api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
         token_tracker: Optional token usage tracker for monitoring API usage.
         enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
+        stream: Whether to stream the response. Default is None (no streaming).
+        timeout: Request timeout in seconds. Default is None.
+        keyword_extraction: Whether to enable keyword extraction mode. When True, triggers
+            special response formatting for keyword extraction. Default is False.
         **kwargs: Additional keyword arguments to pass to the OpenAI API.
             Special kwargs:
            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
                These will be passed to the client constructor but will be overridden by
                explicit parameters (api_key, base_url).
-            - hashing_kv: Will be removed from kwargs before passing to OpenAI.
-            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
     Returns:
         The completed text (with integrated COT content if available) or an async iterator
@@ -181,15 +229,22 @@ async def openai_complete_if_cache(
 
     # Remove special kwargs that shouldn't be passed to OpenAI
     kwargs.pop("hashing_kv", None)
-    kwargs.pop("keyword_extraction", None)
 
     # Extract client configuration options
     client_configs = kwargs.pop("openai_client_configs", {})
 
-    # Create the OpenAI client
+    # Handle keyword extraction mode
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
+
+    # Create the OpenAI client (supports both OpenAI and Azure)
     openai_async_client = create_openai_async_client(
         api_key=api_key,
         base_url=base_url,
+        use_azure=use_azure,
+        azure_deployment=azure_deployment,
+        api_version=api_version,
+        timeout=timeout,
         client_configs=client_configs,
     )
 
@@ -211,10 +266,16 @@ async def openai_complete_if_cache(
 
     messages = kwargs.pop("messages", messages)
 
+    # Add explicit parameters to kwargs so they're passed to the OpenAI API
+    if stream is not None:
+        kwargs["stream"] = stream
+    if timeout is not None:
+        kwargs["timeout"] = timeout
+
     try:
         # Don't use async with context manager, use client directly
         if "response_format" in kwargs:
-            response = await openai_async_client.beta.chat.completions.parse(
+            response = await openai_async_client.chat.completions.parse(
                 model=model, messages=messages, **kwargs
             )
         else:
@@ -383,18 +444,23 @@ async def openai_complete_if_cache(
             )
 
         # Ensure resources are released even if no exception occurs
-        if (
-            iteration_started
-            and hasattr(response, "aclose")
-            and callable(getattr(response, "aclose", None))
-        ):
-            try:
-                await response.aclose()
-                logger.debug("Successfully closed stream response")
-            except Exception as close_error:
-                logger.warning(
-                    f"Failed to close stream response in finally block: {close_error}"
-                )
+        # Note: Some wrapped clients (e.g., Langfuse) may not implement aclose() properly
+        if iteration_started and hasattr(response, "aclose"):
+            aclose_method = getattr(response, "aclose", None)
+            if callable(aclose_method):
+                try:
+                    await response.aclose()
+                    logger.debug("Successfully closed stream response")
+                except (AttributeError, TypeError) as close_error:
+                    # Some wrapper objects may report hasattr(aclose) but fail when called
+                    # This is expected behavior for certain client wrappers
+                    logger.debug(
+                        f"Stream response cleanup not supported by client wrapper: {close_error}"
+                    )
+                except Exception as close_error:
+                    logger.warning(
+                        f"Unexpected error during stream response cleanup: {close_error}"
+                    )
 
         # This prevents resource leaks since the caller doesn't handle closing
         try:
@@ -421,46 +487,57 @@ async def openai_complete_if_cache(
         raise InvalidResponseError("Invalid response from OpenAI API")
 
     message = response.choices[0].message
-    content = getattr(message, "content", None)
-    reasoning_content = getattr(message, "reasoning_content", "")
-    # Handle COT logic for non-streaming responses (only if enabled)
-    final_content = ""
+    # Handle parsed responses (structured output via response_format)
+    # When using chat.completions.parse(), the response is in message.parsed
+    if hasattr(message, "parsed") and message.parsed is not None:
+        # Serialize the parsed structured response to JSON
+        final_content = message.parsed.model_dump_json()
+        logger.debug("Using parsed structured response from API")
+    else:
+        # Handle regular content responses
+        content = getattr(message, "content", None)
+        reasoning_content = getattr(message, "reasoning_content", "")
 
-    if enable_cot:
-        # Check if we should include 
reasoning content
-        should_include_reasoning = False
-        if reasoning_content and reasoning_content.strip():
-            if not content or content.strip() == "":
-                # Case 1: Only reasoning content, should include COT
-                should_include_reasoning = True
-                final_content = (
-                    content or ""
-                )  # Use empty string if content is None
-            else:
-                # Case 3: Both content and reasoning_content present, ignore reasoning
-                should_include_reasoning = False
-                final_content = content
-        else:
-            # No reasoning content, use regular content
-            final_content = content or ""
-
-        # Apply COT wrapping if needed
-        if should_include_reasoning:
-            if r"\u" in reasoning_content:
-                reasoning_content = safe_unicode_decode(
-                    reasoning_content.encode("utf-8")
-                )
-            final_content = f"<think>{reasoning_content}</think>{final_content}"
-    else:
-        # COT disabled, only use regular content
-        final_content = content or ""
-
-    # Validate final content
-    if not final_content or final_content.strip() == "":
-        logger.error("Received empty content from OpenAI API")
-        await openai_async_client.close()  # Ensure client is closed
-        raise InvalidResponseError("Received empty content from OpenAI API")
+        # Handle COT logic for non-streaming responses (only if enabled)
+        final_content = ""
+
+        if enable_cot:
+            # Check if we should include reasoning content
+            should_include_reasoning = False
+            if reasoning_content and reasoning_content.strip():
+                if not content or content.strip() == "":
+                    # Case 1: Only reasoning content, should include COT
+                    should_include_reasoning = True
+                    final_content = (
+                        content or ""
+                    )  # Use empty string if content is None
+                else:
+                    # Case 3: Both content and reasoning_content present, ignore reasoning
+                    should_include_reasoning = False
+                    final_content = content
+            else:
+                # No reasoning content, use regular content
+                final_content = content or ""
+
+            # Apply COT wrapping if needed
+            if should_include_reasoning:
+                if r"\u" in reasoning_content:
+                    reasoning_content = safe_unicode_decode(
+                        reasoning_content.encode("utf-8")
+                    )
+                final_content = (
+                    f"<think>{reasoning_content}</think>{final_content}"
+                )
+        else:
+            # COT disabled, only use regular content
+            final_content = content or ""
+
+        # Validate final content
+        if not final_content or final_content.strip() == "":
+            logger.error("Received empty content from OpenAI API")
+            await openai_async_client.close()  # Ensure client is closed
+            raise InvalidResponseError("Received empty content from OpenAI API")
 
     # Apply Unicode decoding to final content if needed
     if r"\u" in final_content:
@@ -494,15 +571,13 @@ async def openai_complete(
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:
         history_messages = []
-    keyword_extraction = kwargs.pop("keyword_extraction", None)
-    if keyword_extraction:
-        kwargs["response_format"] = "json"
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     return await openai_complete_if_cache(
         model_name,
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        keyword_extraction=keyword_extraction,
         **kwargs,
     )
 
 
@@ -517,15 +592,13 @@ async def gpt_4o_complete(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    keyword_extraction = kwargs.pop("keyword_extraction", None)
-    if keyword_extraction:
-        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         enable_cot=enable_cot,
+        keyword_extraction=keyword_extraction,
         **kwargs,
     )
 
 
@@ -540,15 +613,13 @@ async def gpt_4o_mini_complete(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    keyword_extraction = 
kwargs.pop("keyword_extraction", None) - if keyword_extraction: - kwargs["response_format"] = GPTKeywordExtractionFormat return await openai_complete_if_cache( "gpt-4o-mini", prompt, system_prompt=system_prompt, history_messages=history_messages, enable_cot=enable_cot, + keyword_extraction=keyword_extraction, **kwargs, ) @@ -563,20 +634,20 @@ async def nvidia_openai_complete( ) -> str: if history_messages is None: history_messages = [] - kwargs.pop("keyword_extraction", None) result = await openai_complete_if_cache( "nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k prompt, system_prompt=system_prompt, history_messages=history_messages, enable_cot=enable_cot, + keyword_extraction=keyword_extraction, base_url="https://integrate.api.nvidia.com/v1", **kwargs, ) return result -@wrap_embedding_func_with_attrs(embedding_dim=1536) +@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=60), @@ -591,8 +662,12 @@ async def openai_embed( model: str = "text-embedding-3-small", base_url: str | None = None, api_key: str | None = None, + embedding_dim: int | None = None, client_configs: dict[str, Any] | None = None, token_tracker: Any | None = None, + use_azure: bool = False, + azure_deployment: str | None = None, + api_version: str | None = None, ) -> np.ndarray: """Generate embeddings for a list of texts using OpenAI's API. @@ -601,6 +676,12 @@ async def openai_embed( model: The OpenAI embedding model to use. base_url: Optional base URL for the OpenAI API. api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. + embedding_dim: Optional embedding dimension for dynamic dimension reduction. + **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper. + Do NOT manually pass this parameter when calling the function directly. + The dimension is controlled by the @wrap_embedding_func_with_attrs decorator. + Manually passing a different value will trigger a warning and be ignored. + When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction. client_configs: Additional configuration options for the AsyncOpenAI client. These will override any default configurations but will be overridden by explicit parameters (api_key, base_url). @@ -614,15 +695,30 @@ async def openai_embed( RateLimitError: If the OpenAI API rate limit is exceeded. APITimeoutError: If the OpenAI API request times out. 
""" - # Create the OpenAI client + # Create the OpenAI client (supports both OpenAI and Azure) openai_async_client = create_openai_async_client( - api_key=api_key, base_url=base_url, client_configs=client_configs + api_key=api_key, + base_url=base_url, + use_azure=use_azure, + azure_deployment=azure_deployment, + api_version=api_version, + client_configs=client_configs, ) async with openai_async_client: - response = await openai_async_client.embeddings.create( - model=model, input=texts, encoding_format="base64" - ) + # Prepare API call parameters + api_params = { + "model": model, + "input": texts, + "encoding_format": "base64", + } + + # Add dimensions parameter only if embedding_dim is provided + if embedding_dim is not None: + api_params["dimensions"] = embedding_dim + + # Make API call + response = await openai_async_client.embeddings.create(**api_params) if token_tracker and hasattr(response, "usage"): token_counts = { @@ -639,3 +735,134 @@ async def openai_embed( for dp in response.data ] ) + + +# Azure OpenAI wrapper functions for backward compatibility +async def azure_openai_complete_if_cache( + model, + prompt, + system_prompt: str | None = None, + history_messages: list[dict[str, Any]] | None = None, + enable_cot: bool = False, + base_url: str | None = None, + api_key: str | None = None, + api_version: str | None = None, + keyword_extraction: bool = False, + **kwargs, +): + """Azure OpenAI completion wrapper function. + + This function provides backward compatibility by wrapping the unified + openai_complete_if_cache implementation with Azure-specific parameter handling. + """ + # Handle Azure-specific environment variables and parameters + deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL") + base_url = ( + base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST") + ) + api_key = ( + api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY") + ) + api_version = ( + api_version + or os.getenv("AZURE_OPENAI_API_VERSION") + or os.getenv("OPENAI_API_VERSION") + ) + + # Pop timeout from kwargs if present (will be handled by openai_complete_if_cache) + timeout = kwargs.pop("timeout", None) + + # Call the unified implementation with Azure-specific parameters + return await openai_complete_if_cache( + model=model, + prompt=prompt, + system_prompt=system_prompt, + history_messages=history_messages, + enable_cot=enable_cot, + base_url=base_url, + api_key=api_key, + timeout=timeout, + use_azure=True, + azure_deployment=deployment, + api_version=api_version, + keyword_extraction=keyword_extraction, + **kwargs, + ) + + +async def azure_openai_complete( + prompt, + system_prompt=None, + history_messages=None, + keyword_extraction=False, + **kwargs, +) -> str: + """Azure OpenAI complete wrapper function. + + Provides backward compatibility for azure_openai_complete calls. 
+ """ + if history_messages is None: + history_messages = [] + result = await azure_openai_complete_if_cache( + os.getenv("LLM_MODEL", "gpt-4o-mini"), + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + keyword_extraction=keyword_extraction, + **kwargs, + ) + return result + + +@wrap_embedding_func_with_attrs(embedding_dim=1536) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type( + (RateLimitError, APIConnectionError, APITimeoutError) + ), +) +async def azure_openai_embed( + texts: list[str], + model: str | None = None, + base_url: str | None = None, + api_key: str | None = None, + api_version: str | None = None, +) -> np.ndarray: + """Azure OpenAI embedding wrapper function. + + This function provides backward compatibility by wrapping the unified + openai_embed implementation with Azure-specific parameter handling. + """ + # Handle Azure-specific environment variables and parameters + deployment = ( + os.getenv("AZURE_EMBEDDING_DEPLOYMENT") + or model + or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small") + ) + base_url = ( + base_url + or os.getenv("AZURE_EMBEDDING_ENDPOINT") + or os.getenv("EMBEDDING_BINDING_HOST") + ) + api_key = ( + api_key + or os.getenv("AZURE_EMBEDDING_API_KEY") + or os.getenv("EMBEDDING_BINDING_API_KEY") + ) + api_version = ( + api_version + or os.getenv("AZURE_EMBEDDING_API_VERSION") + or os.getenv("OPENAI_API_VERSION") + ) + + # Call the unified implementation with Azure-specific parameters + return await openai_embed( + texts=texts, + model=model or deployment, + base_url=base_url, + api_key=api_key, + use_azure=True, + azure_deployment=deployment, + api_version=api_version, + )