diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index c87ddbe1..66c3bfe4 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -1,29 +1,15 @@
 from ..utils import verbose_debug, VERBOSE_DEBUG
-import sys
 import os
 import logging
 
-if sys.version_info < (3, 9):
-    from typing import AsyncIterator
-else:
-    from collections.abc import AsyncIterator
-import pipmaster as pm  # Pipmaster for dynamic library install
+from collections.abc import AsyncIterator
+
+import pipmaster as pm  # install specific modules
 
 if not pm.is_installed("openai"):
     pm.install("openai")
 
-# Try to import Langfuse for LLM observability (optional)
-# Falls back to standard OpenAI client if not available
-try:
-    from langfuse.openai import AsyncOpenAI
-
-    LANGFUSE_ENABLED = True
-    logger.info("Langfuse observability enabled for OpenAI client")
-except ImportError:
-    from openai import AsyncOpenAI
-    LANGFUSE_ENABLED = False
-    logger.debug("Langfuse not available, using standard OpenAI client")
-
 from openai import (
     APIConnectionError,
     RateLimitError,
@@ -40,6 +26,7 @@ from lightrag.utils import (
     safe_unicode_decode,
     logger,
 )
+
 from lightrag.types import GPTKeywordExtractionFormat
 from lightrag.api import __api_version__
 
@@ -49,6 +36,32 @@ from typing import Any, Union
 from dotenv import load_dotenv
 
+# Try to import Langfuse for LLM observability (optional)
+# Falls back to standard OpenAI client if not available
+# Langfuse requires proper configuration to work correctly
+LANGFUSE_ENABLED = False
+try:
+    # Check if required Langfuse environment variables are set
+    langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
+    langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
+
+    # Only enable Langfuse if both keys are configured
+    if langfuse_public_key and langfuse_secret_key:
+        from langfuse.openai import AsyncOpenAI
+
+        LANGFUSE_ENABLED = True
+        logger.info("Langfuse observability enabled for OpenAI client")
+    else:
+        from openai import AsyncOpenAI
+
+        logger.debug(
+            "Langfuse environment variables not configured, using standard OpenAI client"
+        )
+except ImportError:
+    from openai import AsyncOpenAI
+
+    logger.debug("Langfuse not available, using standard OpenAI client")
+
 # use the .env that is inside the current folder
 # allows to use different .env file for each lightrag instance
 # the OS environment variables take precedence over the .env file
@@ -64,7 +77,7 @@ class InvalidResponseError(Exception):
 def create_openai_async_client(
     api_key: str | None = None,
     base_url: str | None = None,
-    client_configs: dict[str, Any] = None,
+    client_configs: dict[str, Any] | None = None,
 ) -> AsyncOpenAI:
     """Create an AsyncOpenAI client with the given configuration.
 
@@ -106,57 +119,6 @@ def create_openai_async_client(
     return AsyncOpenAI(**merged_configs)
 
 
-def _normalize_openai_kwargs_for_model(model: str, kwargs: dict[str, Any]) -> None:
-    """
-    Normalize OpenAI API parameters based on the model being used.
-
-    This function handles model-specific parameter requirements:
-    - gpt-5-nano uses 'max_completion_tokens' instead of 'max_tokens'
-    - gpt-5-nano uses reasoning tokens which consume from the token budget
-    - gpt-5-nano doesn't support custom temperature values
-    - Other models support both parameters
-
-    Args:
-        model: The model name (e.g., 'gpt-5-nano', 'gpt-4o', 'gpt-4o-mini')
-        kwargs: The API parameters dict to normalize (modified in-place)
-    """
-    # Handle max_tokens vs max_completion_tokens conversion for gpt-5 models
-    if model.startswith("gpt-5"):
-        # gpt-5-nano and variants use max_completion_tokens
-        if "max_tokens" in kwargs and "max_completion_tokens" not in kwargs:
-            # If only max_tokens is set, move it to max_completion_tokens
-            max_tokens = kwargs.pop("max_tokens")
-            # For gpt-5-nano, we need to account for reasoning tokens
-            # Increase buffer to ensure actual content is generated
-            # Reasoning typically uses 1.5-2x the actual content tokens needed
-            kwargs["max_completion_tokens"] = int(max(max_tokens * 2.5, 300))
-        else:
-            # If both are set, remove max_tokens (it's not supported)
-            max_tokens = kwargs.pop("max_tokens", None)
-            if max_tokens and "max_completion_tokens" in kwargs:
-                # If max_completion_tokens is already set and seems too small, increase it
-                if kwargs["max_completion_tokens"] < 300:
-                    kwargs["max_completion_tokens"] = int(max(kwargs["max_completion_tokens"] * 2.5, 300))
-
-        # Ensure a minimum token budget for gpt-5-nano due to reasoning overhead
-        if "max_completion_tokens" in kwargs:
-            if kwargs["max_completion_tokens"] < 300:
-                # Minimum 300 tokens to account for reasoning (reasoning can be expensive)
-                original = kwargs["max_completion_tokens"]
-                kwargs["max_completion_tokens"] = 300
-                logger.debug(f"Increased max_completion_tokens from {original} to 300 for {model} (reasoning overhead)")
-
-    # Handle temperature constraint for gpt-5 models
-    if model.startswith("gpt-5"):
-        # gpt-5-nano requires default temperature (doesn't support custom values)
-        # Remove any custom temperature setting
-        if "temperature" in kwargs:
-            kwargs.pop("temperature")
-            logger.debug(f"Removed custom temperature for {model}: uses default")
-
-    logger.debug(f"Normalized parameters for {model}: {kwargs}")
-
-
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -180,7 +142,7 @@ async def openai_complete_if_cache(
 ) -> str:
     """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
 
-    This function supports automatic integration of reasoning content (思维链) from models that provide
+    This function supports automatic integration of reasoning content from models that provide
     Chain of Thought capabilities. The reasoning content is seamlessly integrated into the response
     using <think>...</think> tags.
 
@@ -262,9 +224,6 @@ async def openai_complete_if_cache(
 
     messages = kwargs.pop("messages", messages)
 
-    # Normalize API parameters based on model requirements
-    _normalize_openai_kwargs_for_model(model, kwargs)
-
     try:
         # Don't use async with context manager, use client directly
         if "response_format" in kwargs:
@@ -328,19 +287,16 @@ async def openai_complete_if_cache(
                     delta = chunk.choices[0].delta
 
                     content = getattr(delta, "content", None)
-                    reasoning_content = getattr(delta, "reasoning_content", None)
+                    reasoning_content = getattr(delta, "reasoning_content", "")
 
                     # Handle COT logic for streaming (only if enabled)
                     if enable_cot:
-                        if content is not None and content != "":
+                        if content:
                             # Regular content is present
                             if not initial_content_seen:
                                 initial_content_seen = True
                                 # If both content and reasoning_content are present initially, don't start COT
-                                if (
-                                    reasoning_content is not None
-                                    and reasoning_content != ""
-                                ):
+                                if reasoning_content:
                                     cot_active = False
                                     cot_started = False
 
@@ -354,7 +310,7 @@ async def openai_complete_if_cache(
 
                                 content = safe_unicode_decode(content.encode("utf-8"))
                             yield content
-                        elif reasoning_content is not None and reasoning_content != "":
+                        elif reasoning_content:
                             # Only reasoning content is present
                             if not initial_content_seen and not cot_started:
                                 # Start COT if we haven't seen initial content yet
@@ -372,7 +328,7 @@ async def openai_complete_if_cache(
                                     yield reasoning_content
                     else:
                         # COT disabled, only process regular content
-                        if content is not None and content != "":
+                        if content:
                             if r"\u" in content:
                                 content = safe_unicode_decode(content.encode("utf-8"))
                             yield content
@@ -440,18 +396,23 @@ async def openai_complete_if_cache(
                     )
 
                 # Ensure resources are released even if no exception occurs
-                if (
-                    iteration_started
-                    and hasattr(response, "aclose")
-                    and callable(getattr(response, "aclose", None))
-                ):
-                    try:
-                        await response.aclose()
-                        logger.debug("Successfully closed stream response")
-                    except Exception as close_error:
-                        logger.warning(
-                            f"Failed to close stream response in finally block: {close_error}"
-                        )
+                # Note: Some wrapped clients (e.g., Langfuse) may not implement aclose() properly
+                if iteration_started and hasattr(response, "aclose"):
+                    aclose_method = getattr(response, "aclose", None)
+                    if callable(aclose_method):
+                        try:
+                            await response.aclose()
+                            logger.debug("Successfully closed stream response")
+                        except (AttributeError, TypeError) as close_error:
+                            # Some wrapper objects may report hasattr(aclose) but fail when called
+                            # This is expected behavior for certain client wrappers
+                            logger.debug(
+                                f"Stream response cleanup not supported by client wrapper: {close_error}"
+                            )
+                        except Exception as close_error:
+                            logger.warning(
+                                f"Unexpected error during stream response cleanup: {close_error}"
+                            )
 
                 # This prevents resource leaks since the caller doesn't handle closing
                 try:
@@ -479,7 +440,7 @@ async def openai_complete_if_cache(
 
         message = response.choices[0].message
         content = getattr(message, "content", None)
-        reasoning_content = getattr(message, "reasoning_content", None)
+        reasoning_content = getattr(message, "reasoning_content", "")
 
         # Handle COT logic for non-streaming responses (only if enabled)
         final_content = ""
@@ -646,9 +607,10 @@ async def nvidia_openai_complete(
 
 async def openai_embed(
     texts: list[str],
     model: str = "text-embedding-3-small",
-    base_url: str = None,
-    api_key: str = None,
-    client_configs: dict[str, Any] = None,
+    base_url: str | None = None,
+    api_key: str | None = None,
+    client_configs: dict[str, Any] | None = None,
+    token_tracker: Any | None = None,
 ) -> np.ndarray:
     """Generate embeddings for a list of texts using OpenAI's API.
@@ -660,6 +622,7 @@ async def openai_embed(
         client_configs: Additional configuration options for the AsyncOpenAI client.
             These will override any default configurations but will be overridden by
             explicit parameters (api_key, base_url).
+        token_tracker: Optional token usage tracker for monitoring API usage.
 
     Returns:
         A numpy array of embeddings, one per input text.
@@ -678,6 +641,14 @@ async def openai_embed(
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="base64"
     )
+
+    if token_tracker and hasattr(response, "usage"):
+        token_counts = {
+            "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
+            "total_tokens": getattr(response.usage, "total_tokens", 0),
+        }
+        token_tracker.add_usage(token_counts)
+
     return np.array(
         [
             np.array(dp.embedding, dtype=np.float32)