From 8cd4139cbf888410743600e3d195f295884eccde Mon Sep 17 00:00:00 2001
From: yangdx
Date: Fri, 26 Sep 2025 19:05:03 +0800
Subject: [PATCH] refactor: fix double query problem by adding aquery_llm
 function for consistent response handling

- Add new aquery_llm/query_llm methods providing structured responses
- Consolidate /query and /query/stream endpoints to use unified aquery_llm
- Optimize cache handling by moving cache checks to just before LLM calls
---
 lightrag/api/routers/query_routes.py | 112 ++++++------
 lightrag/lightrag.py                 | 245 +++++++++++++++++++++------
 lightrag/operate.py                  | 219 ++++++++++--------------
 lightrag/utils.py                    |   4 +-
 4 files changed, 336 insertions(+), 244 deletions(-)

diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py
index 83df2823..7f7dedd3 100644
--- a/lightrag/api/routers/query_routes.py
+++ b/lightrag/api/routers/query_routes.py
@@ -152,44 +152,42 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
     )
     async def query_text(request: QueryRequest):
         """
-        Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
+        This endpoint performs a RAG query with a non-streaming response.
 
         Parameters:
             request (QueryRequest): The request object containing the query parameters.
 
         Returns:
             QueryResponse: A Pydantic model containing the result of the query processing.
                            If include_references=True, also includes reference list.
-                           If a string is returned (e.g., cache hit), it's directly returned.
-                           Otherwise, an async generator may be used to build the response.
 
         Raises:
             HTTPException: Raised when an error occurs during the request handling process,
                            with status code 500 and detail containing the exception message.
         """
         try:
-            param = request.to_query_params(False)
-            response = await rag.aquery(request.query, param=param)
+            param = request.to_query_params(
+                False
+            )  # Ensure stream=False for non-streaming endpoint
+            # Force stream=False for /query endpoint regardless of include_references setting
+            param.stream = False
 
-            # Get reference list if requested
-            reference_list = None
+            # Unified approach: always use aquery_llm for both cases
+            result = await rag.aquery_llm(request.query, param=param)
+
+            # Extract LLM response and references from unified result
+            llm_response = result.get("llm_response", {})
+            references = result.get("data", {}).get("references", [])
+
+            # Get the non-streaming response content
+            response_content = llm_response.get("content", "")
+            if not response_content:
+                response_content = "No relevant context found for the query."
+
+            # Return response with or without references based on request
             if request.include_references:
-                try:
-                    # Use aquery_data to get reference list independently
-                    data_result = await rag.aquery_data(request.query, param=param)
-                    if isinstance(data_result, dict) and "data" in data_result:
-                        reference_list = data_result["data"].get("references", [])
-                except Exception as e:
-                    logging.warning(f"Failed to get reference list: {str(e)}")
-                    reference_list = []
-
-            # Process response and return with optional references
-            if isinstance(response, str):
-                return QueryResponse(response=response, references=reference_list)
-            elif isinstance(response, dict):
-                result = json.dumps(response, indent=2)
-                return QueryResponse(response=result, references=reference_list)
+                return QueryResponse(response=response_content, references=references)
             else:
-                return QueryResponse(response=str(response), references=reference_list)
+                return QueryResponse(response=response_content, references=None)
         except Exception as e:
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
@@ -197,7 +195,8 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
     @router.post("/query/stream", dependencies=[Depends(combined_auth)])
     async def query_text_stream(request: QueryRequest):
         """
-        This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
+        This endpoint performs a RAG query with a streaming response.
+        Streaming can be turned off by setting stream=False in QueryRequest.
 
         The streaming response includes:
         1. Reference list (sent first as a single message, if include_references=True)
@@ -213,49 +212,42 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
           - Error messages: {"error": "..."} - If any errors occur
         """
         try:
-            param = request.to_query_params(True)
-            response = await rag.aquery(request.query, param=param)
+            param = request.to_query_params(
+                True
+            )  # Ensure stream=True for streaming endpoint
 
             from fastapi.responses import StreamingResponse
 
             async def stream_generator():
-                # Get reference list if requested (default is True for backward compatibility)
-                reference_list = []
-                if request.include_references:
-                    try:
-                        # Use aquery_data to get reference list independently
-                        data_param = request.to_query_params(
-                            False
-                        )  # Non-streaming for data
-                        data_result = await rag.aquery_data(
-                            request.query, param=data_param
-                        )
-                        if isinstance(data_result, dict) and "data" in data_result:
-                            reference_list = data_result["data"].get("references", [])
-                    except Exception as e:
-                        logging.warning(f"Failed to get reference list: {str(e)}")
-                        reference_list = []
+                # Unified approach: always use aquery_llm for all cases
+                result = await rag.aquery_llm(request.query, param=param)
 
-                # Send reference list first (if requested)
-                if request.include_references:
-                    yield f"{json.dumps({'references': reference_list})}\n"
+                # Extract references and LLM response from unified result
+                references = result.get("data", {}).get("references", [])
+                llm_response = result.get("llm_response", {})
 
-                # Then stream the response content
-                if isinstance(response, str):
-                    # If it's a string, send it all at once
-                    yield f"{json.dumps({'response': response})}\n"
-                elif response is None:
-                    # Handle None response (e.g., when only_need_context=True but no context found)
-                    yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
+                # Send reference list first if requested
+                if request.include_references:
+                    yield f"{json.dumps({'references': references})}\n"
+
+                # Then stream the LLM response content
+                if llm_response.get("is_streaming"):
+                    response_stream = llm_response.get("response_iterator")
+                    if response_stream:
+                        try:
+                            async for chunk in response_stream:
+                                if chunk:  # Only send non-empty content
+                                    yield f"{json.dumps({'response': chunk})}\n"
+                        except Exception as e:
+                            logging.error(f"Streaming error: {str(e)}")
+                            yield f"{json.dumps({'error': str(e)})}\n"
                 else:
-                    # If it's an async generator, send chunks one by one
-                    try:
-                        async for chunk in response:
-                            if chunk:  # Only send non-empty content
-                                yield f"{json.dumps({'response': chunk})}\n"
-                    except Exception as e:
-                        logging.error(f"Streaming error: {str(e)}")
-                        yield f"{json.dumps({'error': str(e)})}\n"
+                    # Non-streaming response (fallback)
+                    response_content = llm_response.get("content", "")
+                    if response_content:
+                        yield f"{json.dumps({'response': response_content})}\n"
+                    else:
+                        yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
 
             return StreamingResponse(
                 stream_generator(),
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index afc0bc5f..6e4e3b04 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -2062,74 +2062,53 @@ class LightRAG:
         system_prompt: str | None = None,
     ) -> str | AsyncIterator[str]:
         """
-        Perform a async query.
+        Perform an async query (backward compatibility wrapper).
+
+        This function is now a wrapper around aquery_llm that maintains backward compatibility
+        by returning only the LLM response content in the original format.
 
         Args:
             query (str): The query to be executed.
             param (QueryParam): Configuration parameters for query execution.
                 If param.model_func is provided, it will be used instead of the global model.
-            prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
+            system_prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
 
         Returns:
-            str: The result of the query execution.
+            str | AsyncIterator[str]: The LLM response content.
+ - Non-streaming: Returns str + - Streaming: Returns AsyncIterator[str] """ - # If a custom model is provided in param, temporarily update global config - global_config = asdict(self) + # Call the new aquery_llm function to get complete results + result = await self.aquery_llm(query, param, system_prompt) - query_result = None + # Extract and return only the LLM response for backward compatibility + llm_response = result.get("llm_response", {}) - if param.mode in ["local", "global", "hybrid", "mix"]: - query_result = await kg_query( - query.strip(), - self.chunk_entity_relation_graph, - self.entities_vdb, - self.relationships_vdb, - self.text_chunks, - param, - global_config, - hashing_kv=self.llm_response_cache, - system_prompt=system_prompt, - chunks_vdb=self.chunks_vdb, - ) - elif param.mode == "naive": - query_result = await naive_query( - query.strip(), - self.chunks_vdb, - param, - global_config, - hashing_kv=self.llm_response_cache, - system_prompt=system_prompt, - ) - elif param.mode == "bypass": - # Bypass mode: directly use LLM without knowledge retrieval - use_llm_func = param.model_func or global_config["llm_model_func"] - # Apply higher priority (8) to entity/relation summary tasks - use_llm_func = partial(use_llm_func, _priority=8) - - param.stream = True if param.stream is None else param.stream - response = await use_llm_func( - query.strip(), - system_prompt=system_prompt, - history_messages=param.conversation_history, - enable_cot=True, - stream=param.stream, - ) - # Create QueryResult for bypass mode - query_result = QueryResult( - content=response if not param.stream else None, - response_iterator=response if param.stream else None, - is_streaming=param.stream, - ) + if llm_response.get("is_streaming"): + return llm_response.get("response_iterator") else: - raise ValueError(f"Unknown mode {param.mode}") + return llm_response.get("content", "") - await self._query_done() + def query_data( + self, + query: str, + param: QueryParam = QueryParam(), + ) -> dict[str, Any]: + """ + Synchronous data retrieval API: returns structured retrieval results without LLM generation. - # Return appropriate response based on streaming mode - if query_result.is_streaming: - return query_result.response_iterator - else: - return query_result.content + This function is the synchronous version of aquery_data, providing the same functionality + for users who prefer synchronous interfaces. + + Args: + query: Query text for retrieval. + param: Query parameters controlling retrieval behavior (same as aquery). + + Returns: + dict[str, Any]: Same structured data result as aquery_data. + """ + loop = always_get_an_event_loop() + return loop.run_until_complete(self.aquery_data(query, param)) async def aquery_data( self, @@ -2323,6 +2302,162 @@ class LightRAG: await self._query_done() return final_data + async def aquery_llm( + self, + query: str, + param: QueryParam = QueryParam(), + system_prompt: str | None = None, + ) -> dict[str, Any]: + """ + Asynchronous complete query API: returns structured retrieval results with LLM generation. + + This function performs a single query operation and returns both structured data and LLM response, + based on the original aquery logic to avoid duplicate calls. + + Args: + query: Query text for retrieval and LLM generation. + param: Query parameters controlling retrieval and LLM behavior. + system_prompt: Optional custom system prompt for LLM generation. + + Returns: + dict[str, Any]: Complete response with structured data and LLM response. 
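+                Typical shape (illustrative; it mirrors the keys assembled in the method
+                body below, and the "data" payload is the same structured retrieval
+                result returned by aquery_data):
+
+                {
+                    "status": "success" | "failure",
+                    "message": "...",
+                    "data": {...},       # structured retrieval results (e.g. references)
+                    "metadata": {...},
+                    "llm_response": {
+                        "content": str | None,             # set when is_streaming is False
+                        "response_iterator": ... | None,   # async iterator, set when is_streaming is True
+                        "is_streaming": bool,
+                    },
+                }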
+ """ + global_config = asdict(self) + + try: + query_result = None + + if param.mode in ["local", "global", "hybrid", "mix"]: + query_result = await kg_query( + query.strip(), + self.chunk_entity_relation_graph, + self.entities_vdb, + self.relationships_vdb, + self.text_chunks, + param, + global_config, + hashing_kv=self.llm_response_cache, + system_prompt=system_prompt, + chunks_vdb=self.chunks_vdb, + ) + elif param.mode == "naive": + query_result = await naive_query( + query.strip(), + self.chunks_vdb, + param, + global_config, + hashing_kv=self.llm_response_cache, + system_prompt=system_prompt, + ) + elif param.mode == "bypass": + # Bypass mode: directly use LLM without knowledge retrieval + use_llm_func = param.model_func or global_config["llm_model_func"] + # Apply higher priority (8) to entity/relation summary tasks + use_llm_func = partial(use_llm_func, _priority=8) + + param.stream = True if param.stream is None else param.stream + response = await use_llm_func( + query.strip(), + system_prompt=system_prompt, + history_messages=param.conversation_history, + enable_cot=True, + stream=param.stream, + ) + if type(response) is str: + return { + "status": "success", + "message": "Bypass mode LLM non streaming response", + "data": {}, + "metadata": {}, + "llm_response": { + "content": response, + "response_iterator": None, + "is_streaming": False, + }, + } + else: + return { + "status": "success", + "message": "Bypass mode LLM streaming response", + "data": {}, + "metadata": {}, + "llm_response": { + "content": None, + "response_iterator": response, + "is_streaming": True, + }, + } + else: + raise ValueError(f"Unknown mode {param.mode}") + + await self._query_done() + + # Check if query_result is None + if query_result is None: + return { + "status": "failure", + "message": "Query returned no results", + "data": {}, + "metadata": {}, + "llm_response": { + "content": None, + "response_iterator": None, + "is_streaming": False, + }, + } + + # Extract structured data from query result + raw_data = query_result.raw_data if query_result else {} + raw_data["llm_response"] = { + "content": query_result.content + if not query_result.is_streaming + else None, + "response_iterator": query_result.response_iterator + if query_result.is_streaming + else None, + "is_streaming": query_result.is_streaming, + } + + return raw_data + + except Exception as e: + logger.error(f"Query failed: {e}") + # Return error response + return { + "status": "failure", + "message": f"Query failed: {str(e)}", + "data": {}, + "metadata": {}, + "llm_response": { + "content": None, + "response_iterator": None, + "is_streaming": False, + }, + } + + def query_llm( + self, + query: str, + param: QueryParam = QueryParam(), + system_prompt: str | None = None, + ) -> dict[str, Any]: + """ + Synchronous complete query API: returns structured retrieval results with LLM generation. + + This function is the synchronous version of aquery_llm, providing the same functionality + for users who prefer synchronous interfaces. + + Args: + query: Query text for retrieval and LLM generation. + param: Query parameters controlling retrieval and LLM behavior. + system_prompt: Optional custom system prompt for LLM generation. + + Returns: + dict[str, Any]: Same complete response format as aquery_llm. 
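+
+        Example (illustrative; assumes an already-initialized LightRAG instance named ``rag``):
+            result = rag.query_llm(
+                "What is LightRAG?",
+                param=QueryParam(mode="hybrid", stream=False),
+            )
+            print(result["llm_response"]["content"])
+            print(result["data"].get("references", []))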
+ """ + loop = always_get_an_event_loop() + return loop.run_until_complete(self.aquery_llm(query, param, system_prompt)) + async def _query_done(self): await self.llm_response_cache.index_done_callback() diff --git a/lightrag/operate.py b/lightrag/operate.py index 0f9e8edf..f2930f3d 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -2236,38 +2236,6 @@ async def extract_entities( return chunk_results -@overload -async def kg_query( - query: str, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage, - query_param: QueryParam, - global_config: dict[str, str], - hashing_kv: BaseKVStorage | None = None, - system_prompt: str | None = None, - chunks_vdb: BaseVectorStorage = None, - return_raw_data: Literal[True] = False, -) -> dict[str, Any]: ... - - -@overload -async def kg_query( - query: str, - knowledge_graph_inst: BaseGraphStorage, - entities_vdb: BaseVectorStorage, - relationships_vdb: BaseVectorStorage, - text_chunks_db: BaseKVStorage, - query_param: QueryParam, - global_config: dict[str, str], - hashing_kv: BaseKVStorage | None = None, - system_prompt: str | None = None, - chunks_vdb: BaseVectorStorage = None, - return_raw_data: Literal[False] = False, -) -> str | AsyncIterator[str]: ... - - async def kg_query( query: str, knowledge_graph_inst: BaseGraphStorage, @@ -2308,7 +2276,6 @@ async def kg_query( - stream=True: response_iterator contains streaming response, raw_data contains complete data - default: content contains LLM response text, raw_data contains complete data """ - if not query: return QueryResult(content=PROMPTS["fail_response"]) @@ -2319,32 +2286,6 @@ async def kg_query( # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) - # Handle cache - args_hash = compute_args_hash( - query_param.mode, - query, - query_param.response_type, - query_param.top_k, - query_param.chunk_top_k, - query_param.max_entity_tokens, - query_param.max_relation_tokens, - query_param.max_total_tokens, - query_param.hl_keywords or [], - query_param.ll_keywords or [], - query_param.user_prompt or "", - query_param.enable_rerank, - ) - cached_result = await handle_cache( - hashing_kv, args_hash, query, query_param.mode, cache_type="query" - ) - if ( - cached_result is not None - and not query_param.only_need_context - and not query_param.only_need_prompt - ): - cached_response, _ = cached_result # Extract content, ignore timestamp - return QueryResult(content=cached_response) - hl_keywords, ll_keywords = await get_keywords_from_query( query, query_param, global_config, hashing_kv ) @@ -2417,29 +2358,41 @@ async def kg_query( f"[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})" ) - response = await use_model_func( - user_query, - system_prompt=sys_prompt, - history_messages=query_param.conversation_history, - enable_cot=True, - stream=query_param.stream, + # Handle cache + args_hash = compute_args_hash( + query_param.mode, + query, + query_param.response_type, + query_param.top_k, + query_param.chunk_top_k, + query_param.max_entity_tokens, + query_param.max_relation_tokens, + query_param.max_total_tokens, + query_param.hl_keywords or [], + query_param.ll_keywords or [], + query_param.user_prompt or "", + query_param.enable_rerank, ) - # Return unified result based on actual response type - if isinstance(response, str): - # Non-streaming response 
(string) - if len(response) > len(sys_prompt): - response = ( - response.replace(sys_prompt, "") - .replace("user", "") - .replace("model", "") - .replace(query, "") - .replace("", "") - .replace("", "") - .strip() - ) + cached_result = await handle_cache( + hashing_kv, args_hash, user_query, query_param.mode, cache_type="query" + ) + + if cached_result is not None: + cached_response, _ = cached_result # Extract content, ignore timestamp + logger.info( + " == LLM cache == Query cache hit, using cached response as query result" + ) + response = cached_response + else: + response = await use_model_func( + user_query, + system_prompt=sys_prompt, + history_messages=query_param.conversation_history, + enable_cot=True, + stream=query_param.stream, + ) - # Cache response if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): queryparam_dict = { "mode": query_param.mode, @@ -2466,6 +2419,20 @@ async def kg_query( ), ) + # Return unified result based on actual response type + if isinstance(response, str): + # Non-streaming response (string) + if len(response) > len(sys_prompt): + response = ( + response.replace(sys_prompt, "") + .replace("user", "") + .replace("model", "") + .replace(query, "") + .replace("", "") + .replace("", "") + .strip() + ) + return QueryResult(content=response, raw_data=context_result.raw_data) else: # Streaming response (AsyncIterator) @@ -4052,29 +4019,6 @@ async def naive_query( # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) - # Handle cache - args_hash = compute_args_hash( - query_param.mode, - query, - query_param.response_type, - query_param.top_k, - query_param.chunk_top_k, - query_param.max_entity_tokens, - query_param.max_relation_tokens, - query_param.max_total_tokens, - query_param.hl_keywords or [], - query_param.ll_keywords or [], - query_param.user_prompt or "", - query_param.enable_rerank, - ) - cached_result = await handle_cache( - hashing_kv, args_hash, query, query_param.mode, cache_type="query" - ) - if cached_result is not None: - cached_response, _ = cached_result # Extract content, ignore timestamp - if not query_param.only_need_context and not query_param.only_need_prompt: - return QueryResult(content=cached_response) - tokenizer: Tokenizer = global_config["tokenizer"] if not tokenizer: logger.error("Tokenizer not found in global configuration.") @@ -4211,35 +4155,39 @@ async def naive_query( prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query]) return QueryResult(content=prompt_content, raw_data=raw_data) - len_of_prompts = len(tokenizer.encode(query + sys_prompt)) - logger.debug( - f"[naive_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})" + # Handle cache + args_hash = compute_args_hash( + query_param.mode, + query, + query_param.response_type, + query_param.top_k, + query_param.chunk_top_k, + query_param.max_entity_tokens, + query_param.max_relation_tokens, + query_param.max_total_tokens, + query_param.hl_keywords or [], + query_param.ll_keywords or [], + query_param.user_prompt or "", + query_param.enable_rerank, ) - - response = await use_model_func( - user_query, - system_prompt=sys_prompt, - history_messages=query_param.conversation_history, - enable_cot=True, - stream=query_param.stream, + cached_result = await handle_cache( + hashing_kv, args_hash, user_query, query_param.mode, cache_type="query" ) + if cached_result is not None: + cached_response, _ = 
cached_result # Extract content, ignore timestamp + logger.info( + " == LLM cache == Query cache hit, using cached response as query result" + ) + response = cached_response + else: + response = await use_model_func( + user_query, + system_prompt=sys_prompt, + history_messages=query_param.conversation_history, + enable_cot=True, + stream=query_param.stream, + ) - # Return unified result based on actual response type - if isinstance(response, str): - # Non-streaming response (string) - if len(response) > len(sys_prompt): - response = ( - response[len(sys_prompt) :] - .replace(sys_prompt, "") - .replace("user", "") - .replace("model", "") - .replace(query, "") - .replace("", "") - .replace("", "") - .strip() - ) - - # Cache response if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): queryparam_dict = { "mode": query_param.mode, @@ -4266,6 +4214,21 @@ async def naive_query( ), ) + # Return unified result based on actual response type + if isinstance(response, str): + # Non-streaming response (string) + if len(response) > len(sys_prompt): + response = ( + response[len(sys_prompt) :] + .replace(sys_prompt, "") + .replace("user", "") + .replace("model", "") + .replace(query, "") + .replace("", "") + .replace("", "") + .strip() + ) + return QueryResult(content=response, raw_data=raw_data) else: # Streaming response (AsyncIterator) diff --git a/lightrag/utils.py b/lightrag/utils.py index 1bd9ca8e..c72b6fd2 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1102,7 +1102,9 @@ async def save_to_cache(hashing_kv, cache_data: CacheData): if existing_cache: existing_content = existing_cache.get("return") if existing_content == cache_data.content: - logger.info(f"Cache content unchanged for {flattened_key}, skipping update") + logger.warning( + f"Cache duplication detected for {flattened_key}, skipping update" + ) return # Create cache entry with flattened structure
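
Usage sketch (illustrative, not part of the diff above): a minimal client for the
consolidated /query/stream endpoint. It relies only on what the patch shows, namely
newline-delimited JSON lines carrying "references", "response", or "error" keys; the
base URL, port, request payload values, and the httpx dependency are placeholders
chosen for this sketch, and any required auth headers are omitted.

import asyncio
import json

import httpx  # third-party HTTP client used only for this sketch


async def stream_query(base_url: str = "http://localhost:9621") -> None:
    payload = {"query": "What is LightRAG?", "include_references": True}
    async with httpx.AsyncClient(timeout=None) as client:
        # The endpoint emits newline-delimited JSON: one object per line.
        async with client.stream("POST", f"{base_url}/query/stream", json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line:
                    continue
                message = json.loads(line)
                if "references" in message:
                    print("references:", message["references"])
                elif "response" in message:
                    print(message["response"], end="", flush=True)
                elif "error" in message:
                    print("\nserver error:", message["error"])


if __name__ == "__main__":
    asyncio.run(stream_query())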