Add reference list support to query API endpoints with unified result handling

• Add include_references param to QueryRequest
• Extend QueryResponse with references field
• Create unified QueryResult data structures
• Refactor kg_query and naive_query functions
• Update streaming to send references first
yangdx 2025-09-25 16:21:42 +08:00
parent 64c38864e5
commit b08b8a6a6a
4 changed files with 504 additions and 349 deletions


@ -88,6 +88,11 @@ class QueryRequest(BaseModel):
description="Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. Default is True.",
)
include_references: Optional[bool] = Field(
default=True,
description="If True, includes reference list in responses. Affects /query and /query/stream endpoints. /query/data always includes references.",
)
@field_validator("query", mode="after")
@classmethod
def query_strip_after(cls, query: str) -> str:
@ -122,6 +127,10 @@ class QueryResponse(BaseModel):
response: str = Field(
description="The generated response",
)
references: Optional[List[Dict[str, str]]] = Field(
default=None,
description="Reference list (only included when include_references=True, /query/data always includes references.)",
)
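For illustration, a request and response for the /query endpoint might look like the sketch below; the server URL, port, and answer text are hypothetical, and include_references is shown explicitly even though it defaults to True.

```python
import requests

# Hypothetical base URL; adjust to the running LightRAG API server.
resp = requests.post(
    "http://localhost:9621/query",
    json={"query": "What is LightRAG?", "mode": "hybrid", "include_references": True},
)
body = resp.json()
print(body["response"])                    # generated answer text
for ref in body.get("references") or []:   # e.g. {"reference_id": "1", "file_path": "/docs/a.pdf"}
    print(ref["reference_id"], ref["file_path"])
```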
class QueryDataResponse(BaseModel):
@ -149,6 +158,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
request (QueryRequest): The request object containing the query parameters.
Returns:
QueryResponse: A Pydantic model containing the result of the query processing.
If include_references=True, the reference list is also included.
If a string is returned (e.g., cache hit), it's directly returned.
Otherwise, an async generator may be used to build the response.
@ -160,15 +170,26 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
param = request.to_query_params(False)
response = await rag.aquery(request.query, param=param)
# If response is a string (e.g. cache hit), return directly
if isinstance(response, str):
return QueryResponse(response=response)
# Get reference list if requested
reference_list = None
if request.include_references:
try:
# Use aquery_data to get reference list independently
data_result = await rag.aquery_data(request.query, param=param)
if isinstance(data_result, dict) and "data" in data_result:
reference_list = data_result["data"].get("references", [])
except Exception as e:
logging.warning(f"Failed to get reference list: {str(e)}")
reference_list = []
if isinstance(response, dict):
# Process response and return with optional references
if isinstance(response, str):
return QueryResponse(response=response, references=reference_list)
elif isinstance(response, dict):
result = json.dumps(response, indent=2)
return QueryResponse(response=result)
return QueryResponse(response=result, references=reference_list)
else:
return QueryResponse(response=str(response))
return QueryResponse(response=str(response), references=reference_list)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@ -177,13 +198,19 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
async def query_text_stream(request: QueryRequest):
"""
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
The streaming response includes:
1. Reference list (sent first as a single message, if include_references=True)
2. LLM response content (streamed as multiple chunks)
Args:
request (QueryRequest): The request object containing the query parameters.
optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None.
Returns:
StreamingResponse: A streaming response containing the RAG query results.
StreamingResponse: A streaming response containing:
- First message: {"references": [...]} - Complete reference list (if requested)
- Subsequent messages: {"response": "..."} - LLM response chunks
- Error messages: {"error": "..."} - If any errors occur
"""
try:
param = request.to_query_params(True)
@ -192,6 +219,24 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
from fastapi.responses import StreamingResponse
async def stream_generator():
# Get reference list if requested (default is True for backward compatibility)
reference_list = []
if request.include_references:
try:
# Use aquery_data to get reference list independently
data_param = request.to_query_params(False) # Non-streaming for data
data_result = await rag.aquery_data(request.query, param=data_param)
if isinstance(data_result, dict) and "data" in data_result:
reference_list = data_result["data"].get("references", [])
except Exception as e:
logging.warning(f"Failed to get reference list: {str(e)}")
reference_list = []
# Send reference list first (if requested)
if request.include_references:
yield f"{json.dumps({'references': reference_list})}\n"
# Then stream the response content
if isinstance(response, str):
# If it's a string, send it all at once
yield f"{json.dumps({'response': response})}\n"
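As a sketch of how a client might consume the NDJSON stream described in the docstring above (one JSON object per line: a references message first, then response chunks; the URL and payload are hypothetical):

```python
import json
import requests

with requests.post(
    "http://localhost:9621/query/stream",
    json={"query": "What is LightRAG?", "mode": "hybrid", "include_references": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        msg = json.loads(line)
        if "references" in msg:
            print("references:", msg["references"])    # sent once, before the answer
        elif "response" in msg:
            print(msg["response"], end="", flush=True)  # streamed answer chunks
        elif "error" in msg:
            print("stream error:", msg["error"])
```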


@ -11,6 +11,10 @@ from typing import (
TypedDict,
TypeVar,
Callable,
Optional,
Dict,
List,
AsyncIterator,
)
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
@ -158,6 +162,12 @@ class QueryParam:
Default is True to enable reranking when rerank model is available.
"""
include_references: bool = False
"""If True, includes reference list in the response for supported endpoints.
This parameter controls whether the API response includes a references field
containing citation information for the retrieved content.
"""
@dataclass
class StorageNameSpace(ABC):
@ -814,3 +824,65 @@ class DeletionResult:
message: str
status_code: int = 200
file_path: str | None = None
# Unified Query Result Data Structures for Reference List Support
@dataclass
class QueryResult:
"""
Unified query result data structure for all query modes.
Attributes:
content: Text content for non-streaming responses
response_iterator: Streaming response iterator for streaming responses
raw_data: Complete structured data including references and metadata
is_streaming: Whether this is a streaming result
"""
content: Optional[str] = None
response_iterator: Optional[AsyncIterator[str]] = None
raw_data: Optional[Dict[str, Any]] = None
is_streaming: bool = False
@property
def reference_list(self) -> List[Dict[str, str]]:
"""
Convenience property to extract the reference list from raw_data.
Returns:
List[Dict[str, str]]: Reference list in format:
[{"reference_id": "1", "file_path": "/path/to/file.pdf"}, ...]
"""
if self.raw_data:
return self.raw_data.get("data", {}).get("references", [])
return []
@property
def metadata(self) -> Dict[str, Any]:
"""
Convenient property to extract metadata from raw_data.
Returns:
Dict[str, Any]: Query metadata including query_mode, keywords, etc.
"""
if self.raw_data:
return self.raw_data.get("metadata", {})
return {}
@dataclass
class QueryContextResult:
"""
Unified query context result data structure.
Attributes:
context: LLM context string
raw_data: Complete structured data including reference_list
"""
context: str
raw_data: Dict[str, Any]
@property
def reference_list(self) -> List[Dict[str, str]]:
"""Convenient property to extract reference list from raw_data."""
return self.raw_data.get("data", {}).get("references", [])
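A minimal usage sketch for the dataclasses above, assuming they are exported from lightrag.base as added in this commit; the field values are illustrative.

```python
from lightrag.base import QueryResult

# Non-streaming result: answer text in `content`, references/metadata in `raw_data`.
result = QueryResult(
    content="LightRAG is ...",
    raw_data={
        "data": {"references": [{"reference_id": "1", "file_path": "/docs/a.pdf"}]},
        "metadata": {"query_mode": "hybrid"},
    },
)
assert result.reference_list == [{"reference_id": "1", "file_path": "/docs/a.pdf"}]
assert result.metadata["query_mode"] == "hybrid"

# Streaming result: the iterator is populated instead of `content`.
async def _chunks():
    yield "partial answer ..."

streaming = QueryResult(
    response_iterator=_chunks(), raw_data=result.raw_data, is_streaming=True
)
```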


@ -71,6 +71,7 @@ from lightrag.base import (
StoragesStatus,
DeletionResult,
OllamaServerInfos,
QueryResult,
)
from lightrag.namespace import NameSpace
from lightrag.operate import (
@ -2075,8 +2076,10 @@ class LightRAG:
# If a custom model is provided in param, temporarily update global config
global_config = asdict(self)
query_result = None
if param.mode in ["local", "global", "hybrid", "mix"]:
response = await kg_query(
query_result = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
@ -2089,7 +2092,7 @@ class LightRAG:
chunks_vdb=self.chunks_vdb,
)
elif param.mode == "naive":
response = await naive_query(
query_result = await naive_query(
query.strip(),
self.chunks_vdb,
param,
@ -2111,10 +2114,22 @@ class LightRAG:
enable_cot=True,
stream=param.stream,
)
# Create QueryResult for bypass mode
query_result = QueryResult(
content=response if not param.stream else None,
response_iterator=response if param.stream else None,
is_streaming=param.stream
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
# Return appropriate response based on streaming mode
if query_result.is_streaming:
return query_result.response_iterator
else:
return query_result.content
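From a caller's point of view, aquery keeps its previous contract — a string when stream=False, an async iterator of chunks when stream=True — while the QueryResult plumbing stays internal. A hedged sketch (rag is assumed to be an initialized LightRAG instance):

```python
from lightrag import QueryParam

async def ask(rag, question: str) -> None:
    # Non-streaming: aquery returns the answer text.
    answer = await rag.aquery(question, param=QueryParam(mode="hybrid"))
    print(answer)

    # Streaming: aquery returns an async iterator of response chunks.
    stream = await rag.aquery(question, param=QueryParam(mode="hybrid", stream=True))
    async for chunk in stream:
        print(chunk, end="", flush=True)
```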
async def aquery_data(
self,
@ -2229,61 +2244,84 @@ class LightRAG:
"""
global_config = asdict(self)
if param.mode in ["local", "global", "hybrid", "mix"]:
logger.debug(f"[aquery_data] Using kg_query for mode: {param.mode}")
final_data = await kg_query(
# Create a copy of param to avoid modifying the original
data_param = QueryParam(
mode=param.mode,
only_need_context=True, # Skip LLM generation, only get context and data
only_need_prompt=False,
response_type=param.response_type,
stream=False, # Data retrieval doesn't need streaming
top_k=param.top_k,
chunk_top_k=param.chunk_top_k,
max_entity_tokens=param.max_entity_tokens,
max_relation_tokens=param.max_relation_tokens,
max_total_tokens=param.max_total_tokens,
hl_keywords=param.hl_keywords,
ll_keywords=param.ll_keywords,
conversation_history=param.conversation_history,
history_turns=param.history_turns,
model_func=param.model_func,
user_prompt=param.user_prompt,
enable_rerank=param.enable_rerank,
)
query_result = None
if data_param.mode in ["local", "global", "hybrid", "mix"]:
logger.debug(f"[aquery_data] Using kg_query for mode: {data_param.mode}")
query_result = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
data_param, # Use data_param with only_need_context=True
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=None,
chunks_vdb=self.chunks_vdb,
return_raw_data=True, # Get final processed data
)
elif param.mode == "naive":
logger.debug(f"[aquery_data] Using naive_query for mode: {param.mode}")
final_data = await naive_query(
elif data_param.mode == "naive":
logger.debug(f"[aquery_data] Using naive_query for mode: {data_param.mode}")
query_result = await naive_query(
query.strip(),
self.chunks_vdb,
param,
data_param, # Use data_param with only_need_context=True
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=None,
return_raw_data=True, # Get final processed data
)
elif param.mode == "bypass":
elif data_param.mode == "bypass":
logger.debug("[aquery_data] Using bypass mode")
# bypass mode returns empty data using convert_to_user_format
final_data = convert_to_user_format(
empty_raw_data = convert_to_user_format(
[], # no entities
[], # no relationships
[], # no chunks
[], # no references
"bypass",
)
query_result = QueryResult(
content="",
raw_data=empty_raw_data
)
else:
raise ValueError(f"Unknown mode {param.mode}")
raise ValueError(f"Unknown mode {data_param.mode}")
# Extract raw_data from QueryResult
final_data = query_result.raw_data if query_result else {}
# Log final result counts - adapt to new data format from convert_to_user_format
if isinstance(final_data, dict) and "data" in final_data:
# New format: data is nested under 'data' field
if final_data and "data" in final_data:
data_section = final_data["data"]
entities_count = len(data_section.get("entities", []))
relationships_count = len(data_section.get("relationships", []))
chunks_count = len(data_section.get("chunks", []))
logger.debug(
f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
)
else:
# Fallback for other formats
entities_count = len(final_data.get("entities", []))
relationships_count = len(final_data.get("relationships", []))
chunks_count = len(final_data.get("chunks", []))
logger.debug(
f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
)
logger.warning("[aquery_data] No data section found in query result")
await self._query_done()
return final_data
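To make the aquery_data payload concrete, a hedged sketch of calling it and extracting references (key names follow convert_to_user_format as used above; rag is assumed to be an initialized LightRAG instance):

```python
from lightrag import QueryParam

async def fetch_references(rag, question: str):
    result = await rag.aquery_data(question, param=QueryParam(mode="hybrid"))
    # Expected shape: {"data": {...}, "metadata": {...}}, plus "status"/"message" on failure.
    data = result.get("data", {})
    return data.get("references", []), data.get("chunks", [])
```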


@ -39,6 +39,8 @@ from .base import (
BaseVectorStorage,
TextChunkSchema,
QueryParam,
QueryResult,
QueryContextResult,
)
from .prompt import PROMPTS
from .constants import (
@ -2277,16 +2279,38 @@ async def kg_query(
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
return_raw_data: bool = False,
) -> str | AsyncIterator[str] | dict[str, Any]:
) -> QueryResult:
"""
Execute knowledge graph query and return unified QueryResult object.
Args:
query: Query string
knowledge_graph_inst: Knowledge graph storage instance
entities_vdb: Entity vector database
relationships_vdb: Relationship vector database
text_chunks_db: Text chunks storage
query_param: Query parameters
global_config: Global configuration
hashing_kv: Cache storage
system_prompt: System prompt
chunks_vdb: Document chunks vector database
Returns:
QueryResult: Unified query result object containing:
- content: Non-streaming response text content
- response_iterator: Streaming response iterator
- raw_data: Complete structured data (including references and metadata)
- is_streaming: Whether this is a streaming result
Based on different query_param settings, different fields will be populated:
- only_need_context=True: content contains context string
- only_need_prompt=True: content contains complete prompt
- stream=True: response_iterator contains streaming response, raw_data contains complete data
- default: content contains LLM response text, raw_data contains complete data
"""
if not query:
if return_raw_data:
return {
"status": "failure",
"message": "Query string is empty.",
"data": {},
}
return PROMPTS["fail_response"]
return QueryResult(content=PROMPTS["fail_response"])
if query_param.model_func:
use_model_func = query_param.model_func
@ -2315,12 +2339,11 @@ async def kg_query(
)
if (
cached_result is not None
and not return_raw_data
and not query_param.only_need_context
and not query_param.only_need_prompt
):
cached_response, _ = cached_result # Extract content, ignore timestamp
return cached_response
return QueryResult(content=cached_response)
hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
@ -2339,53 +2362,13 @@ async def kg_query(
logger.warning(f"Forced low_level_keywords to origin query: {query}")
ll_keywords = [query]
else:
if return_raw_data:
return {
"status": "failure",
"message": "Both high_level_keywords and low_level_keywords are empty",
"data": {},
}
return PROMPTS["fail_response"]
return QueryResult(content=PROMPTS["fail_response"])
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
# If raw data is requested, get both context and raw data
if return_raw_data:
context_result = await _build_query_context(
query,
ll_keywords_str,
hl_keywords_str,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
chunks_vdb,
return_raw_data=True,
)
if isinstance(context_result, tuple):
context, raw_data = context_result
logger.debug(f"[kg_query] Context length: {len(context) if context else 0}")
logger.debug(
f"[kg_query] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
)
return raw_data
else:
if not context_result:
return {
"status": "failure",
"message": "Query return empty data set.",
"data": {},
}
else:
raise ValueError(
"Fail to build raw data query result. Invalid return from _build_query_context"
)
# Build context (normal flow)
context = await _build_query_context(
# Build query context (unified interface)
context_result = await _build_query_context(
query,
ll_keywords_str,
hl_keywords_str,
@ -2397,14 +2380,20 @@ async def kg_query(
chunks_vdb,
)
if query_param.only_need_context and not query_param.only_need_prompt:
return context if context is not None else PROMPTS["fail_response"]
if context is None:
return PROMPTS["fail_response"]
if context_result is None:
return QueryResult(content=PROMPTS["fail_response"])
# Return different content based on query parameters
if query_param.only_need_context and not query_param.only_need_prompt:
return QueryResult(
content=context_result.context,
raw_data=context_result.raw_data
)
# Build system prompt
sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context,
context_data=context_result.context,
response_type=query_param.response_type,
)
@ -2415,8 +2404,13 @@ async def kg_query(
)
if query_param.only_need_prompt:
return "\n\n".join([sys_prompt, "---User Query---", user_query])
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(
content=prompt_content,
raw_data=context_result.raw_data
)
# Call LLM
tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(
@ -2430,45 +2424,59 @@ async def kg_query(
enable_cot=True,
stream=query_param.stream,
)
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
if hashing_kv.global_config.get("enable_llm_cache"):
# Save to cache with query parameters
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)
# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
# Cache response
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)
return QueryResult(
content=response,
raw_data=context_result.raw_data
)
else:
# Streaming response (AsyncIterator)
return QueryResult(
response_iterator=response,
raw_data=context_result.raw_data,
is_streaming=True
)
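Internally, callers now branch on the QueryResult fields instead of inspecting the returned value's type; a minimal sketch of that dispatch (the helper name is illustrative, not part of the diff):

```python
from lightrag.base import QueryResult

async def consume(query_result: QueryResult) -> str:
    # Streaming: drain the iterator; raw_data still carries references and metadata.
    if query_result.is_streaming and query_result.response_iterator is not None:
        parts = []
        async for chunk in query_result.response_iterator:
            parts.append(chunk)
        return "".join(parts)
    # Non-streaming: content holds the answer, context, or prompt text.
    return query_result.content or ""
```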
async def get_keywords_from_query(
@ -3123,10 +3131,9 @@ async def _build_llm_context(
query_param: QueryParam,
global_config: dict[str, str],
chunk_tracking: dict = None,
return_raw_data: bool = False,
entity_id_to_original: dict = None,
relation_id_to_original: dict = None,
) -> str | tuple[str, dict[str, Any]]:
) -> tuple[str, dict[str, Any]]:
"""
Build the final LLM context string with token processing.
This includes dynamic token calculation and final chunk truncation.
@ -3134,22 +3141,17 @@ async def _build_llm_context(
tokenizer = global_config.get("tokenizer")
if not tokenizer:
logger.error("Missing tokenizer, cannot build LLM context")
if return_raw_data:
# Return empty raw data structure when no entities/relations
empty_raw_data = convert_to_user_format(
[],
[],
[],
[],
query_param.mode,
)
empty_raw_data["status"] = "failure"
empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
return None, empty_raw_data
else:
logger.error("Tokenizer not found in global configuration.")
return None
# Return empty raw data structure when no tokenizer
empty_raw_data = convert_to_user_format(
[],
[],
[],
[],
query_param.mode,
)
empty_raw_data["status"] = "failure"
empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
return "", empty_raw_data
# Get token limits
max_total_tokens = getattr(
@ -3268,20 +3270,17 @@ The reference documents list in Document Chunks(DC) is as follows (reference_id
# not necessary to use LLM to generate a response
if not entities_context and not relations_context:
if return_raw_data:
# Return empty raw data structure when no entities/relations
empty_raw_data = convert_to_user_format(
[],
[],
[],
[],
query_param.mode,
)
empty_raw_data["status"] = "failure"
empty_raw_data["message"] = "Query returned empty dataset."
return None, empty_raw_data
else:
return None
# Return empty raw data structure when no entities/relations
empty_raw_data = convert_to_user_format(
[],
[],
[],
[],
query_param.mode,
)
empty_raw_data["status"] = "failure"
empty_raw_data["message"] = "Query returned empty dataset."
return "", empty_raw_data
# output chunk tracking information
# format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
@ -3342,26 +3341,23 @@ Document Chunks (DC) reference documents : (Each entry begins with [reference_id
"""
# If final data is requested, return both context and complete data structure
if return_raw_data:
logger.debug(
f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
)
final_data = convert_to_user_format(
entities_context,
relations_context,
truncated_chunks,
reference_list,
query_param.mode,
entity_id_to_original,
relation_id_to_original,
)
logger.debug(
f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
)
return result, final_data
else:
return result
# Always return both context and complete data structure (unified approach)
logger.debug(
f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
)
final_data = convert_to_user_format(
entities_context,
relations_context,
truncated_chunks,
reference_list,
query_param.mode,
entity_id_to_original,
relation_id_to_original,
)
logger.debug(
f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
)
return result, final_data
# Now let's update the old _build_query_context to use the new architecture
@ -3375,16 +3371,17 @@ async def _build_query_context(
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None,
return_raw_data: bool = False,
) -> str | None | tuple[str, dict[str, Any]]:
) -> QueryContextResult | None:
"""
Main query context building function using the new 4-stage architecture:
1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
Returns unified QueryContextResult containing both context and raw_data.
"""
if not query:
logger.warning("Query is empty, skipping context building")
return ""
return None
# Stage 1: Pure search
search_result = await _perform_kg_search(
@ -3435,71 +3432,56 @@ async def _build_query_context(
return None
# Stage 4: Build final LLM context with dynamic token processing
# _build_llm_context now always returns tuple[str, dict]
context, raw_data = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
entity_id_to_original=truncation_result["entity_id_to_original"],
relation_id_to_original=truncation_result["relation_id_to_original"],
)
if return_raw_data:
# Convert keywords strings to lists
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
# Convert keywords strings to lists and add complete metadata to raw_data
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
# Get both context and final data - when return_raw_data=True, _build_llm_context always returns tuple
context, raw_data = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
return_raw_data=True,
entity_id_to_original=truncation_result["entity_id_to_original"],
relation_id_to_original=truncation_result["relation_id_to_original"],
)
# Add complete metadata to raw_data (preserve existing metadata including query_mode)
if "metadata" not in raw_data:
raw_data["metadata"] = {}
# Convert keywords strings to lists and add complete metadata to raw_data
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
# Update keywords while preserving existing metadata
raw_data["metadata"]["keywords"] = {
"high_level": hl_keywords_list,
"low_level": ll_keywords_list,
}
raw_data["metadata"]["processing_info"] = {
"total_entities_found": len(search_result.get("final_entities", [])),
"total_relations_found": len(search_result.get("final_relations", [])),
"entities_after_truncation": len(
truncation_result.get("filtered_entities", [])
),
"relations_after_truncation": len(
truncation_result.get("filtered_relations", [])
),
"merged_chunks_count": len(merged_chunks),
"final_chunks_count": len(raw_data.get("data", {}).get("chunks", [])),
}
# Add complete metadata to raw_data (preserve existing metadata including query_mode)
if "metadata" not in raw_data:
raw_data["metadata"] = {}
# Update keywords while preserving existing metadata
raw_data["metadata"]["keywords"] = {
"high_level": hl_keywords_list,
"low_level": ll_keywords_list,
}
raw_data["metadata"]["processing_info"] = {
"total_entities_found": len(search_result.get("final_entities", [])),
"total_relations_found": len(search_result.get("final_relations", [])),
"entities_after_truncation": len(
truncation_result.get("filtered_entities", [])
),
"relations_after_truncation": len(
truncation_result.get("filtered_relations", [])
),
"merged_chunks_count": len(merged_chunks),
"final_chunks_count": len(raw_data.get("chunks", [])),
}
logger.debug(
f"[_build_query_context] Context length: {len(context) if context else 0}"
)
logger.debug(
f"[_build_query_context] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
)
return context, raw_data
else:
# Normal context building (existing logic)
context = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
)
return context
logger.debug(
f"[_build_query_context] Context length: {len(context) if context else 0}"
)
logger.debug(
f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}"
)
return QueryContextResult(
context=context,
raw_data=raw_data
)
async def _get_node_data(
@ -4105,19 +4087,28 @@ async def naive_query(
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
return_raw_data: bool = False,
) -> str | AsyncIterator[str] | dict[str, Any]:
) -> QueryResult:
"""
Execute naive query and return unified QueryResult object.
Args:
query: Query string
chunks_vdb: Document chunks vector database
query_param: Query parameters
global_config: Global configuration
hashing_kv: Cache storage
system_prompt: System prompt
Returns:
QueryResult: Unified query result object containing:
- content: Non-streaming response text content
- response_iterator: Streaming response iterator
- raw_data: Complete structured data (including references and metadata)
- is_streaming: Whether this is a streaming result
"""
if not query:
if return_raw_data:
# Return empty raw data structure when query is empty
empty_raw_data = {
"status": "failure",
"message": "Query string is empty.",
"data": {},
}
return empty_raw_data
else:
return PROMPTS["fail_response"]
return QueryResult(content=PROMPTS["fail_response"])
if query_param.model_func:
use_model_func = query_param.model_func
@ -4147,41 +4138,31 @@ async def naive_query(
if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
if not query_param.only_need_context and not query_param.only_need_prompt:
return cached_response
return QueryResult(content=cached_response)
tokenizer: Tokenizer = global_config["tokenizer"]
if not tokenizer:
if return_raw_data:
# Return empty raw data structure when tokenizer is missing
empty_raw_data = {
"status": "failure",
"message": "Tokenizer not found in global configuration.",
"data": {},
}
return empty_raw_data
else:
logger.error("Tokenizer not found in global configuration.")
return PROMPTS["fail_response"]
logger.error("Tokenizer not found in global configuration.")
return QueryResult(content=PROMPTS["fail_response"])
chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
if chunks is None or len(chunks) == 0:
# If only raw data is requested, return it directly
if return_raw_data:
empty_raw_data = convert_to_user_format(
[], # naive mode has no entities
[], # naive mode has no relationships
[], # no chunks
[], # no references
"naive",
)
empty_raw_data["message"] = "No relevant document chunks found."
return empty_raw_data
else:
return PROMPTS["fail_response"]
# Build empty raw data structure for naive mode
empty_raw_data = convert_to_user_format(
[], # naive mode has no entities
[], # naive mode has no relationships
[], # no chunks
[], # no references
"naive",
)
empty_raw_data["message"] = "No relevant document chunks found."
return QueryResult(
content=PROMPTS["fail_response"],
raw_data=empty_raw_data
)
# Calculate dynamic token limit for chunks
# Get token limits from query_param (with fallback to global_config)
max_total_tokens = getattr(
query_param,
"max_total_tokens",
@ -4240,30 +4221,26 @@ async def naive_query(
logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks")
# If only raw data is requested, return it directly
if return_raw_data:
# Build raw data structure for naive mode using processed chunks with reference IDs
raw_data = convert_to_user_format(
[], # naive mode has no entities
[], # naive mode has no relationships
processed_chunks_with_ref_ids,
reference_list,
"naive",
)
# Build raw data structure for naive mode using processed chunks with reference IDs
raw_data = convert_to_user_format(
[], # naive mode has no entities
[], # naive mode has no relationships
processed_chunks_with_ref_ids,
reference_list,
"naive",
)
# Add complete metadata for naive mode
if "metadata" not in raw_data:
raw_data["metadata"] = {}
raw_data["metadata"]["keywords"] = {
"high_level": [], # naive mode has no keyword extraction
"low_level": [], # naive mode has no keyword extraction
}
raw_data["metadata"]["processing_info"] = {
"total_chunks_found": len(chunks),
"final_chunks_count": len(processed_chunks_with_ref_ids),
}
return raw_data
# Add complete metadata for naive mode
if "metadata" not in raw_data:
raw_data["metadata"] = {}
raw_data["metadata"]["keywords"] = {
"high_level": [], # naive mode has no keyword extraction
"low_level": [], # naive mode has no keyword extraction
}
raw_data["metadata"]["processing_info"] = {
"total_chunks_found": len(chunks),
"final_chunks_count": len(processed_chunks_with_ref_ids),
}
# Build text_units_context from processed chunks with reference IDs
text_units_context = []
@ -4284,8 +4261,7 @@ async def naive_query(
if ref["reference_id"]
)
if query_param.only_need_context and not query_param.only_need_prompt:
return f"""
context_content = f"""
---Document Chunks(DC)---
```json
@ -4297,6 +4273,13 @@ async def naive_query(
{reference_list_str}
"""
if query_param.only_need_context and not query_param.only_need_prompt:
return QueryResult(
content=context_content,
raw_data=raw_data
)
user_query = (
"\n\n".join([query, query_param.user_prompt])
if query_param.user_prompt
@ -4310,7 +4293,11 @@ async def naive_query(
)
if query_param.only_need_prompt:
return "\n\n".join([sys_prompt, "---User Query---", user_query])
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(
content=prompt_content,
raw_data=raw_data
)
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(
@ -4325,43 +4312,56 @@ async def naive_query(
stream=query_param.stream,
)
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response[len(sys_prompt) :]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response[len(sys_prompt) :]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
if hashing_kv.global_config.get("enable_llm_cache"):
# Save to cache with query parameters
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)
# Cache response
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)
return response
return QueryResult(
content=response,
raw_data=raw_data
)
else:
# Streaming response (AsyncIterator)
return QueryResult(
response_iterator=response,
raw_data=raw_data,
is_streaming=True
)
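For completeness, a sketch of the naive-mode raw_data attached to the QueryResult returned above; the nesting follows convert_to_user_format as used in this diff, while the concrete values (and the exact per-chunk keys) are illustrative.

```python
# Illustrative naive-mode raw_data carried in QueryResult.raw_data.
naive_raw_data = {
    "data": {
        "entities": [],        # naive mode has no entities
        "relationships": [],   # naive mode has no relationships
        "chunks": [{"reference_id": "1", "content": "..."}],
        "references": [{"reference_id": "1", "file_path": "/docs/a.pdf"}],
    },
    "metadata": {
        "query_mode": "naive",
        "keywords": {"high_level": [], "low_level": []},
        "processing_info": {"total_chunks_found": 8, "final_chunks_count": 3},
    },
}
```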