Add reference list support to query API endpoints with unified result handling

• Add include_references param to QueryRequest
• Extend QueryResponse with references field
• Create unified QueryResult data structures
• Refactor kg_query and naive_query functions
• Update streaming to send references first
yangdx 2025-09-25 16:21:42 +08:00
parent 64c38864e5
commit b08b8a6a6a
4 changed files with 504 additions and 349 deletions
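
For context, a minimal sketch of how a client might exercise the new parameter once this change is deployed. The server address and port are assumptions (not part of this commit); the payload and response fields follow the QueryRequest/QueryResponse models shown in the diff below.

```python
# Hedged sketch: calling the updated /query endpoint with include_references.
# The base URL/port are assumptions about a local LightRAG API server.
import requests

resp = requests.post(
    "http://localhost:9621/query",  # assumed server address
    json={
        "query": "What does the document say about evaluation?",
        "mode": "hybrid",
        "include_references": True,  # new parameter added in this commit
    },
    timeout=60,
)
data = resp.json()
print(data["response"])
# "references" is None when include_references=False; otherwise a list of
# {"reference_id": "...", "file_path": "..."} entries per QueryResponse below.
for ref in data.get("references") or []:
    print(ref["reference_id"], ref["file_path"])
```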

View file

@@ -88,6 +88,11 @@ class QueryRequest(BaseModel):
         description="Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. Default is True.",
     )

+    include_references: Optional[bool] = Field(
+        default=True,
+        description="If True, includes reference list in responses. Affects /query and /query/stream endpoints. /query/data always includes references.",
+    )
+
     @field_validator("query", mode="after")
     @classmethod
     def query_strip_after(cls, query: str) -> str:
@@ -122,6 +127,10 @@ class QueryResponse(BaseModel):
     response: str = Field(
         description="The generated response",
     )
+    references: Optional[List[Dict[str, str]]] = Field(
+        default=None,
+        description="Reference list (only included when include_references=True, /query/data always includes references.)",
+    )


 class QueryDataResponse(BaseModel):
@@ -149,6 +158,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
             request (QueryRequest): The request object containing the query parameters.

         Returns:
             QueryResponse: A Pydantic model containing the result of the query processing.
+                If include_references=True, also includes reference list.
                 If a string is returned (e.g., cache hit), it's directly returned.
                 Otherwise, an async generator may be used to build the response.
@@ -160,15 +170,26 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
             param = request.to_query_params(False)
             response = await rag.aquery(request.query, param=param)

-            # If response is a string (e.g. cache hit), return directly
-            if isinstance(response, str):
-                return QueryResponse(response=response)
+            # Get reference list if requested
+            reference_list = None
+            if request.include_references:
+                try:
+                    # Use aquery_data to get reference list independently
+                    data_result = await rag.aquery_data(request.query, param=param)
+                    if isinstance(data_result, dict) and "data" in data_result:
+                        reference_list = data_result["data"].get("references", [])
+                except Exception as e:
+                    logging.warning(f"Failed to get reference list: {str(e)}")
+                    reference_list = []

-            if isinstance(response, dict):
+            # Process response and return with optional references
+            if isinstance(response, str):
+                return QueryResponse(response=response, references=reference_list)
+            elif isinstance(response, dict):
                 result = json.dumps(response, indent=2)
-                return QueryResponse(response=result)
+                return QueryResponse(response=result, references=reference_list)
             else:
-                return QueryResponse(response=str(response))
+                return QueryResponse(response=str(response), references=reference_list)
         except Exception as e:
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
@@ -177,13 +198,19 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
     async def query_text_stream(request: QueryRequest):
         """
         This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.

+        The streaming response includes:
+        1. Reference list (sent first as a single message, if include_references=True)
+        2. LLM response content (streamed as multiple chunks)
+
         Args:
             request (QueryRequest): The request object containing the query parameters.
-            optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None.

         Returns:
-            StreamingResponse: A streaming response containing the RAG query results.
+            StreamingResponse: A streaming response containing:
+                - First message: {"references": [...]} - Complete reference list (if requested)
+                - Subsequent messages: {"response": "..."} - LLM response chunks
+                - Error messages: {"error": "..."} - If any errors occur
         """
         try:
             param = request.to_query_params(True)
@@ -192,6 +219,24 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
             from fastapi.responses import StreamingResponse

             async def stream_generator():
+                # Get reference list if requested (default is True for backward compatibility)
+                reference_list = []
+                if request.include_references:
+                    try:
+                        # Use aquery_data to get reference list independently
+                        data_param = request.to_query_params(False)  # Non-streaming for data
+                        data_result = await rag.aquery_data(request.query, param=data_param)
+                        if isinstance(data_result, dict) and "data" in data_result:
+                            reference_list = data_result["data"].get("references", [])
+                    except Exception as e:
+                        logging.warning(f"Failed to get reference list: {str(e)}")
+                        reference_list = []
+
+                # Send reference list first (if requested)
+                if request.include_references:
+                    yield f"{json.dumps({'references': reference_list})}\n"
+
+                # Then stream the response content
                 if isinstance(response, str):
                     # If it's a string, send it all at once
                     yield f"{json.dumps({'response': response})}\n"

View file

@@ -11,6 +11,10 @@ from typing import (
     TypedDict,
     TypeVar,
     Callable,
+    Optional,
+    Dict,
+    List,
+    AsyncIterator,
 )
 from .utils import EmbeddingFunc
 from .types import KnowledgeGraph
@@ -158,6 +162,12 @@ class QueryParam:
     Default is True to enable reranking when rerank model is available.
     """

+    include_references: bool = False
+    """If True, includes reference list in the response for supported endpoints.
+
+    This parameter controls whether the API response includes a references field
+    containing citation information for the retrieved content.
+    """
+

 @dataclass
 class StorageNameSpace(ABC):
@ -814,3 +824,65 @@ class DeletionResult:
message: str message: str
status_code: int = 200 status_code: int = 200
file_path: str | None = None file_path: str | None = None
# Unified Query Result Data Structures for Reference List Support
@dataclass
class QueryResult:
"""
Unified query result data structure for all query modes.
Attributes:
content: Text content for non-streaming responses
response_iterator: Streaming response iterator for streaming responses
raw_data: Complete structured data including references and metadata
is_streaming: Whether this is a streaming result
"""
content: Optional[str] = None
response_iterator: Optional[AsyncIterator[str]] = None
raw_data: Optional[Dict[str, Any]] = None
is_streaming: bool = False
@property
def reference_list(self) -> List[Dict[str, str]]:
"""
Convenient property to extract reference list from raw_data.
Returns:
List[Dict[str, str]]: Reference list in format:
[{"reference_id": "1", "file_path": "/path/to/file.pdf"}, ...]
"""
if self.raw_data:
return self.raw_data.get("data", {}).get("references", [])
return []
@property
def metadata(self) -> Dict[str, Any]:
"""
Convenient property to extract metadata from raw_data.
Returns:
Dict[str, Any]: Query metadata including query_mode, keywords, etc.
"""
if self.raw_data:
return self.raw_data.get("metadata", {})
return {}
@dataclass
class QueryContextResult:
"""
Unified query context result data structure.
Attributes:
context: LLM context string
raw_data: Complete structured data including reference_list
"""
context: str
raw_data: Dict[str, Any]
@property
def reference_list(self) -> List[Dict[str, str]]:
"""Convenient property to extract reference list from raw_data."""
return self.raw_data.get("data", {}).get("references", [])
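
A rough usage sketch of the new dataclass follows; the handle() helper is hypothetical and only illustrates how a caller could branch on is_streaming and use the convenience properties defined above.

```python
# Hedged usage sketch for QueryResult; handle() is hypothetical caller code.
from lightrag.base import QueryResult  # import location as added in this diff

async def handle(result: QueryResult) -> None:
    if result.is_streaming and result.response_iterator is not None:
        # Streaming result: drain the async iterator chunk by chunk
        async for chunk in result.response_iterator:
            print(chunk, end="", flush=True)
    else:
        # Non-streaming result: the full text is already in `content`
        print(result.content)

    # raw_data holds the structured payload; the properties unpack it safely
    for ref in result.reference_list:
        print(ref.get("reference_id"), ref.get("file_path"))
    print("query metadata:", result.metadata)
```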

View file

@@ -71,6 +71,7 @@ from lightrag.base import (
     StoragesStatus,
     DeletionResult,
     OllamaServerInfos,
+    QueryResult,
 )
 from lightrag.namespace import NameSpace
 from lightrag.operate import (
@@ -2075,8 +2076,10 @@ class LightRAG:
         # If a custom model is provided in param, temporarily update global config
         global_config = asdict(self)

+        query_result = None
+
         if param.mode in ["local", "global", "hybrid", "mix"]:
-            response = await kg_query(
+            query_result = await kg_query(
                 query.strip(),
                 self.chunk_entity_relation_graph,
                 self.entities_vdb,
@@ -2089,7 +2092,7 @@ class LightRAG:
                 chunks_vdb=self.chunks_vdb,
             )
         elif param.mode == "naive":
-            response = await naive_query(
+            query_result = await naive_query(
                 query.strip(),
                 self.chunks_vdb,
                 param,
@@ -2111,10 +2114,22 @@ class LightRAG:
                 enable_cot=True,
                 stream=param.stream,
             )
+
+            # Create QueryResult for bypass mode
+            query_result = QueryResult(
+                content=response if not param.stream else None,
+                response_iterator=response if param.stream else None,
+                is_streaming=param.stream
+            )
         else:
             raise ValueError(f"Unknown mode {param.mode}")

         await self._query_done()
-        return response
+
+        # Return appropriate response based on streaming mode
+        if query_result.is_streaming:
+            return query_result.response_iterator
+        else:
+            return query_result.content

     async def aquery_data(
         self,
@@ -2229,61 +2244,84 @@ class LightRAG:
         """
         global_config = asdict(self)

-        if param.mode in ["local", "global", "hybrid", "mix"]:
-            logger.debug(f"[aquery_data] Using kg_query for mode: {param.mode}")
-            final_data = await kg_query(
+        # Create a copy of param to avoid modifying the original
+        data_param = QueryParam(
+            mode=param.mode,
+            only_need_context=True,  # Skip LLM generation, only get context and data
+            only_need_prompt=False,
+            response_type=param.response_type,
+            stream=False,  # Data retrieval doesn't need streaming
+            top_k=param.top_k,
+            chunk_top_k=param.chunk_top_k,
+            max_entity_tokens=param.max_entity_tokens,
+            max_relation_tokens=param.max_relation_tokens,
+            max_total_tokens=param.max_total_tokens,
+            hl_keywords=param.hl_keywords,
+            ll_keywords=param.ll_keywords,
+            conversation_history=param.conversation_history,
+            history_turns=param.history_turns,
+            model_func=param.model_func,
+            user_prompt=param.user_prompt,
+            enable_rerank=param.enable_rerank,
+        )
+
+        query_result = None
+
+        if data_param.mode in ["local", "global", "hybrid", "mix"]:
+            logger.debug(f"[aquery_data] Using kg_query for mode: {data_param.mode}")
+            query_result = await kg_query(
                 query.strip(),
                 self.chunk_entity_relation_graph,
                 self.entities_vdb,
                 self.relationships_vdb,
                 self.text_chunks,
-                param,
+                data_param,  # Use data_param with only_need_context=True
                 global_config,
                 hashing_kv=self.llm_response_cache,
                 system_prompt=None,
                 chunks_vdb=self.chunks_vdb,
-                return_raw_data=True,  # Get final processed data
             )
-        elif param.mode == "naive":
-            logger.debug(f"[aquery_data] Using naive_query for mode: {param.mode}")
-            final_data = await naive_query(
+        elif data_param.mode == "naive":
+            logger.debug(f"[aquery_data] Using naive_query for mode: {data_param.mode}")
+            query_result = await naive_query(
                 query.strip(),
                 self.chunks_vdb,
-                param,
+                data_param,  # Use data_param with only_need_context=True
                 global_config,
                 hashing_kv=self.llm_response_cache,
                 system_prompt=None,
-                return_raw_data=True,  # Get final processed data
             )
-        elif param.mode == "bypass":
+        elif data_param.mode == "bypass":
             logger.debug("[aquery_data] Using bypass mode")
             # bypass mode returns empty data using convert_to_user_format
-            final_data = convert_to_user_format(
+            empty_raw_data = convert_to_user_format(
                 [],  # no entities
                 [],  # no relationships
                 [],  # no chunks
                 [],  # no references
                 "bypass",
             )
+            query_result = QueryResult(
+                content="",
+                raw_data=empty_raw_data
+            )
         else:
-            raise ValueError(f"Unknown mode {param.mode}")
+            raise ValueError(f"Unknown mode {data_param.mode}")
+
+        # Extract raw_data from QueryResult
+        final_data = query_result.raw_data if query_result else {}

         # Log final result counts - adapt to new data format from convert_to_user_format
-        if isinstance(final_data, dict) and "data" in final_data:
-            # New format: data is nested under 'data' field
+        if final_data and "data" in final_data:
             data_section = final_data["data"]
             entities_count = len(data_section.get("entities", []))
             relationships_count = len(data_section.get("relationships", []))
             chunks_count = len(data_section.get("chunks", []))
+            logger.debug(
+                f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
+            )
         else:
-            # Fallback for other formats
-            entities_count = len(final_data.get("entities", []))
-            relationships_count = len(final_data.get("relationships", []))
-            chunks_count = len(final_data.get("chunks", []))
-
-        logger.debug(
-            f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
-        )
+            logger.warning("[aquery_data] No data section found in query result")

         await self._query_done()
         return final_data
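
In short, aquery keeps its public return type (a string, or an async iterator when param.stream is True) while aquery_data returns the structured dict whose data.references now feeds the API layer. A hedged sketch of the caller-facing behaviour, with the QueryParam arguments as assumptions:

```python
# Hedged sketch of caller-facing behaviour after this refactor.
from lightrag import LightRAG, QueryParam  # import path assumed

async def ask(rag: LightRAG, question: str) -> None:
    # aquery: plain text answer (non-streaming QueryParam assumed here)
    answer = await rag.aquery(question, param=QueryParam(mode="hybrid"))
    print(answer)

    # aquery_data: structured dict; references live under data.references
    data = await rag.aquery_data(question, param=QueryParam(mode="hybrid"))
    references = data.get("data", {}).get("references", [])
    print(f"{len(references)} references retrieved")
```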

View file

@@ -39,6 +39,8 @@ from .base import (
     BaseVectorStorage,
     TextChunkSchema,
     QueryParam,
+    QueryResult,
+    QueryContextResult,
 )
 from .prompt import PROMPTS
 from .constants import (
@@ -2277,16 +2279,38 @@ async def kg_query(
     hashing_kv: BaseKVStorage | None = None,
     system_prompt: str | None = None,
     chunks_vdb: BaseVectorStorage = None,
-    return_raw_data: bool = False,
-) -> str | AsyncIterator[str] | dict[str, Any]:
+) -> QueryResult:
+    """
+    Execute knowledge graph query and return unified QueryResult object.
+
+    Args:
+        query: Query string
+        knowledge_graph_inst: Knowledge graph storage instance
+        entities_vdb: Entity vector database
+        relationships_vdb: Relationship vector database
+        text_chunks_db: Text chunks storage
+        query_param: Query parameters
+        global_config: Global configuration
+        hashing_kv: Cache storage
+        system_prompt: System prompt
+        chunks_vdb: Document chunks vector database
+
+    Returns:
+        QueryResult: Unified query result object containing:
+            - content: Non-streaming response text content
+            - response_iterator: Streaming response iterator
+            - raw_data: Complete structured data (including references and metadata)
+            - is_streaming: Whether this is a streaming result
+
+        Based on different query_param settings, different fields will be populated:
+        - only_need_context=True: content contains context string
+        - only_need_prompt=True: content contains complete prompt
+        - stream=True: response_iterator contains streaming response, raw_data contains complete data
+        - default: content contains LLM response text, raw_data contains complete data
+    """
     if not query:
-        if return_raw_data:
-            return {
-                "status": "failure",
-                "message": "Query string is empty.",
-                "data": {},
-            }
-        return PROMPTS["fail_response"]
+        return QueryResult(content=PROMPTS["fail_response"])

     if query_param.model_func:
         use_model_func = query_param.model_func
@@ -2315,12 +2339,11 @@ async def kg_query(
     )
     if (
         cached_result is not None
-        and not return_raw_data
         and not query_param.only_need_context
         and not query_param.only_need_prompt
     ):
         cached_response, _ = cached_result  # Extract content, ignore timestamp
-        return cached_response
+        return QueryResult(content=cached_response)

     hl_keywords, ll_keywords = await get_keywords_from_query(
         query, query_param, global_config, hashing_kv
@@ -2339,53 +2362,13 @@ async def kg_query(
         logger.warning(f"Forced low_level_keywords to origin query: {query}")
         ll_keywords = [query]
     else:
-        if return_raw_data:
-            return {
-                "status": "failure",
-                "message": "Both high_level_keywords and low_level_keywords are empty",
-                "data": {},
-            }
-        return PROMPTS["fail_response"]
+        return QueryResult(content=PROMPTS["fail_response"])

     ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
     hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

-    # If raw data is requested, get both context and raw data
-    if return_raw_data:
-        context_result = await _build_query_context(
-            query,
-            ll_keywords_str,
-            hl_keywords_str,
-            knowledge_graph_inst,
-            entities_vdb,
-            relationships_vdb,
-            text_chunks_db,
-            query_param,
-            chunks_vdb,
-            return_raw_data=True,
-        )
-        if isinstance(context_result, tuple):
-            context, raw_data = context_result
-            logger.debug(f"[kg_query] Context length: {len(context) if context else 0}")
-            logger.debug(
-                f"[kg_query] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
-            )
-            return raw_data
-        else:
-            if not context_result:
-                return {
-                    "status": "failure",
-                    "message": "Query return empty data set.",
-                    "data": {},
-                }
-            else:
-                raise ValueError(
-                    "Fail to build raw data query result. Invalid return from _build_query_context"
-                )
-
-    # Build context (normal flow)
-    context = await _build_query_context(
+    # Build query context (unified interface)
+    context_result = await _build_query_context(
         query,
         ll_keywords_str,
         hl_keywords_str,
@@ -2397,14 +2380,20 @@ async def kg_query(
         chunks_vdb,
     )

-    if query_param.only_need_context and not query_param.only_need_prompt:
-        return context if context is not None else PROMPTS["fail_response"]
-
-    if context is None:
-        return PROMPTS["fail_response"]
+    if context_result is None:
+        return QueryResult(content=PROMPTS["fail_response"])

+    # Return different content based on query parameters
+    if query_param.only_need_context and not query_param.only_need_prompt:
+        return QueryResult(
+            content=context_result.context,
+            raw_data=context_result.raw_data
+        )
+
+    # Build system prompt
     sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
     sys_prompt = sys_prompt_temp.format(
-        context_data=context,
+        context_data=context_result.context,
         response_type=query_param.response_type,
     )
@@ -2415,8 +2404,13 @@ async def kg_query(
     )

     if query_param.only_need_prompt:
-        return "\n\n".join([sys_prompt, "---User Query---", user_query])
+        prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
+        return QueryResult(
+            content=prompt_content,
+            raw_data=context_result.raw_data
+        )

+    # Call LLM
     tokenizer: Tokenizer = global_config["tokenizer"]
     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(
@@ -2430,45 +2424,59 @@ async def kg_query(
         enable_cot=True,
         stream=query_param.stream,
     )
-    if isinstance(response, str) and len(response) > len(sys_prompt):
-        response = (
-            response.replace(sys_prompt, "")
-            .replace("user", "")
-            .replace("model", "")
-            .replace(query, "")
-            .replace("<system>", "")
-            .replace("</system>", "")
-            .strip()
-        )

-    if hashing_kv.global_config.get("enable_llm_cache"):
-        # Save to cache with query parameters
-        queryparam_dict = {
-            "mode": query_param.mode,
-            "response_type": query_param.response_type,
-            "top_k": query_param.top_k,
-            "chunk_top_k": query_param.chunk_top_k,
-            "max_entity_tokens": query_param.max_entity_tokens,
-            "max_relation_tokens": query_param.max_relation_tokens,
-            "max_total_tokens": query_param.max_total_tokens,
-            "hl_keywords": query_param.hl_keywords or [],
-            "ll_keywords": query_param.ll_keywords or [],
-            "user_prompt": query_param.user_prompt or "",
-            "enable_rerank": query_param.enable_rerank,
-        }
-        await save_to_cache(
-            hashing_kv,
-            CacheData(
-                args_hash=args_hash,
-                content=response,
-                prompt=query,
-                mode=query_param.mode,
-                cache_type="query",
-                queryparam=queryparam_dict,
-            ),
-        )
-
-    return response
+    # Return unified result based on actual response type
+    if isinstance(response, str):
+        # Non-streaming response (string)
+        if len(response) > len(sys_prompt):
+            response = (
+                response.replace(sys_prompt, "")
+                .replace("user", "")
+                .replace("model", "")
+                .replace(query, "")
+                .replace("<system>", "")
+                .replace("</system>", "")
+                .strip()
+            )
+
+        # Cache response
+        if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
+            queryparam_dict = {
+                "mode": query_param.mode,
+                "response_type": query_param.response_type,
+                "top_k": query_param.top_k,
+                "chunk_top_k": query_param.chunk_top_k,
+                "max_entity_tokens": query_param.max_entity_tokens,
+                "max_relation_tokens": query_param.max_relation_tokens,
+                "max_total_tokens": query_param.max_total_tokens,
+                "hl_keywords": query_param.hl_keywords or [],
+                "ll_keywords": query_param.ll_keywords or [],
+                "user_prompt": query_param.user_prompt or "",
+                "enable_rerank": query_param.enable_rerank,
+            }
+            await save_to_cache(
+                hashing_kv,
+                CacheData(
+                    args_hash=args_hash,
+                    content=response,
+                    prompt=query,
+                    mode=query_param.mode,
+                    cache_type="query",
+                    queryparam=queryparam_dict,
+                ),
+            )
+
+        return QueryResult(
+            content=response,
+            raw_data=context_result.raw_data
+        )
+    else:
+        # Streaming response (AsyncIterator)
+        return QueryResult(
+            response_iterator=response,
+            raw_data=context_result.raw_data,
+            is_streaming=True
+        )


 async def get_keywords_from_query(
@@ -3123,10 +3131,9 @@ async def _build_llm_context(
     query_param: QueryParam,
     global_config: dict[str, str],
     chunk_tracking: dict = None,
-    return_raw_data: bool = False,
     entity_id_to_original: dict = None,
     relation_id_to_original: dict = None,
-) -> str | tuple[str, dict[str, Any]]:
+) -> tuple[str, dict[str, Any]]:
     """
     Build the final LLM context string with token processing.
     This includes dynamic token calculation and final chunk truncation.
@@ -3134,22 +3141,17 @@ async def _build_llm_context(
     tokenizer = global_config.get("tokenizer")
     if not tokenizer:
         logger.error("Missing tokenizer, cannot build LLM context")
-        if return_raw_data:
-            # Return empty raw data structure when no entities/relations
-            empty_raw_data = convert_to_user_format(
-                [],
-                [],
-                [],
-                [],
-                query_param.mode,
-            )
-            empty_raw_data["status"] = "failure"
-            empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
-            return None, empty_raw_data
-        else:
-            logger.error("Tokenizer not found in global configuration.")
-            return None
+        # Return empty raw data structure when no tokenizer
+        empty_raw_data = convert_to_user_format(
+            [],
+            [],
+            [],
+            [],
+            query_param.mode,
+        )
+        empty_raw_data["status"] = "failure"
+        empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
+        return "", empty_raw_data

     # Get token limits
     max_total_tokens = getattr(
@@ -3268,20 +3270,17 @@ The reference documents list in Document Chunks(DC) is as follows (reference_id
     # not necessary to use LLM to generate a response
     if not entities_context and not relations_context:
-        if return_raw_data:
-            # Return empty raw data structure when no entities/relations
-            empty_raw_data = convert_to_user_format(
-                [],
-                [],
-                [],
-                [],
-                query_param.mode,
-            )
-            empty_raw_data["status"] = "failure"
-            empty_raw_data["message"] = "Query returned empty dataset."
-            return None, empty_raw_data
-        else:
-            return None
+        # Return empty raw data structure when no entities/relations
+        empty_raw_data = convert_to_user_format(
+            [],
+            [],
+            [],
+            [],
+            query_param.mode,
+        )
+        empty_raw_data["status"] = "failure"
+        empty_raw_data["message"] = "Query returned empty dataset."
+        return "", empty_raw_data

     # output chunks tracking infomations
     # format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
@@ -3342,26 +3341,23 @@ Document Chunks (DC) reference documents : (Each entry begins with [reference_id
 """

-    # If final data is requested, return both context and complete data structure
-    if return_raw_data:
-        logger.debug(
-            f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
-        )
-        final_data = convert_to_user_format(
-            entities_context,
-            relations_context,
-            truncated_chunks,
-            reference_list,
-            query_param.mode,
-            entity_id_to_original,
-            relation_id_to_original,
-        )
-        logger.debug(
-            f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
-        )
-        return result, final_data
-    else:
-        return result
+    # Always return both context and complete data structure (unified approach)
+    logger.debug(
+        f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
+    )
+    final_data = convert_to_user_format(
+        entities_context,
+        relations_context,
+        truncated_chunks,
+        reference_list,
+        query_param.mode,
+        entity_id_to_original,
+        relation_id_to_original,
+    )
+    logger.debug(
+        f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
+    )
+
+    return result, final_data


 # Now let's update the old _build_query_context to use the new architecture
@@ -3375,16 +3371,17 @@ async def _build_query_context(
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
     chunks_vdb: BaseVectorStorage = None,
-    return_raw_data: bool = False,
-) -> str | None | tuple[str, dict[str, Any]]:
+) -> QueryContextResult | None:
     """
     Main query context building function using the new 4-stage architecture:
     1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
+
+    Returns unified QueryContextResult containing both context and raw_data.
     """
     if not query:
         logger.warning("Query is empty, skipping context building")
-        return ""
+        return None

     # Stage 1: Pure search
     search_result = await _perform_kg_search(
@@ -3435,71 +3432,56 @@ async def _build_query_context(
         return None

     # Stage 4: Build final LLM context with dynamic token processing
+    # _build_llm_context now always returns tuple[str, dict]
+    context, raw_data = await _build_llm_context(
+        entities_context=truncation_result["entities_context"],
+        relations_context=truncation_result["relations_context"],
+        merged_chunks=merged_chunks,
+        query=query,
+        query_param=query_param,
+        global_config=text_chunks_db.global_config,
+        chunk_tracking=search_result["chunk_tracking"],
+        entity_id_to_original=truncation_result["entity_id_to_original"],
+        relation_id_to_original=truncation_result["relation_id_to_original"],
+    )

-    if return_raw_data:
-        # Convert keywords strings to lists
-        hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
-        ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
-
-        # Get both context and final data - when return_raw_data=True, _build_llm_context always returns tuple
-        context, raw_data = await _build_llm_context(
-            entities_context=truncation_result["entities_context"],
-            relations_context=truncation_result["relations_context"],
-            merged_chunks=merged_chunks,
-            query=query,
-            query_param=query_param,
-            global_config=text_chunks_db.global_config,
-            chunk_tracking=search_result["chunk_tracking"],
-            return_raw_data=True,
-            entity_id_to_original=truncation_result["entity_id_to_original"],
-            relation_id_to_original=truncation_result["relation_id_to_original"],
-        )
-
-        # Convert keywords strings to lists and add complete metadata to raw_data
-        hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
-        ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
-
-        # Add complete metadata to raw_data (preserve existing metadata including query_mode)
-        if "metadata" not in raw_data:
-            raw_data["metadata"] = {}
-
-        # Update keywords while preserving existing metadata
-        raw_data["metadata"]["keywords"] = {
-            "high_level": hl_keywords_list,
-            "low_level": ll_keywords_list,
-        }
-        raw_data["metadata"]["processing_info"] = {
-            "total_entities_found": len(search_result.get("final_entities", [])),
-            "total_relations_found": len(search_result.get("final_relations", [])),
-            "entities_after_truncation": len(
-                truncation_result.get("filtered_entities", [])
-            ),
-            "relations_after_truncation": len(
-                truncation_result.get("filtered_relations", [])
-            ),
-            "merged_chunks_count": len(merged_chunks),
-            "final_chunks_count": len(raw_data.get("chunks", [])),
-        }
-
-        logger.debug(
-            f"[_build_query_context] Context length: {len(context) if context else 0}"
-        )
-        logger.debug(
-            f"[_build_query_context] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
-        )
-
-        return context, raw_data
-    else:
-        # Normal context building (existing logic)
-        context = await _build_llm_context(
-            entities_context=truncation_result["entities_context"],
-            relations_context=truncation_result["relations_context"],
-            merged_chunks=merged_chunks,
-            query=query,
-            query_param=query_param,
-            global_config=text_chunks_db.global_config,
-            chunk_tracking=search_result["chunk_tracking"],
-        )
-        return context
+    # Convert keywords strings to lists and add complete metadata to raw_data
+    hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
+    ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
+
+    # Add complete metadata to raw_data (preserve existing metadata including query_mode)
+    if "metadata" not in raw_data:
+        raw_data["metadata"] = {}
+
+    # Update keywords while preserving existing metadata
+    raw_data["metadata"]["keywords"] = {
+        "high_level": hl_keywords_list,
+        "low_level": ll_keywords_list,
+    }
+    raw_data["metadata"]["processing_info"] = {
+        "total_entities_found": len(search_result.get("final_entities", [])),
+        "total_relations_found": len(search_result.get("final_relations", [])),
+        "entities_after_truncation": len(
+            truncation_result.get("filtered_entities", [])
+        ),
+        "relations_after_truncation": len(
+            truncation_result.get("filtered_relations", [])
+        ),
+        "merged_chunks_count": len(merged_chunks),
+        "final_chunks_count": len(raw_data.get("data", {}).get("chunks", [])),
+    }
+
+    logger.debug(
+        f"[_build_query_context] Context length: {len(context) if context else 0}"
+    )
+    logger.debug(
+        f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}"
+    )
+
+    return QueryContextResult(
+        context=context,
+        raw_data=raw_data
+    )


 async def _get_node_data(
@@ -4105,19 +4087,28 @@ async def naive_query(
     global_config: dict[str, str],
     hashing_kv: BaseKVStorage | None = None,
     system_prompt: str | None = None,
-    return_raw_data: bool = False,
-) -> str | AsyncIterator[str] | dict[str, Any]:
+) -> QueryResult:
+    """
+    Execute naive query and return unified QueryResult object.
+
+    Args:
+        query: Query string
+        chunks_vdb: Document chunks vector database
+        query_param: Query parameters
+        global_config: Global configuration
+        hashing_kv: Cache storage
+        system_prompt: System prompt
+
+    Returns:
+        QueryResult: Unified query result object containing:
+            - content: Non-streaming response text content
+            - response_iterator: Streaming response iterator
+            - raw_data: Complete structured data (including references and metadata)
+            - is_streaming: Whether this is a streaming result
+    """
     if not query:
-        if return_raw_data:
-            # Return empty raw data structure when query is empty
-            empty_raw_data = {
-                "status": "failure",
-                "message": "Query string is empty.",
-                "data": {},
-            }
-            return empty_raw_data
-        else:
-            return PROMPTS["fail_response"]
+        return QueryResult(content=PROMPTS["fail_response"])

     if query_param.model_func:
         use_model_func = query_param.model_func
@@ -4147,41 +4138,31 @@ async def naive_query(
     if cached_result is not None:
         cached_response, _ = cached_result  # Extract content, ignore timestamp
         if not query_param.only_need_context and not query_param.only_need_prompt:
-            return cached_response
+            return QueryResult(content=cached_response)

     tokenizer: Tokenizer = global_config["tokenizer"]
     if not tokenizer:
-        if return_raw_data:
-            # Return empty raw data structure when tokenizer is missing
-            empty_raw_data = {
-                "status": "failure",
-                "message": "Tokenizer not found in global configuration.",
-                "data": {},
-            }
-            return empty_raw_data
-        else:
-            logger.error("Tokenizer not found in global configuration.")
-            return PROMPTS["fail_response"]
+        logger.error("Tokenizer not found in global configuration.")
+        return QueryResult(content=PROMPTS["fail_response"])

     chunks = await _get_vector_context(query, chunks_vdb, query_param, None)

     if chunks is None or len(chunks) == 0:
-        # If only raw data is requested, return it directly
-        if return_raw_data:
-            empty_raw_data = convert_to_user_format(
-                [],  # naive mode has no entities
-                [],  # naive mode has no relationships
-                [],  # no chunks
-                [],  # no references
-                "naive",
-            )
-            empty_raw_data["message"] = "No relevant document chunks found."
-            return empty_raw_data
-        else:
-            return PROMPTS["fail_response"]
+        # Build empty raw data structure for naive mode
+        empty_raw_data = convert_to_user_format(
+            [],  # naive mode has no entities
+            [],  # naive mode has no relationships
+            [],  # no chunks
+            [],  # no references
+            "naive",
+        )
+        empty_raw_data["message"] = "No relevant document chunks found."
+        return QueryResult(
+            content=PROMPTS["fail_response"],
+            raw_data=empty_raw_data
+        )

     # Calculate dynamic token limit for chunks
-    # Get token limits from query_param (with fallback to global_config)
     max_total_tokens = getattr(
         query_param,
         "max_total_tokens",
@@ -4240,30 +4221,26 @@ async def naive_query(
     logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks")

-    # If only raw data is requested, return it directly
-    if return_raw_data:
-        # Build raw data structure for naive mode using processed chunks with reference IDs
-        raw_data = convert_to_user_format(
-            [],  # naive mode has no entities
-            [],  # naive mode has no relationships
-            processed_chunks_with_ref_ids,
-            reference_list,
-            "naive",
-        )
-
-        # Add complete metadata for naive mode
-        if "metadata" not in raw_data:
-            raw_data["metadata"] = {}
-        raw_data["metadata"]["keywords"] = {
-            "high_level": [],  # naive mode has no keyword extraction
-            "low_level": [],  # naive mode has no keyword extraction
-        }
-        raw_data["metadata"]["processing_info"] = {
-            "total_chunks_found": len(chunks),
-            "final_chunks_count": len(processed_chunks_with_ref_ids),
-        }
-        return raw_data
+    # Build raw data structure for naive mode using processed chunks with reference IDs
+    raw_data = convert_to_user_format(
+        [],  # naive mode has no entities
+        [],  # naive mode has no relationships
+        processed_chunks_with_ref_ids,
+        reference_list,
+        "naive",
+    )
+
+    # Add complete metadata for naive mode
+    if "metadata" not in raw_data:
+        raw_data["metadata"] = {}
+    raw_data["metadata"]["keywords"] = {
+        "high_level": [],  # naive mode has no keyword extraction
+        "low_level": [],  # naive mode has no keyword extraction
+    }
+    raw_data["metadata"]["processing_info"] = {
+        "total_chunks_found": len(chunks),
+        "final_chunks_count": len(processed_chunks_with_ref_ids),
+    }

     # Build text_units_context from processed chunks with reference IDs
     text_units_context = []
if ref["reference_id"] if ref["reference_id"]
) )
if query_param.only_need_context and not query_param.only_need_prompt: context_content = f"""
return f"""
---Document Chunks(DC)--- ---Document Chunks(DC)---
```json ```json
@@ -4297,6 +4273,13 @@ async def naive_query(
 {reference_list_str}
 """

+    if query_param.only_need_context and not query_param.only_need_prompt:
+        return QueryResult(
+            content=context_content,
+            raw_data=raw_data
+        )
+
     user_query = (
         "\n\n".join([query, query_param.user_prompt])
         if query_param.user_prompt
@@ -4310,7 +4293,11 @@ async def naive_query(
     )

     if query_param.only_need_prompt:
-        return "\n\n".join([sys_prompt, "---User Query---", user_query])
+        prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
+        return QueryResult(
+            content=prompt_content,
+            raw_data=raw_data
+        )

     len_of_prompts = len(tokenizer.encode(query + sys_prompt))
     logger.debug(
@@ -4325,43 +4312,56 @@ async def naive_query(
         stream=query_param.stream,
     )

-    if isinstance(response, str) and len(response) > len(sys_prompt):
-        response = (
-            response[len(sys_prompt) :]
-            .replace(sys_prompt, "")
-            .replace("user", "")
-            .replace("model", "")
-            .replace(query, "")
-            .replace("<system>", "")
-            .replace("</system>", "")
-            .strip()
-        )
-
-    if hashing_kv.global_config.get("enable_llm_cache"):
-        # Save to cache with query parameters
-        queryparam_dict = {
-            "mode": query_param.mode,
-            "response_type": query_param.response_type,
-            "top_k": query_param.top_k,
-            "chunk_top_k": query_param.chunk_top_k,
-            "max_entity_tokens": query_param.max_entity_tokens,
-            "max_relation_tokens": query_param.max_relation_tokens,
-            "max_total_tokens": query_param.max_total_tokens,
-            "hl_keywords": query_param.hl_keywords or [],
-            "ll_keywords": query_param.ll_keywords or [],
-            "user_prompt": query_param.user_prompt or "",
-            "enable_rerank": query_param.enable_rerank,
-        }
-        await save_to_cache(
-            hashing_kv,
-            CacheData(
-                args_hash=args_hash,
-                content=response,
-                prompt=query,
-                mode=query_param.mode,
-                cache_type="query",
-                queryparam=queryparam_dict,
-            ),
-        )
-
-    return response
+    # Return unified result based on actual response type
+    if isinstance(response, str):
+        # Non-streaming response (string)
+        if len(response) > len(sys_prompt):
+            response = (
+                response[len(sys_prompt) :]
+                .replace(sys_prompt, "")
+                .replace("user", "")
+                .replace("model", "")
+                .replace(query, "")
+                .replace("<system>", "")
+                .replace("</system>", "")
+                .strip()
+            )
+
+        # Cache response
+        if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
+            queryparam_dict = {
+                "mode": query_param.mode,
+                "response_type": query_param.response_type,
+                "top_k": query_param.top_k,
+                "chunk_top_k": query_param.chunk_top_k,
+                "max_entity_tokens": query_param.max_entity_tokens,
+                "max_relation_tokens": query_param.max_relation_tokens,
+                "max_total_tokens": query_param.max_total_tokens,
+                "hl_keywords": query_param.hl_keywords or [],
+                "ll_keywords": query_param.ll_keywords or [],
+                "user_prompt": query_param.user_prompt or "",
+                "enable_rerank": query_param.enable_rerank,
+            }
+            await save_to_cache(
+                hashing_kv,
+                CacheData(
+                    args_hash=args_hash,
+                    content=response,
+                    prompt=query,
+                    mode=query_param.mode,
+                    cache_type="query",
+                    queryparam=queryparam_dict,
+                ),
+            )
+
+        return QueryResult(
+            content=response,
+            raw_data=raw_data
+        )
+    else:
+        # Streaming response (AsyncIterator)
+        return QueryResult(
+            response_iterator=response,
+            raw_data=raw_data,
+            is_streaming=True
+        )