Merge pull request #2086 from danielaskdd/reasoning_content

feat: Add DeepSeek-Style CoT Support for OpenAI-Compatible LLM Provider
Daniel.y 2025-09-09 22:42:09 +08:00 committed by GitHub
commit 9c9d55b697
12 changed files with 222 additions and 23 deletions

View file

@ -2107,6 +2107,7 @@ class LightRAG:
query.strip(),
system_prompt=system_prompt,
history_messages=param.conversation_history,
enable_cot=True,
stream=param.stream,
)
else:

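For callers that plug in their own LLM function, the query path now forwards enable_cot as a keyword argument. Below is a minimal sketch of a compatible wrapper, assuming an OpenAI-compatible backend that emits reasoning_content; the function and model names are illustrative and not part of this diff.

# Sketch: a custom llm_model_func that accepts the new enable_cot keyword and
# forwards it to openai_complete_if_cache. "deepseek-chat" is only a placeholder.
from lightrag.llm.openai import openai_complete_if_cache

async def my_llm_model_func(
    prompt,
    system_prompt=None,
    history_messages=None,
    enable_cot=False,
    **kwargs,
):
    # Forward enable_cot so that DeepSeek-style reasoning_content, when present,
    # comes back wrapped in <think>...</think> tags.
    return await openai_complete_if_cache(
        "deepseek-chat",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages or [],
        enable_cot=enable_cot,
        **kwargs,
    )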
View file

@ -59,12 +59,17 @@ async def anthropic_complete_if_cache(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
history_messages = []
if enable_cot:
logger.debug(
"enable_cot=True is not supported for the Anthropic API and will be ignored."
)
if not api_key:
api_key = os.environ.get("ANTHROPIC_API_KEY")
@ -150,6 +155,7 @@ async def anthropic_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@ -160,6 +166,7 @@ async def anthropic_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -169,6 +176,7 @@ async def claude_3_opus_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@ -178,6 +186,7 @@ async def claude_3_opus_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -187,6 +196,7 @@ async def claude_3_sonnet_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@ -196,6 +206,7 @@ async def claude_3_sonnet_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -205,6 +216,7 @@ async def claude_3_haiku_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@ -214,6 +226,7 @@ async def claude_3_haiku_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)

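For providers that do not expose reasoning content (Anthropic here, and likewise Azure OpenAI, Bedrock, Hugging Face, LlamaIndex, lmdeploy, lollms, Ollama, and ZhipuAI below), the new keyword is accepted for signature compatibility, logged at DEBUG level, and otherwise ignored, so existing callers are unaffected. A hedged usage sketch follows; the model name is a placeholder and the call requires ANTHROPIC_API_KEY.

import asyncio
from lightrag.llm.anthropic import anthropic_complete_if_cache

async def main():
    # enable_cot=True is ignored by this provider apart from a debug log;
    # the call behaves exactly as it did before this change.
    return await anthropic_complete_if_cache(
        "claude-3-5-haiku-latest",  # placeholder model name
        "Summarize the design in one sentence.",
        enable_cot=True,
    )

# asyncio.run(main())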
View file

@ -24,6 +24,7 @@ from tenacity import (
from lightrag.utils import (
wrap_embedding_func_with_attrs,
safe_unicode_decode,
logger,
)
import numpy as np
@ -41,11 +42,16 @@ async def azure_openai_complete_if_cache(
prompt,
system_prompt: str | None = None,
history_messages: Iterable[ChatCompletionMessageParam] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
**kwargs,
):
if enable_cot:
logger.debug(
"enable_cot=True is not supported for the Azure OpenAI API and will be ignored."
)
deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
base_url = (
base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")

View file

@ -44,11 +44,18 @@ async def bedrock_complete_if_cache(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
if enable_cot:
import logging
logging.debug(
"enable_cot=True is not supported for Bedrock and will be ignored."
)
# Respect existing env; only set if a non-empty value is available
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key

View file

@ -56,8 +56,15 @@ async def hf_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
**kwargs,
) -> str:
if enable_cot:
from lightrag.utils import logger
logger.debug(
"enable_cot=True is not supported for Hugging Face local models and will be ignored."
)
model_name = model
hf_model, hf_tokenizer = initialize_hf_model(model_name)
messages = []
@ -114,7 +121,12 @@ async def hf_model_if_cache(
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
prompt,
system_prompt=None,
history_messages=[],
keyword_extraction=False,
enable_cot: bool = False,
**kwargs,
) -> str:
kwargs.pop("keyword_extraction", None)
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
@ -123,6 +135,7 @@ async def hf_model_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
return result

View file

@ -94,9 +94,14 @@ async def llama_index_complete_if_cache(
prompt: str,
system_prompt: Optional[str] = None,
history_messages: List[dict] = [],
enable_cot: bool = False,
chat_kwargs={},
) -> str:
"""Complete the prompt using LlamaIndex."""
if enable_cot:
logger.debug(
"enable_cot=True is not supported for LlamaIndex implementation and will be ignored."
)
try:
# Format messages for chat
formatted_messages = []
@ -138,6 +143,7 @@ async def llama_index_complete(
prompt,
system_prompt=None,
history_messages=None,
enable_cot: bool = False,
keyword_extraction=False,
settings: LlamaIndexSettings = None,
**kwargs,
@ -162,6 +168,7 @@ async def llama_index_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
return result

View file

@ -56,6 +56,7 @@ async def lmdeploy_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
chat_template=None,
model_format="hf",
quant_policy=0,
@ -89,6 +90,12 @@ async def lmdeploy_model_if_cache(
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"""
if enable_cot:
from lightrag.utils import logger
logger.debug(
"enable_cot=True is not supported for lmdeploy and will be ignored."
)
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig

View file

@ -39,10 +39,15 @@ async def lollms_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
base_url="http://localhost:9600",
**kwargs,
) -> Union[str, AsyncIterator[str]]:
"""Client implementation for lollms generation."""
if enable_cot:
from lightrag.utils import logger
logger.debug("enable_cot=True is not supported for lollms and will be ignored.")
stream = True if kwargs.get("stream") else False
api_key = kwargs.pop("api_key", None)
@ -98,7 +103,12 @@ async def lollms_model_if_cache(
async def lollms_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
"""Complete function for lollms model generation."""
@ -119,6 +129,7 @@ async def lollms_model_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)

View file

@ -43,8 +43,11 @@ async def _ollama_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
if enable_cot:
logger.debug("enable_cot=True is not supported for ollama and will be ignored.")
stream = True if kwargs.get("stream") else False
kwargs.pop("max_tokens", None)
@ -123,7 +126,12 @@ async def _ollama_model_if_cache(
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
@ -134,6 +142,7 @@ async def ollama_model_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)

View file

@ -111,12 +111,29 @@ async def openai_complete_if_cache(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
**kwargs: Any,
) -> str:
"""Complete a prompt using OpenAI's API with caching support.
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
This function supports automatic integration of reasoning content (chain of thought) from models
that provide Chain of Thought capabilities. The reasoning content is integrated into the response
using <think>...</think> tags.
Note on `reasoning_content`: this feature relies on a DeepSeek-style `reasoning_content` field
in the API response, which may be provided by OpenAI-compatible endpoints that support
Chain of Thought.
COT Integration Rules:
1. COT content is accepted only when regular content is empty and `reasoning_content` has content.
2. COT processing stops when regular content becomes available.
3. If both `content` and `reasoning_content` are present simultaneously, reasoning is ignored.
4. If both fields have content from the start, COT is never activated.
5. For streaming: COT content is inserted into the content stream with <think> tags.
6. For non-streaming: COT content is prepended to regular content with <think> tags.
Args:
model: The OpenAI model to use.
@ -125,6 +142,8 @@ async def openai_complete_if_cache(
history_messages: Optional list of previous messages in the conversation.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
token_tracker: Optional token usage tracker for monitoring API usage.
enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
**kwargs: Additional keyword arguments to pass to the OpenAI API.
Special kwargs:
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
@ -134,7 +153,8 @@ async def openai_complete_if_cache(
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
Returns:
The completed text or an async iterator of text chunks if streaming.
The completed text (with integrated COT content if available) or an async iterator
of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
Raises:
InvalidResponseError: If the response from OpenAI is invalid or empty.
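To make the docstring rules above concrete, here is a small illustration (a sketch, not part of the diff) of what a non-streaming call with enable_cot=True can return and how a caller may split the reasoning block from the final answer.

import re

# Hypothetical non-streaming result from a model that returned reasoning_content:
result = "<think>The question asks about X, so check Y first...</think>The answer is Z."

# At most one leading <think> block is produced by this integration.
match = re.match(r"^<think>(.*?)</think>(.*)$", result, flags=re.DOTALL)
reasoning, answer = (match.group(1), match.group(2)) if match else ("", result)
print(answer)  # -> "The answer is Z."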
@ -217,6 +237,11 @@ async def openai_complete_if_cache(
iteration_started = False
final_chunk_usage = None
# COT (Chain of Thought) state tracking
cot_active = False
cot_started = False
initial_content_seen = False
try:
iteration_started = True
async for chunk in response:
@ -232,20 +257,65 @@ async def openai_complete_if_cache(
logger.warning(f"Received chunk without choices: {chunk}")
continue
# Check if delta exists and has content
if not hasattr(chunk.choices[0], "delta") or not hasattr(
chunk.choices[0].delta, "content"
):
# Check if delta exists
if not hasattr(chunk.choices[0], "delta"):
# This might be the final chunk, continue to check for usage
continue
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
delta = chunk.choices[0].delta
content = getattr(delta, "content", None)
reasoning_content = getattr(delta, "reasoning_content", None)
yield content
# Handle COT logic for streaming (only if enabled)
if enable_cot:
if content is not None and content != "":
# Regular content is present
if not initial_content_seen:
initial_content_seen = True
# If both content and reasoning_content are present initially, don't start COT
if (
reasoning_content is not None
and reasoning_content != ""
):
cot_active = False
cot_started = False
# If COT was active, end it
if cot_active:
yield "</think>"
cot_active = False
# Process regular content
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
elif reasoning_content is not None and reasoning_content != "":
# Only reasoning content is present
if not initial_content_seen and not cot_started:
# Start COT if we haven't seen initial content yet
if not cot_active:
yield "<think>"
cot_active = True
cot_started = True
# Process reasoning content if COT is active
if cot_active:
if r"\u" in reasoning_content:
reasoning_content = safe_unicode_decode(
reasoning_content.encode("utf-8")
)
yield reasoning_content
else:
# COT disabled, only process regular content
if content is not None and content != "":
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
# If neither content nor reasoning_content, continue to next chunk
if content is None and reasoning_content is None:
continue
# After streaming is complete, track token usage
if token_tracker and final_chunk_usage:
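The streaming branch only inserts the <think> and </think> markers as standalone chunks; separating reasoning text from answer text is left to the consumer. A minimal consumer sketch under that assumption:

async def split_cot_stream(chunks):
    # chunks: the async iterator returned by openai_complete_if_cache when stream=True.
    # The <think> markers are yielded as their own chunks by the code above,
    # so plain string comparison is enough here.
    in_think = False
    reasoning, answer = [], []
    async for piece in chunks:
        if piece == "<think>":
            in_think = True
        elif piece == "</think>":
            in_think = False
        elif in_think:
            reasoning.append(piece)
        else:
            answer.append(piece)
    return "".join(reasoning), "".join(answer)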
@ -313,21 +383,56 @@ async def openai_complete_if_cache(
not response
or not response.choices
or not hasattr(response.choices[0], "message")
or not hasattr(response.choices[0].message, "content")
):
logger.error("Invalid response from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Invalid response from OpenAI API")
content = response.choices[0].message.content
message = response.choices[0].message
content = getattr(message, "content", None)
reasoning_content = getattr(message, "reasoning_content", None)
if not content or content.strip() == "":
# Handle COT logic for non-streaming responses (only if enabled)
final_content = ""
if enable_cot:
# Check if we should include reasoning content
should_include_reasoning = False
if reasoning_content and reasoning_content.strip():
if not content or content.strip() == "":
# Case 1: Only reasoning content, should include COT
should_include_reasoning = True
final_content = (
content or ""
) # Use empty string if content is None
else:
# Case 3: Both content and reasoning_content present, ignore reasoning
should_include_reasoning = False
final_content = content
else:
# No reasoning content, use regular content
final_content = content or ""
# Apply COT wrapping if needed
if should_include_reasoning:
if r"\u" in reasoning_content:
reasoning_content = safe_unicode_decode(
reasoning_content.encode("utf-8")
)
final_content = f"<think>{reasoning_content}</think>{final_content}"
else:
# COT disabled, only use regular content
final_content = content or ""
# Validate final content
if not final_content or final_content.strip() == "":
logger.error("Received empty content from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Received empty content from OpenAI API")
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
# Apply Unicode decoding to final content if needed
if r"\u" in final_content:
final_content = safe_unicode_decode(final_content.encode("utf-8"))
if token_tracker and hasattr(response, "usage"):
token_counts = {
@ -339,10 +444,10 @@ async def openai_complete_if_cache(
}
token_tracker.add_usage(token_counts)
logger.debug(f"Response content len: {len(content)}")
logger.debug(f"Response content len: {len(final_content)}")
verbose_debug(f"Response: {response}")
return content
return final_content
finally:
# Ensure client is closed in all cases for non-streaming responses
await openai_async_client.close()
@ -374,6 +479,7 @@ async def gpt_4o_complete(
prompt,
system_prompt=None,
history_messages=None,
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> str:
@ -387,6 +493,7 @@ async def gpt_4o_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -395,6 +502,7 @@ async def gpt_4o_mini_complete(
prompt,
system_prompt=None,
history_messages=None,
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> str:
@ -408,6 +516,7 @@ async def gpt_4o_mini_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -416,6 +525,7 @@ async def nvidia_openai_complete(
prompt,
system_prompt=None,
history_messages=None,
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> str:
@ -427,6 +537,7 @@ async def nvidia_openai_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
base_url="https://integrate.api.nvidia.com/v1",
**kwargs,
)

View file

@ -49,8 +49,13 @@ async def zhipu_complete_if_cache(
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
history_messages: List[Dict[str, str]] = [],
enable_cot: bool = False,
**kwargs,
) -> str:
if enable_cot:
logger.debug(
"enable_cot=True is not supported for ZhipuAI and will be ignored."
)
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
@ -91,7 +96,12 @@ async def zhipu_complete_if_cache(
async def zhipu_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
prompt,
system_prompt=None,
history_messages=[],
keyword_extraction=False,
enable_cot: bool = False,
**kwargs,
):
# Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
keyword_extraction = kwargs.pop("keyword_extraction", None)
@ -122,6 +132,7 @@ async def zhipu_complete(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -163,6 +174,7 @@ async def zhipu_complete(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)

View file

@ -2214,6 +2214,7 @@ async def kg_query(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
enable_cot=True,
)
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
@ -3736,6 +3737,7 @@ async def naive_query(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
enable_cot=True,
)
if isinstance(response, str) and len(response) > len(sys_prompt):
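Because kg_query and naive_query now request CoT unconditionally, callers of LightRAG.aquery against a reasoning-capable OpenAI-compatible backend may see a leading <think>...</think> block in the returned text (or extra chunks when streaming). A hedged usage sketch; the query string and mode are arbitrary examples.

from lightrag import LightRAG, QueryParam

async def demo(rag: LightRAG):
    text = await rag.aquery(
        "What entities are related to the main character?",
        param=QueryParam(mode="hybrid", stream=False),
    )
    # Strip the optional reasoning block if only the final answer is wanted.
    if isinstance(text, str) and text.startswith("<think>"):
        text = text.split("</think>", 1)[-1]
    return text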