refactor: fix double query problem by adding aquery_llm function for consistent response handling

- Add new aquery_llm/query_llm methods providing structured responses
- Consolidate /query and /query/stream endpoints to use unified aquery_llm
- Optimize cache handling by moving cache checks before LLM calls
parent 862026905a
commit 8cd4139cbf
4 changed files with 336 additions and 244 deletions
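For orientation, here is a minimal sketch (not part of the commit) of the unified result dict that aquery_llm returns, reconstructed from the return statements in this diff; all values are placeholders.

# Illustrative only: approximate shape of the aquery_llm result, inferred from
# the return statements in the diff below. Field values are placeholders.
result = {
    "status": "success",
    "message": "...",
    "data": {"references": []},  # structured retrieval data; the routes read data["references"]
    "metadata": {},
    "llm_response": {
        "content": "answer text",   # set for non-streaming responses
        "response_iterator": None,  # AsyncIterator[str] when is_streaming is True
        "is_streaming": False,
    },
}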
@@ -152,44 +152,42 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
)
async def query_text(request: QueryRequest):
"""
Handle a POST request at the /query endpoint to process user queries using RAG capabilities.
This endpoint performs a RAG query with non-streaming response.

Parameters:
request (QueryRequest): The request object containing the query parameters.
Returns:
QueryResponse: A Pydantic model containing the result of the query processing.
If include_references=True, also includes reference list.
If a string is returned (e.g., cache hit), it's directly returned.
Otherwise, an async generator may be used to build the response.

Raises:
HTTPException: Raised when an error occurs during the request handling process,
with status code 500 and detail containing the exception message.
"""
try:
param = request.to_query_params(False)
response = await rag.aquery(request.query, param=param)
param = request.to_query_params(
False
) # Ensure stream=False for non-streaming endpoint
# Force stream=False for /query endpoint regardless of include_references setting
param.stream = False

# Get reference list if requested
reference_list = None
# Unified approach: always use aquery_llm for both cases
result = await rag.aquery_llm(request.query, param=param)

# Extract LLM response and references from unified result
llm_response = result.get("llm_response", {})
references = result.get("data", {}).get("references", [])

# Get the non-streaming response content
response_content = llm_response.get("content", "")
if not response_content:
response_content = "No relevant context found for the query."

# Return response with or without references based on request
if request.include_references:
try:
# Use aquery_data to get reference list independently
data_result = await rag.aquery_data(request.query, param=param)
if isinstance(data_result, dict) and "data" in data_result:
reference_list = data_result["data"].get("references", [])
except Exception as e:
logging.warning(f"Failed to get reference list: {str(e)}")
reference_list = []

# Process response and return with optional references
if isinstance(response, str):
return QueryResponse(response=response, references=reference_list)
elif isinstance(response, dict):
result = json.dumps(response, indent=2)
return QueryResponse(response=result, references=reference_list)
return QueryResponse(response=response_content, references=references)
else:
return QueryResponse(response=str(response), references=reference_list)
return QueryResponse(response=response_content, references=None)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
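For context, a hedged sketch of how a client might call the consolidated /query endpoint after this change (not part of the diff; the base URL is a placeholder and QueryRequest fields other than query and include_references are omitted):

import requests  # assumed HTTP client, not used by the server code above

resp = requests.post(
    "http://localhost:9621/query",  # placeholder base URL
    json={"query": "What is LightRAG?", "include_references": True},
)
body = resp.json()
print(body["response"])    # LLM answer, taken from llm_response["content"]
print(body["references"])  # reference list, taken from data["references"]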
@@ -197,7 +195,8 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
@router.post("/query/stream", dependencies=[Depends(combined_auth)])
async def query_text_stream(request: QueryRequest):
"""
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
This endpoint performs a RAG query with a streaming response.
Streaming can be turned off by setting stream=False in QueryRequest.

The streaming response includes:
1. Reference list (sent first as a single message, if include_references=True)
@@ -213,49 +212,42 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
- Error messages: {"error": "..."} - If any errors occur
"""
try:
param = request.to_query_params(True)
response = await rag.aquery(request.query, param=param)
param = request.to_query_params(
True
) # Ensure stream=True for streaming endpoint

from fastapi.responses import StreamingResponse

async def stream_generator():
# Get reference list if requested (default is True for backward compatibility)
reference_list = []
if request.include_references:
try:
# Use aquery_data to get reference list independently
data_param = request.to_query_params(
False
) # Non-streaming for data
data_result = await rag.aquery_data(
request.query, param=data_param
)
if isinstance(data_result, dict) and "data" in data_result:
reference_list = data_result["data"].get("references", [])
except Exception as e:
logging.warning(f"Failed to get reference list: {str(e)}")
reference_list = []
# Unified approach: always use aquery_llm for all cases
result = await rag.aquery_llm(request.query, param=param)

# Send reference list first (if requested)
if request.include_references:
yield f"{json.dumps({'references': reference_list})}\n"
# Extract references and LLM response from unified result
references = result.get("data", {}).get("references", [])
llm_response = result.get("llm_response", {})

# Then stream the response content
if isinstance(response, str):
# If it's a string, send it all at once
yield f"{json.dumps({'response': response})}\n"
elif response is None:
# Handle None response (e.g., when only_need_context=True but no context found)
yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"
# Send reference list first if requested
if request.include_references:
yield f"{json.dumps({'references': references})}\n"

# Then stream the LLM response content
if llm_response.get("is_streaming"):
response_stream = llm_response.get("response_iterator")
if response_stream:
try:
async for chunk in response_stream:
if chunk: # Only send non-empty content
yield f"{json.dumps({'response': chunk})}\n"
except Exception as e:
logging.error(f"Streaming error: {str(e)}")
yield f"{json.dumps({'error': str(e)})}\n"
else:
# If it's an async generator, send chunks one by one
try:
async for chunk in response:
if chunk: # Only send non-empty content
yield f"{json.dumps({'response': chunk})}\n"
except Exception as e:
logging.error(f"Streaming error: {str(e)}")
yield f"{json.dumps({'error': str(e)})}\n"
# Non-streaming response (fallback)
response_content = llm_response.get("content", "")
if response_content:
yield f"{json.dumps({'response': response_content})}\n"
else:
yield f"{json.dumps({'response': 'No relevant context found for the query.'})}\n"

return StreamingResponse(
stream_generator(),
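A hedged client-side sketch (not part of the diff) of consuming the NDJSON stream produced by stream_generator above: each line is one JSON object, the reference list arrives first when include_references is set, and error objects may appear mid-stream. The httpx dependency and base URL are assumptions.

import asyncio
import json

import httpx  # assumed HTTP client


async def consume(url: str, query: str) -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", url, json={"query": query, "include_references": True}
        ) as resp:
            async for line in resp.aiter_lines():
                if not line:
                    continue
                msg = json.loads(line)
                if "references" in msg:
                    print("references:", msg["references"])  # sent once, first
                elif "error" in msg:
                    print("error:", msg["error"])
                else:
                    print(msg["response"], end="", flush=True)  # response chunks


asyncio.run(consume("http://localhost:9621/query/stream", "What is LightRAG?"))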
@@ -2062,74 +2062,53 @@ class LightRAG:
system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
"""
Perform an async query.
Perform an async query (backward compatibility wrapper).

This function is now a wrapper around aquery_llm that maintains backward compatibility
by returning only the LLM response content in the original format.

Args:
query (str): The query to be executed.
param (QueryParam): Configuration parameters for query execution.
If param.model_func is provided, it will be used instead of the global model.
prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
system_prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].

Returns:
str: The result of the query execution.
str | AsyncIterator[str]: The LLM response content.
- Non-streaming: Returns str
- Streaming: Returns AsyncIterator[str]
"""
# If a custom model is provided in param, temporarily update global config
global_config = asdict(self)
# Call the new aquery_llm function to get complete results
result = await self.aquery_llm(query, param, system_prompt)

query_result = None
# Extract and return only the LLM response for backward compatibility
llm_response = result.get("llm_response", {})

if param.mode in ["local", "global", "hybrid", "mix"]:
query_result = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
chunks_vdb=self.chunks_vdb,
)
elif param.mode == "naive":
query_result = await naive_query(
query.strip(),
self.chunks_vdb,
param,
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
)
elif param.mode == "bypass":
# Bypass mode: directly use LLM without knowledge retrieval
use_llm_func = param.model_func or global_config["llm_model_func"]
# Apply higher priority (8) to entity/relation summary tasks
use_llm_func = partial(use_llm_func, _priority=8)

param.stream = True if param.stream is None else param.stream
response = await use_llm_func(
query.strip(),
system_prompt=system_prompt,
history_messages=param.conversation_history,
enable_cot=True,
stream=param.stream,
)
# Create QueryResult for bypass mode
query_result = QueryResult(
content=response if not param.stream else None,
response_iterator=response if param.stream else None,
is_streaming=param.stream,
)
if llm_response.get("is_streaming"):
return llm_response.get("response_iterator")
else:
raise ValueError(f"Unknown mode {param.mode}")
return llm_response.get("content", "")

await self._query_done()
def query_data(
self,
query: str,
param: QueryParam = QueryParam(),
) -> dict[str, Any]:
"""
Synchronous data retrieval API: returns structured retrieval results without LLM generation.

# Return appropriate response based on streaming mode
if query_result.is_streaming:
return query_result.response_iterator
else:
return query_result.content
This function is the synchronous version of aquery_data, providing the same functionality
for users who prefer synchronous interfaces.

Args:
query: Query text for retrieval.
param: Query parameters controlling retrieval behavior (same as aquery).

Returns:
dict[str, Any]: Same structured data result as aquery_data.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery_data(query, param))

async def aquery_data(
self,
@@ -2323,6 +2302,162 @@ class LightRAG:
await self._query_done()
return final_data

async def aquery_llm(
self,
query: str,
param: QueryParam = QueryParam(),
system_prompt: str | None = None,
) -> dict[str, Any]:
"""
Asynchronous complete query API: returns structured retrieval results with LLM generation.

This function performs a single query operation and returns both structured data and LLM response,
based on the original aquery logic to avoid duplicate calls.

Args:
query: Query text for retrieval and LLM generation.
param: Query parameters controlling retrieval and LLM behavior.
system_prompt: Optional custom system prompt for LLM generation.

Returns:
dict[str, Any]: Complete response with structured data and LLM response.
"""
global_config = asdict(self)

try:
query_result = None

if param.mode in ["local", "global", "hybrid", "mix"]:
query_result = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
chunks_vdb=self.chunks_vdb,
)
elif param.mode == "naive":
query_result = await naive_query(
query.strip(),
self.chunks_vdb,
param,
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
)
elif param.mode == "bypass":
# Bypass mode: directly use LLM without knowledge retrieval
use_llm_func = param.model_func or global_config["llm_model_func"]
# Apply higher priority (8) to entity/relation summary tasks
use_llm_func = partial(use_llm_func, _priority=8)

param.stream = True if param.stream is None else param.stream
response = await use_llm_func(
query.strip(),
system_prompt=system_prompt,
history_messages=param.conversation_history,
enable_cot=True,
stream=param.stream,
)
if type(response) is str:
return {
"status": "success",
"message": "Bypass mode LLM non streaming response",
"data": {},
"metadata": {},
"llm_response": {
"content": response,
"response_iterator": None,
"is_streaming": False,
},
}
else:
return {
"status": "success",
"message": "Bypass mode LLM streaming response",
"data": {},
"metadata": {},
"llm_response": {
"content": None,
"response_iterator": response,
"is_streaming": True,
},
}
else:
raise ValueError(f"Unknown mode {param.mode}")

await self._query_done()

# Check if query_result is None
if query_result is None:
return {
"status": "failure",
"message": "Query returned no results",
"data": {},
"metadata": {},
"llm_response": {
"content": None,
"response_iterator": None,
"is_streaming": False,
},
}

# Extract structured data from query result
raw_data = query_result.raw_data if query_result else {}
raw_data["llm_response"] = {
"content": query_result.content
if not query_result.is_streaming
else None,
"response_iterator": query_result.response_iterator
if query_result.is_streaming
else None,
"is_streaming": query_result.is_streaming,
}

return raw_data

except Exception as e:
logger.error(f"Query failed: {e}")
# Return error response
return {
"status": "failure",
"message": f"Query failed: {str(e)}",
"data": {},
"metadata": {},
"llm_response": {
"content": None,
"response_iterator": None,
"is_streaming": False,
},
}

def query_llm(
self,
query: str,
param: QueryParam = QueryParam(),
system_prompt: str | None = None,
) -> dict[str, Any]:
"""
Synchronous complete query API: returns structured retrieval results with LLM generation.

This function is the synchronous version of aquery_llm, providing the same functionality
for users who prefer synchronous interfaces.

Args:
query: Query text for retrieval and LLM generation.
param: Query parameters controlling retrieval and LLM behavior.
system_prompt: Optional custom system prompt for LLM generation.

Returns:
dict[str, Any]: Same complete response format as aquery_llm.
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery_llm(query, param, system_prompt))

async def _query_done(self):
await self.llm_response_cache.index_done_callback()
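A minimal usage sketch (not part of the diff) showing how a caller might consume the unified aquery_llm result, assuming rag is an already-initialized LightRAG instance and QueryParam is imported from the package:

async def answer(rag, question: str) -> None:
    # QueryParam(mode="mix") mirrors one of the modes handled above; adjust as needed.
    result = await rag.aquery_llm(question, param=QueryParam(mode="mix"))

    references = result.get("data", {}).get("references", [])
    llm = result.get("llm_response", {})

    if llm.get("is_streaming"):
        async for chunk in llm["response_iterator"]:
            print(chunk, end="", flush=True)
    else:
        print(llm.get("content", ""))

    print("\nreferences:", references)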
@@ -2236,38 +2236,6 @@ async def extract_entities(
return chunk_results


@overload
async def kg_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
return_raw_data: Literal[True] = False,
) -> dict[str, Any]: ...


@overload
async def kg_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
return_raw_data: Literal[False] = False,
) -> str | AsyncIterator[str]: ...


async def kg_query(
query: str,
knowledge_graph_inst: BaseGraphStorage,
@@ -2308,7 +2276,6 @@ async def kg_query(
- stream=True: response_iterator contains streaming response, raw_data contains complete data
- default: content contains LLM response text, raw_data contains complete data
"""

if not query:
return QueryResult(content=PROMPTS["fail_response"])
@@ -2319,32 +2286,6 @@ async def kg_query(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)

# Handle cache
args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)
cached_result = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
if (
cached_result is not None
and not query_param.only_need_context
and not query_param.only_need_prompt
):
cached_response, _ = cached_result # Extract content, ignore timestamp
return QueryResult(content=cached_response)

hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
)
@@ -2417,29 +2358,41 @@ async def kg_query(
f"[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
)

response = await use_model_func(
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
# Handle cache
args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)

# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
cached_result = await handle_cache(
hashing_kv, args_hash, user_query, query_param.mode, cache_type="query"
)

if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
logger.info(
" == LLM cache == Query cache hit, using cached response as query result"
)
response = cached_response
else:
response = await use_model_func(
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
)

# Cache response
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
@@ -2466,6 +2419,20 @@ async def kg_query(
),
)

# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)

return QueryResult(content=response, raw_data=context_result.raw_data)
else:
# Streaming response (AsyncIterator)
@@ -4052,29 +4019,6 @@ async def naive_query(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)

# Handle cache
args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)
cached_result = await handle_cache(
hashing_kv, args_hash, query, query_param.mode, cache_type="query"
)
if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
if not query_param.only_need_context and not query_param.only_need_prompt:
return QueryResult(content=cached_response)

tokenizer: Tokenizer = global_config["tokenizer"]
if not tokenizer:
logger.error("Tokenizer not found in global configuration.")
@@ -4211,35 +4155,39 @@ async def naive_query(
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(content=prompt_content, raw_data=raw_data)

len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(
f"[naive_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})"
# Handle cache
args_hash = compute_args_hash(
query_param.mode,
query,
query_param.response_type,
query_param.top_k,
query_param.chunk_top_k,
query_param.max_entity_tokens,
query_param.max_relation_tokens,
query_param.max_total_tokens,
query_param.hl_keywords or [],
query_param.ll_keywords or [],
query_param.user_prompt or "",
query_param.enable_rerank,
)

response = await use_model_func(
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
cached_result = await handle_cache(
hashing_kv, args_hash, user_query, query_param.mode, cache_type="query"
)
if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
logger.info(
" == LLM cache == Query cache hit, using cached response as query result"
)
response = cached_response
else:
response = await use_model_func(
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
)

# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response[len(sys_prompt) :]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)

# Cache response
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
@@ -4266,6 +4214,21 @@ async def naive_query(
),
)

# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response[len(sys_prompt) :]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)

return QueryResult(content=response, raw_data=raw_data)
else:
# Streaming response (AsyncIterator)
@@ -1102,7 +1102,9 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
if existing_cache:
existing_content = existing_cache.get("return")
if existing_content == cache_data.content:
logger.info(f"Cache content unchanged for {flattened_key}, skipping update")
logger.warning(
f"Cache duplication detected for {flattened_key}, skipping update"
)
return

# Create cache entry with flattened structure