From 077d9be5d71df610a1c0b58313915c7fb410d8b2 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Tue, 9 Sep 2025 22:34:36 +0800
Subject: [PATCH] Add DeepSeek-style Chain of Thought (CoT) support for
 OpenAI-compatible LLM providers

- Add enable_cot parameter to all LLM APIs
- Implement CoT for OpenAI with <think> tags
- Log debug messages for providers that do not support CoT
- Enable CoT in query operations
- Handle streaming and non-streaming CoT
---
 lightrag/lightrag.py             |   1 +
 lightrag/llm/anthropic.py        |  13 +++
 lightrag/llm/azure_openai.py     |   6 ++
 lightrag/llm/bedrock.py          |   7 ++
 lightrag/llm/hf.py               |  15 +++-
 lightrag/llm/llama_index_impl.py |   7 ++
 lightrag/llm/lmdeploy.py         |   7 ++
 lightrag/llm/lollms.py           |  13 ++-
 lightrag/llm/ollama.py           |  11 ++-
 lightrag/llm/openai.py           | 149 +++++++++++++++++++++++++++----
 lightrag/llm/zhipu.py            |  14 ++-
 lightrag/operate.py              |   2 +
 12 files changed, 222 insertions(+), 23 deletions(-)

diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 55731d7a..43ef9249 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -2107,6 +2107,7 @@ class LightRAG:
                 query.strip(),
                 system_prompt=system_prompt,
                 history_messages=param.conversation_history,
+                enable_cot=True,
                 stream=param.stream,
             )
         else:
diff --git a/lightrag/llm/anthropic.py b/lightrag/llm/anthropic.py
index 98a997d5..fe18300c 100644
--- a/lightrag/llm/anthropic.py
+++ b/lightrag/llm/anthropic.py
@@ -59,12 +59,17 @@ async def anthropic_complete_if_cache(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     base_url: str | None = None,
     api_key: str | None = None,
     **kwargs: Any,
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:
         history_messages = []
+    if enable_cot:
+        logger.debug(
+            "enable_cot=True is not supported for the Anthropic API and will be ignored."
+        )

     if not api_key:
         api_key = os.environ.get("ANTHROPIC_API_KEY")
@@ -150,6 +155,7 @@ async def anthropic_complete(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     **kwargs: Any,
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:
@@ -160,6 +166,7 @@ async def anthropic_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

@@ -169,6 +176,7 @@ async def claude_3_opus_complete(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     **kwargs: Any,
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:
@@ -178,6 +186,7 @@ async def claude_3_opus_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

@@ -187,6 +196,7 @@ async def claude_3_sonnet_complete(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     **kwargs: Any,
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:
@@ -196,6 +206,7 @@ async def claude_3_sonnet_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

@@ -205,6 +216,7 @@ async def claude_3_haiku_complete(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     **kwargs: Any,
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:
@@ -214,6 +226,7 @@ async def claude_3_haiku_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py
index 0ede0824..824ff088 100644
--- a/lightrag/llm/azure_openai.py
+++ b/lightrag/llm/azure_openai.py
@@ -24,6 +24,7 @@ from tenacity import (
 from lightrag.utils import (
     wrap_embedding_func_with_attrs,
     safe_unicode_decode,
+    logger,
 )

 import numpy as np
@@ -41,11 +42,16 @@ async def azure_openai_complete_if_cache(
     prompt,
     system_prompt: str | None = None,
     history_messages: Iterable[ChatCompletionMessageParam] | None = None,
+    enable_cot: bool = False,
     base_url: str | None = None,
     api_key: str | None = None,
     api_version: str | None = None,
     **kwargs,
 ):
+    if enable_cot:
+        logger.debug(
+            "enable_cot=True is not supported for the Azure OpenAI API and will be ignored."
+        )
     deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
     base_url = (
         base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py
index 69d00e2d..16737341 100644
--- a/lightrag/llm/bedrock.py
+++ b/lightrag/llm/bedrock.py
@@ -44,11 +44,18 @@ async def bedrock_complete_if_cache(
     prompt,
     system_prompt=None,
     history_messages=[],
+    enable_cot: bool = False,
     aws_access_key_id=None,
     aws_secret_access_key=None,
     aws_session_token=None,
     **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
+    if enable_cot:
+        import logging
+
+        logging.debug(
+            "enable_cot=True is not supported for Bedrock and will be ignored."
+        )
     # Respect existing env; only set if a non-empty value is available
     access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
     secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
diff --git a/lightrag/llm/hf.py b/lightrag/llm/hf.py
index 7adf1570..c33b1c7f 100644
--- a/lightrag/llm/hf.py
+++ b/lightrag/llm/hf.py
@@ -56,8 +56,15 @@ async def hf_model_if_cache(
     prompt,
     system_prompt=None,
     history_messages=[],
+    enable_cot: bool = False,
     **kwargs,
 ) -> str:
+    if enable_cot:
+        from lightrag.utils import logger
+
+        logger.debug(
+            "enable_cot=True is not supported for Hugging Face local models and will be ignored."
+        )
     model_name = model
     hf_model, hf_tokenizer = initialize_hf_model(model_name)
     messages = []
@@ -114,7 +121,12 @@ async def hf_model_complete(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    keyword_extraction=False,
+    enable_cot: bool = False,
+    **kwargs,
 ) -> str:
     kwargs.pop("keyword_extraction", None)
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
@@ -123,6 +135,7 @@ async def hf_model_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )
     return result
diff --git a/lightrag/llm/llama_index_impl.py b/lightrag/llm/llama_index_impl.py
index 2e50e33e..38ec7cd1 100644
--- a/lightrag/llm/llama_index_impl.py
+++ b/lightrag/llm/llama_index_impl.py
@@ -94,9 +94,14 @@ async def llama_index_complete_if_cache(
     prompt: str,
     system_prompt: Optional[str] = None,
     history_messages: List[dict] = [],
+    enable_cot: bool = False,
     chat_kwargs={},
 ) -> str:
     """Complete the prompt using LlamaIndex."""
+    if enable_cot:
+        logger.debug(
+            "enable_cot=True is not supported for LlamaIndex implementation and will be ignored."
+        )
     try:
         # Format messages for chat
         formatted_messages = []
@@ -138,6 +143,7 @@ async def llama_index_complete(
     prompt,
     system_prompt=None,
     history_messages=None,
+    enable_cot: bool = False,
     keyword_extraction=False,
     settings: LlamaIndexSettings = None,
     **kwargs,
@@ -162,6 +168,7 @@ async def llama_index_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )
     return result
diff --git a/lightrag/llm/lmdeploy.py b/lightrag/llm/lmdeploy.py
index 50e60cd6..8916b0fd 100644
--- a/lightrag/llm/lmdeploy.py
+++ b/lightrag/llm/lmdeploy.py
@@ -56,6 +56,7 @@ async def lmdeploy_model_if_cache(
     prompt,
     system_prompt=None,
     history_messages=[],
+    enable_cot: bool = False,
     chat_template=None,
     model_format="hf",
     quant_policy=0,
@@ -89,6 +90,12 @@ async def lmdeploy_model_if_cache(
         do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
             Default to be False, which means greedy decoding will be applied.
     """
+    if enable_cot:
+        from lightrag.utils import logger
+
+        logger.debug(
+            "enable_cot=True is not supported for lmdeploy and will be ignored."
+        )
     try:
         import lmdeploy
         from lmdeploy import version_info, GenerationConfig
diff --git a/lightrag/llm/lollms.py b/lightrag/llm/lollms.py
index 39b64ce3..9274dbfc 100644
--- a/lightrag/llm/lollms.py
+++ b/lightrag/llm/lollms.py
@@ -39,10 +39,15 @@ async def lollms_model_if_cache(
     prompt,
     system_prompt=None,
     history_messages=[],
+    enable_cot: bool = False,
     base_url="http://localhost:9600",
     **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
     """Client implementation for lollms generation."""
+    if enable_cot:
+        from lightrag.utils import logger
+
+        logger.debug("enable_cot=True is not supported for lollms and will be ignored.")

     stream = True if kwargs.get("stream") else False
     api_key = kwargs.pop("api_key", None)
@@ -98,7 +103,12 @@ async def lollms_model_complete(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    enable_cot: bool = False,
+    keyword_extraction=False,
+    **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
     """Complete function for lollms model generation."""

@@ -119,6 +129,7 @@ async def lollms_model_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
index 6423fa90..c6fb46fc 100644
--- a/lightrag/llm/ollama.py
+++ b/lightrag/llm/ollama.py
@@ -43,8 +43,11 @@ async def _ollama_model_if_cache(
     prompt,
     system_prompt=None,
     history_messages=[],
+    enable_cot: bool = False,
     **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
+    if enable_cot:
+        logger.debug("enable_cot=True is not supported for ollama and will be ignored.")
     stream = True if kwargs.get("stream") else False

     kwargs.pop("max_tokens", None)
@@ -123,7 +126,12 @@ async def ollama_model_complete(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    enable_cot: bool = False,
+    keyword_extraction=False,
+    **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
     keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:
@@ -134,6 +142,7 @@ async def ollama_model_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index dd377862..6d486afc 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -111,12 +111,29 @@ async def openai_complete_if_cache(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     base_url: str | None = None,
     api_key: str | None = None,
     token_tracker: Any | None = None,
     **kwargs: Any,
 ) -> str:
-    """Complete a prompt using OpenAI's API with caching support.
+    """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
+
+    This function supports automatic integration of reasoning content (chain of thought) from models that provide
+    Chain of Thought capabilities. The reasoning content is seamlessly integrated into the response
+    using <think>...</think> tags.
+
+    Note on `reasoning_content`: This feature relies on a DeepSeek-style `reasoning_content` field
+    in the API response, which may be provided by OpenAI-compatible endpoints that support
+    Chain of Thought.
+
+    COT Integration Rules:
+    1. COT content is accepted only when regular content is empty and `reasoning_content` has content.
+    2. COT processing stops when regular content becomes available.
+    3. If both `content` and `reasoning_content` are present simultaneously, reasoning is ignored.
+    4. If both fields have content from the start, COT is never activated.
+    5. For streaming: COT content is inserted into the content stream with <think>...</think> tags.
+    6. For non-streaming: COT content is prepended to regular content, wrapped in <think>...</think> tags.

     Args:
         model: The OpenAI model to use.
@@ -125,6 +142,8 @@ async def openai_complete_if_cache(
         history_messages: Optional list of previous messages in the conversation.
         base_url: Optional base URL for the OpenAI API.
         api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        token_tracker: Optional token usage tracker for monitoring API usage.
+        enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
         **kwargs: Additional keyword arguments to pass to the OpenAI API.
             Special kwargs:
             - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
              Will be merged with default configs and applied to the client.
             - hashing_kv: Will be removed from kwargs before passing to OpenAI.
             - keyword_extraction: Will be removed from kwargs before passing to OpenAI.

     Returns:
-        The completed text or an async iterator of text chunks if streaming.
+        The completed text (with integrated COT content if available) or an async iterator
+        of text chunks if streaming. COT content is wrapped in <think>...</think> tags.

     Raises:
         InvalidResponseError: If the response from OpenAI is invalid or empty.
@@ -217,6 +237,11 @@
             iteration_started = False
             final_chunk_usage = None

+            # COT (Chain of Thought) state tracking
+            cot_active = False
+            cot_started = False
+            initial_content_seen = False
+
             try:
                 iteration_started = True
                 async for chunk in response:
@@ -232,20 +257,65 @@
                         logger.warning(f"Received chunk without choices: {chunk}")
                         continue

-                    # Check if delta exists and has content
-                    if not hasattr(chunk.choices[0], "delta") or not hasattr(
-                        chunk.choices[0].delta, "content"
-                    ):
+                    # Check if delta exists
+                    if not hasattr(chunk.choices[0], "delta"):
                         # This might be the final chunk, continue to check for usage
                         continue

-                    content = chunk.choices[0].delta.content
-                    if content is None:
-                        continue
-                    if r"\u" in content:
-                        content = safe_unicode_decode(content.encode("utf-8"))
+                    delta = chunk.choices[0].delta
+                    content = getattr(delta, "content", None)
+                    reasoning_content = getattr(delta, "reasoning_content", None)

-                    yield content
+                    # Handle COT logic for streaming (only if enabled)
+                    if enable_cot:
+                        if content is not None and content != "":
+                            # Regular content is present
+                            if not initial_content_seen:
+                                initial_content_seen = True
+                                # If both content and reasoning_content are present initially, don't start COT
+                                if (
+                                    reasoning_content is not None
+                                    and reasoning_content != ""
+                                ):
+                                    cot_active = False
+                                    cot_started = False
+
+                            # If COT was active, end it
+                            if cot_active:
+                                yield "</think>"
+                                cot_active = False
+
+                            # Process regular content
+                            if r"\u" in content:
+                                content = safe_unicode_decode(content.encode("utf-8"))
+                            yield content
+
+                        elif reasoning_content is not None and reasoning_content != "":
+                            # Only reasoning content is present
+                            if not initial_content_seen and not cot_started:
+                                # Start COT if we haven't seen initial content yet
+                                if not cot_active:
+                                    yield "<think>"
+                                    cot_active = True
+                                    cot_started = True
+
+                            # Process reasoning content if COT is active
+                            if cot_active:
+                                if r"\u" in reasoning_content:
+                                    reasoning_content = safe_unicode_decode(
+                                        reasoning_content.encode("utf-8")
+                                    )
+                                yield reasoning_content
+                    else:
+                        # COT disabled, only process regular content
+                        if content is not None and content != "":
+                            if r"\u" in content:
+                                content = safe_unicode_decode(content.encode("utf-8"))
+                            yield content
+
+                    # If neither content nor reasoning_content, continue to next chunk
+                    if content is None and reasoning_content is None:
+                        continue

                 # After streaming is complete, track token usage
                 if token_tracker and final_chunk_usage:
@@ -313,21 +383,56 @@
                 not response
                 or not response.choices
                 or not hasattr(response.choices[0], "message")
-                or not hasattr(response.choices[0].message, "content")
             ):
                 logger.error("Invalid response from OpenAI API")
                 await openai_async_client.close()  # Ensure client is closed
                 raise InvalidResponseError("Invalid response from OpenAI API")

-            content = response.choices[0].message.content
+            message = response.choices[0].message
+            content = getattr(message, "content", None)
+            reasoning_content = getattr(message, "reasoning_content", None)

-            if not content or content.strip() == "":
+            # Handle COT logic for non-streaming responses (only if enabled)
+            final_content = ""
+
+            if enable_cot:
+                # Check if we should include reasoning content
+                should_include_reasoning = False
+                if reasoning_content and reasoning_content.strip():
+                    if not content or content.strip() == "":
+                        # Case 1: Only reasoning content, should include COT
+                        should_include_reasoning = True
+                        final_content = (
+                            content or ""
+                        )  # Use empty string if content is None
+                    else:
+                        # Case 3: Both content and reasoning_content present, ignore reasoning
+                        should_include_reasoning = False
+                        final_content = content
+                else:
+                    # No reasoning content, use regular content
+                    final_content = content or ""
+
+                # Apply COT wrapping if needed
+                if should_include_reasoning:
+                    if r"\u" in reasoning_content:
+                        reasoning_content = safe_unicode_decode(
+                            reasoning_content.encode("utf-8")
+                        )
+                    final_content = f"<think>{reasoning_content}</think>{final_content}"
+            else:
+                # COT disabled, only use regular content
+                final_content = content or ""
+
+            # Validate final content
+            if not final_content or final_content.strip() == "":
                 logger.error("Received empty content from OpenAI API")
                 await openai_async_client.close()  # Ensure client is closed
                 raise InvalidResponseError("Received empty content from OpenAI API")

-            if r"\u" in content:
-                content = safe_unicode_decode(content.encode("utf-8"))
+            # Apply Unicode decoding to final content if needed
+            if r"\u" in final_content:
+                final_content = safe_unicode_decode(final_content.encode("utf-8"))

             if token_tracker and hasattr(response, "usage"):
                 token_counts = {
@@ -339,10 +444,10 @@
                 }
                 token_tracker.add_usage(token_counts)

-            logger.debug(f"Response content len: {len(content)}")
+            logger.debug(f"Response content len: {len(final_content)}")
             verbose_debug(f"Response: {response}")

-            return content
+            return final_content
         finally:
             # Ensure client is closed in all cases for non-streaming responses
             await openai_async_client.close()
@@ -374,6 +479,7 @@ async def gpt_4o_complete(
     prompt,
     system_prompt=None,
     history_messages=None,
+    enable_cot: bool = False,
     keyword_extraction=False,
     **kwargs,
 ) -> str:
@@ -387,6 +493,7 @@ async def gpt_4o_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

@@ -395,6 +502,7 @@ async def gpt_4o_mini_complete(
     prompt,
     system_prompt=None,
     history_messages=None,
+    enable_cot: bool = False,
     keyword_extraction=False,
     **kwargs,
 ) -> str:
@@ -408,6 +516,7 @@ async def gpt_4o_mini_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

@@ -416,6 +525,7 @@ async def nvidia_openai_complete(
     prompt,
     system_prompt=None,
     history_messages=None,
+    enable_cot: bool = False,
     keyword_extraction=False,
     **kwargs,
 ) -> str:
@@ -427,6 +537,7 @@ async def nvidia_openai_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         base_url="https://integrate.api.nvidia.com/v1",
         **kwargs,
     )
diff --git a/lightrag/llm/zhipu.py b/lightrag/llm/zhipu.py
index c9d1253e..d90f3cc1 100644
--- a/lightrag/llm/zhipu.py
+++ b/lightrag/llm/zhipu.py
@@ -49,8 +49,13 @@ async def zhipu_complete_if_cache(
     api_key: Optional[str] = None,
     system_prompt: Optional[str] = None,
     history_messages: List[Dict[str, str]] = [],
+    enable_cot: bool = False,
     **kwargs,
 ) -> str:
+    if enable_cot:
+        logger.debug(
+            "enable_cot=True is not supported for ZhipuAI and will be ignored."
+        )
     # dynamically load ZhipuAI
     try:
         from zhipuai import ZhipuAI
@@ -91,7 +96,12 @@ async def zhipu_complete(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    keyword_extraction=False,
+    enable_cot: bool = False,
+    **kwargs,
 ):
     # Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
     keyword_extraction = kwargs.pop("keyword_extraction", None)
@@ -122,6 +132,7 @@ async def zhipu_complete(
             prompt=prompt,
             system_prompt=system_prompt,
             history_messages=history_messages,
+            enable_cot=enable_cot,
             **kwargs,
         )

@@ -163,6 +174,7 @@ async def zhipu_complete(
         prompt=prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

diff --git a/lightrag/operate.py b/lightrag/operate.py
index 35126fa2..cfb4c67f 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -2214,6 +2214,7 @@ async def kg_query(
         query,
         system_prompt=sys_prompt,
         stream=query_param.stream,
+        enable_cot=True,
     )
     if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (
@@ -3736,6 +3737,7 @@ async def naive_query(
         query,
         system_prompt=sys_prompt,
         stream=query_param.stream,
+        enable_cot=True,
     )
     if isinstance(response, str) and len(response) > len(sys_prompt):
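
Usage sketch (illustrative, not part of the patch): a minimal example of how a caller might pass enable_cot=True to openai_complete_if_cache and then split the <think> block away from the final answer. It assumes the function keeps its usual (model, prompt, ...) positional signature and that OPENAI_API_KEY (or a compatible endpoint via base_url) is configured; the model name and the split_cot helper are placeholders for illustration only and are not added by this change.

import asyncio
import re

from lightrag.llm.openai import openai_complete_if_cache


def split_cot(text: str) -> tuple[str, str]:
    # Split "<think>reasoning</think>answer" into (reasoning, answer).
    match = re.match(r"<think>(.*?)</think>(.*)", text, flags=re.DOTALL)
    if match:
        return match.group(1).strip(), match.group(2).strip()
    return "", text.strip()


async def main() -> None:
    result = await openai_complete_if_cache(
        "your-reasoning-model",  # placeholder; any OpenAI-compatible model that returns reasoning_content
        "Why is the sky blue?",
        system_prompt="You are a concise assistant.",
        enable_cot=True,  # wrap any reasoning_content in <think>...</think>
        # base_url="https://your-openai-compatible-endpoint/v1",  # if not using the default OpenAI endpoint
    )
    reasoning, answer = split_cot(result)
    print("Reasoning:", reasoning[:200])
    print("Answer:", answer)


if __name__ == "__main__":
    asyncio.run(main())

In streaming mode the <think> and </think> markers arrive as ordinary chunks in the async iterator, so the same kind of split can be applied once the chunks have been joined.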