diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 08f0e6d9..62aaf17d 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -2108,6 +2108,7 @@ class LightRAG:
query.strip(),
system_prompt=system_prompt,
history_messages=param.conversation_history,
+ enable_cot=True,
stream=param.stream,
)
else:
diff --git a/lightrag/llm/anthropic.py b/lightrag/llm/anthropic.py
index 98a997d5..fe18300c 100644
--- a/lightrag/llm/anthropic.py
+++ b/lightrag/llm/anthropic.py
@@ -59,12 +59,17 @@ async def anthropic_complete_if_cache(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
+ enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
history_messages = []
+ if enable_cot:
+ logger.debug(
+ "enable_cot=True is not supported for the Anthropic API and will be ignored."
+ )
if not api_key:
api_key = os.environ.get("ANTHROPIC_API_KEY")
@@ -150,6 +155,7 @@ async def anthropic_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
+ enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@@ -160,6 +166,7 @@ async def anthropic_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
@@ -169,6 +176,7 @@ async def claude_3_opus_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
+ enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@@ -178,6 +186,7 @@ async def claude_3_opus_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
@@ -187,6 +196,7 @@ async def claude_3_sonnet_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
+ enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@@ -196,6 +206,7 @@ async def claude_3_sonnet_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
@@ -205,6 +216,7 @@ async def claude_3_haiku_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
+ enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@@ -214,6 +226,7 @@ async def claude_3_haiku_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py
index 0ede0824..824ff088 100644
--- a/lightrag/llm/azure_openai.py
+++ b/lightrag/llm/azure_openai.py
@@ -24,6 +24,7 @@ from tenacity import (
from lightrag.utils import (
wrap_embedding_func_with_attrs,
safe_unicode_decode,
+ logger,
)
import numpy as np
@@ -41,11 +42,16 @@ async def azure_openai_complete_if_cache(
prompt,
system_prompt: str | None = None,
history_messages: Iterable[ChatCompletionMessageParam] | None = None,
+ enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
**kwargs,
):
+ if enable_cot:
+ logger.debug(
+ "enable_cot=True is not supported for the Azure OpenAI API and will be ignored."
+ )
deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
base_url = (
base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py
index 69d00e2d..16737341 100644
--- a/lightrag/llm/bedrock.py
+++ b/lightrag/llm/bedrock.py
@@ -44,11 +44,18 @@ async def bedrock_complete_if_cache(
prompt,
system_prompt=None,
history_messages=[],
+ enable_cot: bool = False,
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
+ if enable_cot:
+ from lightrag.utils import logger
+
+ logger.debug(
+ "enable_cot=True is not supported for Bedrock and will be ignored."
+ )
# Respect existing env; only set if a non-empty value is available
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
diff --git a/lightrag/llm/hf.py b/lightrag/llm/hf.py
index 7adf1570..c33b1c7f 100644
--- a/lightrag/llm/hf.py
+++ b/lightrag/llm/hf.py
@@ -56,8 +56,15 @@ async def hf_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
+ enable_cot: bool = False,
**kwargs,
) -> str:
+ if enable_cot:
+ from lightrag.utils import logger
+
+ logger.debug(
+ "enable_cot=True is not supported for Hugging Face local models and will be ignored."
+ )
model_name = model
hf_model, hf_tokenizer = initialize_hf_model(model_name)
messages = []
@@ -114,7 +121,12 @@ async def hf_model_if_cache(
async def hf_model_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ keyword_extraction=False,
+ enable_cot: bool = False,
+ **kwargs,
) -> str:
kwargs.pop("keyword_extraction", None)
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
@@ -123,6 +135,7 @@ async def hf_model_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
return result
diff --git a/lightrag/llm/llama_index_impl.py b/lightrag/llm/llama_index_impl.py
index 2e50e33e..38ec7cd1 100644
--- a/lightrag/llm/llama_index_impl.py
+++ b/lightrag/llm/llama_index_impl.py
@@ -94,9 +94,14 @@ async def llama_index_complete_if_cache(
prompt: str,
system_prompt: Optional[str] = None,
history_messages: List[dict] = [],
+ enable_cot: bool = False,
chat_kwargs={},
) -> str:
"""Complete the prompt using LlamaIndex."""
+ if enable_cot:
+ logger.debug(
+ "enable_cot=True is not supported for LlamaIndex implementation and will be ignored."
+ )
try:
# Format messages for chat
formatted_messages = []
@@ -138,6 +143,7 @@ async def llama_index_complete(
prompt,
system_prompt=None,
history_messages=None,
+ enable_cot: bool = False,
keyword_extraction=False,
settings: LlamaIndexSettings = None,
**kwargs,
@@ -162,6 +168,7 @@ async def llama_index_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
return result
diff --git a/lightrag/llm/lmdeploy.py b/lightrag/llm/lmdeploy.py
index 50e60cd6..8916b0fd 100644
--- a/lightrag/llm/lmdeploy.py
+++ b/lightrag/llm/lmdeploy.py
@@ -56,6 +56,7 @@ async def lmdeploy_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
+ enable_cot: bool = False,
chat_template=None,
model_format="hf",
quant_policy=0,
@@ -89,6 +90,12 @@ async def lmdeploy_model_if_cache(
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"""
+ if enable_cot:
+ from lightrag.utils import logger
+
+ logger.debug(
+ "enable_cot=True is not supported for lmdeploy and will be ignored."
+ )
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig
diff --git a/lightrag/llm/lollms.py b/lightrag/llm/lollms.py
index 39b64ce3..9274dbfc 100644
--- a/lightrag/llm/lollms.py
+++ b/lightrag/llm/lollms.py
@@ -39,10 +39,15 @@ async def lollms_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
+ enable_cot: bool = False,
base_url="http://localhost:9600",
**kwargs,
) -> Union[str, AsyncIterator[str]]:
"""Client implementation for lollms generation."""
+ if enable_cot:
+ from lightrag.utils import logger
+
+ logger.debug("enable_cot=True is not supported for lollms and will be ignored.")
stream = True if kwargs.get("stream") else False
api_key = kwargs.pop("api_key", None)
@@ -98,7 +103,12 @@ async def lollms_model_if_cache(
async def lollms_model_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ enable_cot: bool = False,
+ keyword_extraction=False,
+ **kwargs,
) -> Union[str, AsyncIterator[str]]:
"""Complete function for lollms model generation."""
@@ -119,6 +129,7 @@ async def lollms_model_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
index 6423fa90..c6fb46fc 100644
--- a/lightrag/llm/ollama.py
+++ b/lightrag/llm/ollama.py
@@ -43,8 +43,11 @@ async def _ollama_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
+ enable_cot: bool = False,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
+ if enable_cot:
+ logger.debug("enable_cot=True is not supported for ollama and will be ignored.")
stream = True if kwargs.get("stream") else False
kwargs.pop("max_tokens", None)
@@ -123,7 +126,12 @@ async def _ollama_model_if_cache(
async def ollama_model_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ enable_cot: bool = False,
+ keyword_extraction=False,
+ **kwargs,
) -> Union[str, AsyncIterator[str]]:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
@@ -134,6 +142,7 @@ async def ollama_model_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index dd377862..6d486afc 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -111,12 +111,29 @@ async def openai_complete_if_cache(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
+ enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
**kwargs: Any,
) -> str:
- """Complete a prompt using OpenAI's API with caching support.
+ """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
+
+ This function supports automatic integration of reasoning content (Chain of Thought) from models
+ that expose it. The reasoning content is wrapped in <think>...</think> tags and merged into the
+ response.
+
+ Note on `reasoning_content`: this feature relies on a DeepSeek-style `reasoning_content` field
+ in the API response, which may be provided by OpenAI-compatible endpoints that support
+ Chain of Thought.
+
+ COT Integration Rules:
+ 1. COT content is accepted only when regular content is empty and `reasoning_content` has content.
+ 2. COT processing stops when regular content becomes available.
+ 3. If both `content` and `reasoning_content` are present simultaneously, reasoning is ignored.
+ 4. If both fields have content from the start, COT is never activated.
+ 5. For streaming: COT content is inserted into the content stream, wrapped in <think>...</think> tags.
+ 6. For non-streaming: COT content is prepended to the regular content, wrapped in <think>...</think> tags.
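+ Example (illustrative; the model name is a placeholder and assumes an endpoint that returns
+ a `reasoning_content` field):
+     result = await openai_complete_if_cache(
+         "deepseek-reasoner",
+         "Why is the sky blue?",
+         enable_cot=True,
+     )
+     # With a reasoning-capable model, `result` may look like:
+     # "<think>Rayleigh scattering favours shorter wavelengths...</think>The sky appears blue because..."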
Args:
model: The OpenAI model to use.
@@ -125,6 +142,8 @@ async def openai_complete_if_cache(
history_messages: Optional list of previous messages in the conversation.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+ token_tracker: Optional token usage tracker for monitoring API usage.
+ enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
**kwargs: Additional keyword arguments to pass to the OpenAI API.
Special kwargs:
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
@@ -134,7 +153,8 @@ async def openai_complete_if_cache(
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
Returns:
- The completed text or an async iterator of text chunks if streaming.
+ The completed text (with integrated COT content if available) or an async iterator
+ of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
Raises:
InvalidResponseError: If the response from OpenAI is invalid or empty.
@@ -217,6 +237,11 @@ async def openai_complete_if_cache(
iteration_started = False
final_chunk_usage = None
+ # COT (Chain of Thought) state tracking
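+ # cot_active: a <think> block is currently open and reasoning chunks are being streamed
+ # cot_started: a <think> block has been opened at least once for this response
+ # initial_content_seen: regular content has already arrived, so later reasoning is ignored (rules 3/4)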
+ cot_active = False
+ cot_started = False
+ initial_content_seen = False
+
try:
iteration_started = True
async for chunk in response:
@@ -232,20 +257,65 @@ async def openai_complete_if_cache(
logger.warning(f"Received chunk without choices: {chunk}")
continue
- # Check if delta exists and has content
- if not hasattr(chunk.choices[0], "delta") or not hasattr(
- chunk.choices[0].delta, "content"
- ):
+ # Check if delta exists
+ if not hasattr(chunk.choices[0], "delta"):
# This might be the final chunk, continue to check for usage
continue
- content = chunk.choices[0].delta.content
- if content is None:
- continue
- if r"\u" in content:
- content = safe_unicode_decode(content.encode("utf-8"))
+ delta = chunk.choices[0].delta
+ content = getattr(delta, "content", None)
+ reasoning_content = getattr(delta, "reasoning_content", None)
- yield content
+ # Handle COT logic for streaming (only if enabled)
+ if enable_cot:
+ if content is not None and content != "":
+ # Regular content is present
+ if not initial_content_seen:
+ initial_content_seen = True
+ # If both content and reasoning_content are present initially, don't start COT
+ if (
+ reasoning_content is not None
+ and reasoning_content != ""
+ ):
+ cot_active = False
+ cot_started = False
+
+ # If COT was active, end it
+ if cot_active:
+ yield ""
+ cot_active = False
+
+ # Process regular content
+ if r"\u" in content:
+ content = safe_unicode_decode(content.encode("utf-8"))
+ yield content
+
+ elif reasoning_content is not None and reasoning_content != "":
+ # Only reasoning content is present
+ if not initial_content_seen and not cot_started:
+ # Start COT if we haven't seen initial content yet
+ if not cot_active:
+ yield ""
+ cot_active = True
+ cot_started = True
+
+ # Process reasoning content if COT is active
+ if cot_active:
+ if r"\u" in reasoning_content:
+ reasoning_content = safe_unicode_decode(
+ reasoning_content.encode("utf-8")
+ )
+ yield reasoning_content
+ else:
+ # COT disabled, only process regular content
+ if content is not None and content != "":
+ if r"\u" in content:
+ content = safe_unicode_decode(content.encode("utf-8"))
+ yield content
+
+ # If neither content nor reasoning_content, continue to next chunk
+ if content is None and reasoning_content is None:
+ continue
# After streaming is complete, track token usage
if token_tracker and final_chunk_usage:
@@ -313,21 +383,56 @@ async def openai_complete_if_cache(
not response
or not response.choices
or not hasattr(response.choices[0], "message")
- or not hasattr(response.choices[0].message, "content")
):
logger.error("Invalid response from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Invalid response from OpenAI API")
- content = response.choices[0].message.content
+ message = response.choices[0].message
+ content = getattr(message, "content", None)
+ reasoning_content = getattr(message, "reasoning_content", None)
- if not content or content.strip() == "":
+ # Handle COT logic for non-streaming responses (only if enabled)
+ final_content = ""
+
+ if enable_cot:
+ # Check if we should include reasoning content
+ should_include_reasoning = False
+ if reasoning_content and reasoning_content.strip():
+ if not content or content.strip() == "":
+ # Case 1: Only reasoning content, should include COT
+ should_include_reasoning = True
+ final_content = (
+ content or ""
+ ) # Use empty string if content is None
+ else:
+ # Case 3: Both content and reasoning_content present, ignore reasoning
+ should_include_reasoning = False
+ final_content = content
+ else:
+ # No reasoning content, use regular content
+ final_content = content or ""
+
+ # Apply COT wrapping if needed
+ if should_include_reasoning:
+ if r"\u" in reasoning_content:
+ reasoning_content = safe_unicode_decode(
+ reasoning_content.encode("utf-8")
+ )
+ final_content = f"{reasoning_content}{final_content}"
+ else:
+ # COT disabled, only use regular content
+ final_content = content or ""
+
+ # Validate final content
+ if not final_content or final_content.strip() == "":
logger.error("Received empty content from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Received empty content from OpenAI API")
- if r"\u" in content:
- content = safe_unicode_decode(content.encode("utf-8"))
+ # Apply Unicode decoding to final content if needed
+ if r"\u" in final_content:
+ final_content = safe_unicode_decode(final_content.encode("utf-8"))
if token_tracker and hasattr(response, "usage"):
token_counts = {
@@ -339,10 +444,10 @@ async def openai_complete_if_cache(
}
token_tracker.add_usage(token_counts)
- logger.debug(f"Response content len: {len(content)}")
+ logger.debug(f"Response content len: {len(final_content)}")
verbose_debug(f"Response: {response}")
- return content
+ return final_content
finally:
# Ensure client is closed in all cases for non-streaming responses
await openai_async_client.close()
@@ -374,6 +479,7 @@ async def gpt_4o_complete(
prompt,
system_prompt=None,
history_messages=None,
+ enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> str:
@@ -387,6 +493,7 @@ async def gpt_4o_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
@@ -395,6 +502,7 @@ async def gpt_4o_mini_complete(
prompt,
system_prompt=None,
history_messages=None,
+ enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> str:
@@ -408,6 +516,7 @@ async def gpt_4o_mini_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
@@ -416,6 +525,7 @@ async def nvidia_openai_complete(
prompt,
system_prompt=None,
history_messages=None,
+ enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> str:
@@ -427,6 +537,7 @@ async def nvidia_openai_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
base_url="https://integrate.api.nvidia.com/v1",
**kwargs,
)
diff --git a/lightrag/llm/zhipu.py b/lightrag/llm/zhipu.py
index c9d1253e..d90f3cc1 100644
--- a/lightrag/llm/zhipu.py
+++ b/lightrag/llm/zhipu.py
@@ -49,8 +49,13 @@ async def zhipu_complete_if_cache(
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
history_messages: List[Dict[str, str]] = [],
+ enable_cot: bool = False,
**kwargs,
) -> str:
+ if enable_cot:
+ logger.debug(
+ "enable_cot=True is not supported for ZhipuAI and will be ignored."
+ )
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
@@ -91,7 +96,12 @@ async def zhipu_complete_if_cache(
async def zhipu_complete(
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ keyword_extraction=False,
+ enable_cot: bool = False,
+ **kwargs,
):
# Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
keyword_extraction = kwargs.pop("keyword_extraction", None)
@@ -122,6 +132,7 @@ async def zhipu_complete(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
@@ -163,6 +174,7 @@ async def zhipu_complete(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
+ enable_cot=enable_cot,
**kwargs,
)
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 162a9bc4..6635ef94 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -2213,6 +2213,7 @@ async def kg_query(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
+ enable_cot=True,
)
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
@@ -4070,6 +4071,7 @@ async def naive_query(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
+ enable_cot=True,
)
if isinstance(response, str) and len(response) > len(sys_prompt):