refactor: fix double query problem by adding aquery_llm function for consistent response handling

- Add new aquery_llm/query_llm methods providing structured responses
- Consolidate /query and /query/stream endpoints to use unified aquery_llm
- Optimize cache handling by moving query-cache checks from the start of kg_query/naive_query to just before the LLM call, so structured retrieval data is still built on cache hits
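
For illustration, a minimal sketch of how a caller might consume the new structured result. It assumes an already-initialized LightRAG instance `rag` (storage backends and model functions configured elsewhere); the dictionary keys follow the llm_response / data structure added in this commit:

    from lightrag import LightRAG, QueryParam

    async def ask(rag: LightRAG, question: str) -> None:
        # One retrieval pass, one LLM call: references and answer come from the same result
        result = await rag.aquery_llm(question, param=QueryParam(mode="hybrid", stream=False))
        references = result.get("data", {}).get("references", [])
        llm = result.get("llm_response", {})
        if llm.get("is_streaming"):
            # Streaming mode: consume the async iterator chunk by chunk
            async for chunk in llm["response_iterator"]:
                print(chunk, end="", flush=True)
        else:
            print(llm.get("content", ""))
        print(f"\n[{len(references)} references]")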
yangdx 2025-09-26 19:05:03 +08:00
parent 862026905a
commit 8cd4139cbf
4 changed files with 336 additions and 244 deletions


@@ -152,44 +152,42 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
    )
    async def query_text(request: QueryRequest):
        """
-        Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
+        This endpoint performs a RAG query with non-streaming response.

        Parameters:
            request (QueryRequest): The request object containing the query parameters.

        Returns:
            QueryResponse: A Pydantic model containing the result of the query processing.
            If include_references=True, also includes reference list.
-            If a string is returned (e.g., cache hit), it's directly returned.
-            Otherwise, an async generator may be used to build the response.

        Raises:
            HTTPException: Raised when an error occurs during the request handling process,
                with status code 500 and detail containing the exception message.
        """
        try:
-            param = request.to_query_params(False)
-            response = await rag.aquery(request.query, param=param)
-
-            # Get reference list if requested
-            reference_list = None
-            if request.include_references:
-                try:
-                    # Use aquery_data to get reference list independently
-                    data_result = await rag.aquery_data(request.query, param=param)
-                    if isinstance(data_result, dict) and "data" in data_result:
-                        reference_list = data_result["data"].get("references", [])
-                except Exception as e:
-                    logging.warning(f"Failed to get reference list: {str(e)}")
-                    reference_list = []
-
-            # Process response and return with optional references
-            if isinstance(response, str):
-                return QueryResponse(response=response, references=reference_list)
-            elif isinstance(response, dict):
-                result = json.dumps(response, indent=2)
-                return QueryResponse(response=result, references=reference_list)
-            else:
-                return QueryResponse(response=str(response), references=reference_list)
+            param = request.to_query_params(
+                False
+            )  # Ensure stream=False for non-streaming endpoint
+            # Force stream=False for /query endpoint regardless of include_references setting
+            param.stream = False
+
+            # Unified approach: always use aquery_llm for both cases
+            result = await rag.aquery_llm(request.query, param=param)
+
+            # Extract LLM response and references from unified result
+            llm_response = result.get("llm_response", {})
+            references = result.get("data", {}).get("references", [])
+
+            # Get the non-streaming response content
+            response_content = llm_response.get("content", "")
+            if not response_content:
+                response_content = "No relevant context found for the query."
+
+            # Return response with or without references based on request
+            if request.include_references:
+                return QueryResponse(response=response_content, references=references)
+            else:
+                return QueryResponse(response=response_content, references=None)
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e))
@@ -197,7 +195,8 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
    @router.post("/query/stream", dependencies=[Depends(combined_auth)])
    async def query_text_stream(request: QueryRequest):
        """
-        This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
+        This endpoint performs a RAG query with streaming response.
+        Streaming can be turned off by setting stream=False in QueryRequest.

        The streaming response includes:
        1. Reference list (sent first as a single message, if include_references=True)
@@ -213,49 +212,42 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
            - Error messages: {"error": "..."} - If any errors occur
        """
        try:
-            param = request.to_query_params(True)
-            response = await rag.aquery(request.query, param=param)
+            param = request.to_query_params(
+                True
+            )  # Ensure stream=True for streaming endpoint

            from fastapi.responses import StreamingResponse

            async def stream_generator():
-                # Get reference list if requested (default is True for backward compatibility)
-                reference_list = []
-                if request.include_references:
-                    try:
-                        # Use aquery_data to get reference list independently
-                        data_param = request.to_query_params(
-                            False
-                        )  # Non-streaming for data
-                        data_result = await rag.aquery_data(
-                            request.query, param=data_param
-                        )
-                        if isinstance(data_result, dict) and "data" in data_result:
-                            reference_list = data_result["data"].get("references", [])
-                    except Exception as e:
-                        logging.warning(f"Failed to get reference list: {str(e)}")
-                        reference_list = []
-
-                # Send reference list first (if requested)
-                if request.include_references:
-                    yield f"{json.dumps({'references': reference_list})}\n"
-
-                # Then stream the response content
-                if isinstance(response, str):
-                    # If it's a string, send it all at once
-                    yield f"{json.dumps({'response': response})}\n"
-                elif response is None:
-                    # Handle None response (e.g., when only_need_context=True but no context found)
-                    yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
-                else:
-                    # If it's an async generator, send chunks one by one
-                    try:
-                        async for chunk in response:
-                            if chunk:  # Only send non-empty content
-                                yield f"{json.dumps({'response': chunk})}\n"
-                    except Exception as e:
-                        logging.error(f"Streaming error: {str(e)}")
-                        yield f"{json.dumps({'error': str(e)})}\n"
+                # Unified approach: always use aquery_llm for all cases
+                result = await rag.aquery_llm(request.query, param=param)
+
+                # Extract references and LLM response from unified result
+                references = result.get("data", {}).get("references", [])
+                llm_response = result.get("llm_response", {})
+
+                # Send reference list first if requested
+                if request.include_references:
+                    yield f"{json.dumps({'references': references})}\n"
+
+                # Then stream the LLM response content
+                if llm_response.get("is_streaming"):
+                    response_stream = llm_response.get("response_iterator")
+                    if response_stream:
+                        try:
+                            async for chunk in response_stream:
+                                if chunk:  # Only send non-empty content
+                                    yield f"{json.dumps({'response': chunk})}\n"
+                        except Exception as e:
+                            logging.error(f"Streaming error: {str(e)}")
+                            yield f"{json.dumps({'error': str(e)})}\n"
+                else:
+                    # Non-streaming response (fallback)
+                    response_content = llm_response.get("content", "")
+                    if response_content:
+                        yield f"{json.dumps({'response': response_content})}\n"
+                    else:
+                        yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"

            return StreamingResponse(
                stream_generator(),
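
For reference, a client might consume the NDJSON protocol described in the docstring above along these lines. This is a sketch only: the base URL and port, the use of httpx, and the omission of any auth headers are assumptions about the deployment, not part of this commit:

    import json
    import httpx

    BASE_URL = "http://localhost:9621"  # adjust to your LightRAG server

    def stream_query(question: str) -> None:
        payload = {"query": question, "mode": "hybrid", "include_references": True}
        with httpx.stream("POST", f"{BASE_URL}/query/stream", json=payload, timeout=None) as resp:
            resp.raise_for_status()
            # One JSON object per line: references first, then response chunks or errors
            for line in resp.iter_lines():
                if not line:
                    continue
                msg = json.loads(line)
                if "references" in msg:
                    print(f"[{len(msg['references'])} references]")
                elif "error" in msg:
                    print(f"[error] {msg['error']}")
                else:
                    print(msg.get("response", ""), end="", flush=True)

    stream_query("What changed in this commit?")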


@@ -2062,74 +2062,53 @@ class LightRAG:
        system_prompt: str | None = None,
    ) -> str | AsyncIterator[str]:
        """
-        Perform a async query.
+        Perform an async query (backward compatibility wrapper).
+
+        This function is now a wrapper around aquery_llm that maintains backward compatibility
+        by returning only the LLM response content in the original format.

        Args:
            query (str): The query to be executed.
            param (QueryParam): Configuration parameters for query execution.
                If param.model_func is provided, it will be used instead of the global model.
-            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+            system_prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].

        Returns:
-            str: The result of the query execution.
+            str | AsyncIterator[str]: The LLM response content.
+                - Non-streaming: Returns str
+                - Streaming: Returns AsyncIterator[str]
        """
-        # If a custom model is provided in param, temporarily update global config
-        global_config = asdict(self)
-        query_result = None
-
-        if param.mode in ["local", "global", "hybrid", "mix"]:
-            query_result = await kg_query(
-                query.strip(),
-                self.chunk_entity_relation_graph,
-                self.entities_vdb,
-                self.relationships_vdb,
-                self.text_chunks,
-                param,
-                global_config,
-                hashing_kv=self.llm_response_cache,
-                system_prompt=system_prompt,
-                chunks_vdb=self.chunks_vdb,
-            )
-        elif param.mode == "naive":
-            query_result = await naive_query(
-                query.strip(),
-                self.chunks_vdb,
-                param,
-                global_config,
-                hashing_kv=self.llm_response_cache,
-                system_prompt=system_prompt,
-            )
-        elif param.mode == "bypass":
-            # Bypass mode: directly use LLM without knowledge retrieval
-            use_llm_func = param.model_func or global_config["llm_model_func"]
-            # Apply higher priority (8) to entity/relation summary tasks
-            use_llm_func = partial(use_llm_func, _priority=8)
-            param.stream = True if param.stream is None else param.stream
-            response = await use_llm_func(
-                query.strip(),
-                system_prompt=system_prompt,
-                history_messages=param.conversation_history,
-                enable_cot=True,
-                stream=param.stream,
-            )
-            # Create QueryResult for bypass mode
-            query_result = QueryResult(
-                content=response if not param.stream else None,
-                response_iterator=response if param.stream else None,
-                is_streaming=param.stream,
-            )
-        else:
-            raise ValueError(f"Unknown mode {param.mode}")
-
-        await self._query_done()
-
-        # Return appropriate response based on streaming mode
-        if query_result.is_streaming:
-            return query_result.response_iterator
-        else:
-            return query_result.content
+        # Call the new aquery_llm function to get complete results
+        result = await self.aquery_llm(query, param, system_prompt)
+
+        # Extract and return only the LLM response for backward compatibility
+        llm_response = result.get("llm_response", {})
+        if llm_response.get("is_streaming"):
+            return llm_response.get("response_iterator")
+        else:
+            return llm_response.get("content", "")
+
+    def query_data(
+        self,
+        query: str,
+        param: QueryParam = QueryParam(),
+    ) -> dict[str, Any]:
+        """
+        Synchronous data retrieval API: returns structured retrieval results without LLM generation.
+
+        This function is the synchronous version of aquery_data, providing the same functionality
+        for users who prefer synchronous interfaces.
+
+        Args:
+            query: Query text for retrieval.
+            param: Query parameters controlling retrieval behavior (same as aquery).
+
+        Returns:
+            dict[str, Any]: Same structured data result as aquery_data.
+        """
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(self.aquery_data(query, param))

    async def aquery_data(
        self,
@@ -2323,6 +2302,162 @@ class LightRAG:
        await self._query_done()

        return final_data

+    async def aquery_llm(
+        self,
+        query: str,
+        param: QueryParam = QueryParam(),
+        system_prompt: str | None = None,
+    ) -> dict[str, Any]:
+        """
+        Asynchronous complete query API: returns structured retrieval results with LLM generation.
+
+        This function performs a single query operation and returns both structured data and LLM response,
+        based on the original aquery logic to avoid duplicate calls.
+
+        Args:
+            query: Query text for retrieval and LLM generation.
+            param: Query parameters controlling retrieval and LLM behavior.
+            system_prompt: Optional custom system prompt for LLM generation.
+
+        Returns:
+            dict[str, Any]: Complete response with structured data and LLM response.
+        """
+        global_config = asdict(self)
+
+        try:
+            query_result = None
+
+            if param.mode in ["local", "global", "hybrid", "mix"]:
+                query_result = await kg_query(
+                    query.strip(),
+                    self.chunk_entity_relation_graph,
+                    self.entities_vdb,
+                    self.relationships_vdb,
+                    self.text_chunks,
+                    param,
+                    global_config,
+                    hashing_kv=self.llm_response_cache,
+                    system_prompt=system_prompt,
+                    chunks_vdb=self.chunks_vdb,
+                )
+            elif param.mode == "naive":
+                query_result = await naive_query(
+                    query.strip(),
+                    self.chunks_vdb,
+                    param,
+                    global_config,
+                    hashing_kv=self.llm_response_cache,
+                    system_prompt=system_prompt,
+                )
+            elif param.mode == "bypass":
+                # Bypass mode: directly use LLM without knowledge retrieval
+                use_llm_func = param.model_func or global_config["llm_model_func"]
+                # Apply higher priority (8) to entity/relation summary tasks
+                use_llm_func = partial(use_llm_func, _priority=8)
+                param.stream = True if param.stream is None else param.stream
+                response = await use_llm_func(
+                    query.strip(),
+                    system_prompt=system_prompt,
+                    history_messages=param.conversation_history,
+                    enable_cot=True,
+                    stream=param.stream,
+                )
+                if type(response) is str:
+                    return {
+                        "status": "success",
+                        "message": "Bypass mode LLM non streaming response",
+                        "data": {},
+                        "metadata": {},
+                        "llm_response": {
+                            "content": response,
+                            "response_iterator": None,
+                            "is_streaming": False,
+                        },
+                    }
+                else:
+                    return {
+                        "status": "success",
+                        "message": "Bypass mode LLM streaming response",
+                        "data": {},
+                        "metadata": {},
+                        "llm_response": {
+                            "content": None,
+                            "response_iterator": response,
+                            "is_streaming": True,
+                        },
+                    }
+            else:
+                raise ValueError(f"Unknown mode {param.mode}")
+
+            await self._query_done()
+
+            # Check if query_result is None
+            if query_result is None:
+                return {
+                    "status": "failure",
+                    "message": "Query returned no results",
+                    "data": {},
+                    "metadata": {},
+                    "llm_response": {
+                        "content": None,
+                        "response_iterator": None,
+                        "is_streaming": False,
+                    },
+                }
+
+            # Extract structured data from query result
+            raw_data = query_result.raw_data if query_result else {}
+            raw_data["llm_response"] = {
+                "content": query_result.content
+                if not query_result.is_streaming
+                else None,
+                "response_iterator": query_result.response_iterator
+                if query_result.is_streaming
+                else None,
+                "is_streaming": query_result.is_streaming,
+            }
+            return raw_data
+
+        except Exception as e:
+            logger.error(f"Query failed: {e}")
+            # Return error response
+            return {
+                "status": "failure",
+                "message": f"Query failed: {str(e)}",
+                "data": {},
+                "metadata": {},
+                "llm_response": {
+                    "content": None,
+                    "response_iterator": None,
+                    "is_streaming": False,
+                },
+            }
+
+    def query_llm(
+        self,
+        query: str,
+        param: QueryParam = QueryParam(),
+        system_prompt: str | None = None,
+    ) -> dict[str, Any]:
+        """
+        Synchronous complete query API: returns structured retrieval results with LLM generation.
+
+        This function is the synchronous version of aquery_llm, providing the same functionality
+        for users who prefer synchronous interfaces.
+
+        Args:
+            query: Query text for retrieval and LLM generation.
+            param: Query parameters controlling retrieval and LLM behavior.
+            system_prompt: Optional custom system prompt for LLM generation.
+
+        Returns:
+            dict[str, Any]: Same complete response format as aquery_llm.
+        """
+        loop = always_get_an_event_loop()
+        return loop.run_until_complete(self.aquery_llm(query, param, system_prompt))

    async def _query_done(self):
        await self.llm_response_cache.index_done_callback()
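
The wrappers above keep the pre-existing calling conventions intact. A short sketch of both entry points, again assuming an initialized LightRAG instance `rag`:

    from lightrag import QueryParam

    async def compat_demo(rag) -> None:
        # aquery() is now a thin wrapper over aquery_llm(): it still returns only the
        # LLM output, a str when stream=False and an AsyncIterator[str] when stream=True.
        answer = await rag.aquery(
            "Summarize the indexed documents",
            param=QueryParam(mode="naive", stream=False),
        )
        print(answer)

    def sync_demo(rag) -> None:
        # query_llm() runs aquery_llm() on an event loop and returns the structured
        # dict; the answer lives under the "llm_response" key.
        result = rag.query_llm(
            "Summarize the indexed documents",
            param=QueryParam(mode="naive", stream=False),
        )
        print(result["llm_response"]["content"])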


@@ -2236,38 +2236,6 @@ async def extract_entities(
    return chunk_results


-@overload
-async def kg_query(
-    query: str,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage,
-    query_param: QueryParam,
-    global_config: dict[str, str],
-    hashing_kv: BaseKVStorage | None = None,
-    system_prompt: str | None = None,
-    chunks_vdb: BaseVectorStorage = None,
-    return_raw_data: Literal[True] = False,
-) -> dict[str, Any]: ...
-
-
-@overload
-async def kg_query(
-    query: str,
-    knowledge_graph_inst: BaseGraphStorage,
-    entities_vdb: BaseVectorStorage,
-    relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage,
-    query_param: QueryParam,
-    global_config: dict[str, str],
-    hashing_kv: BaseKVStorage | None = None,
-    system_prompt: str | None = None,
-    chunks_vdb: BaseVectorStorage = None,
-    return_raw_data: Literal[False] = False,
-) -> str | AsyncIterator[str]: ...
-
-
 async def kg_query(
    query: str,
    knowledge_graph_inst: BaseGraphStorage,
@@ -2308,7 +2276,6 @@ async def kg_query(
        - stream=True: response_iterator contains streaming response, raw_data contains complete data
        - default: content contains LLM response text, raw_data contains complete data
    """
-
    if not query:
        return QueryResult(content=PROMPTS["fail_response"])
@@ -2319,32 +2286,6 @@ async def kg_query(
    # Apply higher priority (5) to query relation LLM function
    use_model_func = partial(use_model_func, _priority=5)

-    # Handle cache
-    args_hash = compute_args_hash(
-        query_param.mode,
-        query,
-        query_param.response_type,
-        query_param.top_k,
-        query_param.chunk_top_k,
-        query_param.max_entity_tokens,
-        query_param.max_relation_tokens,
-        query_param.max_total_tokens,
-        query_param.hl_keywords or [],
-        query_param.ll_keywords or [],
-        query_param.user_prompt or "",
-        query_param.enable_rerank,
-    )
-    cached_result = await handle_cache(
-        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
-    )
-    if (
-        cached_result is not None
-        and not query_param.only_need_context
-        and not query_param.only_need_prompt
-    ):
-        cached_response, _ = cached_result  # Extract content, ignore timestamp
-        return QueryResult(content=cached_response)
-
    hl_keywords, ll_keywords = await get_keywords_from_query(
        query, query_param, global_config, hashing_kv
    )
@@ -2417,29 +2358,41 @@ async def kg_query(
        f"[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
    )

-    response = await use_model_func(
-        user_query,
-        system_prompt=sys_prompt,
-        history_messages=query_param.conversation_history,
-        enable_cot=True,
-        stream=query_param.stream,
-    )
-
-    # Return unified result based on actual response type
-    if isinstance(response, str):
-        # Non-streaming response (string)
-        if len(response) > len(sys_prompt):
-            response = (
-                response.replace(sys_prompt, "")
-                .replace("user", "")
-                .replace("model", "")
-                .replace(query, "")
-                .replace("<system>", "")
-                .replace("</system>", "")
-                .strip()
-            )
-
+    # Handle cache
+    args_hash = compute_args_hash(
+        query_param.mode,
+        query,
+        query_param.response_type,
+        query_param.top_k,
+        query_param.chunk_top_k,
+        query_param.max_entity_tokens,
+        query_param.max_relation_tokens,
+        query_param.max_total_tokens,
+        query_param.hl_keywords or [],
+        query_param.ll_keywords or [],
+        query_param.user_prompt or "",
+        query_param.enable_rerank,
+    )
+    cached_result = await handle_cache(
+        hashing_kv, args_hash, user_query, query_param.mode, cache_type="query"
+    )
+
+    if cached_result is not None:
+        cached_response, _ = cached_result  # Extract content, ignore timestamp
+        logger.info(
+            " == LLM cache == Query cache hit, using cached response as query result"
+        )
+        response = cached_response
+    else:
+        response = await use_model_func(
+            user_query,
+            system_prompt=sys_prompt,
+            history_messages=query_param.conversation_history,
+            enable_cot=True,
+            stream=query_param.stream,
+        )
+
+    # Cache response
    if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
        queryparam_dict = {
            "mode": query_param.mode,
@@ -2466,6 +2419,20 @@ async def kg_query(
            ),
        )

+    # Return unified result based on actual response type
+    if isinstance(response, str):
+        # Non-streaming response (string)
+        if len(response) > len(sys_prompt):
+            response = (
+                response.replace(sys_prompt, "")
+                .replace("user", "")
+                .replace("model", "")
+                .replace(query, "")
+                .replace("<system>", "")
+                .replace("</system>", "")
+                .strip()
+            )
+
        return QueryResult(content=response, raw_data=context_result.raw_data)
    else:
        # Streaming response (AsyncIterator)
@@ -4052,29 +4019,6 @@ async def naive_query(
    # Apply higher priority (5) to query relation LLM function
    use_model_func = partial(use_model_func, _priority=5)

-    # Handle cache
-    args_hash = compute_args_hash(
-        query_param.mode,
-        query,
-        query_param.response_type,
-        query_param.top_k,
-        query_param.chunk_top_k,
-        query_param.max_entity_tokens,
-        query_param.max_relation_tokens,
-        query_param.max_total_tokens,
-        query_param.hl_keywords or [],
-        query_param.ll_keywords or [],
-        query_param.user_prompt or "",
-        query_param.enable_rerank,
-    )
-    cached_result = await handle_cache(
-        hashing_kv, args_hash, query, query_param.mode, cache_type="query"
-    )
-    if cached_result is not None:
-        cached_response, _ = cached_result  # Extract content, ignore timestamp
-        if not query_param.only_need_context and not query_param.only_need_prompt:
-            return QueryResult(content=cached_response)
-
    tokenizer: Tokenizer = global_config["tokenizer"]
    if not tokenizer:
        logger.error("Tokenizer not found in global configuration.")
@@ -4211,35 +4155,39 @@ async def naive_query(
        prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
        return QueryResult(content=prompt_content, raw_data=raw_data)

-    len_of_prompts = len(tokenizer.encode(query + sys_prompt))
-    logger.debug(
-        f"[naive_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
-    )
-
-    response = await use_model_func(
-        user_query,
-        system_prompt=sys_prompt,
-        history_messages=query_param.conversation_history,
-        enable_cot=True,
-        stream=query_param.stream,
-    )
-
-    # Return unified result based on actual response type
-    if isinstance(response, str):
-        # Non-streaming response (string)
-        if len(response) > len(sys_prompt):
-            response = (
-                response[len(sys_prompt) :]
-                .replace(sys_prompt, "")
-                .replace("user", "")
-                .replace("model", "")
-                .replace(query, "")
-                .replace("<system>", "")
-                .replace("</system>", "")
-                .strip()
-            )
-
+    # Handle cache
+    args_hash = compute_args_hash(
+        query_param.mode,
+        query,
+        query_param.response_type,
+        query_param.top_k,
+        query_param.chunk_top_k,
+        query_param.max_entity_tokens,
+        query_param.max_relation_tokens,
+        query_param.max_total_tokens,
+        query_param.hl_keywords or [],
+        query_param.ll_keywords or [],
+        query_param.user_prompt or "",
+        query_param.enable_rerank,
+    )
+    cached_result = await handle_cache(
+        hashing_kv, args_hash, user_query, query_param.mode, cache_type="query"
+    )
+
+    if cached_result is not None:
+        cached_response, _ = cached_result  # Extract content, ignore timestamp
+        logger.info(
+            " == LLM cache == Query cache hit, using cached response as query result"
+        )
+        response = cached_response
+    else:
+        response = await use_model_func(
+            user_query,
+            system_prompt=sys_prompt,
+            history_messages=query_param.conversation_history,
+            enable_cot=True,
+            stream=query_param.stream,
+        )
+
+    # Cache response
    if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
        queryparam_dict = {
            "mode": query_param.mode,
@@ -4266,6 +4214,21 @@ async def naive_query(
            ),
        )

+    # Return unified result based on actual response type
+    if isinstance(response, str):
+        # Non-streaming response (string)
+        if len(response) > len(sys_prompt):
+            response = (
+                response[len(sys_prompt) :]
+                .replace(sys_prompt, "")
+                .replace("user", "")
+                .replace("model", "")
+                .replace(query, "")
+                .replace("<system>", "")
+                .replace("</system>", "")
+                .strip()
+            )
+
        return QueryResult(content=response, raw_data=raw_data)
    else:
        # Streaming response (AsyncIterator)


@@ -1102,7 +1102,9 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
    if existing_cache:
        existing_content = existing_cache.get("return")
        if existing_content == cache_data.content:
-            logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
+            logger.warning(
+                f"Cache duplication detected for {flattened_key}, skipping update"
+            )
            return

    # Create cache entry with flattened structure