Merge branch 'main' into tongda/main
This commit is contained in: commit 6774058670
12 changed files with 222 additions and 23 deletions
@@ -2108,6 +2108,7 @@ class LightRAG:
                 query.strip(),
                 system_prompt=system_prompt,
                 history_messages=param.conversation_history,
+                enable_cot=True,
                 stream=param.stream,
             )
         else:
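Review note: with this change the query path in `LightRAG.aquery` always forwards `enable_cot=True` to the configured LLM function. A minimal compatibility sketch (illustrative names, not part of this diff) of a user-supplied `llm_model_func` that tolerates the new keyword:

```python
# Sketch only: a custom llm_model_func that accepts the enable_cot flag
# LightRAG now passes, and simply ignores it for backends without a
# separate reasoning channel.
async def my_llm_model_func(
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list | None = None,
    enable_cot: bool = False,  # forwarded by LightRAG; safe to ignore here
    **kwargs,
) -> str:
    _ = enable_cot  # this hypothetical backend has no reasoning_content support
    # ... call the real model API here ...
    return f"echo: {prompt}"
```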
@@ -59,12 +59,17 @@ async def anthropic_complete_if_cache(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     base_url: str | None = None,
     api_key: str | None = None,
     **kwargs: Any,
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:
         history_messages = []
+    if enable_cot:
+        logger.debug(
+            "enable_cot=True is not supported for the Anthropic API and will be ignored."
+        )
     if not api_key:
         api_key = os.environ.get("ANTHROPIC_API_KEY")

@@ -150,6 +155,7 @@ async def anthropic_complete(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     **kwargs: Any,
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:

@@ -160,6 +166,7 @@ async def anthropic_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

@@ -169,6 +176,7 @@ async def claude_3_opus_complete(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     **kwargs: Any,
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:

@@ -178,6 +186,7 @@ async def claude_3_opus_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

@@ -187,6 +196,7 @@ async def claude_3_sonnet_complete(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     **kwargs: Any,
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:

@@ -196,6 +206,7 @@ async def claude_3_sonnet_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

@@ -205,6 +216,7 @@ async def claude_3_haiku_complete(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     **kwargs: Any,
 ) -> Union[str, AsyncIterator[str]]:
     if history_messages is None:

@@ -214,6 +226,7 @@ async def claude_3_haiku_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )
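Review note: the same pattern — accept `enable_cot` for a uniform signature, log once at debug level, and otherwise ignore it — repeats below for Azure OpenAI, Bedrock, Hugging Face, LlamaIndex, lmdeploy, lollms, ollama, and ZhipuAI. A hedged template for any additional custom binding (the function name is illustrative, not from this diff):

```python
from typing import Any

from lightrag.utils import logger


async def custom_binding_complete(
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    enable_cot: bool = False,
    **kwargs: Any,
) -> str:
    # Keep the interface uniform: accept the flag, note that it is unused here.
    if enable_cot:
        logger.debug("enable_cot=True is not supported by this binding and will be ignored.")
    # ... provider-specific call goes here ...
    return "model output"
```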
@@ -24,6 +24,7 @@ from tenacity import (
 from lightrag.utils import (
     wrap_embedding_func_with_attrs,
     safe_unicode_decode,
+    logger,
 )
 
 import numpy as np

@@ -41,11 +42,16 @@ async def azure_openai_complete_if_cache(
     prompt,
     system_prompt: str | None = None,
     history_messages: Iterable[ChatCompletionMessageParam] | None = None,
+    enable_cot: bool = False,
     base_url: str | None = None,
     api_key: str | None = None,
     api_version: str | None = None,
     **kwargs,
 ):
+    if enable_cot:
+        logger.debug(
+            "enable_cot=True is not supported for the Azure OpenAI API and will be ignored."
+        )
     deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
     base_url = (
         base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
@@ -44,11 +44,18 @@ async def bedrock_complete_if_cache(
     prompt,
     system_prompt=None,
     history_messages=[],
+    enable_cot: bool = False,
     aws_access_key_id=None,
     aws_secret_access_key=None,
     aws_session_token=None,
     **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
+    if enable_cot:
+        import logging
+
+        logging.debug(
+            "enable_cot=True is not supported for Bedrock and will be ignored."
+        )
     # Respect existing env; only set if a non-empty value is available
     access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
     secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
@@ -56,8 +56,15 @@ async def hf_model_if_cache(
     prompt,
     system_prompt=None,
     history_messages=[],
+    enable_cot: bool = False,
     **kwargs,
 ) -> str:
+    if enable_cot:
+        from lightrag.utils import logger
+
+        logger.debug(
+            "enable_cot=True is not supported for Hugging Face local models and will be ignored."
+        )
     model_name = model
     hf_model, hf_tokenizer = initialize_hf_model(model_name)
     messages = []

@@ -114,7 +121,12 @@ async def hf_model_if_cache(
 async def hf_model_complete(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    keyword_extraction=False,
+    enable_cot: bool = False,
+    **kwargs,
 ) -> str:
     kwargs.pop("keyword_extraction", None)
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]

@@ -123,6 +135,7 @@ async def hf_model_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )
     return result
@@ -94,9 +94,14 @@ async def llama_index_complete_if_cache(
     prompt: str,
     system_prompt: Optional[str] = None,
     history_messages: List[dict] = [],
+    enable_cot: bool = False,
     chat_kwargs={},
 ) -> str:
     """Complete the prompt using LlamaIndex."""
+    if enable_cot:
+        logger.debug(
+            "enable_cot=True is not supported for LlamaIndex implementation and will be ignored."
+        )
     try:
         # Format messages for chat
         formatted_messages = []

@@ -138,6 +143,7 @@ async def llama_index_complete(
     prompt,
     system_prompt=None,
     history_messages=None,
+    enable_cot: bool = False,
     keyword_extraction=False,
     settings: LlamaIndexSettings = None,
     **kwargs,

@@ -162,6 +168,7 @@ async def llama_index_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )
     return result
@@ -56,6 +56,7 @@ async def lmdeploy_model_if_cache(
     prompt,
     system_prompt=None,
     history_messages=[],
+    enable_cot: bool = False,
     chat_template=None,
     model_format="hf",
     quant_policy=0,

@@ -89,6 +90,12 @@ async def lmdeploy_model_if_cache(
         do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
             Default to be False, which means greedy decoding will be applied.
     """
+    if enable_cot:
+        from lightrag.utils import logger
+
+        logger.debug(
+            "enable_cot=True is not supported for lmdeploy and will be ignored."
+        )
     try:
         import lmdeploy
         from lmdeploy import version_info, GenerationConfig
@@ -39,10 +39,15 @@ async def lollms_model_if_cache(
     prompt,
     system_prompt=None,
     history_messages=[],
+    enable_cot: bool = False,
     base_url="http://localhost:9600",
     **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
     """Client implementation for lollms generation."""
+    if enable_cot:
+        from lightrag.utils import logger
+
+        logger.debug("enable_cot=True is not supported for lollms and will be ignored.")
 
     stream = True if kwargs.get("stream") else False
     api_key = kwargs.pop("api_key", None)

@@ -98,7 +103,12 @@ async def lollms_model_if_cache(
 async def lollms_model_complete(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    enable_cot: bool = False,
+    keyword_extraction=False,
+    **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
     """Complete function for lollms model generation."""

@@ -119,6 +129,7 @@ async def lollms_model_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )
@@ -43,8 +43,11 @@ async def _ollama_model_if_cache(
     prompt,
     system_prompt=None,
     history_messages=[],
+    enable_cot: bool = False,
     **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
+    if enable_cot:
+        logger.debug("enable_cot=True is not supported for ollama and will be ignored.")
     stream = True if kwargs.get("stream") else False
 
     kwargs.pop("max_tokens", None)

@@ -123,7 +126,12 @@ async def _ollama_model_if_cache(
 async def ollama_model_complete(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    enable_cot: bool = False,
+    keyword_extraction=False,
+    **kwargs,
 ) -> Union[str, AsyncIterator[str]]:
     keyword_extraction = kwargs.pop("keyword_extraction", None)
     if keyword_extraction:

@@ -134,6 +142,7 @@ async def ollama_model_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )
@@ -111,12 +111,29 @@ async def openai_complete_if_cache(
     prompt: str,
     system_prompt: str | None = None,
     history_messages: list[dict[str, Any]] | None = None,
+    enable_cot: bool = False,
     base_url: str | None = None,
     api_key: str | None = None,
     token_tracker: Any | None = None,
     **kwargs: Any,
 ) -> str:
-    """Complete a prompt using OpenAI's API with caching support.
+    """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
+
+    This function supports automatic integration of reasoning content (Chain of Thought) from models that provide
+    Chain of Thought capabilities. The reasoning content is seamlessly integrated into the response
+    using <think>...</think> tags.
+
+    Note on `reasoning_content`: This feature relies on a DeepSeek-style `reasoning_content` field
+    in the API response, which may be provided by OpenAI-compatible endpoints that support
+    Chain of Thought.
+
+    COT Integration Rules:
+    1. COT content is accepted only when regular content is empty and `reasoning_content` has content.
+    2. COT processing stops when regular content becomes available.
+    3. If both `content` and `reasoning_content` are present simultaneously, reasoning is ignored.
+    4. If both fields have content from the start, COT is never activated.
+    5. For streaming: COT content is inserted into the content stream with <think> tags.
+    6. For non-streaming: COT content is prepended to regular content with <think> tags.
 
     Args:
         model: The OpenAI model to use.
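Review note: the non-streaming rules above can be summarized by a small pure function. This is an illustrative restatement of the documented behaviour, not code from the diff:

```python
def merge_cot(content: str | None, reasoning_content: str | None, enable_cot: bool) -> str:
    """Illustrative only: apply the documented non-streaming COT rules."""
    if not enable_cot:
        return content or ""
    if reasoning_content and reasoning_content.strip():
        if not content or content.strip() == "":
            # Rules 1 and 6: only reasoning present -> wrap it in <think> tags.
            return f"<think>{reasoning_content}</think>"
        # Rule 3: both present -> reasoning is ignored.
        return content
    return content or ""
```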
@@ -125,6 +142,8 @@ async def openai_complete_if_cache(
         history_messages: Optional list of previous messages in the conversation.
         base_url: Optional base URL for the OpenAI API.
         api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        token_tracker: Optional token usage tracker for monitoring API usage.
+        enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
         **kwargs: Additional keyword arguments to pass to the OpenAI API.
             Special kwargs:
             - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.

@@ -134,7 +153,8 @@ async def openai_complete_if_cache(
             - keyword_extraction: Will be removed from kwargs before passing to OpenAI.
 
     Returns:
-        The completed text or an async iterator of text chunks if streaming.
+        The completed text (with integrated COT content if available) or an async iterator
+        of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
 
     Raises:
         InvalidResponseError: If the response from OpenAI is invalid or empty.
@@ -217,6 +237,11 @@ async def openai_complete_if_cache(
         iteration_started = False
         final_chunk_usage = None
 
+        # COT (Chain of Thought) state tracking
+        cot_active = False
+        cot_started = False
+        initial_content_seen = False
+
         try:
             iteration_started = True
             async for chunk in response:
@@ -232,20 +257,65 @@ async def openai_complete_if_cache(
                     logger.warning(f"Received chunk without choices: {chunk}")
                     continue
 
-                # Check if delta exists and has content
-                if not hasattr(chunk.choices[0], "delta") or not hasattr(
-                    chunk.choices[0].delta, "content"
-                ):
+                # Check if delta exists
+                if not hasattr(chunk.choices[0], "delta"):
                     # This might be the final chunk, continue to check for usage
                     continue
 
-                content = chunk.choices[0].delta.content
-                if content is None:
-                    continue
-                if r"\u" in content:
-                    content = safe_unicode_decode(content.encode("utf-8"))
+                delta = chunk.choices[0].delta
+                content = getattr(delta, "content", None)
+                reasoning_content = getattr(delta, "reasoning_content", None)
 
-                yield content
+                # Handle COT logic for streaming (only if enabled)
+                if enable_cot:
+                    if content is not None and content != "":
+                        # Regular content is present
+                        if not initial_content_seen:
+                            initial_content_seen = True
+                            # If both content and reasoning_content are present initially, don't start COT
+                            if (
+                                reasoning_content is not None
+                                and reasoning_content != ""
+                            ):
+                                cot_active = False
+                                cot_started = False
+
+                        # If COT was active, end it
+                        if cot_active:
+                            yield "</think>"
+                            cot_active = False
+
+                        # Process regular content
+                        if r"\u" in content:
+                            content = safe_unicode_decode(content.encode("utf-8"))
+                        yield content
+
+                    elif reasoning_content is not None and reasoning_content != "":
+                        # Only reasoning content is present
+                        if not initial_content_seen and not cot_started:
+                            # Start COT if we haven't seen initial content yet
+                            if not cot_active:
+                                yield "<think>"
+                                cot_active = True
+                                cot_started = True
+
+                        # Process reasoning content if COT is active
+                        if cot_active:
+                            if r"\u" in reasoning_content:
+                                reasoning_content = safe_unicode_decode(
+                                    reasoning_content.encode("utf-8")
+                                )
+                            yield reasoning_content
+                else:
+                    # COT disabled, only process regular content
+                    if content is not None and content != "":
+                        if r"\u" in content:
+                            content = safe_unicode_decode(content.encode("utf-8"))
+                        yield content
+
+                # If neither content nor reasoning_content, continue to next chunk
+                if content is None and reasoning_content is None:
+                    continue
 
             # After streaming is complete, track token usage
             if token_tracker and final_chunk_usage:
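Review note: because the stream now interleaves `<think>` / `</think>` marker chunks with reasoning and answer text, a caller that wants them separated has to track the tags. An illustrative consumer sketch (assuming the tags arrive as whole chunks, as yielded above; not part of this diff):

```python
from typing import AsyncIterator, Tuple


async def split_think_stream(chunks: AsyncIterator[str]) -> Tuple[str, str]:
    """Illustrative only: collect streamed reasoning and answer text separately."""
    reasoning_parts: list[str] = []
    answer_parts: list[str] = []
    in_think = False
    async for chunk in chunks:
        if chunk == "<think>":
            in_think = True
        elif chunk == "</think>":
            in_think = False
        elif in_think:
            reasoning_parts.append(chunk)
        else:
            answer_parts.append(chunk)
    return "".join(reasoning_parts), "".join(answer_parts)
```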
@@ -313,21 +383,56 @@ async def openai_complete_if_cache(
                 not response
                 or not response.choices
                 or not hasattr(response.choices[0], "message")
                 or not hasattr(response.choices[0].message, "content")
             ):
                 logger.error("Invalid response from OpenAI API")
                 await openai_async_client.close()  # Ensure client is closed
                 raise InvalidResponseError("Invalid response from OpenAI API")
 
-            content = response.choices[0].message.content
+            message = response.choices[0].message
+            content = getattr(message, "content", None)
+            reasoning_content = getattr(message, "reasoning_content", None)
 
-            if not content or content.strip() == "":
+            # Handle COT logic for non-streaming responses (only if enabled)
+            final_content = ""
+
+            if enable_cot:
+                # Check if we should include reasoning content
+                should_include_reasoning = False
+                if reasoning_content and reasoning_content.strip():
+                    if not content or content.strip() == "":
+                        # Case 1: Only reasoning content, should include COT
+                        should_include_reasoning = True
+                        final_content = (
+                            content or ""
+                        )  # Use empty string if content is None
+                    else:
+                        # Case 3: Both content and reasoning_content present, ignore reasoning
+                        should_include_reasoning = False
+                        final_content = content
+                else:
+                    # No reasoning content, use regular content
+                    final_content = content or ""
+
+                # Apply COT wrapping if needed
+                if should_include_reasoning:
+                    if r"\u" in reasoning_content:
+                        reasoning_content = safe_unicode_decode(
+                            reasoning_content.encode("utf-8")
+                        )
+                    final_content = f"<think>{reasoning_content}</think>{final_content}"
+            else:
+                # COT disabled, only use regular content
+                final_content = content or ""
+
+            # Validate final content
+            if not final_content or final_content.strip() == "":
                 logger.error("Received empty content from OpenAI API")
                 await openai_async_client.close()  # Ensure client is closed
                 raise InvalidResponseError("Received empty content from OpenAI API")
 
-            if r"\u" in content:
-                content = safe_unicode_decode(content.encode("utf-8"))
+            # Apply Unicode decoding to final content if needed
+            if r"\u" in final_content:
+                final_content = safe_unicode_decode(final_content.encode("utf-8"))
 
             if token_tracker and hasattr(response, "usage"):
                 token_counts = {
@@ -339,10 +444,10 @@ async def openai_complete_if_cache(
                 }
                 token_tracker.add_usage(token_counts)
 
-            logger.debug(f"Response content len: {len(content)}")
+            logger.debug(f"Response content len: {len(final_content)}")
             verbose_debug(f"Response: {response}")
 
-            return content
+            return final_content
         finally:
             # Ensure client is closed in all cases for non-streaming responses
             await openai_async_client.close()
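Review note: a hedged usage sketch for the non-streaming path. The model name, endpoint, and key below are placeholders, and the positional `model, prompt` order follows the docstring above; this is not an official example from the repository:

```python
import asyncio

from lightrag.llm.openai import openai_complete_if_cache  # assumed module path


async def main() -> None:
    answer = await openai_complete_if_cache(
        "deepseek-reasoner",                    # assumed reasoning-capable model
        "Why is the sky blue?",
        system_prompt="Answer in one paragraph.",
        enable_cot=True,                        # reasoning arrives inside <think>...</think>
        base_url="https://api.example.com/v1",  # placeholder OpenAI-compatible endpoint
        api_key="sk-...",                       # placeholder
    )
    print(answer)


asyncio.run(main())
```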
@@ -374,6 +479,7 @@ async def gpt_4o_complete(
     prompt,
     system_prompt=None,
     history_messages=None,
+    enable_cot: bool = False,
     keyword_extraction=False,
     **kwargs,
 ) -> str:

@@ -387,6 +493,7 @@ async def gpt_4o_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

@@ -395,6 +502,7 @@ async def gpt_4o_mini_complete(
     prompt,
     system_prompt=None,
     history_messages=None,
+    enable_cot: bool = False,
     keyword_extraction=False,
     **kwargs,
 ) -> str:

@@ -408,6 +516,7 @@ async def gpt_4o_mini_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         **kwargs,
     )

@@ -416,6 +525,7 @@ async def nvidia_openai_complete(
     prompt,
     system_prompt=None,
     history_messages=None,
+    enable_cot: bool = False,
     keyword_extraction=False,
     **kwargs,
 ) -> str:

@@ -427,6 +537,7 @@ async def nvidia_openai_complete(
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
+        enable_cot=enable_cot,
         base_url="https://integrate.api.nvidia.com/v1",
         **kwargs,
     )
@@ -49,8 +49,13 @@ async def zhipu_complete_if_cache(
     api_key: Optional[str] = None,
     system_prompt: Optional[str] = None,
     history_messages: List[Dict[str, str]] = [],
+    enable_cot: bool = False,
     **kwargs,
 ) -> str:
+    if enable_cot:
+        logger.debug(
+            "enable_cot=True is not supported for ZhipuAI and will be ignored."
+        )
     # dynamically load ZhipuAI
     try:
         from zhipuai import ZhipuAI

@@ -91,7 +96,12 @@ async def zhipu_complete_if_cache(
 async def zhipu_complete(
-    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    keyword_extraction=False,
+    enable_cot: bool = False,
+    **kwargs,
 ):
     # Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
     keyword_extraction = kwargs.pop("keyword_extraction", None)

@@ -122,6 +132,7 @@ async def zhipu_complete(
             prompt=prompt,
             system_prompt=system_prompt,
             history_messages=history_messages,
+            enable_cot=enable_cot,
             **kwargs,
         )

@@ -163,6 +174,7 @@ async def zhipu_complete(
             prompt=prompt,
             system_prompt=system_prompt,
             history_messages=history_messages,
+            enable_cot=enable_cot,
             **kwargs,
         )
@@ -2213,6 +2213,7 @@ async def kg_query(
         query,
         system_prompt=sys_prompt,
         stream=query_param.stream,
+        enable_cot=True,
     )
     if isinstance(response, str) and len(response) > len(sys_prompt):
         response = (

@@ -4070,6 +4071,7 @@ async def naive_query(
         query,
         system_prompt=sys_prompt,
         stream=query_param.stream,
+        enable_cot=True,
     )
 
     if isinstance(response, str) and len(response) > len(sys_prompt):
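Review note: since `kg_query` and `naive_query` now always request `enable_cot=True`, a response string may carry a leading `<think>...</think>` block when the model exposes `reasoning_content`. An illustrative post-processing helper for callers that only want the final answer (not part of this diff):

```python
import re


def strip_think(response: str) -> str:
    """Illustrative only: drop a <think>...</think> reasoning block if present."""
    return re.sub(r"<think>.*?</think>", "", response, flags=re.DOTALL).strip()


# Example: strip_think("<think>chain of thought</think>Final answer.") -> "Final answer."
```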