Merge pull request #2147 from danielaskdd/return-reference-on-query

Feature: Add Reference List Support for All Query Endpoints
This commit is contained in:
Daniel.y 2025-09-25 16:58:32 +08:00 committed by GitHub
commit b4cc249dca
6 changed files with 1009 additions and 359 deletions

View file

@ -1 +1 @@
__api_version__ = "0230"
__api_version__ = "0231"

View file

@ -88,6 +88,11 @@ class QueryRequest(BaseModel):
description="Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. Default is True.",
)
include_references: Optional[bool] = Field(
default=True,
description="If True, includes reference list in responses. Affects /query and /query/stream endpoints. /query/data always includes references.",
)
@field_validator("query", mode="after")
@classmethod
def query_strip_after(cls, query: str) -> str:
@ -122,6 +127,10 @@ class QueryResponse(BaseModel):
response: str = Field(
description="The generated response",
)
references: Optional[List[Dict[str, str]]] = Field(
default=None,
description="Reference list (only included when include_references=True, /query/data always includes references.)",
)
class QueryDataResponse(BaseModel):
@ -149,6 +158,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
request (QueryRequest): The request object containing the query parameters.
Returns:
QueryResponse: A Pydantic model containing the result of the query processing.
If include_references=True, the reference list is also included.
If a string is returned (e.g., cache hit), it's directly returned.
Otherwise, an async generator may be used to build the response.
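As an illustration of the shapes described above, a minimal request/response sketch for POST /query (a hedged example; the values are placeholders, not taken from this diff):
# Hypothetical request body for POST /query:
request_body = {
    "query": "who authored LightRAG",
    "mode": "mix",
    "include_references": True,
}
# Expected response shape when include_references=True:
# {
#   "response": "...generated answer...",
#   "references": [{"reference_id": "1", "file_path": "docs/example.pdf"}]
# }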
@ -160,15 +170,26 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
param = request.to_query_params(False)
response = await rag.aquery(request.query, param=param)
# If response is a string (e.g. cache hit), return directly
if isinstance(response, str):
return QueryResponse(response=response)
# Get reference list if requested
reference_list = None
if request.include_references:
try:
# Use aquery_data to get reference list independently
data_result = await rag.aquery_data(request.query, param=param)
if isinstance(data_result, dict) and "data" in data_result:
reference_list = data_result["data"].get("references", [])
except Exception as e:
logging.warning(f"Failed to get reference list: {str(e)}")
reference_list = []
if isinstance(response, dict):
# Process response and return with optional references
if isinstance(response, str):
return QueryResponse(response=response, references=reference_list)
elif isinstance(response, dict):
result = json.dumps(response, indent=2)
return QueryResponse(response=result)
return QueryResponse(response=result, references=reference_list)
else:
return QueryResponse(response=str(response))
return QueryResponse(response=str(response), references=reference_list)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@ -178,12 +199,18 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
"""
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
The streaming response includes:
1. Reference list (sent first as a single message, if include_references=True)
2. LLM response content (streamed as multiple chunks)
Args:
request (QueryRequest): The request object containing the query parameters.
optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None.
Returns:
StreamingResponse: A streaming response containing the RAG query results.
StreamingResponse: A streaming response containing:
- First message: {"references": [...]} - Complete reference list (if requested)
- Subsequent messages: {"response": "..."} - LLM response chunks
- Error messages: {"error": "..."} - If any errors occur
"""
try:
param = request.to_query_params(True)
@ -192,6 +219,28 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
from fastapi.responses import StreamingResponse
async def stream_generator():
# Get reference list if requested (default is True for backward compatibility)
reference_list = []
if request.include_references:
try:
# Use aquery_data to get reference list independently
data_param = request.to_query_params(
False
) # Non-streaming for data
data_result = await rag.aquery_data(
request.query, param=data_param
)
if isinstance(data_result, dict) and "data" in data_result:
reference_list = data_result["data"].get("references", [])
except Exception as e:
logging.warning(f"Failed to get reference list: {str(e)}")
reference_list = []
# Send reference list first (if requested)
if request.include_references:
yield f"{json.dumps({'references': reference_list})}\n"
# Then stream the response content
if isinstance(response, str):
# If it's a string, send it all at once
yield f"{json.dumps({'response': response})}\n"

View file

@ -11,6 +11,10 @@ from typing import (
TypedDict,
TypeVar,
Callable,
Optional,
Dict,
List,
AsyncIterator,
)
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
@ -158,6 +162,12 @@ class QueryParam:
Default is True to enable reranking when rerank model is available.
"""
include_references: bool = False
"""If True, includes reference list in the response for supported endpoints.
This parameter controls whether the API response includes a references field
containing citation information for the retrieved content.
"""
@dataclass
class StorageNameSpace(ABC):
@ -814,3 +824,68 @@ class DeletionResult:
message: str
status_code: int = 200
file_path: str | None = None
# Unified Query Result Data Structures for Reference List Support
@dataclass
class QueryResult:
"""
Unified query result data structure for all query modes.
Attributes:
content: Text content for non-streaming responses
response_iterator: Streaming response iterator for streaming responses
raw_data: Complete structured data including references and metadata
is_streaming: Whether this is a streaming result
"""
content: Optional[str] = None
response_iterator: Optional[AsyncIterator[str]] = None
raw_data: Optional[Dict[str, Any]] = None
is_streaming: bool = False
@property
def reference_list(self) -> List[Dict[str, str]]:
"""
Convenient property to extract reference list from raw_data.
Returns:
List[Dict[str, str]]: Reference list in format:
[{"reference_id": "1", "file_path": "/path/to/file.pdf"}, ...]
"""
if self.raw_data:
return self.raw_data.get("data", {}).get("references", [])
return []
@property
def metadata(self) -> Dict[str, Any]:
"""
Convenient property to extract metadata from raw_data.
Returns:
Dict[str, Any]: Query metadata including query_mode, keywords, etc.
"""
if self.raw_data:
return self.raw_data.get("metadata", {})
return {}
@dataclass
class QueryContextResult:
"""
Unified query context result data structure.
Attributes:
context: LLM context string
raw_data: Complete structured data including reference_list
"""
context: str
raw_data: Dict[str, Any]
@property
def reference_list(self) -> List[Dict[str, str]]:
"""Convenient property to extract reference list from raw_data."""
return self.raw_data.get("data", {}).get("references", [])

View file

@ -71,6 +71,7 @@ from lightrag.base import (
StoragesStatus,
DeletionResult,
OllamaServerInfos,
QueryResult,
)
from lightrag.namespace import NameSpace
from lightrag.operate import (
@ -2075,8 +2076,10 @@ class LightRAG:
# If a custom model is provided in param, temporarily update global config
global_config = asdict(self)
query_result = None
if param.mode in ["local", "global", "hybrid", "mix"]:
response = await kg_query(
query_result = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
@ -2089,7 +2092,7 @@ class LightRAG:
chunks_vdb=self.chunks_vdb,
)
elif param.mode == "naive":
response = await naive_query(
query_result = await naive_query(
query.strip(),
self.chunks_vdb,
param,
@ -2111,10 +2114,22 @@ class LightRAG:
enable_cot=True,
stream=param.stream,
)
# Create QueryResult for bypass mode
query_result = QueryResult(
content=response if not param.stream else None,
response_iterator=response if param.stream else None,
is_streaming=param.stream,
)
else:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
return response
# Return appropriate response based on streaming mode
if query_result.is_streaming:
return query_result.response_iterator
else:
return query_result.content
async def aquery_data(
self,
@ -2229,61 +2244,81 @@ class LightRAG:
"""
global_config = asdict(self)
if param.mode in ["local", "global", "hybrid", "mix"]:
logger.debug(f"[aquery_data] Using kg_query for mode: {param.mode}")
final_data = await kg_query(
# Create a copy of param to avoid modifying the original
data_param = QueryParam(
mode=param.mode,
only_need_context=True, # Skip LLM generation, only get context and data
only_need_prompt=False,
response_type=param.response_type,
stream=False, # Data retrieval doesn't need streaming
top_k=param.top_k,
chunk_top_k=param.chunk_top_k,
max_entity_tokens=param.max_entity_tokens,
max_relation_tokens=param.max_relation_tokens,
max_total_tokens=param.max_total_tokens,
hl_keywords=param.hl_keywords,
ll_keywords=param.ll_keywords,
conversation_history=param.conversation_history,
history_turns=param.history_turns,
model_func=param.model_func,
user_prompt=param.user_prompt,
enable_rerank=param.enable_rerank,
)
query_result = None
if data_param.mode in ["local", "global", "hybrid", "mix"]:
logger.debug(f"[aquery_data] Using kg_query for mode: {data_param.mode}")
query_result = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
param,
data_param, # Use data_param with only_need_context=True
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=None,
chunks_vdb=self.chunks_vdb,
return_raw_data=True, # Get final processed data
)
elif param.mode == "naive":
logger.debug(f"[aquery_data] Using naive_query for mode: {param.mode}")
final_data = await naive_query(
elif data_param.mode == "naive":
logger.debug(f"[aquery_data] Using naive_query for mode: {data_param.mode}")
query_result = await naive_query(
query.strip(),
self.chunks_vdb,
param,
data_param, # Use data_param with only_need_context=True
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=None,
return_raw_data=True, # Get final processed data
)
elif param.mode == "bypass":
elif data_param.mode == "bypass":
logger.debug("[aquery_data] Using bypass mode")
# bypass mode returns empty data using convert_to_user_format
final_data = convert_to_user_format(
empty_raw_data = convert_to_user_format(
[], # no entities
[], # no relationships
[], # no chunks
[], # no references
"bypass",
)
query_result = QueryResult(content="", raw_data=empty_raw_data)
else:
raise ValueError(f"Unknown mode {param.mode}")
raise ValueError(f"Unknown mode {data_param.mode}")
# Extract raw_data from QueryResult
final_data = query_result.raw_data if query_result else {}
# Log final result counts - adapt to new data format from convert_to_user_format
if isinstance(final_data, dict) and "data" in final_data:
# New format: data is nested under 'data' field
if final_data and "data" in final_data:
data_section = final_data["data"]
entities_count = len(data_section.get("entities", []))
relationships_count = len(data_section.get("relationships", []))
chunks_count = len(data_section.get("chunks", []))
logger.debug(
f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
)
else:
# Fallback for other formats
entities_count = len(final_data.get("entities", []))
relationships_count = len(final_data.get("relationships", []))
chunks_count = len(final_data.get("chunks", []))
logger.debug(
f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
)
logger.warning("[aquery_data] No data section found in query result")
await self._query_done()
return final_data
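At the library level, the reference flow used by the API routes reduces to the pair of calls sketched here (hedged; `rag` stands for an already-initialized LightRAG instance, and the QueryParam import path is an assumption):
from lightrag import QueryParam  # assumed import path

async def answer_with_references(rag, question: str):
    # `rag` is an initialized LightRAG instance (placeholder).
    param = QueryParam(mode="mix", stream=False)
    answer = await rag.aquery(question, param=param)          # plain answer text
    data = await rag.aquery_data(question, param=param)       # structured result
    references = data.get("data", {}).get("references", [])   # [{"reference_id": ..., "file_path": ...}, ...]
    return answer, references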

View file

@ -39,6 +39,8 @@ from .base import (
BaseVectorStorage,
TextChunkSchema,
QueryParam,
QueryResult,
QueryContextResult,
)
from .prompt import PROMPTS
from .constants import (
@ -2277,16 +2279,38 @@ async def kg_query(
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
return_raw_data: bool = False,
) -> str | AsyncIterator[str] | dict[str, Any]:
) -> QueryResult:
"""
Execute knowledge graph query and return unified QueryResult object.
Args:
query: Query string
knowledge_graph_inst: Knowledge graph storage instance
entities_vdb: Entity vector database
relationships_vdb: Relationship vector database
text_chunks_db: Text chunks storage
query_param: Query parameters
global_config: Global configuration
hashing_kv: Cache storage
system_prompt: System prompt
chunks_vdb: Document chunks vector database
Returns:
QueryResult: Unified query result object containing:
- content: Non-streaming response text content
- response_iterator: Streaming response iterator
- raw_data: Complete structured data (including references and metadata)
- is_streaming: Whether this is a streaming result
Based on different query_param settings, different fields will be populated:
- only_need_context=True: content contains context string
- only_need_prompt=True: content contains complete prompt
- stream=True: response_iterator contains streaming response, raw_data contains complete data
- default: content contains LLM response text, raw_data contains complete data
"""
if not query:
if return_raw_data:
return {
"status": "failure",
"message": "Query string is empty.",
"data": {},
}
return PROMPTS["fail_response"]
return QueryResult(content=PROMPTS["fail_response"])
if query_param.model_func:
use_model_func = query_param.model_func
@ -2315,12 +2339,11 @@ async def kg_query(
)
if (
cached_result is not None
and not return_raw_data
and not query_param.only_need_context
and not query_param.only_need_prompt
):
cached_response, _ = cached_result # Extract content, ignore timestamp
return cached_response
return QueryResult(content=cached_response)
hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
@ -2339,53 +2362,13 @@ async def kg_query(
logger.warning(f"Forced low_level_keywords to origin query: {query}")
ll_keywords = [query]
else:
if return_raw_data:
return {
"status": "failure",
"message": "Both high_level_keywords and low_level_keywords are empty",
"data": {},
}
return PROMPTS["fail_response"]
return QueryResult(content=PROMPTS["fail_response"])
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
# If raw data is requested, get both context and raw data
if return_raw_data:
context_result = await _build_query_context(
query,
ll_keywords_str,
hl_keywords_str,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
chunks_vdb,
return_raw_data=True,
)
if isinstance(context_result, tuple):
context, raw_data = context_result
logger.debug(f"[kg_query] Context length: {len(context) if context else 0}")
logger.debug(
f"[kg_query] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
)
return raw_data
else:
if not context_result:
return {
"status": "failure",
"message": "Query return empty data set.",
"data": {},
}
else:
raise ValueError(
"Fail to build raw data query result. Invalid return from _build_query_context"
)
# Build context (normal flow)
context = await _build_query_context(
# Build query context (unified interface)
context_result = await _build_query_context(
query,
ll_keywords_str,
hl_keywords_str,
@ -2397,14 +2380,19 @@ async def kg_query(
chunks_vdb,
)
if query_param.only_need_context and not query_param.only_need_prompt:
return context if context is not None else PROMPTS["fail_response"]
if context is None:
return PROMPTS["fail_response"]
if context_result is None:
return QueryResult(content=PROMPTS["fail_response"])
# Return different content based on query parameters
if query_param.only_need_context and not query_param.only_need_prompt:
return QueryResult(
content=context_result.context, raw_data=context_result.raw_data
)
# Build system prompt
sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
context_data=context,
context_data=context_result.context,
response_type=query_param.response_type,
)
@ -2415,8 +2403,10 @@ async def kg_query(
)
if query_param.only_need_prompt:
return "\n\n".join([sys_prompt, "---User Query---", user_query])
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(content=prompt_content, raw_data=context_result.raw_data)
# Call LLM
tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(
@ -2430,45 +2420,56 @@ async def kg_query(
enable_cot=True,
stream=query_param.stream,
)
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
if hashing_kv.global_config.get("enable_llm_cache"):
# Save to cache with query parameters
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)
# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
return response
# Cache response
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)
return QueryResult(content=response, raw_data=context_result.raw_data)
else:
# Streaming response (AsyncIterator)
return QueryResult(
response_iterator=response,
raw_data=context_result.raw_data,
is_streaming=True,
)
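To make the population rules listed in the docstring concrete, a consumer sketch for the QueryResult returned by kg_query (hedged; graph, entities_vdb, relationships_vdb, text_chunks, and global_config are placeholders for the objects that LightRAG.aquery normally passes in):
async def run_kg_query(graph, entities_vdb, relationships_vdb, text_chunks, global_config):
    result = await kg_query(
        "who authored LightRAG",
        graph, entities_vdb, relationships_vdb, text_chunks,
        QueryParam(mode="mix", stream=True),
        global_config,
    )
    if result.is_streaming:
        async for chunk in result.response_iterator:  # answer streamed in pieces
            print(chunk, end="")
    else:
        print(result.content)                         # full answer text
    print(result.reference_list)                      # extracted from raw_data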
async def get_keywords_from_query(
@ -3123,10 +3124,9 @@ async def _build_llm_context(
query_param: QueryParam,
global_config: dict[str, str],
chunk_tracking: dict = None,
return_raw_data: bool = False,
entity_id_to_original: dict = None,
relation_id_to_original: dict = None,
) -> str | tuple[str, dict[str, Any]]:
) -> tuple[str, dict[str, Any]]:
"""
Build the final LLM context string with token processing.
This includes dynamic token calculation and final chunk truncation.
@ -3134,22 +3134,17 @@ async def _build_llm_context(
tokenizer = global_config.get("tokenizer")
if not tokenizer:
logger.error("Missing tokenizer, cannot build LLM context")
if return_raw_data:
# Return empty raw data structure when no entities/relations
empty_raw_data = convert_to_user_format(
[],
[],
[],
[],
query_param.mode,
)
empty_raw_data["status"] = "failure"
empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
return None, empty_raw_data
else:
logger.error("Tokenizer not found in global configuration.")
return None
# Return empty raw data structure when no tokenizer
empty_raw_data = convert_to_user_format(
[],
[],
[],
[],
query_param.mode,
)
empty_raw_data["status"] = "failure"
empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
return "", empty_raw_data
# Get token limits
max_total_tokens = getattr(
@ -3268,20 +3263,17 @@ The reference documents list in Document Chunks(DC) is as follows (reference_id
# not necessary to use LLM to generate a response
if not entities_context and not relations_context:
if return_raw_data:
# Return empty raw data structure when no entities/relations
empty_raw_data = convert_to_user_format(
[],
[],
[],
[],
query_param.mode,
)
empty_raw_data["status"] = "failure"
empty_raw_data["message"] = "Query returned empty dataset."
return None, empty_raw_data
else:
return None
# Return empty raw data structure when no entities/relations
empty_raw_data = convert_to_user_format(
[],
[],
[],
[],
query_param.mode,
)
empty_raw_data["status"] = "failure"
empty_raw_data["message"] = "Query returned empty dataset."
return "", empty_raw_data
# output chunk tracking information
# format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
@ -3342,26 +3334,23 @@ Document Chunks (DC) reference documents : (Each entry begins with [reference_id
"""
# If final data is requested, return both context and complete data structure
if return_raw_data:
logger.debug(
f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
)
final_data = convert_to_user_format(
entities_context,
relations_context,
truncated_chunks,
reference_list,
query_param.mode,
entity_id_to_original,
relation_id_to_original,
)
logger.debug(
f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
)
return result, final_data
else:
return result
# Always return both context and complete data structure (unified approach)
logger.debug(
f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
)
final_data = convert_to_user_format(
entities_context,
relations_context,
truncated_chunks,
reference_list,
query_param.mode,
entity_id_to_original,
relation_id_to_original,
)
logger.debug(
f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
)
return result, final_data
# Now let's update the old _build_query_context to use the new architecture
@ -3375,16 +3364,17 @@ async def _build_query_context(
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None,
return_raw_data: bool = False,
) -> str | None | tuple[str, dict[str, Any]]:
) -> QueryContextResult | None:
"""
Main query context building function using the new 4-stage architecture:
1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
Returns unified QueryContextResult containing both context and raw_data.
"""
if not query:
logger.warning("Query is empty, skipping context building")
return ""
return None
# Stage 1: Pure search
search_result = await _perform_kg_search(
@ -3435,71 +3425,53 @@ async def _build_query_context(
return None
# Stage 4: Build final LLM context with dynamic token processing
# _build_llm_context now always returns tuple[str, dict]
context, raw_data = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
entity_id_to_original=truncation_result["entity_id_to_original"],
relation_id_to_original=truncation_result["relation_id_to_original"],
)
if return_raw_data:
# Convert keywords strings to lists
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
# Convert keywords strings to lists and add complete metadata to raw_data
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
# Get both context and final data - when return_raw_data=True, _build_llm_context always returns tuple
context, raw_data = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
return_raw_data=True,
entity_id_to_original=truncation_result["entity_id_to_original"],
relation_id_to_original=truncation_result["relation_id_to_original"],
)
# Add complete metadata to raw_data (preserve existing metadata including query_mode)
if "metadata" not in raw_data:
raw_data["metadata"] = {}
# Convert keywords strings to lists and add complete metadata to raw_data
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
# Update keywords while preserving existing metadata
raw_data["metadata"]["keywords"] = {
"high_level": hl_keywords_list,
"low_level": ll_keywords_list,
}
raw_data["metadata"]["processing_info"] = {
"total_entities_found": len(search_result.get("final_entities", [])),
"total_relations_found": len(search_result.get("final_relations", [])),
"entities_after_truncation": len(
truncation_result.get("filtered_entities", [])
),
"relations_after_truncation": len(
truncation_result.get("filtered_relations", [])
),
"merged_chunks_count": len(merged_chunks),
"final_chunks_count": len(raw_data.get("data", {}).get("chunks", [])),
}
# Add complete metadata to raw_data (preserve existing metadata including query_mode)
if "metadata" not in raw_data:
raw_data["metadata"] = {}
logger.debug(
f"[_build_query_context] Context length: {len(context) if context else 0}"
)
logger.debug(
f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}"
)
# Update keywords while preserving existing metadata
raw_data["metadata"]["keywords"] = {
"high_level": hl_keywords_list,
"low_level": ll_keywords_list,
}
raw_data["metadata"]["processing_info"] = {
"total_entities_found": len(search_result.get("final_entities", [])),
"total_relations_found": len(search_result.get("final_relations", [])),
"entities_after_truncation": len(
truncation_result.get("filtered_entities", [])
),
"relations_after_truncation": len(
truncation_result.get("filtered_relations", [])
),
"merged_chunks_count": len(merged_chunks),
"final_chunks_count": len(raw_data.get("chunks", [])),
}
logger.debug(
f"[_build_query_context] Context length: {len(context) if context else 0}"
)
logger.debug(
f"[_build_query_context] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
)
return context, raw_data
else:
# Normal context building (existing logic)
context = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
)
return context
return QueryContextResult(context=context, raw_data=raw_data)
async def _get_node_data(
@ -4105,19 +4077,28 @@ async def naive_query(
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
return_raw_data: bool = False,
) -> str | AsyncIterator[str] | dict[str, Any]:
) -> QueryResult:
"""
Execute naive query and return unified QueryResult object.
Args:
query: Query string
chunks_vdb: Document chunks vector database
query_param: Query parameters
global_config: Global configuration
hashing_kv: Cache storage
system_prompt: System prompt
Returns:
QueryResult: Unified query result object containing:
- content: Non-streaming response text content
- response_iterator: Streaming response iterator
- raw_data: Complete structured data (including references and metadata)
- is_streaming: Whether this is a streaming result
"""
if not query:
if return_raw_data:
# Return empty raw data structure when query is empty
empty_raw_data = {
"status": "failure",
"message": "Query string is empty.",
"data": {},
}
return empty_raw_data
else:
return PROMPTS["fail_response"]
return QueryResult(content=PROMPTS["fail_response"])
if query_param.model_func:
use_model_func = query_param.model_func
@ -4147,41 +4128,28 @@ async def naive_query(
if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
if not query_param.only_need_context and not query_param.only_need_prompt:
return cached_response
return QueryResult(content=cached_response)
tokenizer: Tokenizer = global_config["tokenizer"]
if not tokenizer:
if return_raw_data:
# Return empty raw data structure when tokenizer is missing
empty_raw_data = {
"status": "failure",
"message": "Tokenizer not found in global configuration.",
"data": {},
}
return empty_raw_data
else:
logger.error("Tokenizer not found in global configuration.")
return PROMPTS["fail_response"]
logger.error("Tokenizer not found in global configuration.")
return QueryResult(content=PROMPTS["fail_response"])
chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
if chunks is None or len(chunks) == 0:
# If only raw data is requested, return it directly
if return_raw_data:
empty_raw_data = convert_to_user_format(
[], # naive mode has no entities
[], # naive mode has no relationships
[], # no chunks
[], # no references
"naive",
)
empty_raw_data["message"] = "No relevant document chunks found."
return empty_raw_data
else:
return PROMPTS["fail_response"]
# Build empty raw data structure for naive mode
empty_raw_data = convert_to_user_format(
[], # naive mode has no entities
[], # naive mode has no relationships
[], # no chunks
[], # no references
"naive",
)
empty_raw_data["message"] = "No relevant document chunks found."
return QueryResult(content=PROMPTS["fail_response"], raw_data=empty_raw_data)
# Calculate dynamic token limit for chunks
# Get token limits from query_param (with fallback to global_config)
max_total_tokens = getattr(
query_param,
"max_total_tokens",
@ -4240,30 +4208,26 @@ async def naive_query(
logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks")
# If only raw data is requested, return it directly
if return_raw_data:
# Build raw data structure for naive mode using processed chunks with reference IDs
raw_data = convert_to_user_format(
[], # naive mode has no entities
[], # naive mode has no relationships
processed_chunks_with_ref_ids,
reference_list,
"naive",
)
# Build raw data structure for naive mode using processed chunks with reference IDs
raw_data = convert_to_user_format(
[], # naive mode has no entities
[], # naive mode has no relationships
processed_chunks_with_ref_ids,
reference_list,
"naive",
)
# Add complete metadata for naive mode
if "metadata" not in raw_data:
raw_data["metadata"] = {}
raw_data["metadata"]["keywords"] = {
"high_level": [], # naive mode has no keyword extraction
"low_level": [], # naive mode has no keyword extraction
}
raw_data["metadata"]["processing_info"] = {
"total_chunks_found": len(chunks),
"final_chunks_count": len(processed_chunks_with_ref_ids),
}
return raw_data
# Add complete metadata for naive mode
if "metadata" not in raw_data:
raw_data["metadata"] = {}
raw_data["metadata"]["keywords"] = {
"high_level": [], # naive mode has no keyword extraction
"low_level": [], # naive mode has no keyword extraction
}
raw_data["metadata"]["processing_info"] = {
"total_chunks_found": len(chunks),
"final_chunks_count": len(processed_chunks_with_ref_ids),
}
# Build text_units_context from processed chunks with reference IDs
text_units_context = []
@ -4284,8 +4248,7 @@ async def naive_query(
if ref["reference_id"]
)
if query_param.only_need_context and not query_param.only_need_prompt:
return f"""
context_content = f"""
---Document Chunks(DC)---
```json
@ -4297,6 +4260,10 @@ async def naive_query(
{reference_list_str}
"""
if query_param.only_need_context and not query_param.only_need_prompt:
return QueryResult(content=context_content, raw_data=raw_data)
user_query = (
"\n\n".join([query, query_param.user_prompt])
if query_param.user_prompt
@ -4310,7 +4277,8 @@ async def naive_query(
)
if query_param.only_need_prompt:
return "\n\n".join([sys_prompt, "---User Query---", user_query])
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(content=prompt_content, raw_data=raw_data)
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(
@ -4325,43 +4293,51 @@ async def naive_query(
stream=query_param.stream,
)
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response[len(sys_prompt) :]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
# Return unified result based on actual response type
if isinstance(response, str):
# Non-streaming response (string)
if len(response) > len(sys_prompt):
response = (
response[len(sys_prompt) :]
.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
if hashing_kv.global_config.get("enable_llm_cache"):
# Save to cache with query parameters
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)
# Cache response
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)
return response
return QueryResult(content=response, raw_data=raw_data)
else:
# Streaming response (AsyncIterator)
return QueryResult(
response_iterator=response, raw_data=raw_data, is_streaming=True
)

View file

@ -11,7 +11,8 @@ Updated to handle the new data format where:
import requests
import time
from typing import Dict, Any
import json
from typing import Dict, Any, List, Optional
# API configuration
API_KEY = "your-secure-api-key-here-123"
@ -21,6 +22,456 @@ BASE_URL = "http://localhost:9621"
AUTH_HEADERS = {"Content-Type": "application/json", "X-API-Key": API_KEY}
def validate_references_format(references: List[Dict[str, Any]]) -> bool:
"""Validate the format of references list"""
if not isinstance(references, list):
print(f"❌ References should be a list, got {type(references)}")
return False
for i, ref in enumerate(references):
if not isinstance(ref, dict):
print(f"❌ Reference {i} should be a dict, got {type(ref)}")
return False
required_fields = ["reference_id", "file_path"]
for field in required_fields:
if field not in ref:
print(f"❌ Reference {i} missing required field: {field}")
return False
if not isinstance(ref[field], str):
print(
f"❌ Reference {i} field '{field}' should be string, got {type(ref[field])}"
)
return False
return True
def parse_streaming_response(
response_text: str,
) -> tuple[Optional[List[Dict]], List[str], List[str]]:
"""Parse streaming response and extract references, response chunks, and errors"""
references = None
response_chunks = []
errors = []
lines = response_text.strip().split("\n")
for line in lines:
line = line.strip()
if not line or line.startswith("data: "):
if line.startswith("data: "):
line = line[6:] # Remove 'data: ' prefix
if not line:
continue
try:
data = json.loads(line)
if "references" in data:
references = data["references"]
elif "response" in data:
response_chunks.append(data["response"])
elif "error" in data:
errors.append(data["error"])
except json.JSONDecodeError:
# Skip non-JSON lines (like SSE comments)
continue
return references, response_chunks, errors
def test_query_endpoint_references():
"""Test /query endpoint references functionality"""
print("\n" + "=" * 60)
print("Testing /query endpoint references functionality")
print("=" * 60)
query_text = "who authored LightRAG"
endpoint = f"{BASE_URL}/query"
# Test 1: References enabled (default)
print("\n🧪 Test 1: References enabled (default)")
print("-" * 40)
try:
response = requests.post(
endpoint,
json={"query": query_text, "mode": "mix", "include_references": True},
headers=AUTH_HEADERS,
timeout=30,
)
if response.status_code == 200:
data = response.json()
# Check response structure
if "response" not in data:
print("❌ Missing 'response' field")
return False
if "references" not in data:
print("❌ Missing 'references' field when include_references=True")
return False
references = data["references"]
if references is None:
print("❌ References should not be None when include_references=True")
return False
if not validate_references_format(references):
return False
print(f"✅ References enabled: Found {len(references)} references")
print(f" Response length: {len(data['response'])} characters")
# Display reference list
if references:
print(" 📚 Reference List:")
for i, ref in enumerate(references, 1):
ref_id = ref.get("reference_id", "Unknown")
file_path = ref.get("file_path", "Unknown")
print(f" {i}. ID: {ref_id} | File: {file_path}")
else:
print(f"❌ Request failed: {response.status_code}")
print(f" Error: {response.text}")
return False
except Exception as e:
print(f"❌ Test 1 failed: {str(e)}")
return False
# Test 2: References disabled
print("\n🧪 Test 2: References disabled")
print("-" * 40)
try:
response = requests.post(
endpoint,
json={"query": query_text, "mode": "mix", "include_references": False},
headers=AUTH_HEADERS,
timeout=30,
)
if response.status_code == 200:
data = response.json()
# Check response structure
if "response" not in data:
print("❌ Missing 'response' field")
return False
references = data.get("references")
if references is not None:
print("❌ References should be None when include_references=False")
return False
print("✅ References disabled: No references field present")
print(f" Response length: {len(data['response'])} characters")
else:
print(f"❌ Request failed: {response.status_code}")
print(f" Error: {response.text}")
return False
except Exception as e:
print(f"❌ Test 2 failed: {str(e)}")
return False
print("\n✅ /query endpoint references tests passed!")
return True
def test_query_stream_endpoint_references():
"""Test /query/stream endpoint references functionality"""
print("\n" + "=" * 60)
print("Testing /query/stream endpoint references functionality")
print("=" * 60)
query_text = "who authored LightRAG"
endpoint = f"{BASE_URL}/query/stream"
# Test 1: Streaming with references enabled
print("\n🧪 Test 1: Streaming with references enabled")
print("-" * 40)
try:
response = requests.post(
endpoint,
json={"query": query_text, "mode": "mix", "include_references": True},
headers=AUTH_HEADERS,
timeout=30,
stream=True,
)
if response.status_code == 200:
# Collect streaming response
full_response = ""
for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
if chunk:
# Ensure chunk is string type
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
full_response += chunk
# Parse streaming response
references, response_chunks, errors = parse_streaming_response(
full_response
)
if errors:
print(f"❌ Errors in streaming response: {errors}")
return False
if references is None:
print("❌ No references found in streaming response")
return False
if not validate_references_format(references):
return False
if not response_chunks:
print("❌ No response chunks found in streaming response")
return False
print(f"✅ Streaming with references: Found {len(references)} references")
print(f" Response chunks: {len(response_chunks)}")
print(
f" Total response length: {sum(len(chunk) for chunk in response_chunks)} characters"
)
# Display reference list
if references:
print(" 📚 Reference List:")
for i, ref in enumerate(references, 1):
ref_id = ref.get("reference_id", "Unknown")
file_path = ref.get("file_path", "Unknown")
print(f" {i}. ID: {ref_id} | File: {file_path}")
else:
print(f"❌ Request failed: {response.status_code}")
print(f" Error: {response.text}")
return False
except Exception as e:
print(f"❌ Test 1 failed: {str(e)}")
return False
# Test 2: Streaming with references disabled
print("\n🧪 Test 2: Streaming with references disabled")
print("-" * 40)
try:
response = requests.post(
endpoint,
json={"query": query_text, "mode": "mix", "include_references": False},
headers=AUTH_HEADERS,
timeout=30,
stream=True,
)
if response.status_code == 200:
# Collect streaming response
full_response = ""
for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
if chunk:
# Ensure chunk is string type
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
full_response += chunk
# Parse streaming response
references, response_chunks, errors = parse_streaming_response(
full_response
)
if errors:
print(f"❌ Errors in streaming response: {errors}")
return False
if references is not None:
print("❌ References should be None when include_references=False")
return False
if not response_chunks:
print("❌ No response chunks found in streaming response")
return False
print("✅ Streaming without references: No references present")
print(f" Response chunks: {len(response_chunks)}")
print(
f" Total response length: {sum(len(chunk) for chunk in response_chunks)} characters"
)
else:
print(f"❌ Request failed: {response.status_code}")
print(f" Error: {response.text}")
return False
except Exception as e:
print(f"❌ Test 2 failed: {str(e)}")
return False
print("\n✅ /query/stream endpoint references tests passed!")
return True
def test_references_consistency():
"""Test references consistency across all endpoints"""
print("\n" + "=" * 60)
print("Testing references consistency across endpoints")
print("=" * 60)
query_text = "who authored LightRAG"
query_params = {
"query": query_text,
"mode": "mix",
"top_k": 10,
"chunk_top_k": 8,
"include_references": True,
}
references_data = {}
# Test /query endpoint
print("\n🧪 Testing /query endpoint")
print("-" * 40)
try:
response = requests.post(
f"{BASE_URL}/query", json=query_params, headers=AUTH_HEADERS, timeout=30
)
if response.status_code == 200:
data = response.json()
references_data["query"] = data.get("references", [])
print(f"✅ /query: {len(references_data['query'])} references")
else:
print(f"❌ /query failed: {response.status_code}")
return False
except Exception as e:
print(f"❌ /query test failed: {str(e)}")
return False
# Test /query/stream endpoint
print("\n🧪 Testing /query/stream endpoint")
print("-" * 40)
try:
response = requests.post(
f"{BASE_URL}/query/stream",
json=query_params,
headers=AUTH_HEADERS,
timeout=30,
stream=True,
)
if response.status_code == 200:
full_response = ""
for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
if chunk:
# Ensure chunk is string type
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
full_response += chunk
references, _, errors = parse_streaming_response(full_response)
if errors:
print(f"❌ Errors: {errors}")
return False
references_data["stream"] = references or []
print(f"✅ /query/stream: {len(references_data['stream'])} references")
else:
print(f"❌ /query/stream failed: {response.status_code}")
return False
except Exception as e:
print(f"❌ /query/stream test failed: {str(e)}")
return False
# Test /query/data endpoint
print("\n🧪 Testing /query/data endpoint")
print("-" * 40)
try:
response = requests.post(
f"{BASE_URL}/query/data",
json=query_params,
headers=AUTH_HEADERS,
timeout=30,
)
if response.status_code == 200:
data = response.json()
query_data = data.get("data", {})
references_data["data"] = query_data.get("references", [])
print(f"✅ /query/data: {len(references_data['data'])} references")
else:
print(f"❌ /query/data failed: {response.status_code}")
return False
except Exception as e:
print(f"❌ /query/data test failed: {str(e)}")
return False
# Compare references consistency
print("\n🔍 Comparing references consistency")
print("-" * 40)
# Convert to sets of (reference_id, file_path) tuples for comparison
def refs_to_set(refs):
return set(
(ref.get("reference_id", ""), ref.get("file_path", "")) for ref in refs
)
query_refs = refs_to_set(references_data["query"])
stream_refs = refs_to_set(references_data["stream"])
data_refs = refs_to_set(references_data["data"])
# Check consistency
consistency_passed = True
if query_refs != stream_refs:
print("❌ References mismatch between /query and /query/stream")
print(f" /query only: {query_refs - stream_refs}")
print(f" /query/stream only: {stream_refs - query_refs}")
consistency_passed = False
if query_refs != data_refs:
print("❌ References mismatch between /query and /query/data")
print(f" /query only: {query_refs - data_refs}")
print(f" /query/data only: {data_refs - query_refs}")
consistency_passed = False
if stream_refs != data_refs:
print("❌ References mismatch between /query/stream and /query/data")
print(f" /query/stream only: {stream_refs - data_refs}")
print(f" /query/data only: {data_refs - stream_refs}")
consistency_passed = False
if consistency_passed:
print("✅ All endpoints return consistent references")
print(f" Common references count: {len(query_refs)}")
# Display common reference list
if query_refs:
print(" 📚 Common Reference List:")
for i, (ref_id, file_path) in enumerate(sorted(query_refs), 1):
print(f" {i}. ID: {ref_id} | File: {file_path}")
return consistency_passed
def test_aquery_data_endpoint():
"""Test the /query/data endpoint"""
@ -239,15 +690,79 @@ def compare_with_regular_query():
print(f" Regular query error: {str(e)}")
def run_all_reference_tests():
"""Run all reference-related tests"""
print("\n" + "🚀" * 20)
print("LightRAG References Test Suite")
print("🚀" * 20)
all_tests_passed = True
# Test 1: /query endpoint references
try:
if not test_query_endpoint_references():
all_tests_passed = False
except Exception as e:
print(f"❌ /query endpoint test failed with exception: {str(e)}")
all_tests_passed = False
# Test 2: /query/stream endpoint references
try:
if not test_query_stream_endpoint_references():
all_tests_passed = False
except Exception as e:
print(f"❌ /query/stream endpoint test failed with exception: {str(e)}")
all_tests_passed = False
# Test 3: References consistency across endpoints
try:
if not test_references_consistency():
all_tests_passed = False
except Exception as e:
print(f"❌ References consistency test failed with exception: {str(e)}")
all_tests_passed = False
# Final summary
print("\n" + "=" * 60)
print("TEST SUITE SUMMARY")
print("=" * 60)
if all_tests_passed:
print("🎉 ALL TESTS PASSED!")
print("✅ /query endpoint references functionality works correctly")
print("✅ /query/stream endpoint references functionality works correctly")
print("✅ References are consistent across all endpoints")
print("\n🔧 System is ready for production use with reference support!")
else:
print("❌ SOME TESTS FAILED!")
print("Please check the error messages above and fix the issues.")
print("\n🔧 System needs attention before production deployment.")
return all_tests_passed
if __name__ == "__main__":
# Run main test
test_aquery_data_endpoint()
import sys
# Run comparison test
compare_with_regular_query()
if len(sys.argv) > 1 and sys.argv[1] == "--references-only":
# Run only the new reference tests
success = run_all_reference_tests()
sys.exit(0 if success else 1)
else:
# Run original tests plus new reference tests
print("Running original aquery_data endpoint test...")
test_aquery_data_endpoint()
print("\n💡 Usage tips:")
print("1. Ensure LightRAG API service is running")
print("2. Adjust base_url and authentication information as needed")
print("3. Modify query parameters to test different retrieval strategies")
print("4. Data query results can be used for further analysis and processing")
print("\nRunning comparison test...")
compare_with_regular_query()
print("\nRunning new reference tests...")
run_all_reference_tests()
print("\n💡 Usage tips:")
print("1. Ensure LightRAG API service is running")
print("2. Adjust base_url and authentication information as needed")
print("3. Modify query parameters to test different retrieval strategies")
print("4. Data query results can be used for further analysis and processing")
print("5. Run with --references-only flag to test only reference functionality")