diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 3339ea3a..a2bbfa23 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -138,9 +138,9 @@ async def openai_complete_if_cache(
     base_url: str | None = None,
     api_key: str | None = None,
     token_tracker: Any | None = None,
-    keyword_extraction: bool = False,  # Will be removed from kwargs before passing to OpenAI
     stream: bool | None = None,
     timeout: int | None = None,
+    keyword_extraction: bool = False,
     **kwargs: Any,
 ) -> str:
     """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
@@ -170,14 +170,15 @@ async def openai_complete_if_cache(
         api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
         token_tracker: Optional token usage tracker for monitoring API usage.
         enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
+        stream: Whether to stream the response. Default is False.
+        timeout: Request timeout in seconds. Default is None.
+        keyword_extraction: Whether to enable keyword extraction mode. When True, triggers
+            special response formatting for keyword extraction. Default is False.
         **kwargs: Additional keyword arguments to pass to the OpenAI API.
             Special kwargs:
             - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
                 These will be passed to the client constructor but will be overridden by
                 explicit parameters (api_key, base_url).
-            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
-            - stream: Whether to stream the response. Default is False.
-            - timeout: Request timeout in seconds. Default is None.
 
     Returns:
         The completed text (with integrated COT content if available) or an async iterator
@@ -198,7 +199,6 @@ async def openai_complete_if_cache(
 
     # Remove special kwargs that shouldn't be passed to OpenAI
     kwargs.pop("hashing_kv", None)
-    kwargs.pop("keyword_extraction", None)
 
     # Extract client configuration options
     client_configs = kwargs.pop("openai_client_configs", {})
@@ -228,6 +228,12 @@
     messages = kwargs.pop("messages", messages)
 
+    # Add explicit parameters back to kwargs so they're passed to OpenAI API
+    if stream is not None:
+        kwargs["stream"] = stream
+    if timeout is not None:
+        kwargs["timeout"] = timeout
+
     try:
         # Don't use async with context manager, use client directly
         if "response_format" in kwargs:
@@ -516,7 +522,6 @@ async def openai_complete(
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:
         history_messages = []
-    keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["response_format"] = "json"
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
@@ -525,6 +530,7 @@ async def openai_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        keyword_extraction=keyword_extraction,
         **kwargs,
     )
 
@@ -539,7 +545,6 @@ async def gpt_4o_complete(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
@@ -548,6 +553,7 @@
         system_prompt=system_prompt,
         history_messages=history_messages,
         enable_cot=enable_cot,
+        keyword_extraction=keyword_extraction,
         **kwargs,
     )
 
@@ -562,7 +568,6 @@ async def gpt_4o_mini_complete(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
         kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
@@ -571,6 +576,7 @@
         system_prompt=system_prompt,
         history_messages=history_messages,
         enable_cot=enable_cot,
+        keyword_extraction=keyword_extraction,
         **kwargs,
     )
 
@@ -585,13 +591,13 @@ async def nvidia_openai_complete(
 ) -> str:
     if history_messages is None:
         history_messages = []
-    kwargs.pop("keyword_extraction", None)
     result = await openai_complete_if_cache(
         "nvidia/llama-3.1-nemotron-70b-instruct",  # context length 128k
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         enable_cot=enable_cot,
+        keyword_extraction=keyword_extraction,
         base_url="https://integrate.api.nvidia.com/v1",
         **kwargs,
     )
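
Reviewer note: after this refactor, keyword_extraction, stream, and timeout are explicit parameters of openai_complete_if_cache, so the flag can no longer leak into the OpenAI request payload through **kwargs, and stream/timeout are re-added to kwargs only when explicitly set. Below is a minimal usage sketch of the new calling convention; the example_usage driver, the prompt strings, and the assumption that OPENAI_API_KEY is set in the environment are illustrative and not part of this patch.

# Illustrative sketch only (not part of the patch).
import asyncio

from lightrag.llm.openai import gpt_4o_mini_complete


async def example_usage() -> None:
    # Hypothetical driver; assumes OPENAI_API_KEY is set in the environment.
    # keyword_extraction=True is now forwarded as a named parameter: the
    # wrapper still maps it to GPTKeywordExtractionFormat, and
    # openai_complete_if_cache accepts the flag directly instead of
    # scrubbing it out of kwargs before calling the OpenAI API.
    keywords = await gpt_4o_mini_complete(
        "Extract keywords from: 'LightRAG builds a knowledge graph from text.'",
        system_prompt="Return high-level and low-level keywords.",
        keyword_extraction=True,
    )
    print(keywords)


if __name__ == "__main__":
    asyncio.run(example_usage())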