This commit is contained in:
hzywhite 2025-09-10 11:23:33 +08:00
commit d0709d5416
30 changed files with 673 additions and 308 deletions

View file

@ -48,7 +48,7 @@ jobs:
images: ghcr.io/${{ github.repository }}
tags: |
type=raw,value=${{ steps.get_tag.outputs.tag }}
type=raw,value=latest,enable={{is_default_branch}}
type=raw,value=latest
- name: Build and push Docker image
uses: docker/build-push-action@v5

View file

@ -3,6 +3,7 @@ name: Upload LightRAG-hku Package
on:
release:
types: [published]
workflow_dispatch:
permissions:
contents: read

View file

@ -107,7 +107,7 @@ RERANK_BINDING=null
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Default value for Jina AI
# RERANK_MODELjina-reranker-v2-base-multilingual
# RERANK_MODEL=jina-reranker-v2-base-multilingual
# RERANK_BINDING_HOST=https://api.jina.ai/v1/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
@ -175,8 +175,8 @@ LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING=openai
### OpenAI Specific Parameters
### To mitigate endless output, set the temperature to a highter value
# OPENAI_LLM_TEMPERATURE=0.95
### Set max_output_tokens to mitigate endless output from some LLMs (keep it below LLM_TIMEOUT * llm_output_tokens/second, e.g. 9000 = 180s * 50 tokens/s)
# OPENAI_LLM_MAX_TOKENS=9000
### OpenRouter Specific Parameters
# OPENAI_LLM_EXTRA_BODY='{"reasoning": {"enabled": false}}'
@ -189,7 +189,8 @@ LLM_BINDING_API_KEY=your_api_key
### Ollama Server Specific Parameters
### OLLAMA_LLM_NUM_CTX must be provided, and should be at least MAX_TOTAL_TOKENS + 2000
OLLAMA_LLM_NUM_CTX=32768
# OLLAMA_LLM_TEMPERATURE=1.0
### Set max_output_tokens to mitigate endless output from some LLMs (keep it below LLM_TIMEOUT * llm_output_tokens/second, e.g. 9000 = 180s * 50 tokens/s)
# OLLAMA_LLM_NUM_PREDICT=9000
### Stop sequences for Ollama LLM
# OLLAMA_LLM_STOP='["</s>", "<|EOT|>"]'
### use the following command to see all support options for Ollama LLM

View file

@ -1 +1 @@
__api_version__ = "0215"
__api_version__ = "0218"

View file

@ -417,7 +417,7 @@ def update_uvicorn_mode_config():
global_args.workers = 1
# Log warning directly here
logging.warning(
f"In uvicorn mode, workers parameter was set to {original_workers}. Forcing workers=1"
f">> Forcing workers=1 in uvicorn mode(Ignoring workers={original_workers})"
)

View file

@ -69,7 +69,7 @@ DEFAULT_EMBEDDING_FUNC_MAX_ASYNC = 8 # Default max async for embedding function
DEFAULT_EMBEDDING_BATCH_NUM = 10 # Default batch size for embedding computations
# Gunicorn worker timeout
DEFAULT_TIMEOUT = 210
DEFAULT_TIMEOUT = 300
# Default llm and embedding timeout
DEFAULT_LLM_TIMEOUT = 180

View file

@ -164,9 +164,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
logger.error(f"Error during ChromaDB upsert: {str(e)}")
raise
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
try:
embedding = await self.embedding_func(
[query], _priority=5

View file

@ -787,13 +787,13 @@ class PostgreSQLDB:
FROM information_schema.columns
WHERE table_name = $1 AND column_name = $2
"""
params = {
"table_name": migration["table"].lower(),
"column_name": migration["column"],
}
column_info = await self.query(
check_column_sql,
{
"table_name": migration["table"].lower(),
"column_name": migration["column"],
},
list(params.values()),
)
if not column_info:
@ -1035,10 +1035,8 @@ class PostgreSQLDB:
WHERE table_name = $1
AND table_schema = 'public'
"""
table_exists = await self.query(
check_table_sql, {"table_name": table_name.lower()}
)
params = {"table_name": table_name.lower()}
table_exists = await self.query(check_table_sql, list(params.values()))
if not table_exists:
logger.info(f"Creating table {table_name}")
@ -1121,7 +1119,8 @@ class PostgreSQLDB:
AND indexname = $1
"""
existing = await self.query(check_sql, {"indexname": index["name"]})
params = {"indexname": index["name"]}
existing = await self.query(check_sql, list(params.values()))
if not existing:
logger.info(f"Creating pagination index: {index['description']}")
@ -1217,7 +1216,7 @@ class PostgreSQLDB:
async def query(
self,
sql: str,
params: dict[str, Any] | None = None,
params: list[Any] | None = None,
multirows: bool = False,
with_age: bool = False,
graph_name: str | None = None,
@ -1230,7 +1229,7 @@ class PostgreSQLDB:
try:
if params:
rows = await connection.fetch(sql, *params.values())
rows = await connection.fetch(sql, *params)
else:
rows = await connection.fetch(sql)
@ -1446,7 +1445,7 @@ class PGKVStorage(BaseKVStorage):
params = {"workspace": self.workspace}
try:
results = await self.db.query(sql, params, multirows=True)
results = await self.db.query(sql, list(params.values()), multirows=True)
# Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
@ -1540,7 +1539,7 @@ class PGKVStorage(BaseKVStorage):
"""Get data by id."""
sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"workspace": self.workspace, "id": id}
response = await self.db.query(sql, params)
response = await self.db.query(sql, list(params.values()))
if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
# Parse llm_cache_list JSON string back to list
@ -1620,7 +1619,7 @@ class PGKVStorage(BaseKVStorage):
ids=",".join([f"'{id}'" for id in ids])
)
params = {"workspace": self.workspace}
results = await self.db.query(sql, params, multirows=True)
results = await self.db.query(sql, list(params.values()), multirows=True)
if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
# Parse llm_cache_list JSON string back to list for each result
@ -1708,7 +1707,7 @@ class PGKVStorage(BaseKVStorage):
)
params = {"workspace": self.workspace}
try:
res = await self.db.query(sql, params, multirows=True)
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
exist_keys = [key["id"] for key in res]
else:
@ -2023,7 +2022,7 @@ class PGVectorStorage(BaseVectorStorage):
"closer_than_threshold": 1 - self.cosine_better_than_threshold,
"top_k": top_k,
}
results = await self.db.query(sql, params=params, multirows=True)
results = await self.db.query(sql, params=list(params.values()), multirows=True)
return results
async def index_done_callback(self) -> None:
@ -2120,7 +2119,7 @@ class PGVectorStorage(BaseVectorStorage):
params = {"workspace": self.workspace, "id": id}
try:
result = await self.db.query(query, params)
result = await self.db.query(query, list(params.values()))
if result:
return dict(result)
return None
@ -2154,7 +2153,7 @@ class PGVectorStorage(BaseVectorStorage):
params = {"workspace": self.workspace}
try:
results = await self.db.query(query, params, multirows=True)
results = await self.db.query(query, list(params.values()), multirows=True)
return [dict(record) for record in results]
except Exception as e:
logger.error(
@ -2187,7 +2186,7 @@ class PGVectorStorage(BaseVectorStorage):
params = {"workspace": self.workspace}
try:
results = await self.db.query(query, params, multirows=True)
results = await self.db.query(query, list(params.values()), multirows=True)
vectors_dict = {}
for result in results:
@ -2274,7 +2273,7 @@ class PGDocStatusStorage(DocStatusStorage):
)
params = {"workspace": self.workspace}
try:
res = await self.db.query(sql, params, multirows=True)
res = await self.db.query(sql, list(params.values()), multirows=True)
if res:
exist_keys = [key["id"] for key in res]
else:
@ -2292,7 +2291,7 @@ class PGDocStatusStorage(DocStatusStorage):
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2"
params = {"workspace": self.workspace, "id": id}
result = await self.db.query(sql, params, True)
result = await self.db.query(sql, list(params.values()), True)
if result is None or result == []:
return None
else:
@ -2338,7 +2337,7 @@ class PGDocStatusStorage(DocStatusStorage):
sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)"
params = {"workspace": self.workspace, "ids": ids}
results = await self.db.query(sql, params, True)
results = await self.db.query(sql, list(params.values()), True)
if not results:
return []
@ -2389,7 +2388,8 @@ class PGDocStatusStorage(DocStatusStorage):
FROM LIGHTRAG_DOC_STATUS
where workspace=$1 GROUP BY STATUS
"""
result = await self.db.query(sql, {"workspace": self.workspace}, True)
params = {"workspace": self.workspace}
result = await self.db.query(sql, list(params.values()), True)
counts = {}
for doc in result:
counts[doc["status"]] = doc["count"]
@ -2401,7 +2401,7 @@ class PGDocStatusStorage(DocStatusStorage):
"""all documents with a specific status"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
params = {"workspace": self.workspace, "status": status.value}
result = await self.db.query(sql, params, True)
result = await self.db.query(sql, list(params.values()), True)
docs_by_status = {}
for element in result:
@ -2455,7 +2455,7 @@ class PGDocStatusStorage(DocStatusStorage):
"""Get all documents with a specific track_id"""
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and track_id=$2"
params = {"workspace": self.workspace, "track_id": track_id}
result = await self.db.query(sql, params, True)
result = await self.db.query(sql, list(params.values()), True)
docs_by_track_id = {}
for element in result:
@ -2555,7 +2555,7 @@ class PGDocStatusStorage(DocStatusStorage):
# Query for total count
count_sql = f"SELECT COUNT(*) as total FROM LIGHTRAG_DOC_STATUS {where_clause}"
count_result = await self.db.query(count_sql, params)
count_result = await self.db.query(count_sql, list(params.values()))
total_count = count_result["total"] if count_result else 0
# Query for paginated data
@ -2568,7 +2568,7 @@ class PGDocStatusStorage(DocStatusStorage):
params["limit"] = page_size
params["offset"] = offset
result = await self.db.query(data_sql, params, True)
result = await self.db.query(data_sql, list(params.values()), True)
# Convert to (doc_id, DocProcessingStatus) tuples
documents = []
@ -2625,7 +2625,7 @@ class PGDocStatusStorage(DocStatusStorage):
GROUP BY status
"""
params = {"workspace": self.workspace}
result = await self.db.query(sql, params, True)
result = await self.db.query(sql, list(params.values()), True)
counts = {}
total_count = 0
@ -3071,7 +3071,7 @@ class PGGraphStorage(BaseGraphStorage):
if readonly:
data = await self.db.query(
query,
params,
list(params.values()) if params else None,
multirows=True,
with_age=True,
graph_name=self.graph_name,
@ -3102,114 +3102,92 @@ class PGGraphStorage(BaseGraphStorage):
return result
async def has_node(self, node_id: str) -> bool:
entity_name_label = self._normalize_node_id(node_id)
query = f"""
SELECT EXISTS (
SELECT 1
FROM {self.graph_name}.base
WHERE ag_catalog.agtype_access_operator(
VARIADIC ARRAY[properties, '"entity_id"'::agtype]
) = (to_json($1::text)::text)::agtype
LIMIT 1
) AS node_exists;
"""
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
RETURN count(n) > 0 AS node_exists
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
single_result = (await self._query(query))[0]
return single_result["node_exists"]
params = {"node_id": node_id}
row = (await self._query(query, params=params))[0]
return bool(row["node_exists"])
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
src_label = self._normalize_node_id(source_node_id)
tgt_label = self._normalize_node_id(target_node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
RETURN COUNT(r) > 0 AS edge_exists
$$) AS (edge_exists bool)""" % (
self.graph_name,
src_label,
tgt_label,
)
single_result = (await self._query(query))[0]
return single_result["edge_exists"]
query = f"""
WITH a AS (
SELECT id AS vid
FROM {self.graph_name}.base
WHERE ag_catalog.agtype_access_operator(
VARIADIC ARRAY[properties, '"entity_id"'::agtype]
) = (to_json($1::text)::text)::agtype
),
b AS (
SELECT id AS vid
FROM {self.graph_name}.base
WHERE ag_catalog.agtype_access_operator(
VARIADIC ARRAY[properties, '"entity_id"'::agtype]
) = (to_json($2::text)::text)::agtype
)
SELECT EXISTS (
SELECT 1
FROM {self.graph_name}."DIRECTED" d
JOIN a ON d.start_id = a.vid
JOIN b ON d.end_id = b.vid
LIMIT 1
)
OR EXISTS (
SELECT 1
FROM {self.graph_name}."DIRECTED" d
JOIN a ON d.end_id = a.vid
JOIN b ON d.start_id = b.vid
LIMIT 1
) AS edge_exists;
"""
params = {
"source_node_id": source_node_id,
"target_node_id": target_node_id,
}
row = (await self._query(query, params=params))[0]
return bool(row["edge_exists"])
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties"""
label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})
RETURN n
$$) AS (n agtype)""" % (self.graph_name, label)
record = await self._query(query)
if record:
node = record[0]
node_dict = node["n"]["properties"]
# Process string result, parse it to JSON dictionary
if isinstance(node_dict, str):
try:
node_dict = json.loads(node_dict)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse node string: {node_dict}"
)
return node_dict
result = await self.get_nodes_batch(node_ids=[label])
if result and node_id in result:
return result[node_id]
return None
async def node_degree(self, node_id: str) -> int:
label = self._normalize_node_id(node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (n:base {entity_id: "%s"})-[r]-()
RETURN count(r) AS total_edge_count
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
record = (await self._query(query))[0]
if record:
edge_count = int(record["total_edge_count"])
return edge_count
result = await self.node_degrees_batch(node_ids=[label])
if result and node_id in result:
return result[node_id]
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
return degrees
result = await self.edge_degrees_batch(edges=[(src_id, tgt_id)])
if result and (src_id, tgt_id) in result:
return result[(src_id, tgt_id)]
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
"""Get edge properties between two nodes"""
src_label = self._normalize_node_id(source_node_id)
tgt_label = self._normalize_node_id(target_node_id)
query = """SELECT * FROM cypher('%s', $$
MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"})
RETURN properties(r) as edge_properties
LIMIT 1
$$) AS (edge_properties agtype)""" % (
self.graph_name,
src_label,
tgt_label,
)
record = await self._query(query)
if record and record[0] and record[0]["edge_properties"]:
result = record[0]["edge_properties"]
# Process string result, parse it to JSON dictionary
if isinstance(result, str):
try:
result = json.loads(result)
except json.JSONDecodeError:
logger.warning(
f"[{self.workspace}] Failed to parse edge string: {result}"
)
return result
result = await self.get_edges_batch([{"src": src_label, "tgt": tgt_label}])
if result and (src_label, tgt_label) in result:
return result[(src_label, tgt_label)]
return None
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""

View file

@ -1619,6 +1619,15 @@ class LightRAG:
pipeline_status["latest_message"] = log_message
pipeline_status["history_messages"].append(log_message)
# Prevent memory growth: keep only latest 5000 messages when exceeding 10000
if len(pipeline_status["history_messages"]) > 10000:
logger.info(
f"Trimming pipeline history from {len(pipeline_status['history_messages'])} to 5000 messages"
)
pipeline_status["history_messages"] = (
pipeline_status["history_messages"][-5000:]
)
# Get document content from full_docs
content_data = await self.full_docs.get_by_id(doc_id)
if not content_data:
@ -2246,6 +2255,7 @@ class LightRAG:
query.strip(),
system_prompt=system_prompt,
history_messages=param.conversation_history,
enable_cot=True,
stream=param.stream,
)
else:

View file

@ -59,12 +59,17 @@ async def anthropic_complete_if_cache(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
history_messages = []
if enable_cot:
logger.debug(
"enable_cot=True is not supported for the Anthropic API and will be ignored."
)
if not api_key:
api_key = os.environ.get("ANTHROPIC_API_KEY")
@ -150,6 +155,7 @@ async def anthropic_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@ -160,6 +166,7 @@ async def anthropic_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -169,6 +176,7 @@ async def claude_3_opus_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@ -178,6 +186,7 @@ async def claude_3_opus_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -187,6 +196,7 @@ async def claude_3_sonnet_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@ -196,6 +206,7 @@ async def claude_3_sonnet_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -205,6 +216,7 @@ async def claude_3_haiku_complete(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
**kwargs: Any,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
@ -214,6 +226,7 @@ async def claude_3_haiku_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)

View file

@ -24,6 +24,7 @@ from tenacity import (
from lightrag.utils import (
wrap_embedding_func_with_attrs,
safe_unicode_decode,
logger,
)
import numpy as np
@ -41,11 +42,16 @@ async def azure_openai_complete_if_cache(
prompt,
system_prompt: str | None = None,
history_messages: Iterable[ChatCompletionMessageParam] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
**kwargs,
):
if enable_cot:
logger.debug(
"enable_cot=True is not supported for the Azure OpenAI API and will be ignored."
)
deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
base_url = (
base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")

View file

@ -44,11 +44,18 @@ async def bedrock_complete_if_cache(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
if enable_cot:
import logging
logging.debug(
"enable_cot=True is not supported for Bedrock and will be ignored."
)
# Respect existing env; only set if a non-empty value is available
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key

View file

@ -56,8 +56,15 @@ async def hf_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
**kwargs,
) -> str:
if enable_cot:
from lightrag.utils import logger
logger.debug(
"enable_cot=True is not supported for Hugging Face local models and will be ignored."
)
model_name = model
hf_model, hf_tokenizer = initialize_hf_model(model_name)
messages = []
@ -114,7 +121,12 @@ async def hf_model_if_cache(
async def hf_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
prompt,
system_prompt=None,
history_messages=[],
keyword_extraction=False,
enable_cot: bool = False,
**kwargs,
) -> str:
kwargs.pop("keyword_extraction", None)
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
@ -123,6 +135,7 @@ async def hf_model_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
return result

View file

@ -94,9 +94,14 @@ async def llama_index_complete_if_cache(
prompt: str,
system_prompt: Optional[str] = None,
history_messages: List[dict] = [],
enable_cot: bool = False,
chat_kwargs={},
) -> str:
"""Complete the prompt using LlamaIndex."""
if enable_cot:
logger.debug(
"enable_cot=True is not supported for LlamaIndex implementation and will be ignored."
)
try:
# Format messages for chat
formatted_messages = []
@ -138,6 +143,7 @@ async def llama_index_complete(
prompt,
system_prompt=None,
history_messages=None,
enable_cot: bool = False,
keyword_extraction=False,
settings: LlamaIndexSettings = None,
**kwargs,
@ -162,6 +168,7 @@ async def llama_index_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
return result

View file

@ -56,6 +56,7 @@ async def lmdeploy_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
chat_template=None,
model_format="hf",
quant_policy=0,
@ -89,6 +90,12 @@ async def lmdeploy_model_if_cache(
do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
Default to be False, which means greedy decoding will be applied.
"""
if enable_cot:
from lightrag.utils import logger
logger.debug(
"enable_cot=True is not supported for lmdeploy and will be ignored."
)
try:
import lmdeploy
from lmdeploy import version_info, GenerationConfig

View file

@ -39,10 +39,15 @@ async def lollms_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
base_url="http://localhost:9600",
**kwargs,
) -> Union[str, AsyncIterator[str]]:
"""Client implementation for lollms generation."""
if enable_cot:
from lightrag.utils import logger
logger.debug("enable_cot=True is not supported for lollms and will be ignored.")
stream = True if kwargs.get("stream") else False
api_key = kwargs.pop("api_key", None)
@ -98,7 +103,12 @@ async def lollms_model_if_cache(
async def lollms_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
"""Complete function for lollms model generation."""
@ -119,6 +129,7 @@ async def lollms_model_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)

View file

@ -43,8 +43,11 @@ async def _ollama_model_if_cache(
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
if enable_cot:
logger.debug("enable_cot=True is not supported for ollama and will be ignored.")
stream = True if kwargs.get("stream") else False
kwargs.pop("max_tokens", None)
@ -123,7 +126,12 @@ async def _ollama_model_if_cache(
async def ollama_model_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
prompt,
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
@ -134,6 +142,7 @@ async def ollama_model_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)

View file

@ -111,12 +111,29 @@ async def openai_complete_if_cache(
prompt: str,
system_prompt: str | None = None,
history_messages: list[dict[str, Any]] | None = None,
enable_cot: bool = False,
base_url: str | None = None,
api_key: str | None = None,
token_tracker: Any | None = None,
**kwargs: Any,
) -> str:
"""Complete a prompt using OpenAI's API with caching support.
"""Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration.
This function supports automatic integration of reasoning content (思维链) from models that provide
Chain of Thought capabilities. The reasoning content is seamlessly integrated into the response
using <think>...</think> tags.
Note on `reasoning_content`: This feature relies on a Deepseek Style `reasoning_content`
in the API response, which may be provided by OpenAI-compatible endpoints that support
Chain of Thought.
COT Integration Rules:
1. COT content is accepted only when regular content is empty and `reasoning_content` has content.
2. COT processing stops when regular content becomes available.
3. If both `content` and `reasoning_content` are present simultaneously, reasoning is ignored.
4. If both fields have content from the start, COT is never activated.
5. For streaming: COT content is inserted into the content stream with <think> tags.
6. For non-streaming: COT content is prepended to regular content with <think> tags.
Args:
model: The OpenAI model to use.
@ -125,6 +142,8 @@ async def openai_complete_if_cache(
history_messages: Optional list of previous messages in the conversation.
base_url: Optional base URL for the OpenAI API.
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
token_tracker: Optional token usage tracker for monitoring API usage.
enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False.
**kwargs: Additional keyword arguments to pass to the OpenAI API.
Special kwargs:
- openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
@ -134,7 +153,8 @@ async def openai_complete_if_cache(
- keyword_extraction: Will be removed from kwargs before passing to OpenAI.
Returns:
The completed text or an async iterator of text chunks if streaming.
The completed text (with integrated COT content if available) or an async iterator
of text chunks if streaming. COT content is wrapped in <think>...</think> tags.
Raises:
InvalidResponseError: If the response from OpenAI is invalid or empty.
@ -172,6 +192,7 @@ async def openai_complete_if_cache(
logger.debug("===== Entering func of LLM =====")
logger.debug(f"Model: {model} Base URL: {base_url}")
logger.debug(f"Client Configs: {client_configs}")
logger.debug(f"Additional kwargs: {kwargs}")
logger.debug(f"Num of history messages: {len(history_messages)}")
verbose_debug(f"System prompt: {system_prompt}")
@ -216,6 +237,11 @@ async def openai_complete_if_cache(
iteration_started = False
final_chunk_usage = None
# COT (Chain of Thought) state tracking
cot_active = False
cot_started = False
initial_content_seen = False
try:
iteration_started = True
async for chunk in response:
@ -231,20 +257,65 @@ async def openai_complete_if_cache(
logger.warning(f"Received chunk without choices: {chunk}")
continue
# Check if delta exists and has content
if not hasattr(chunk.choices[0], "delta") or not hasattr(
chunk.choices[0].delta, "content"
):
# Check if delta exists
if not hasattr(chunk.choices[0], "delta"):
# This might be the final chunk, continue to check for usage
continue
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
delta = chunk.choices[0].delta
content = getattr(delta, "content", None)
reasoning_content = getattr(delta, "reasoning_content", None)
yield content
# Handle COT logic for streaming (only if enabled)
if enable_cot:
if content is not None and content != "":
# Regular content is present
if not initial_content_seen:
initial_content_seen = True
# If both content and reasoning_content are present initially, don't start COT
if (
reasoning_content is not None
and reasoning_content != ""
):
cot_active = False
cot_started = False
# If COT was active, end it
if cot_active:
yield "</think>"
cot_active = False
# Process regular content
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
elif reasoning_content is not None and reasoning_content != "":
# Only reasoning content is present
if not initial_content_seen and not cot_started:
# Start COT if we haven't seen initial content yet
if not cot_active:
yield "<think>"
cot_active = True
cot_started = True
# Process reasoning content if COT is active
if cot_active:
if r"\u" in reasoning_content:
reasoning_content = safe_unicode_decode(
reasoning_content.encode("utf-8")
)
yield reasoning_content
else:
# COT disabled, only process regular content
if content is not None and content != "":
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
# If neither content nor reasoning_content, continue to next chunk
if content is None and reasoning_content is None:
continue
# After streaming is complete, track token usage
if token_tracker and final_chunk_usage:
@ -312,21 +383,56 @@ async def openai_complete_if_cache(
not response
or not response.choices
or not hasattr(response.choices[0], "message")
or not hasattr(response.choices[0].message, "content")
):
logger.error("Invalid response from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Invalid response from OpenAI API")
content = response.choices[0].message.content
message = response.choices[0].message
content = getattr(message, "content", None)
reasoning_content = getattr(message, "reasoning_content", None)
if not content or content.strip() == "":
# Handle COT logic for non-streaming responses (only if enabled)
final_content = ""
if enable_cot:
# Check if we should include reasoning content
should_include_reasoning = False
if reasoning_content and reasoning_content.strip():
if not content or content.strip() == "":
# Case 1: Only reasoning content, should include COT
should_include_reasoning = True
final_content = (
content or ""
) # Use empty string if content is None
else:
# Case 3: Both content and reasoning_content present, ignore reasoning
should_include_reasoning = False
final_content = content
else:
# No reasoning content, use regular content
final_content = content or ""
# Apply COT wrapping if needed
if should_include_reasoning:
if r"\u" in reasoning_content:
reasoning_content = safe_unicode_decode(
reasoning_content.encode("utf-8")
)
final_content = f"<think>{reasoning_content}</think>{final_content}"
else:
# COT disabled, only use regular content
final_content = content or ""
# Validate final content
if not final_content or final_content.strip() == "":
logger.error("Received empty content from OpenAI API")
await openai_async_client.close() # Ensure client is closed
raise InvalidResponseError("Received empty content from OpenAI API")
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
# Apply Unicode decoding to final content if needed
if r"\u" in final_content:
final_content = safe_unicode_decode(final_content.encode("utf-8"))
if token_tracker and hasattr(response, "usage"):
token_counts = {
@ -338,10 +444,10 @@ async def openai_complete_if_cache(
}
token_tracker.add_usage(token_counts)
logger.debug(f"Response content len: {len(content)}")
logger.debug(f"Response content len: {len(final_content)}")
verbose_debug(f"Response: {response}")
return content
return final_content
finally:
# Ensure client is closed in all cases for non-streaming responses
await openai_async_client.close()
@ -373,6 +479,7 @@ async def gpt_4o_complete(
prompt,
system_prompt=None,
history_messages=None,
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> str:
@ -386,6 +493,7 @@ async def gpt_4o_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -394,6 +502,7 @@ async def gpt_4o_mini_complete(
prompt,
system_prompt=None,
history_messages=None,
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> str:
@ -407,6 +516,7 @@ async def gpt_4o_mini_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -415,6 +525,7 @@ async def nvidia_openai_complete(
prompt,
system_prompt=None,
history_messages=None,
enable_cot: bool = False,
keyword_extraction=False,
**kwargs,
) -> str:
@ -426,6 +537,7 @@ async def nvidia_openai_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
base_url="https://integrate.api.nvidia.com/v1",
**kwargs,
)

View file

@ -49,8 +49,13 @@ async def zhipu_complete_if_cache(
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
history_messages: List[Dict[str, str]] = [],
enable_cot: bool = False,
**kwargs,
) -> str:
if enable_cot:
logger.debug(
"enable_cot=True is not supported for ZhipuAI and will be ignored."
)
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
@ -91,7 +96,12 @@ async def zhipu_complete_if_cache(
async def zhipu_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
prompt,
system_prompt=None,
history_messages=[],
keyword_extraction=False,
enable_cot: bool = False,
**kwargs,
):
# Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
keyword_extraction = kwargs.pop("keyword_extraction", None)
@ -122,6 +132,7 @@ async def zhipu_complete(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)
@ -163,6 +174,7 @@ async def zhipu_complete(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
**kwargs,
)

View file

@ -306,7 +306,7 @@ async def _summarize_descriptions(
use_prompt,
use_llm_func,
llm_response_cache=llm_response_cache,
cache_type="extract",
cache_type="summary",
)
return summary
@ -319,9 +319,8 @@ async def _handle_single_entity_extraction(
if len(record_attributes) < 4 or "entity" not in record_attributes[0]:
if len(record_attributes) > 1 and "entity" in record_attributes[0]:
logger.warning(
f"Entity extraction failed in {chunk_key}: expecting 4 fields but got {len(record_attributes)}"
f"{chunk_key}: Entity `{record_attributes[1]}` extraction failed -- expecting 4 fields but got {len(record_attributes)}"
)
logger.warning(f"Entity extracted: {record_attributes[1]}")
return None
try:
@ -389,9 +388,8 @@ async def _handle_single_relationship_extraction(
if len(record_attributes) < 5 or "relationship" not in record_attributes[0]:
if len(record_attributes) > 1 and "relationship" in record_attributes[0]:
logger.warning(
f"Relation extraction failed in {chunk_key}: expecting 5 fields but got {len(record_attributes)}"
f"{chunk_key}: Relation `{record_attributes[1]}` extraction failed -- expecting 5 fields but got {len(record_attributes)}"
)
logger.warning(f"Relation extracted: {record_attributes[1]}")
return None
try:
@ -839,6 +837,11 @@ async def _process_extraction_result(
bracket_pattern = f"[)](\\s*{re.escape(record_delimiter)}\\s*)[(]"
result = re.sub(bracket_pattern, ")\\1(", result)
if completion_delimiter not in result:
logger.warning(
f"{chunk_key}: Complete delimiter can not be found in extraction result"
)
records = split_string_by_multi_markers(
result,
[record_delimiter, completion_delimiter],
@ -1942,7 +1945,6 @@ async def extract_entities(
# add example's format
examples = examples.format(**example_context_base)
entity_extract_prompt = PROMPTS["entity_extraction"]
context_base = dict(
tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
@ -1952,8 +1954,6 @@ async def extract_entities(
language=language,
)
continue_prompt = PROMPTS["entity_continue_extraction"].format(**context_base)
processed_chunks = 0
total_chunks = len(ordered_chunks)
@ -1976,13 +1976,20 @@ async def extract_entities(
cache_keys_collector = []
# Get initial extraction
hint_prompt = entity_extract_prompt.format(
entity_extraction_system_prompt = PROMPTS[
"entity_extraction_system_prompt"
].format(**{**context_base, "input_text": content})
entity_extraction_user_prompt = PROMPTS["entity_extraction_user_prompt"].format(
**{**context_base, "input_text": content}
)
entity_continue_extraction_user_prompt = PROMPTS[
"entity_continue_extraction_user_prompt"
].format(**{**context_base, "input_text": content})
final_result = await use_llm_func_with_cache(
hint_prompt,
entity_extraction_user_prompt,
use_llm_func,
system_prompt=entity_extraction_system_prompt,
llm_response_cache=llm_response_cache,
cache_type="extract",
chunk_id=chunk_key,
@ -1990,7 +1997,9 @@ async def extract_entities(
)
# Store LLM cache reference in chunk (will be handled by use_llm_func_with_cache)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
history = pack_user_ass_to_openai_messages(
entity_extraction_user_prompt, final_result
)
# Process initial extraction with file path
maybe_nodes, maybe_edges = await _process_extraction_result(
@ -2005,8 +2014,9 @@ async def extract_entities(
# Process additional gleaning results
if entity_extract_max_gleaning > 0:
glean_result = await use_llm_func_with_cache(
continue_prompt,
entity_continue_extraction_user_prompt,
use_llm_func,
system_prompt=entity_extraction_system_prompt,
llm_response_cache=llm_response_cache,
history_messages=history,
cache_type="extract",
@ -2014,8 +2024,6 @@ async def extract_entities(
cache_keys_collector=cache_keys_collector,
)
history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
# Process gleaning result separately with file path
glean_nodes, glean_edges = await _process_extraction_result(
glean_result,
@ -2234,6 +2242,7 @@ async def kg_query(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
enable_cot=True,
)
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
@ -3031,7 +3040,7 @@ async def _get_node_data(
):
# get similar entities
logger.info(
f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
)
results = await entities_vdb.query(query, top_k=query_param.top_k)
@ -3307,7 +3316,7 @@ async def _get_edge_data(
query_param: QueryParam,
):
logger.info(
f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
)
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
@ -3756,6 +3765,7 @@ async def naive_query(
query,
system_prompt=sys_prompt,
stream=query_param.stream,
enable_cot=True,
)
if isinstance(response, str) and len(response) > len(sys_prompt):

View file

@ -6,61 +6,64 @@ PROMPTS: dict[str, Any] = {}
PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
# TODO: Deprecated
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
PROMPTS["DEFAULT_USER_PROMPT"] = "n/a"
PROMPTS["entity_extraction"] = """---Task---
For a given input text and entity types in the provided real data, extract all entities and their relationships, then return them in the specified language and format described below.
PROMPTS["entity_extraction_system_prompt"] = """---Role---
You are a Knowledge Graph Specialist responsible for extracting entities and relationships from the input text.
---Instructions---
1. Recognizing definitively conceptualized entities in text. For each identified entity, extract the following information:
- entity_name: Name of the entity, use same language as input text. If English, capitalized the name
- entity_type: Categorize the entity using the provided entity types. If a suitable category cannot be determined, classify it as `Other`.
- entity_description: Provide a comprehensive description of the entity's attributes and activities based on the information present in the input text. To ensure clarity and precision, all descriptions must replace pronouns and referential terms (e.g., "this document," "our company," "I," "you," "he/she") with the specific nouns they represent.
2. Format each entity as: (entity{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
3. From the entities identified, identify all pairs of (source_entity, target_entity) that are directly and clearly related, and extract the following information:
- source_entity: name of the source entity
- target_entity: name of the target entity
- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
- relationship_description: Explain the nature of the relationship between the source and target entities, providing a clear rationale for their connection
4. Format each relationship as: (relationship{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_description>)
5. Use `{tuple_delimiter}` as field delimiter. Use `{record_delimiter}` as the entity or relation list delimiter.
6. Output `{completion_delimiter}` when all the entities and relationships are extracted.
7. Ensure the output language is {language}.
---Quality Guidelines---
- Only extract entities and relationships that are clearly defined and meaningful in the context
- Avoid over-interpretation; stick to what is explicitly stated in the text
- For all output content, explicitly name the subject or object rather than using pronouns
- Include specific numerical data in entity name when relevant
- Ensure entity names are consistent throughout the extraction
1. **Entity Extraction:** Identify clearly defined and meaningful entities in the input text, and extract the following information:
- entity_name: Name of the entity, ensure entity names are consistent throughout the extraction.
- entity_type: Categorize the entity using the following entity types: {entity_types}; if none of the provided types are suitable, classify it as `Other`.
- entity_description: Provide a comprehensive description of the entity's attributes and activities based on the information present in the input text.
2. **Entity Output Format:** (entity{tuple_delimiter}entity_name{tuple_delimiter}entity_type{tuple_delimiter}entity_description)
3. **Relationship Extraction:** Identify direct, clearly-stated and meaningful relationships between extracted entities within the input text, and extract the following information:
- source_entity: name of the source entity.
- target_entity: name of the target entity.
- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details.
- relationship_description: Explain the nature of the relationship between the source and target entities, providing a clear rationale for their connection.
4. **Relationship Output Format:** (relationship{tuple_delimiter}source_entity{tuple_delimiter}target_entity{tuple_delimiter}relationship_keywords{tuple_delimiter}relationship_description)
5. **Relationship Order:** Prioritize relationships based on their significance to the intended meaning of input text, and output more crucial relationships first.
6. **Avoid Pronouns:** For entity names and all descriptions, explicitly name the subject or object instead of using pronouns; avoid pronouns such as `this document`, `our company`, `I`, `you`, and `he/she`.
7. **Undirected Relationship:** Treat relationships as undirected; swapping the source and target entities does not constitute a new relationship. Avoid outputting duplicate relationships.
8. **Language:** Output entity names, keywords and descriptions in {language}.
9. **Delimiter:** Use `{record_delimiter}` as the entity or relationship list delimiter; output `{completion_delimiter}` when all the entities and relationships are extracted.
---Examples---
{examples}
---Real Data---
---Real Data to be Processed---
<Input>
Entity_types: [{entity_types}]
Text:
```
{input_text}
```
"""
PROMPTS["entity_extraction_user_prompt"] = """---Task---
Extract entities and relationships from the input text to be Processed.
---Instructions---
1. Output entities and relationships, prioritized by their relevance to the input text's core meaning.
2. Output `{completion_delimiter}` when all the entities and relationships are extracted.
3. Ensure the output language is {language}.
<Output>
"""
PROMPTS["entity_continue_extraction"] = """---Task---
Identify any missed entities or relationships in the last extraction task.
PROMPTS["entity_continue_extraction_user_prompt"] = """---Task---
Identify any missed entities or relationships from the input text to be Processed of last extraction task.
---Instructions---
1. Output the entities and relationships in the same format as previous extraction task.
2. Do not include entities and relations that have been previously extracted.
3. If the entity doesn't clearly fit in any of entity types provided, classify it as "Other".
2. Do not include entities and relations that have been correctly extracted in last extraction task.
3. If the entity or relation output is truncated or has missing fields in last extraction task, please re-output it in the correct format.
4. Output `{completion_delimiter}` when all the entities and relationships are extracted.
5. Ensure the output language is {language}.
@ -68,11 +71,7 @@ Identify any missed entities or relationships in the last extraction task.
"""
PROMPTS["entity_extraction_examples"] = [
"""[Example 1]
<Input>
Entity_types: [organization,person,location,event,technology,equiment,product,Document,category]
Text:
"""<Input Text>
```
while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
@ -97,11 +96,7 @@ It was a small transformation, barely perceptible, but one that Alex noted with
{completion_delimiter}
""",
"""[Example 2]
<Input>
Entity_types: [organization,person,location,event,technology,equiment,product,Document,category]
Text:
"""<Input Text>
```
Stock markets faced a sharp downturn today as tech giants saw significant declines, with the Global Tech Index dropping by 3.4% in midday trading. Analysts attribute the selloff to investor concerns over rising interest rates and regulatory uncertainty.
@ -128,11 +123,7 @@ Financial experts are closely watching the Federal Reserve's next move, as specu
{completion_delimiter}
""",
"""[Example 3]
<Input>
Entity_types: [organization,person,location,event,technology,equiment,product,Document,category]
Text:
"""<Input Text>
```
At the World Athletics Championship in Tokyo, Noah Carter broke the 100m sprint record using cutting-edge carbon-fiber spikes.
```
@ -150,29 +141,6 @@ At the World Athletics Championship in Tokyo, Noah Carter broke the 100m sprint
(relationship{tuple_delimiter}Noah Carter{tuple_delimiter}World Athletics Championship{tuple_delimiter}athlete participation, competition{tuple_delimiter}Noah Carter is competing at the World Athletics Championship.){record_delimiter}
{completion_delimiter}
""",
"""[Example 4]
<Input>
Entity_types: [organization,person,location,event,technology,equiment,product,Document,category]
Text:
```
在北京举行的人工智能大会上腾讯公司的首席技术官张伟发布了最新的大语言模型"腾讯智言"该模型在自然语言处理方面取得了重大突破
```
<Output>
(entity{tuple_delimiter}人工智能大会{tuple_delimiter}event{tuple_delimiter}人工智能大会是在北京举行的技术会议专注于人工智能领域的最新发展){record_delimiter}
(entity{tuple_delimiter}北京{tuple_delimiter}location{tuple_delimiter}北京是人工智能大会的举办城市){record_delimiter}
(entity{tuple_delimiter}腾讯公司{tuple_delimiter}organization{tuple_delimiter}腾讯公司是参与人工智能大会的科技企业发布了新的语言模型产品){record_delimiter}
(entity{tuple_delimiter}张伟{tuple_delimiter}person{tuple_delimiter}张伟是腾讯公司的首席技术官在大会上发布了新产品){record_delimiter}
(entity{tuple_delimiter}腾讯智言{tuple_delimiter}product{tuple_delimiter}腾讯智言是腾讯公司发布的大语言模型产品在自然语言处理方面有重大突破){record_delimiter}
(entity{tuple_delimiter}自然语言处理技术{tuple_delimiter}technology{tuple_delimiter}自然语言处理技术是腾讯智言模型取得重大突破的技术领域){record_delimiter}
(relationship{tuple_delimiter}人工智能大会{tuple_delimiter}北京{tuple_delimiter}会议地点, 举办关系{tuple_delimiter}人工智能大会在北京举行){record_delimiter}
(relationship{tuple_delimiter}张伟{tuple_delimiter}腾讯公司{tuple_delimiter}雇佣关系, 高管职位{tuple_delimiter}张伟担任腾讯公司的首席技术官){record_delimiter}
(relationship{tuple_delimiter}张伟{tuple_delimiter}腾讯智言{tuple_delimiter}产品发布, 技术展示{tuple_delimiter}张伟在大会上发布了腾讯智言大语言模型){record_delimiter}
(relationship{tuple_delimiter}腾讯智言{tuple_delimiter}自然语言处理技术{tuple_delimiter}技术应用, 突破创新{tuple_delimiter}腾讯智言在自然语言处理技术方面取得了重大突破){record_delimiter}
{completion_delimiter}
""",
]

View file

@ -473,12 +473,12 @@ def priority_limit_async_func_call(
nonlocal max_execution_timeout, max_task_duration
if max_execution_timeout is None:
max_execution_timeout = (
llm_timeout + 150
) # LLM timeout + 150s buffer for low-level retry
llm_timeout * 2
) # Reserved timeout buffer for low-level retry
if max_task_duration is None:
max_task_duration = (
llm_timeout + 180
) # LLM timeout + 180s buffer for health check phase
llm_timeout * 2 + 15
) # Reserved timeout buffer for health check phase
queue = asyncio.PriorityQueue(maxsize=max_queue_size)
tasks = set()
@ -707,7 +707,7 @@ def priority_limit_async_func_call(
timeout_info.append(f"Health Check: {max_task_duration}s")
timeout_str = (
f" (Timeouts: {', '.join(timeout_info)})" if timeout_info else ""
f"(Timeouts: {', '.join(timeout_info)})" if timeout_info else ""
)
logger.info(
f"{queue_name}: {workers_needed} new workers initialized {timeout_str}"
@ -1034,7 +1034,7 @@ async def handle_cache(
args_hash,
prompt,
mode="default",
cache_type=None,
cache_type="unknown",
) -> str | None:
"""Generic cache handling function with flattened cache keys"""
if hashing_kv is None:
@ -1646,9 +1646,10 @@ def remove_think_tags(text: str) -> str:
async def use_llm_func_with_cache(
input_text: str,
user_prompt: str,
use_llm_func: callable,
llm_response_cache: "BaseKVStorage | None" = None,
system_prompt: str | None = None,
max_tokens: int = None,
history_messages: list[dict[str, str]] = None,
cache_type: str = "extract",
@ -1677,7 +1678,10 @@ async def use_llm_func_with_cache(
LLM response text
"""
# Sanitize input text to prevent UTF-8 encoding errors for all LLM providers
safe_input_text = sanitize_text_for_encoding(input_text)
safe_user_prompt = sanitize_text_for_encoding(user_prompt)
safe_system_prompt = (
sanitize_text_for_encoding(system_prompt) if system_prompt else None
)
# Sanitize history messages if provided
safe_history_messages = None
@ -1688,13 +1692,19 @@ async def use_llm_func_with_cache(
if "content" in safe_msg:
safe_msg["content"] = sanitize_text_for_encoding(safe_msg["content"])
safe_history_messages.append(safe_msg)
history = json.dumps(safe_history_messages, ensure_ascii=False)
else:
history = None
if llm_response_cache:
if safe_history_messages:
history = json.dumps(safe_history_messages, ensure_ascii=False)
_prompt = history + "\n" + safe_input_text
else:
_prompt = safe_input_text
prompt_parts = []
if safe_user_prompt:
prompt_parts.append(safe_user_prompt)
if safe_system_prompt:
prompt_parts.append(safe_system_prompt)
if history:
prompt_parts.append(history)
_prompt = "\n".join(prompt_parts)
arg_hash = compute_args_hash(_prompt)
# Generate cache key for this LLM call
@ -1725,7 +1735,9 @@ async def use_llm_func_with_cache(
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
res: str = await use_llm_func(safe_input_text, **kwargs)
res: str = await use_llm_func(
safe_user_prompt, system_prompt=safe_system_prompt, **kwargs
)
res = remove_think_tags(res)
@ -1755,7 +1767,9 @@ async def use_llm_func_with_cache(
kwargs["max_tokens"] = max_tokens
try:
res = await use_llm_func(safe_input_text, **kwargs)
res = await use_llm_func(
safe_user_prompt, system_prompt=safe_system_prompt, **kwargs
)
except Exception as e:
# Add [LLM func] prefix to error message
error_msg = f"[LLM func] {str(e)}"

View file

@ -99,6 +99,9 @@ export type QueryMode = 'naive' | 'local' | 'global' | 'hybrid' | 'mix' | 'bypas
/**
 * One chat message: the author's role and the raw text, plus optional fields
 * that describe a model's `<think>…</think>` reasoning when the UI has parsed it.
 */
export type Message = {
role: 'user' | 'assistant' | 'system'
// Raw message text as received; may still contain `<think>` tags.
content: string
// Reasoning text; the retrieval UI fills this from the text inside `<think>…</think>`.
thinkingContent?: string
// Text to render as the answer; the retrieval UI fills this from the text after
// `</think>`, or from the whole content when no `<think>` tag is present.
displayContent?: string
// Reasoning duration in seconds (the retrieval UI rounds to 2 decimals); null until measured.
thinkingTime?: number | null
}
export type QueryRequest = {

View file

@ -15,12 +15,13 @@ import type { Element } from 'hast'
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'
import { oneLight, oneDark } from 'react-syntax-highlighter/dist/cjs/styles/prism'
import { LoaderIcon, CopyIcon } from 'lucide-react'
import { LoaderIcon, CopyIcon, ChevronDownIcon } from 'lucide-react'
import { useTranslation } from 'react-i18next'
export type MessageWithError = Message & {
id: string // Unique identifier for stable React keys
isError?: boolean
isThinking?: boolean // Flag to indicate if the message is in a "thinking" state
/**
* Indicates if the mermaid diagram in this message has been rendered.
* Used to persist the rendering state across updates and prevent flickering.
@ -33,6 +34,26 @@ export const ChatMessage = ({ message }: { message: MessageWithError }) => { //
const { t } = useTranslation()
const { theme } = useTheme()
const [katexPlugin, setKatexPlugin] = useState<any>(null)
const [isThinkingExpanded, setIsThinkingExpanded] = useState<boolean>(false)
// Directly use props passed from the parent.
const { thinkingContent, displayContent, thinkingTime, isThinking } = message
// Reset expansion state when new thinking starts
useEffect(() => {
if (isThinking) {
// When thinking starts, always reset to collapsed state
setIsThinkingExpanded(false)
}
}, [isThinking, message.id])
// The content to display is now non-ambiguous.
const finalThinkingContent = thinkingContent
// For user messages, displayContent will be undefined, so we fall back to content.
// For assistant messages, we prefer displayContent but fallback to content for backward compatibility
const finalDisplayContent = message.role === 'user'
? message.content
: (displayContent !== undefined ? displayContent : (message.content || ''))
// Load KaTeX dynamically
useEffect(() => {
@ -59,6 +80,27 @@ export const ChatMessage = ({ message }: { message: MessageWithError }) => { //
}
}, [message, t]) // Added t to dependency array
// Element overrides passed as `components` to the main-answer ReactMarkdown.
// Memoized so the object identity only changes when `message.mermaidRendered`
// changes, which avoids handing ReactMarkdown a new components map on every render.
const mainMarkdownComponents = useMemo(() => ({
// Code blocks are rendered by CodeHighlight; diagram rendering is off unless
// the message has been flagged as mermaid-rendered.
code: (props: any) => (
<CodeHighlight
{...props}
renderAsDiagram={message.mermaidRendered ?? false}
/>
),
// Block elements below only apply spacing / typography classes.
p: ({ children }: { children?: ReactNode }) => <p className="my-2">{children}</p>,
h1: ({ children }: { children?: ReactNode }) => <h1 className="text-xl font-bold mt-4 mb-2">{children}</h1>,
h2: ({ children }: { children?: ReactNode }) => <h2 className="text-lg font-bold mt-4 mb-2">{children}</h2>,
h3: ({ children }: { children?: ReactNode }) => <h3 className="text-base font-bold mt-3 mb-2">{children}</h3>,
h4: ({ children }: { children?: ReactNode }) => <h4 className="text-base font-semibold mt-3 mb-2">{children}</h4>,
ul: ({ children }: { children?: ReactNode }) => <ul className="list-disc pl-5 my-2">{children}</ul>,
ol: ({ children }: { children?: ReactNode }) => <ol className="list-decimal pl-5 my-2">{children}</ol>,
li: ({ children }: { children?: ReactNode }) => <li className="my-1">{children}</li>
}), [message.mermaidRendered]);
// Element overrides for the ReactMarkdown that renders the expandable thinking
// section. Only `code` is overridden (same CodeHighlight wiring as the main
// answer); all other elements use ReactMarkdown's defaults.
const thinkingMarkdownComponents = useMemo(() => ({
code: (props: any) => (<CodeHighlight {...props} renderAsDiagram={message.mermaidRendered ?? false} />)
}), [message.mermaidRendered]);
return (
<div
className={`${
@ -69,55 +111,93 @@ export const ChatMessage = ({ message }: { message: MessageWithError }) => { //
: 'w-[95%] bg-muted'
} rounded-lg px-4 py-2`}
>
<div className="relative">
<ReactMarkdown
className="prose dark:prose-invert max-w-none text-sm break-words prose-headings:mt-4 prose-headings:mb-2 prose-p:my-2 prose-ul:my-2 prose-ol:my-2 prose-li:my-1 [&_.katex]:text-current [&_.katex-display]:my-4 [&_.katex-display]:overflow-x-auto"
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[
...(katexPlugin ? [[
katexPlugin,
{
errorColor: theme === 'dark' ? '#ef4444' : '#dc2626',
throwOnError: false,
displayMode: false
{/* Thinking process display - only for assistant messages */}
{message.role === 'assistant' && (isThinking || thinkingTime !== null) && (
<div className="mb-2">
<div
className="flex items-center text-gray-500 dark:text-gray-400 hover:text-gray-700 dark:hover:text-gray-200 transition-colors duration-200 text-sm cursor-pointer select-none"
onClick={() => {
// Allow expansion when there's thinking content, even during thinking process
if (finalThinkingContent && finalThinkingContent.trim() !== '') {
setIsThinkingExpanded(!isThinkingExpanded)
}
] as any] : []),
rehypeReact
]}
skipHtml={false}
// Memoize the components object to prevent unnecessary re-renders of ReactMarkdown children
components={useMemo(() => ({
code: (props: any) => ( // Add type annotation if needed, e.g., props: CodeProps from 'react-markdown/lib/ast-to-react'
<CodeHighlight
{...props}
renderAsDiagram={message.mermaidRendered ?? false}
/>
),
p: ({ children }: { children?: ReactNode }) => <p className="my-2">{children}</p>,
h1: ({ children }: { children?: ReactNode }) => <h1 className="text-xl font-bold mt-4 mb-2">{children}</h1>,
h2: ({ children }: { children?: ReactNode }) => <h2 className="text-lg font-bold mt-4 mb-2">{children}</h2>,
h3: ({ children }: { children?: ReactNode }) => <h3 className="text-base font-bold mt-3 mb-2">{children}</h3>,
h4: ({ children }: { children?: ReactNode }) => <h4 className="text-base font-semibold mt-3 mb-2">{children}</h4>,
ul: ({ children }: { children?: ReactNode }) => <ul className="list-disc pl-5 my-2">{children}</ul>,
ol: ({ children }: { children?: ReactNode }) => <ol className="list-decimal pl-5 my-2">{children}</ol>,
li: ({ children }: { children?: ReactNode }) => <li className="my-1">{children}</li>
}), [message.mermaidRendered])} // Dependency ensures update if mermaid state changes
>
{message.content}
</ReactMarkdown>
{message.role === 'assistant' && message.content && message.content.length > 0 && ( // Added check for message.content existence
<Button
onClick={handleCopyMarkdown}
className="absolute right-0 bottom-0 size-6 rounded-md opacity-20 transition-opacity hover:opacity-100"
tooltip={t('retrievePanel.chatMessage.copyTooltip')}
variant="default"
size="icon"
}}
>
<CopyIcon className="size-4" /> {/* Explicit size */}
</Button>
)}
</div>
{message.content === '' && <LoaderIcon className="animate-spin duration-2000" />} {/* Check for empty string specifically */}
{isThinking ? (
<>
<LoaderIcon className="mr-2 size-4 animate-spin" />
<span>{t('retrievePanel.chatMessage.thinking')}</span>
</>
) : (
typeof thinkingTime === 'number' && <span>{t('retrievePanel.chatMessage.thinkingTime', { time: thinkingTime })}</span>
)}
{/* Show chevron when there's thinking content, even during thinking process */}
{finalThinkingContent && finalThinkingContent.trim() !== '' && <ChevronDownIcon className={`ml-2 size-4 shrink-0 transition-transform ${isThinkingExpanded ? 'rotate-180' : ''}`} />}
</div>
{/* Show thinking content when expanded and content exists, even during thinking process */}
{isThinkingExpanded && finalThinkingContent && finalThinkingContent.trim() !== '' && (
<div className="mt-2 pl-4 border-l-2 border-primary/20 text-sm prose dark:prose-invert max-w-none break-words prose-p:my-1 prose-headings:my-2">
{isThinking && (
<div className="mb-2 text-xs text-gray-400 dark:text-gray-500 italic">
{t('retrievePanel.chatMessage.thinkingInProgress', 'Thinking in progress...')}
</div>
)}
<ReactMarkdown
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[
...(katexPlugin ? [[katexPlugin, { errorColor: theme === 'dark' ? '#ef4444' : '#dc2626', throwOnError: false, displayMode: false }] as any] : []),
rehypeReact
]}
skipHtml={false}
components={thinkingMarkdownComponents}
>
{finalThinkingContent}
</ReactMarkdown>
</div>
)}
</div>
)}
{/* Main content display */}
{finalDisplayContent && (
<div className="relative">
<ReactMarkdown
className="prose dark:prose-invert max-w-none text-sm break-words prose-headings:mt-4 prose-headings:mb-2 prose-p:my-2 prose-ul:my-2 prose-ol:my-2 prose-li:my-1 [&_.katex]:text-current [&_.katex-display]:my-4 [&_.katex-display]:overflow-x-auto"
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[
...(katexPlugin ? [[
katexPlugin,
{
errorColor: theme === 'dark' ? '#ef4444' : '#dc2626',
throwOnError: false,
displayMode: false
}
] as any] : []),
rehypeReact
]}
skipHtml={false}
components={mainMarkdownComponents}
>
{finalDisplayContent}
</ReactMarkdown>
{message.role === 'assistant' && finalDisplayContent && finalDisplayContent.length > 0 && (
<Button
onClick={handleCopyMarkdown}
className="absolute right-0 bottom-0 size-6 rounded-md opacity-20 transition-opacity hover:opacity-100"
tooltip={t('retrievePanel.chatMessage.copyTooltip')}
variant="default"
size="icon"
>
<CopyIcon className="size-4" /> {/* Explicit size */}
</Button>
)}
</div>
)}
{(() => {
// More comprehensive loading state check
const hasVisibleContent = finalDisplayContent && finalDisplayContent.trim() !== '';
const isLoadingState = !hasVisibleContent && !isThinking && !thinkingTime;
return isLoadingState && <LoaderIcon className="animate-spin duration-2000" />;
})()}
</div>
)
}

View file

@ -58,6 +58,8 @@ export default function RetrievalTesting() {
const [inputError, setInputError] = useState('') // Error message for input
// Reference to track if we should follow scroll during streaming (using ref for synchronous updates)
const shouldFollowScrollRef = useRef(true)
const thinkingStartTime = useRef<number | null>(null)
const thinkingProcessed = useRef(false)
// Reference to track if user interaction is from the form area
const isFormInteractionRef = useRef(false)
// Reference to track if scroll was triggered programmatically
@ -67,6 +69,16 @@ export default function RetrievalTesting() {
const messagesEndRef = useRef<HTMLDivElement>(null)
const messagesContainerRef = useRef<HTMLDivElement>(null)
// Unmount cleanup: clear the recorded thinking start time.
// NOTE(review): the ref only holds a number (or null), so this is tidy-up rather
// than a leak fix — confirm whether it is still needed.
useEffect(() => {
// Empty dependency list: the returned cleanup runs once, when the component unmounts.
return () => {
if (thinkingStartTime.current) {
thinkingStartTime.current = null;
}
};
}, []);
// Scroll to bottom function - restored smooth scrolling with better handling
const scrollToBottom = useCallback(() => {
// Set flag to indicate this is a programmatic scroll
@ -115,6 +127,10 @@ export default function RetrievalTesting() {
// Clear error message
setInputError('')
// Reset thinking timer state for new query to prevent confusion
thinkingStartTime.current = null
thinkingProcessed.current = false
// Create messages
// Save the original input (with prefix if any) in userMessage.content for display
const userMessage: MessageWithError = {
@ -127,7 +143,11 @@ export default function RetrievalTesting() {
id: generateUniqueId(), // Use browser-compatible ID generation
content: '',
role: 'assistant',
mermaidRendered: false
mermaidRendered: false,
thinkingTime: null, // Explicitly initialize to null
thinkingContent: undefined, // Explicitly initialize to undefined
displayContent: undefined, // Explicitly initialize to undefined
isThinking: false // Explicitly initialize to false
}
const prevMessages = [...messages]
@ -153,6 +173,47 @@ export default function RetrievalTesting() {
const updateAssistantMessage = (chunk: string, isError?: boolean) => {
assistantMessage.content += chunk
// Real-time parsing of <think>...</think> reasoning in the accumulated streamed content.
// Runs on every chunk and derives isThinking, thinkingContent, displayContent and thinkingTime.
const thinkStartTag = '<think>'
const thinkEndTag = '</think>'
const thinkStartIndex = assistantMessage.content.indexOf(thinkStartTag)
const thinkBodyStart = thinkStartIndex + thinkStartTag.length
// Search for the closing tag only after the opening tag. A </think> that precedes <think>
// would otherwise be paired with it, and substring() swaps its arguments when start > end,
// producing wrong thinking and display content
const thinkEndIndex =
  thinkStartIndex !== -1 ? assistantMessage.content.indexOf(thinkEndTag, thinkBodyStart) : -1

// Start thinking timer on first sight of think tag
if (thinkStartIndex !== -1 && !thinkingStartTime.current) {
  thinkingStartTime.current = Date.now()
}

if (thinkStartIndex !== -1) {
  if (thinkEndIndex !== -1) {
    // Thinking has finished: the closing tag is present
    assistantMessage.isThinking = false

    // Only calculate time and extract thinking content once
    if (!thinkingProcessed.current) {
      // Test for a missing value, not a falsy one, so a computed 0 counts as already set
      if (thinkingStartTime.current && typeof assistantMessage.thinkingTime !== 'number') {
        const duration = (Date.now() - thinkingStartTime.current) / 1000
        // Seconds, rounded to two decimals
        assistantMessage.thinkingTime = parseFloat(duration.toFixed(2))
      }
      assistantMessage.thinkingContent = assistantMessage.content
        .substring(thinkBodyStart, thinkEndIndex)
        .trim()
      thinkingProcessed.current = true
    }

    // Always update display content as content after </think> may grow
    // NOTE: any text that precedes <think> is not part of displayContent
    assistantMessage.displayContent = assistantMessage.content
      .substring(thinkEndIndex + thinkEndTag.length)
      .trim()
  } else {
    // Still thinking - update thinking content in real-time and show no answer text yet
    assistantMessage.isThinking = true
    assistantMessage.thinkingContent = assistantMessage.content.substring(thinkBodyStart)
    assistantMessage.displayContent = ''
  }
} else {
  // No think tag at all: the whole content is the answer
  assistantMessage.isThinking = false
  assistantMessage.displayContent = assistantMessage.content
}
// Detect if the assistant message contains a complete mermaid code block
// Simple heuristic: look for ```mermaid ... ```
const mermaidBlockRegex = /```mermaid\s+([\s\S]+?)```/g
@ -167,13 +228,21 @@ export default function RetrievalTesting() {
}
assistantMessage.mermaidRendered = mermaidRendered
// Single unified update to avoid race conditions
setMessages((prev) => {
const newMessages = [...prev]
const lastMessage = newMessages[newMessages.length - 1]
if (lastMessage.role === 'assistant') {
lastMessage.content = assistantMessage.content
lastMessage.isError = isError
lastMessage.mermaidRendered = assistantMessage.mermaidRendered
if (lastMessage && lastMessage.id === assistantMessage.id) {
// Update all properties at once to maintain consistency
Object.assign(lastMessage, {
content: assistantMessage.content,
thinkingContent: assistantMessage.thinkingContent,
displayContent: assistantMessage.displayContent,
isThinking: assistantMessage.isThinking,
isError: isError,
mermaidRendered: assistantMessage.mermaidRendered,
thinkingTime: assistantMessage.thinkingTime
})
}
return newMessages
})
@ -223,9 +292,30 @@ export default function RetrievalTesting() {
// Clear loading and add messages to state
setIsLoading(false)
isReceivingResponseRef.current = false
useSettingsStore
.getState()
.setRetrievalHistory([...prevMessages, userMessage, assistantMessage])
// Finalize thinking state once the response has ended.
// Fallback thinking-time calculation for a response whose thinking started but whose time
// was never recorded (e.g. the stream ended before </think> arrived).
// Check for a missing value rather than a falsy one: a legitimately computed 0 (thinking that
// completed within the rounding resolution) must not be overwritten with the whole response duration
if (
  assistantMessage.thinkingContent &&
  thinkingStartTime.current &&
  typeof assistantMessage.thinkingTime !== 'number'
) {
  const duration = (Date.now() - thinkingStartTime.current) / 1000
  // Seconds, rounded to two decimals
  assistantMessage.thinkingTime = parseFloat(duration.toFixed(2))
}
// Nothing above can throw, so no try/catch is needed; always leave the timer state cleared
assistantMessage.isThinking = false
thinkingStartTime.current = null
// Save history with error handling
try {
useSettingsStore
.getState()
.setRetrievalHistory([...prevMessages, userMessage, assistantMessage])
} catch (error) {
console.error('Error saving retrieval history:', error)
}
}
},
[inputValue, isLoading, messages, setMessages, t, scrollToBottom]

View file

@ -341,7 +341,10 @@
"retrievePanel": {
"chatMessage": {
"copyTooltip": "نسخ إلى الحافظة",
"copyError": "فشل نسخ النص إلى الحافظة"
"copyError": "فشل نسخ النص إلى الحافظة",
"thinking": "جاري التفكير...",
"thinkingTime": "وقت التفكير {{time}} ثانية",
"thinkingInProgress": "التفكير قيد التقدم..."
},
"retrieval": {
"startPrompt": "ابدأ الاسترجاع بكتابة استفسارك أدناه",

View file

@ -341,7 +341,10 @@
"retrievePanel": {
"chatMessage": {
"copyTooltip": "Copy to clipboard",
"copyError": "Failed to copy text to clipboard"
"copyError": "Failed to copy text to clipboard",
"thinking": "Thinking...",
"thinkingTime": "Thinking time {{time}}s",
"thinkingInProgress": "Thinking in progress..."
},
"retrieval": {
"startPrompt": "Start a retrieval by typing your query below",

View file

@ -341,7 +341,10 @@
"retrievePanel": {
"chatMessage": {
"copyTooltip": "Copier dans le presse-papiers",
"copyError": "Échec de la copie du texte dans le presse-papiers"
"copyError": "Échec de la copie du texte dans le presse-papiers",
"thinking": "Réflexion en cours...",
"thinkingTime": "Temps de réflexion {{time}}s",
"thinkingInProgress": "Réflexion en cours..."
},
"retrieval": {
"startPrompt": "Démarrez une récupération en tapant votre requête ci-dessous",

View file

@ -341,7 +341,10 @@
"retrievePanel": {
"chatMessage": {
"copyTooltip": "复制到剪贴板",
"copyError": "复制文本到剪贴板失败"
"copyError": "复制文本到剪贴板失败",
"thinking": "正在思考...",
"thinkingTime": "思考用时 {{time}} 秒",
"thinkingInProgress": "思考进行中..."
},
"retrieval": {
"startPrompt": "输入查询开始检索",

View file

@ -341,7 +341,10 @@
"retrievePanel": {
"chatMessage": {
"copyTooltip": "複製到剪貼簿",
"copyError": "複製文字到剪貼簿失敗"
"copyError": "複製文字到剪貼簿失敗",
"thinking": "正在思考...",
"thinkingTime": "思考用時 {{time}} 秒",
"thinkingInProgress": "思考進行中..."
},
"retrieval": {
"startPrompt": "輸入查詢開始檢索",