diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py
index 83df2823..7f7dedd3 100644
--- a/lightrag/api/routers/query_routes.py
+++ b/lightrag/api/routers/query_routes.py
@@ -152,44 +152,42 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
)
async def query_text(request: QueryRequest):
"""
- Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
+    This endpoint performs a RAG query and returns a non-streaming response.
Parameters:
request (QueryRequest): The request object containing the query parameters.
Returns:
QueryResponse: A Pydantic model containing the result of the query processing.
If include_references=True, also includes reference list.
- If a string is returned (e.g., cache hit), it's directly returned.
- Otherwise, an async generator may be used to build the response.
Raises:
HTTPException: Raised when an error occurs during the request handling process,
with status code 500 and detail containing the exception message.
"""
try:
- param = request.to_query_params(False)
- response = await rag.aquery(request.query, param=param)
+ param = request.to_query_params(
+ False
+ ) # Ensure stream=False for non-streaming endpoint
+ # Force stream=False for /query endpoint regardless of include_references setting
+ param.stream = False
- # Get reference list if requested
- reference_list = None
+ # Unified approach: always use aquery_llm for both cases
+ result = await rag.aquery_llm(request.query, param=param)
+
+ # Extract LLM response and references from unified result
+ llm_response = result.get("llm_response", {})
+ references = result.get("data", {}).get("references", [])
+
+ # Get the non-streaming response content
+ response_content = llm_response.get("content", "")
+ if not response_content:
+ response_content = "No relevant context found for the query."
+
+ # Return response with or without references based on request
if request.include_references:
- try:
- # Use aquery_data to get reference list independently
- data_result = await rag.aquery_data(request.query, param=param)
- if isinstance(data_result, dict) and "data" in data_result:
- reference_list = data_result["data"].get("references", [])
- except Exception as e:
- logging.warning(f"Failed to get reference list: {str(e)}")
- reference_list = []
-
- # Process response and return with optional references
- if isinstance(response, str):
- return QueryResponse(response=response, references=reference_list)
- elif isinstance(response, dict):
- result = json.dumps(response, indent=2)
- return QueryResponse(response=result, references=reference_list)
+ return QueryResponse(response=response_content, references=references)
else:
- return QueryResponse(response=str(response), references=reference_list)
+ return QueryResponse(response=response_content, references=None)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
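A minimal client sketch for the non-streaming /query handler above. The base URL, port, and timeout are illustrative assumptions, not values taken from this diff; only the request and response fields (query, include_references, response, references) come from the handler itself.

import httpx  # assumed HTTP client for this sketch

def query_once(question: str) -> tuple[str, list]:
    payload = {"query": question, "include_references": True}
    resp = httpx.post(
        "http://localhost:9621/query",  # base URL and port are assumptions
        json=payload,
        timeout=120.0,
    )
    resp.raise_for_status()
    body = resp.json()
    # QueryResponse fields from the handler: "response", plus "references" when requested
    return body["response"], body.get("references") or []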
@@ -197,7 +195,8 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
@router.post("/query/stream", dependencies=[Depends(combined_auth)])
async def query_text_stream(request: QueryRequest):
"""
- This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
+    This endpoint performs a RAG query and streams the response.
+    Streaming can be turned off by setting stream=False in QueryRequest.
The streaming response includes:
1. Reference list (sent first as a single message, if include_references=True)
@@ -213,49 +212,42 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
- Error messages: {"error": "..."} - If any errors occur
"""
try:
- param = request.to_query_params(True)
- response = await rag.aquery(request.query, param=param)
+ param = request.to_query_params(
+ True
+ ) # Ensure stream=True for streaming endpoint
from fastapi.responses import StreamingResponse
async def stream_generator():
- # Get reference list if requested (default is True for backward compatibility)
- reference_list = []
- if request.include_references:
- try:
- # Use aquery_data to get reference list independently
- data_param = request.to_query_params(
- False
- ) # Non-streaming for data
- data_result = await rag.aquery_data(
- request.query, param=data_param
- )
- if isinstance(data_result, dict) and "data" in data_result:
- reference_list = data_result["data"].get("references", [])
- except Exception as e:
- logging.warning(f"Failed to get reference list: {str(e)}")
- reference_list = []
+ # Unified approach: always use aquery_llm for all cases
+ result = await rag.aquery_llm(request.query, param=param)
- # Send reference list first (if requested)
- if request.include_references:
- yield f"{json.dumps({'references': reference_list})}\n"
+ # Extract references and LLM response from unified result
+ references = result.get("data", {}).get("references", [])
+ llm_response = result.get("llm_response", {})
- # Then stream the response content
- if isinstance(response, str):
- # If it's a string, send it all at once
- yield f"{json.dumps({'response': response})}\n"
- elif response is None:
- # Handle None response (e.g., when only_need_context=True but no context found)
- yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
+ # Send reference list first if requested
+ if request.include_references:
+ yield f"{json.dumps({'references': references})}\n"
+
+ # Then stream the LLM response content
+ if llm_response.get("is_streaming"):
+ response_stream = llm_response.get("response_iterator")
+ if response_stream:
+ try:
+ async for chunk in response_stream:
+ if chunk: # Only send non-empty content
+ yield f"{json.dumps({'response': chunk})}\n"
+ except Exception as e:
+ logging.error(f"Streaming error: {str(e)}")
+ yield f"{json.dumps({'error': str(e)})}\n"
else:
- # If it's an async generator, send chunks one by one
- try:
- async for chunk in response:
- if chunk: # Only send non-empty content
- yield f"{json.dumps({'response': chunk})}\n"
- except Exception as e:
- logging.error(f"Streaming error: {str(e)}")
- yield f"{json.dumps({'error': str(e)})}\n"
+ # Non-streaming response (fallback)
+ response_content = llm_response.get("content", "")
+ if response_content:
+ yield f"{json.dumps({'response': response_content})}\n"
+ else:
+ yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
return StreamingResponse(
stream_generator(),
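A matching client sketch for the NDJSON stream produced by stream_generator() above: one JSON object per line, references first (when requested), then response chunks, with error objects possible mid-stream. The base URL and port are illustrative assumptions.

import asyncio
import json

import httpx  # assumed async HTTP client for this sketch

async def consume_stream(question: str) -> None:
    payload = {"query": question, "include_references": True}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:9621/query/stream", json=payload
        ) as resp:
            async for line in resp.aiter_lines():
                if not line:
                    continue
                msg = json.loads(line)
                if "references" in msg:
                    print("references:", msg["references"])  # sent once, first
                elif "response" in msg:
                    print(msg["response"], end="", flush=True)  # content chunk
                elif "error" in msg:
                    print("\nstream error:", msg["error"])

# asyncio.run(consume_stream("What does this repository do?"))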
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index afc0bc5f..6e4e3b04 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -2062,74 +2062,53 @@ class LightRAG:
system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
"""
- Perform a async query.
+        Perform an async query (backward-compatibility wrapper).
+
+ This function is now a wrapper around aquery_llm that maintains backward compatibility
+ by returning only the LLM response content in the original format.
Args:
query (str): The query to be executed.
param (QueryParam): Configuration parameters for query execution.
If param.model_func is provided, it will be used instead of the global model.
- prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+            system_prompt (Optional[str]): Custom system prompt for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
Returns:
- str: The result of the query execution.
+ str | AsyncIterator[str]: The LLM response content.
+ - Non-streaming: Returns str
+ - Streaming: Returns AsyncIterator[str]
"""
- # If a custom model is provided in param, temporarily update global config
- global_config = asdict(self)
+ # Call the new aquery_llm function to get complete results
+ result = await self.aquery_llm(query, param, system_prompt)
- query_result = None
+ # Extract and return only the LLM response for backward compatibility
+ llm_response = result.get("llm_response", {})
- if param.mode in ["local", "global", "hybrid", "mix"]:
- query_result = await kg_query(
- query.strip(),
- self.chunk_entity_relation_graph,
- self.entities_vdb,
- self.relationships_vdb,
- self.text_chunks,
- param,
- global_config,
- hashing_kv=self.llm_response_cache,
- system_prompt=system_prompt,
- chunks_vdb=self.chunks_vdb,
- )
- elif param.mode == "naive":
- query_result = await naive_query(
- query.strip(),
- self.chunks_vdb,
- param,
- global_config,
- hashing_kv=self.llm_response_cache,
- system_prompt=system_prompt,
- )
- elif param.mode == "bypass":
- # Bypass mode: directly use LLM without knowledge retrieval
- use_llm_func = param.model_func or global_config["llm_model_func"]
- # Apply higher priority (8) to entity/relation summary tasks
- use_llm_func = partial(use_llm_func, _priority=8)
-
- param.stream = True if param.stream is None else param.stream
- response = await use_llm_func(
- query.strip(),
- system_prompt=system_prompt,
- history_messages=param.conversation_history,
- enable_cot=True,
- stream=param.stream,
- )
- # Create QueryResult for bypass mode
- query_result = QueryResult(
- content=response if not param.stream else None,
- response_iterator=response if param.stream else None,
- is_streaming=param.stream,
- )
+ if llm_response.get("is_streaming"):
+ return llm_response.get("response_iterator")
else:
- raise ValueError(f"Unknown mode {param.mode}")
+ return llm_response.get("content", "")
- await self._query_done()
+ def query_data(
+ self,
+ query: str,
+ param: QueryParam = QueryParam(),
+ ) -> dict[str, Any]:
+ """
+ Synchronous data retrieval API: returns structured retrieval results without LLM generation.
- # Return appropriate response based on streaming mode
- if query_result.is_streaming:
- return query_result.response_iterator
- else:
- return query_result.content
+ This function is the synchronous version of aquery_data, providing the same functionality
+ for users who prefer synchronous interfaces.
+
+ Args:
+ query: Query text for retrieval.
+ param: Query parameters controlling retrieval behavior (same as aquery).
+
+ Returns:
+ dict[str, Any]: Same structured data result as aquery_data.
+ """
+ loop = always_get_an_event_loop()
+ return loop.run_until_complete(self.aquery_data(query, param))
async def aquery_data(
self,
@@ -2323,6 +2302,162 @@ class LightRAG:
await self._query_done()
return final_data
+ async def aquery_llm(
+ self,
+ query: str,
+ param: QueryParam = QueryParam(),
+ system_prompt: str | None = None,
+ ) -> dict[str, Any]:
+ """
+ Asynchronous complete query API: returns structured retrieval results with LLM generation.
+
+        This function performs a single query operation and returns both the structured retrieval data
+        and the LLM response, based on the original aquery logic to avoid duplicate calls.
+
+ Args:
+ query: Query text for retrieval and LLM generation.
+ param: Query parameters controlling retrieval and LLM behavior.
+ system_prompt: Optional custom system prompt for LLM generation.
+
+ Returns:
+ dict[str, Any]: Complete response with structured data and LLM response.
+ """
+ global_config = asdict(self)
+
+ try:
+ query_result = None
+
+ if param.mode in ["local", "global", "hybrid", "mix"]:
+ query_result = await kg_query(
+ query.strip(),
+ self.chunk_entity_relation_graph,
+ self.entities_vdb,
+ self.relationships_vdb,
+ self.text_chunks,
+ param,
+ global_config,
+ hashing_kv=self.llm_response_cache,
+ system_prompt=system_prompt,
+ chunks_vdb=self.chunks_vdb,
+ )
+ elif param.mode == "naive":
+ query_result = await naive_query(
+ query.strip(),
+ self.chunks_vdb,
+ param,
+ global_config,
+ hashing_kv=self.llm_response_cache,
+ system_prompt=system_prompt,
+ )
+ elif param.mode == "bypass":
+ # Bypass mode: directly use LLM without knowledge retrieval
+ use_llm_func = param.model_func or global_config["llm_model_func"]
+                # Apply higher priority (8) to the bypass-mode query LLM call
+ use_llm_func = partial(use_llm_func, _priority=8)
+
+ param.stream = True if param.stream is None else param.stream
+ response = await use_llm_func(
+ query.strip(),
+ system_prompt=system_prompt,
+ history_messages=param.conversation_history,
+ enable_cot=True,
+ stream=param.stream,
+ )
+                if isinstance(response, str):
+ return {
+ "status": "success",
+ "message": "Bypass mode LLM non streaming response",
+ "data": {},
+ "metadata": {},
+ "llm_response": {
+ "content": response,
+ "response_iterator": None,
+ "is_streaming": False,
+ },
+ }
+ else:
+ return {
+ "status": "success",
+ "message": "Bypass mode LLM streaming response",
+ "data": {},
+ "metadata": {},
+ "llm_response": {
+ "content": None,
+ "response_iterator": response,
+ "is_streaming": True,
+ },
+ }
+ else:
+ raise ValueError(f"Unknown mode {param.mode}")
+
+ await self._query_done()
+
+ # Check if query_result is None
+ if query_result is None:
+ return {
+ "status": "failure",
+ "message": "Query returned no results",
+ "data": {},
+ "metadata": {},
+ "llm_response": {
+ "content": None,
+ "response_iterator": None,
+ "is_streaming": False,
+ },
+ }
+
+ # Extract structured data from query result
+ raw_data = query_result.raw_data if query_result else {}
+ raw_data["llm_response"] = {
+ "content": query_result.content
+ if not query_result.is_streaming
+ else None,
+ "response_iterator": query_result.response_iterator
+ if query_result.is_streaming
+ else None,
+ "is_streaming": query_result.is_streaming,
+ }
+
+ return raw_data
+
+ except Exception as e:
+ logger.error(f"Query failed: {e}")
+ # Return error response
+ return {
+ "status": "failure",
+ "message": f"Query failed: {str(e)}",
+ "data": {},
+ "metadata": {},
+ "llm_response": {
+ "content": None,
+ "response_iterator": None,
+ "is_streaming": False,
+ },
+ }
+
+ def query_llm(
+ self,
+ query: str,
+ param: QueryParam = QueryParam(),
+ system_prompt: str | None = None,
+ ) -> dict[str, Any]:
+ """
+ Synchronous complete query API: returns structured retrieval results with LLM generation.
+
+ This function is the synchronous version of aquery_llm, providing the same functionality
+ for users who prefer synchronous interfaces.
+
+ Args:
+ query: Query text for retrieval and LLM generation.
+ param: Query parameters controlling retrieval and LLM behavior.
+ system_prompt: Optional custom system prompt for LLM generation.
+
+ Returns:
+ dict[str, Any]: Same complete response format as aquery_llm.
+ """
+ loop = always_get_an_event_loop()
+ return loop.run_until_complete(self.aquery_llm(query, param, system_prompt))
+
async def _query_done(self):
await self.llm_response_cache.index_done_callback()
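A usage sketch for the unified result shape returned by aquery_llm() above (data/metadata/llm_response, with status/message set on the bypass and failure paths). It assumes an already-initialized LightRAG instance and that QueryParam is importable from the package root; everything else uses only fields shown in this diff.

import asyncio

from lightrag import LightRAG, QueryParam  # import path assumed

async def ask(rag: LightRAG, question: str) -> str:
    param = QueryParam(mode="mix", stream=False)
    result = await rag.aquery_llm(question, param=param)

    if result.get("status") == "failure":
        raise RuntimeError(result.get("message", "query failed"))

    references = result.get("data", {}).get("references", [])
    llm = result.get("llm_response", {})

    if llm.get("is_streaming"):
        # Streaming path: drain the async iterator into a single string
        answer = "".join([chunk async for chunk in llm["response_iterator"]])
    else:
        answer = llm.get("content", "")

    print(f"{len(references)} references retrieved")
    return answer

# answer = asyncio.run(ask(rag, "Summarize the indexed documents."))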
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 0f9e8edf..f2930f3d 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -2236,38 +2236,6 @@ async def extract_entities(
return chunk_results
-@overload
-async def kg_query(
- query: str,
- knowledge_graph_inst: BaseGraphStorage,
- entities_vdb: BaseVectorStorage,
- relationships_vdb: BaseVectorStorage,
- text_chunks_db: BaseKVStorage,
- query_param: QueryParam,
- global_config: dict[str, str],
- hashing_kv: BaseKVStorage | None = None,
- system_prompt: str | None = None,
- chunks_vdb: BaseVectorStorage = None,
- return_raw_data: Literal[True] = False,
-) -> dict[str, Any]: ...
-
-
-@overload
-async def kg_query(
- query: str,
- knowledge_graph_inst: BaseGraphStorage,
- entities_vdb: BaseVectorStorage,
- relationships_vdb: BaseVectorStorage,
- text_chunks_db: BaseKVStorage,
- query_param: QueryParam,
- global_config: dict[str, str],
- hashing_kv: BaseKVStorage | None = None,
- system_prompt: str | None = None,
- chunks_vdb: BaseVectorStorage = None,
- return_raw_data: Literal[False] = False,
-) -> str | AsyncIterator[str]: ...
-
-
async def kg_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
@@ -2308,7 +2276,6 @@ async def kg_query(
- stream=True: response_iterator contains streaming response, raw_data contains complete data
- default: content contains LLM response text, raw_data contains complete data
"""
-
if not query:
return QueryResult(content=PROMPTS["fail_response"])
@@ -2319,32 +2286,6 @@ async def kg_query(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
- # Handle cache
- args_hash = compute_args_hash(
- query_param.mode,
- query,
- query_param.response_type,
- query_param.top_k,
- query_param.chunk_top_k,
- query_param.max_entity_tokens,
- query_param.max_relation_tokens,
- query_param.max_total_tokens,
- query_param.hl_keywords or [],
- query_param.ll_keywords or [],
- query_param.user_prompt or "",
- query_param.enable_rerank,
- )
- cached_result = await handle_cache(
- hashing_kv, args_hash, query, query_param.mode, cache_type="query"
- )
- if (
- cached_result is not None
- and not query_param.only_need_context
- and not query_param.only_need_prompt
- ):
- cached_response, _ = cached_result # Extract content, ignore timestamp
- return QueryResult(content=cached_response)
-
hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
)
@@ -2417,29 +2358,41 @@ async def kg_query(
f"[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
)
- response = await use_model_func(
- user_query,
- system_prompt=sys_prompt,
- history_messages=query_param.conversation_history,
- enable_cot=True,
- stream=query_param.stream,
+ # Handle cache
+ args_hash = compute_args_hash(
+ query_param.mode,
+ query,
+ query_param.response_type,
+ query_param.top_k,
+ query_param.chunk_top_k,
+ query_param.max_entity_tokens,
+ query_param.max_relation_tokens,
+ query_param.max_total_tokens,
+ query_param.hl_keywords or [],
+ query_param.ll_keywords or [],
+ query_param.user_prompt or "",
+ query_param.enable_rerank,
)
- # Return unified result based on actual response type
- if isinstance(response, str):
- # Non-streaming response (string)
- if len(response) > len(sys_prompt):
- response = (
- response.replace(sys_prompt, "")
- .replace("user", "")
- .replace("model", "")
- .replace(query, "")
- .replace("", "")
- .replace("", "")
- .strip()
- )
+ cached_result = await handle_cache(
+ hashing_kv, args_hash, user_query, query_param.mode, cache_type="query"
+ )
+
+ if cached_result is not None:
+ cached_response, _ = cached_result # Extract content, ignore timestamp
+ logger.info(
+ " == LLM cache == Query cache hit, using cached response as query result"
+ )
+ response = cached_response
+ else:
+ response = await use_model_func(
+ user_query,
+ system_prompt=sys_prompt,
+ history_messages=query_param.conversation_history,
+ enable_cot=True,
+ stream=query_param.stream,
+ )
- # Cache response
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
@@ -2466,6 +2419,20 @@ async def kg_query(
),
)
+ # Return unified result based on actual response type
+ if isinstance(response, str):
+ # Non-streaming response (string)
+ if len(response) > len(sys_prompt):
+ response = (
+ response.replace(sys_prompt, "")
+ .replace("user", "")
+ .replace("model", "")
+ .replace(query, "")
+ .replace("", "")
+ .replace("", "")
+ .strip()
+ )
+
return QueryResult(content=response, raw_data=context_result.raw_data)
else:
# Streaming response (AsyncIterator)
@@ -4052,29 +4019,6 @@ async def naive_query(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
- # Handle cache
- args_hash = compute_args_hash(
- query_param.mode,
- query,
- query_param.response_type,
- query_param.top_k,
- query_param.chunk_top_k,
- query_param.max_entity_tokens,
- query_param.max_relation_tokens,
- query_param.max_total_tokens,
- query_param.hl_keywords or [],
- query_param.ll_keywords or [],
- query_param.user_prompt or "",
- query_param.enable_rerank,
- )
- cached_result = await handle_cache(
- hashing_kv, args_hash, query, query_param.mode, cache_type="query"
- )
- if cached_result is not None:
- cached_response, _ = cached_result # Extract content, ignore timestamp
- if not query_param.only_need_context and not query_param.only_need_prompt:
- return QueryResult(content=cached_response)
-
tokenizer: Tokenizer = global_config["tokenizer"]
if not tokenizer:
logger.error("Tokenizer not found in global configuration.")
@@ -4211,35 +4155,39 @@ async def naive_query(
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(content=prompt_content, raw_data=raw_data)
- len_of_prompts = len(tokenizer.encode(query + sys_prompt))
- logger.debug(
- f"[naive_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
+ # Handle cache
+ args_hash = compute_args_hash(
+ query_param.mode,
+ query,
+ query_param.response_type,
+ query_param.top_k,
+ query_param.chunk_top_k,
+ query_param.max_entity_tokens,
+ query_param.max_relation_tokens,
+ query_param.max_total_tokens,
+ query_param.hl_keywords or [],
+ query_param.ll_keywords or [],
+ query_param.user_prompt or "",
+ query_param.enable_rerank,
)
-
- response = await use_model_func(
- user_query,
- system_prompt=sys_prompt,
- history_messages=query_param.conversation_history,
- enable_cot=True,
- stream=query_param.stream,
+ cached_result = await handle_cache(
+ hashing_kv, args_hash, user_query, query_param.mode, cache_type="query"
)
+ if cached_result is not None:
+ cached_response, _ = cached_result # Extract content, ignore timestamp
+ logger.info(
+ " == LLM cache == Query cache hit, using cached response as query result"
+ )
+ response = cached_response
+ else:
+ response = await use_model_func(
+ user_query,
+ system_prompt=sys_prompt,
+ history_messages=query_param.conversation_history,
+ enable_cot=True,
+ stream=query_param.stream,
+ )
- # Return unified result based on actual response type
- if isinstance(response, str):
- # Non-streaming response (string)
- if len(response) > len(sys_prompt):
- response = (
- response[len(sys_prompt) :]
- .replace(sys_prompt, "")
- .replace("user", "")
- .replace("model", "")
- .replace(query, "")
- .replace("", "")
- .replace("", "")
- .strip()
- )
-
- # Cache response
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
@@ -4266,6 +4214,21 @@ async def naive_query(
),
)
+ # Return unified result based on actual response type
+ if isinstance(response, str):
+ # Non-streaming response (string)
+ if len(response) > len(sys_prompt):
+ response = (
+ response[len(sys_prompt) :]
+ .replace(sys_prompt, "")
+ .replace("user", "")
+ .replace("model", "")
+ .replace(query, "")
+ .replace("", "")
+ .replace("", "")
+ .strip()
+ )
+
return QueryResult(content=response, raw_data=raw_data)
else:
# Streaming response (AsyncIterator)
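The operate.py hunks above keep the cache hash computed over the same query parameters as before, but move the lookup so it is keyed on the final user_query and happens immediately before the LLM call, with the save step unchanged afterwards. The helper below is a generic, self-contained illustration of that check-then-call-then-save pattern; it does not use LightRAG's handle_cache/save_to_cache, and the in-memory dict stands in for the KV cache storage (hashing_kv).

import hashlib
import json
from typing import Awaitable, Callable

_cache: dict[str, str] = {}  # stand-in for the LLM response cache

def args_hash(**params: object) -> str:
    # Deterministic key over every parameter that can change the answer
    blob = json.dumps(params, sort_keys=True, default=str)
    return hashlib.md5(blob.encode("utf-8")).hexdigest()

async def cached_llm_call(
    user_query: str,
    mode: str,
    llm: Callable[[str], Awaitable[str]],
) -> str:
    key = args_hash(mode=mode, query=user_query)
    if key in _cache:
        # Cache hit: reuse the stored response instead of calling the model
        return _cache[key]
    response = await llm(user_query)
    _cache[key] = response  # save after the call, mirroring save_to_cache
    return response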
diff --git a/lightrag/utils.py b/lightrag/utils.py
index 1bd9ca8e..c72b6fd2 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1102,7 +1102,9 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
if existing_cache:
existing_content = existing_cache.get("return")
if existing_content == cache_data.content:
- logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
+ logger.warning(
+ f"Cache duplication detected for {flattened_key}, skipping update"
+ )
return
# Create cache entry with flattened structure