refactor: fix double query problem by adding aquery_llm function for consistent response handling

- Add new aquery_llm/query_llm methods providing structured responses
- Consolidate /query and /query/stream endpoints to use unified aquery_llm
- Optimize cache handling by moving cache checks before LLM calls
yangdx 2025-09-26 19:05:03 +08:00
parent 862026905a
commit 8cd4139cbf
4 changed files with 336 additions and 244 deletions
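For orientation, the structured result that aquery_llm returns (and that both endpoints now consume) has roughly the shape sketched below. The field names come from the returns added in this commit; the exact contents of "data" and "metadata" depend on the query mode, so treat this as an illustrative shape rather than a strict contract.

result = {
    "status": "success",               # or "failure" on error
    "message": "...",                  # human-readable status
    "data": {"references": [...]},     # structured retrieval data (mode-dependent)
    "metadata": {},                    # mode-dependent metadata
    "llm_response": {
        "content": "answer text",      # set for non-streaming responses
        "response_iterator": None,     # AsyncIterator[str] when streaming
        "is_streaming": False,
    },
}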


@@ -152,44 +152,42 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
)
async def query_text(request: QueryRequest):
"""
Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
This endpoint performs a RAG query with a non-streaming response.
Parameters:
request (QueryRequest): The request object containing the query parameters.
Returns:
QueryResponse: A Pydantic model containing the result of the query processing.
If include_references=True, the reference list is also included.
If a string is returned (e.g., cache hit), it's directly returned.
Otherwise, an async generator may be used to build the response.
Raises:
HTTPException: Raised when an error occurs during the request handling process,
with status code 500 and detail containing the exception message.
"""
try:
param = request.to_query_params(False)
response = await rag.aquery(request.query, param=param)
param = request.to_query_params(
False
) # Ensure stream=False for non-streaming endpoint
# Force stream=False for /query endpoint regardless of include_references setting
param.stream = False
# Get reference list if requested
reference_list = None
# Unified approach: always use aquery_llm for both cases
result = await rag.aquery_llm(request.query, param=param)
# Extract LLM response and references from unified result
llm_response = result.get("llm_response", {})
references = result.get("data", {}).get("references", [])
# Get the non-streaming response content
response_content = llm_response.get("content", "")
if not response_content:
response_content = "No relevant context found for the query."
# Return response with or without references based on request
if request.include_references:
try:
# Use aquery_data to get reference list independently
data_result = await rag.aquery_data(request.query, param=param)
if isinstance(data_result, dict) and "data" in data_result:
reference_list = data_result["data"].get("references", [])
except Exception as e:
logging.warning(f"Failed to get reference list: {str(e)}")
reference_list = []
# Process response and return with optional references
if isinstance(response, str):
return QueryResponse(response=response, references=reference_list)
elif isinstance(response, dict):
result = json.dumps(response, indent=2)
return QueryResponse(response=result, references=reference_list)
return QueryResponse(response=response_content, references=references)
else:
return QueryResponse(response=str(response), references=reference_list)
return QueryResponse(response=response_content, references=None)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
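For context, a minimal client-side sketch of the consolidated /query endpoint is shown below. The base URL, port, and payload fields other than query and include_references are assumptions for illustration; adjust them to your deployment and authentication setup.

import requests

payload = {
    "query": "What is LightRAG?",
    "mode": "hybrid",                 # assumed QueryRequest field; see QueryRequest for the full schema
    "include_references": True,       # ask the server to attach the reference list
}
resp = requests.post("http://localhost:9621/query", json=payload, timeout=120)
resp.raise_for_status()
body = resp.json()
print(body["response"])               # LLM answer text
print(body.get("references"))         # reference list, or None when not requested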
@@ -197,7 +195,8 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
@router.post("/query/stream", dependencies=[Depends(combined_auth)])
async def query_text_stream(request: QueryRequest):
"""
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
This endpoint performs a RAG query with a streaming response.
Streaming can be turned off by setting stream=False in QueryRequest.
The streaming response includes:
1. Reference list (sent first as a single message, if include_references=True)
@@ -213,49 +212,42 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
- Error messages: {"error": "..."} - If any errors occur
"""
try:
param = request.to_query_params(True)
response = await rag.aquery(request.query, param=param)
param = request.to_query_params(
True
) # Ensure stream=True for streaming endpoint
from fastapi.responses import StreamingResponse
async def stream_generator():
# Get reference list if requested (default is True for backward compatibility)
reference_list = []
if request.include_references:
try:
# Use aquery_data to get reference list independently
data_param = request.to_query_params(
False
) # Non-streaming for data
data_result = await rag.aquery_data(
request.query, param=data_param
)
if isinstance(data_result, dict) and "data" in data_result:
reference_list = data_result["data"].get("references", [])
except Exception as e:
logging.warning(f"Failed to get reference list: {str(e)}")
reference_list = []
# Unified approach: always use aquery_llm for all cases
result = await rag.aquery_llm(request.query, param=param)
# Send reference list first (if requested)
if request.include_references:
yield f"{json.dumps({'references': reference_list})}\n"
# Extract references and LLM response from unified result
references = result.get("data", {}).get("references", [])
llm_response = result.get("llm_response", {})
# Then stream the response content
if isinstance(response, str):
# If it's a string, send it all at once
yield f"{json.dumps({'response': response})}\n"
elif response is None:
# Handle None response (e.g., when only_need_context=True but no context found)
yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
# Send reference list first if requested
if request.include_references:
yield f"{json.dumps({'references': references})}\n"
# Then stream the LLM response content
if llm_response.get("is_streaming"):
response_stream = llm_response.get("response_iterator")
if response_stream:
try:
async for chunk in response_stream:
if chunk: # Only send non-empty content
yield f"{json.dumps({'response': chunk})}\n"
except Exception as e:
logging.error(f"Streaming error: {str(e)}")
yield f"{json.dumps({'error': str(e)})}\n"
else:
# If it's an async generator, send chunks one by one
try:
async for chunk in response:
if chunk: # Only send non-empty content
yield f"{json.dumps({'response': chunk})}\n"
except Exception as e:
logging.error(f"Streaming error: {str(e)}")
yield f"{json.dumps({'error': str(e)})}\n"
# Non-streaming response (fallback)
response_content = llm_response.get("content", "")
if response_content:
yield f"{json.dumps({'response': response_content})}\n"
else:
yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
return StreamingResponse(
stream_generator(),
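The stream is newline-delimited JSON, so a client simply reads lines and dispatches on the keys described in the docstring above. A rough consumer sketch follows; the URL and payload are assumptions for illustration.

import asyncio
import json
import httpx

async def consume_stream() -> None:
    payload = {"query": "What is LightRAG?", "include_references": True}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:9621/query/stream", json=payload
        ) as resp:
            async for line in resp.aiter_lines():
                if not line:
                    continue
                msg = json.loads(line)
                if "references" in msg:
                    print("references:", msg["references"])
                elif "error" in msg:
                    print("error:", msg["error"])
                else:
                    print(msg.get("response", ""), end="", flush=True)

asyncio.run(consume_stream())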


@@ -2062,74 +2062,53 @@ class LightRAG:
system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
"""
Perform an async query.
Perform an async query (backward compatibility wrapper).
This function is now a wrapper around aquery_llm that maintains backward compatibility
by returning only the LLM response content in the original format.
Args:
query (str): The query to be executed.
param (QueryParam): Configuration parameters for query execution.
If param.model_func is provided, it will be used instead of the global model.
prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
system_prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
Returns:
str: The result of the query execution.
str | AsyncIterator[str]: The LLM response content.
- Non-streaming: Returns str
- Streaming: Returns AsyncIterator[str]
"""
# If a custom model is provided in param, temporarily update global config
global_config = asdict(self)
# Call the new aquery_llm function to get complete results
result = await self.aquery_llm(query, param, system_prompt)
query_result = None
# Extract and return only the LLM response for backward compatibility
llm_response = result.get("llm_response", {})
if param.mode in ["local", "global", "hybrid", "mix"]:
query_result = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
chunks_vdb=self.chunks_vdb,
)
elif param.mode == "naive":
query_result = await naive_query(
query.strip(),
self.chunks_vdb,
param,
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
)
elif param.mode == "bypass":
# Bypass mode: directly use LLM without knowledge retrieval
use_llm_func = param.model_func or global_config["llm_model_func"]
# Apply higher priority (8) to the bypass-mode LLM call
use_llm_func = partial(use_llm_func, _priority=8)
param.stream = True if param.stream is None else param.stream
response = await use_llm_func(
query.strip(),
system_prompt=system_prompt,
history_messages=param.conversation_history,
enable_cot=True,
stream=param.stream,
)
# Create QueryResult for bypass mode
query_result = QueryResult(
content=response if not param.stream else None,
response_iterator=response if param.stream else None,
is_streaming=param.stream,
)
if llm_response.get("is_streaming"):
return llm_response.get("response_iterator")
else:
raise ValueError(f"Unknown mode {param.mode}")
return llm_response.get("content", "")
await self._query_done()
def query_data(
self,
query: str,
param: QueryParam = QueryParam(),
) -> dict[str, Any]:
"""
Synchronous data retrieval API: returns structured retrieval results without LLM generation.
# Return appropriate response based on streaming mode
if query_result.is_streaming:
return query_result.response_iterator
else:
return query_result.content
This function is the synchronous version of aquery_data, providing the same functionality
for users who prefer synchronous interfaces.
Args:
query: Query text for retrieval.
param: Query parameters controlling retrieval behavior (same as aquery).
Returns:
dict[str, Any]: Same structured data result as aquery_data.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery_data(query, param))
async def aquery_data(
self,
@@ -2323,6 +2302,162 @@ class LightRAG:
await self._query_done()
return final_data
async def aquery_llm(
self,
query: str,
param: QueryParam = QueryParam(),
system_prompt: str | None = None,
) -> dict[str, Any]:
"""
Asynchronous complete query API: returns structured retrieval results with LLM generation.
This function performs a single query operation and returns both structured data and LLM response,
based on the original aquery logic to avoid duplicate calls.
Args:
query: Query text for retrieval and LLM generation.
param: Query parameters controlling retrieval and LLM behavior.
system_prompt: Optional custom system prompt for LLM generation.
Returns:
dict[str, Any]: Complete response with structured data and LLM response.
"""
global_config = asdict(self)
try:
query_result = None
if param.mode in ["local", "global", "hybrid", "mix"]:
query_result = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
chunks_vdb=self.chunks_vdb,
)
elif param.mode == "naive":
query_result = await naive_query(
query.strip(),
self.chunks_vdb,
param,
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
)
elif param.mode == "bypass":
# Bypass mode: directly use LLM without knowledge retrieval
use_llm_func = param.model_func or global_config["llm_model_func"]
# Apply higher priority (8) to the bypass-mode LLM call
use_llm_func = partial(use_llm_func, _priority=8)
param.stream = True if param.stream is None else param.stream
response = await use_llm_func(
query.strip(),
system_prompt=system_prompt,
history_messages=param.conversation_history,
enable_cot=True,
stream=param.stream,
)
if isinstance(response, str):
return {
"status": "success",
"message": "Bypass mode LLM non streaming response",
"data": {},
"metadata": {},
"llm_response": {
"content": response,
"response_iterator": None,
"is_streaming": False,
},
}
else:
return {
"status": "success",
"message": "Bypass mode LLM streaming response",
"data": {},
"metadata": {},
"llm_response": {
"content": None,
"response_iterator": response,
"is_streaming": True,
},
}
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
# Check if query_result is None
if query_result is None:
return {
"status": "failure",
"message": "Query returned no results",
"data": {},
"metadata": {},
"llm_response": {
"content": None,
"response_iterator": None,
"is_streaming": False,
},
}
# Extract structured data from query result
raw_data = query_result.raw_data if query_result else {}
raw_data["llm_response"] = {
"content": query_result.content
if not query_result.is_streaming
else None,
"response_iterator": query_result.response_iterator
if query_result.is_streaming
else None,
"is_streaming": query_result.is_streaming,
}
return raw_data
except Exception as e:
logger.error(f"Query failed: {e}")
# Return error response
return {
"status": "failure",
"message": f"Query failed: {str(e)}",
"data": {},
"metadata": {},
"llm_response": {
"content": None,
"response_iterator": None,
"is_streaming": False,
},
}
def query_llm(
self,
query: str,
param: QueryParam = QueryParam(),
system_prompt: str | None = None,
) -> dict[str, Any]:
"""
Synchronous complete query API: returns structured retrieval results with LLM generation.
This function is the synchronous version of aquery_llm, providing the same functionality
for users who prefer synchronous interfaces.
Args:
query: Query text for retrieval and LLM generation.
param: Query parameters controlling retrieval and LLM behavior.
system_prompt: Optional custom system prompt for LLM generation.
Returns:
dict[str, Any]: Same complete response format as aquery_llm.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery_llm(query, param, system_prompt))
async def _query_done(self):
await self.llm_response_cache.index_done_callback()
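Taken together, aquery stays the backward-compatible entry point (returning only the LLM content), while aquery_llm exposes the full structured result. A minimal sketch of calling the new method directly; LightRAG construction is omitted, the import path is assumed, and the keys mirror those used in the diff.

from lightrag import LightRAG, QueryParam  # import path assumed

async def answer(rag: LightRAG, question: str) -> str:
    param = QueryParam(mode="hybrid", stream=False)
    result = await rag.aquery_llm(question, param=param)
    llm = result.get("llm_response", {})
    if llm.get("is_streaming"):
        # Drain the async iterator when the LLM streamed its output
        return "".join([chunk async for chunk in llm["response_iterator"]])
    return llm.get("content", "")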


@@ -2236,38 +2236,6 @@ async def extract_entities(
return chunk_results
@overload
async def kg_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
return_raw_data: Literal[True] = False,
) -> dict[str, Any]: ...
@overload
async def kg_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
return_raw_data: Literal[False] = False,
) -> str | AsyncIterator[str]: ...
async def kg_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
@@ -2308,7 +2276,6 @@ async def kg_query(
- stream=True: response_iterator contains streaming response, raw_data contains complete data
- default: content contains LLM response text, raw_data contains complete data
"""
if not query:
return QueryResult(content=PROMPTS["fail_response"])
@@ -2319,32 +2286,6 @@ async def kg_query(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
# Handle cache
args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)
cached_result = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
if (
cached_result is not None
and not query_param.only_need_context
and not query_param.only_need_prompt
):
cached_response, _ = cached_result # Extract content, ignore timestamp
return QueryResult(content=cached_response)
hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
)
@@ -2417,29 +2358,41 @@ async def kg_query(
f"[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
)
response = await use_model_func(
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
# Handle cache
args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)
# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
cached_result = await handle_cache(
hashing_kv, args_hash, user_query, query_param.mode, cache_type="query"
)
if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
logger.info(
" == LLM cache == Query cache hit, using cached response as query result"
)
response = cached_response
else:
response = await use_model_func(
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
)
# Cache response
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
@@ -2466,6 +2419,20 @@ async def kg_query(
),
)
# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return QueryResult(content=response, raw_data=context_result.raw_data)
else:
# Streaming response (AsyncIterator)
@@ -4052,29 +4019,6 @@ async def naive_query(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
# Handle cache
args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)
cached_result = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
if not query_param.only_need_context and not query_param.only_need_prompt:
return QueryResult(content=cached_response)
tokenizer: Tokenizer = global_config["tokenizer"]
if not tokenizer:
logger.error("Tokenizer not found in global configuration.")
@@ -4211,35 +4155,39 @@ async def naive_query(
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(content=prompt_content, raw_data=raw_data)
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(
f"[naive_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
# Handle cache
args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)
response = await use_model_func(
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
cached_result = await handle_cache(
hashing_kv, args_hash, user_query, query_param.mode, cache_type="query"
)
if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
logger.info(
" == LLM cache == Query cache hit, using cached response as query result"
)
response = cached_response
else:
response = await use_model_func(
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
)
# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response[len(sys_prompt) :]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
# Cache response
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
@@ -4266,6 +4214,21 @@ async def naive_query(
),
)
# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response[len(sys_prompt) :]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return QueryResult(content=response, raw_data=raw_data)
else:
# Streaming response (AsyncIterator)
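The net effect in both kg_query and naive_query is that the prompt is built first, the query cache is consulted next, and the LLM is only invoked on a cache miss, with the result written back afterwards. A self-contained sketch of that control flow; the real code uses compute_args_hash, handle_cache, and save_to_cache from lightrag.utils, and the in-memory dict and fake LLM below are stand-ins for illustration.

import asyncio
import hashlib

_cache: dict[str, str] = {}  # stand-in for the LLM response cache

async def fake_llm(prompt: str) -> str:
    return f"answer for: {prompt}"

async def cached_query(user_query: str, sys_prompt: str, mode: str) -> str:
    # Hash the query parameters, analogous to compute_args_hash in the diff
    args_hash = hashlib.md5(f"{mode}:{user_query}:{sys_prompt}".encode()).hexdigest()
    cached = _cache.get(args_hash)  # cache check now happens right before the LLM call
    if cached is not None:
        return cached
    response = await fake_llm(sys_prompt + "\n---User Query---\n" + user_query)
    _cache[args_hash] = response    # cache only genuinely new LLM responses
    return response

print(asyncio.run(cached_query("What is LightRAG?", "You are a helpful assistant.", "hybrid")))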


@@ -1102,7 +1102,9 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
if existing_cache:
existing_content = existing_cache.get("return")
if existing_content == cache_data.content:
logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
logger.warning(
f"Cache duplication detected for {flattened_key}, skipping update"
)
return
# Create cache entry with flattened structure