Merge pull request #2147 from danielaskdd/return-reference-on-query

Feature: Add Reference List Support for All Query Endpoints
This commit is contained in:
Daniel.y 2025-09-25 16:58:32 +08:00 committed by GitHub
commit b4cc249dca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 1009 additions and 359 deletions

View file

@ -1 +1 @@
__api_version__ = "0230" __api_version__ = "0231"

View file

@ -88,6 +88,11 @@ class QueryRequest(BaseModel):
description="Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. Default is True.", description="Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. Default is True.",
) )
include_references: Optional[bool] = Field(
default=True,
description="If True, includes reference list in responses. Affects /query and /query/stream endpoints. /query/data always includes references.",
)
@field_validator("query", mode="after") @field_validator("query", mode="after")
@classmethod @classmethod
def query_strip_after(cls, query: str) -> str: def query_strip_after(cls, query: str) -> str:
@ -122,6 +127,10 @@ class QueryResponse(BaseModel):
response: str = Field( response: str = Field(
description="The generated response", description="The generated response",
) )
references: Optional[List[Dict[str, str]]] = Field(
default=None,
description="Reference list (only included when include_references=True, /query/data always includes references.)",
)
class QueryDataResponse(BaseModel): class QueryDataResponse(BaseModel):
@ -149,6 +158,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
request (QueryRequest): The request object containing the query parameters. request (QueryRequest): The request object containing the query parameters.
Returns: Returns:
QueryResponse: A Pydantic model containing the result of the query processing. QueryResponse: A Pydantic model containing the result of the query processing.
If include_references=True, also includes reference list.
If a string is returned (e.g., cache hit), it's directly returned. If a string is returned (e.g., cache hit), it's directly returned.
Otherwise, an async generator may be used to build the response. Otherwise, an async generator may be used to build the response.
@ -160,15 +170,26 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
param = request.to_query_params(False) param = request.to_query_params(False)
response = await rag.aquery(request.query, param=param) response = await rag.aquery(request.query, param=param)
# If response is a string (e.g. cache hit), return directly # Get reference list if requested
if isinstance(response, str): reference_list = None
return QueryResponse(response=response) if request.include_references:
try:
# Use aquery_data to get reference list independently
data_result = await rag.aquery_data(request.query, param=param)
if isinstance(data_result, dict) and "data" in data_result:
reference_list = data_result["data"].get("references", [])
except Exception as e:
logging.warning(f"Failed to get reference list: {str(e)}")
reference_list = []
if isinstance(response, dict): # Process response and return with optional references
if isinstance(response, str):
return QueryResponse(response=response, references=reference_list)
elif isinstance(response, dict):
result = json.dumps(response, indent=2) result = json.dumps(response, indent=2)
return QueryResponse(response=result) return QueryResponse(response=result, references=reference_list)
else: else:
return QueryResponse(response=str(response)) return QueryResponse(response=str(response), references=reference_list)
except Exception as e: except Exception as e:
trace_exception(e) trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@ -178,12 +199,18 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
""" """
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response. This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
The streaming response includes:
1. Reference list (sent first as a single message, if include_references=True)
2. LLM response content (streamed as multiple chunks)
Args: Args:
request (QueryRequest): The request object containing the query parameters. request (QueryRequest): The request object containing the query parameters.
optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None.
Returns: Returns:
StreamingResponse: A streaming response containing the RAG query results. StreamingResponse: A streaming response containing:
- First message: {"references": [...]} - Complete reference list (if requested)
- Subsequent messages: {"response": "..."} - LLM response chunks
- Error messages: {"error": "..."} - If any errors occur
""" """
try: try:
param = request.to_query_params(True) param = request.to_query_params(True)
@ -192,6 +219,28 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
async def stream_generator(): async def stream_generator():
# Get reference list if requested (default is True for backward compatibility)
reference_list = []
if request.include_references:
try:
# Use aquery_data to get reference list independently
data_param = request.to_query_params(
False
) # Non-streaming for data
data_result = await rag.aquery_data(
request.query, param=data_param
)
if isinstance(data_result, dict) and "data" in data_result:
reference_list = data_result["data"].get("references", [])
except Exception as e:
logging.warning(f"Failed to get reference list: {str(e)}")
reference_list = []
# Send reference list first (if requested)
if request.include_references:
yield f"{json.dumps({'references': reference_list})}\n"
# Then stream the response content
if isinstance(response, str): if isinstance(response, str):
# If it's a string, send it all at once # If it's a string, send it all at once
yield f"{json.dumps({'response': response})}\n" yield f"{json.dumps({'response': response})}\n"

View file

@ -11,6 +11,10 @@ from typing import (
TypedDict, TypedDict,
TypeVar, TypeVar,
Callable, Callable,
Optional,
Dict,
List,
AsyncIterator,
) )
from .utils import EmbeddingFunc from .utils import EmbeddingFunc
from .types import KnowledgeGraph from .types import KnowledgeGraph
@ -158,6 +162,12 @@ class QueryParam:
Default is True to enable reranking when rerank model is available. Default is True to enable reranking when rerank model is available.
""" """
include_references: bool = False
"""If True, includes reference list in the response for supported endpoints.
This parameter controls whether the API response includes a references field
containing citation information for the retrieved content.
"""
@dataclass @dataclass
class StorageNameSpace(ABC): class StorageNameSpace(ABC):
@ -814,3 +824,68 @@ class DeletionResult:
message: str message: str
status_code: int = 200 status_code: int = 200
file_path: str | None = None file_path: str | None = None
# Unified Query Result Data Structures for Reference List Support
@dataclass
class QueryResult:
"""
Unified query result data structure for all query modes.
Attributes:
content: Text content for non-streaming responses
response_iterator: Streaming response iterator for streaming responses
raw_data: Complete structured data including references and metadata
is_streaming: Whether this is a streaming result
"""
content: Optional[str] = None
response_iterator: Optional[AsyncIterator[str]] = None
raw_data: Optional[Dict[str, Any]] = None
is_streaming: bool = False
@property
def reference_list(self) -> List[Dict[str, str]]:
"""
Convenient property to extract reference list from raw_data.
Returns:
List[Dict[str, str]]: Reference list in format:
[{"reference_id": "1", "file_path": "/path/to/file.pdf"}, ...]
"""
if self.raw_data:
return self.raw_data.get("data", {}).get("references", [])
return []
@property
def metadata(self) -> Dict[str, Any]:
"""
Convenient property to extract metadata from raw_data.
Returns:
Dict[str, Any]: Query metadata including query_mode, keywords, etc.
"""
if self.raw_data:
return self.raw_data.get("metadata", {})
return {}
@dataclass
class QueryContextResult:
"""
Unified query context result data structure.
Attributes:
context: LLM context string
raw_data: Complete structured data including reference_list
"""
context: str
raw_data: Dict[str, Any]
@property
def reference_list(self) -> List[Dict[str, str]]:
"""Convenient property to extract reference list from raw_data."""
return self.raw_data.get("data", {}).get("references", [])

View file

@ -71,6 +71,7 @@ from lightrag.base import (
StoragesStatus, StoragesStatus,
DeletionResult, DeletionResult,
OllamaServerInfos, OllamaServerInfos,
QueryResult,
) )
from lightrag.namespace import NameSpace from lightrag.namespace import NameSpace
from lightrag.operate import ( from lightrag.operate import (
@ -2075,8 +2076,10 @@ class LightRAG:
# If a custom model is provided in param, temporarily update global config # If a custom model is provided in param, temporarily update global config
global_config = asdict(self) global_config = asdict(self)
query_result = None
if param.mode in ["local", "global", "hybrid", "mix"]: if param.mode in ["local", "global", "hybrid", "mix"]:
response = await kg_query( query_result = await kg_query(
query.strip(), query.strip(),
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.entities_vdb, self.entities_vdb,
@ -2089,7 +2092,7 @@ class LightRAG:
chunks_vdb=self.chunks_vdb, chunks_vdb=self.chunks_vdb,
) )
elif param.mode == "naive": elif param.mode == "naive":
response = await naive_query( query_result = await naive_query(
query.strip(), query.strip(),
self.chunks_vdb, self.chunks_vdb,
param, param,
@ -2111,10 +2114,22 @@ class LightRAG:
enable_cot=True, enable_cot=True,
stream=param.stream, stream=param.stream,
) )
# Create QueryResult for bypass mode
query_result = QueryResult(
content=response if not param.stream else None,
response_iterator=response if param.stream else None,
is_streaming=param.stream,
)
else: else:
raise ValueError(f"Unknown mode {param.mode}") raise ValueError(f"Unknown mode {param.mode}")
await self._query_done() await self._query_done()
return response
# Return appropriate response based on streaming mode
if query_result.is_streaming:
return query_result.response_iterator
else:
return query_result.content
async def aquery_data( async def aquery_data(
self, self,
@ -2229,61 +2244,81 @@ class LightRAG:
""" """
global_config = asdict(self) global_config = asdict(self)
if param.mode in ["local", "global", "hybrid", "mix"]: # Create a copy of param to avoid modifying the original
logger.debug(f"[aquery_data] Using kg_query for mode: {param.mode}") data_param = QueryParam(
final_data = await kg_query( mode=param.mode,
only_need_context=True, # Skip LLM generation, only get context and data
only_need_prompt=False,
response_type=param.response_type,
stream=False, # Data retrieval doesn't need streaming
top_k=param.top_k,
chunk_top_k=param.chunk_top_k,
max_entity_tokens=param.max_entity_tokens,
max_relation_tokens=param.max_relation_tokens,
max_total_tokens=param.max_total_tokens,
hl_keywords=param.hl_keywords,
ll_keywords=param.ll_keywords,
conversation_history=param.conversation_history,
history_turns=param.history_turns,
model_func=param.model_func,
user_prompt=param.user_prompt,
enable_rerank=param.enable_rerank,
)
query_result = None
if data_param.mode in ["local", "global", "hybrid", "mix"]:
logger.debug(f"[aquery_data] Using kg_query for mode: {data_param.mode}")
query_result = await kg_query(
query.strip(), query.strip(),
self.chunk_entity_relation_graph, self.chunk_entity_relation_graph,
self.entities_vdb, self.entities_vdb,
self.relationships_vdb, self.relationships_vdb,
self.text_chunks, self.text_chunks,
param, data_param, # Use data_param with only_need_context=True
global_config, global_config,
hashing_kv=self.llm_response_cache, hashing_kv=self.llm_response_cache,
system_prompt=None, system_prompt=None,
chunks_vdb=self.chunks_vdb, chunks_vdb=self.chunks_vdb,
return_raw_data=True, # Get final processed data
) )
elif param.mode == "naive": elif data_param.mode == "naive":
logger.debug(f"[aquery_data] Using naive_query for mode: {param.mode}") logger.debug(f"[aquery_data] Using naive_query for mode: {data_param.mode}")
final_data = await naive_query( query_result = await naive_query(
query.strip(), query.strip(),
self.chunks_vdb, self.chunks_vdb,
param, data_param, # Use data_param with only_need_context=True
global_config, global_config,
hashing_kv=self.llm_response_cache, hashing_kv=self.llm_response_cache,
system_prompt=None, system_prompt=None,
return_raw_data=True, # Get final processed data
) )
elif param.mode == "bypass": elif data_param.mode == "bypass":
logger.debug("[aquery_data] Using bypass mode") logger.debug("[aquery_data] Using bypass mode")
# bypass mode returns empty data using convert_to_user_format # bypass mode returns empty data using convert_to_user_format
final_data = convert_to_user_format( empty_raw_data = convert_to_user_format(
[], # no entities [], # no entities
[], # no relationships [], # no relationships
[], # no chunks [], # no chunks
[], # no references [], # no references
"bypass", "bypass",
) )
query_result = QueryResult(content="", raw_data=empty_raw_data)
else: else:
raise ValueError(f"Unknown mode {param.mode}") raise ValueError(f"Unknown mode {data_param.mode}")
# Extract raw_data from QueryResult
final_data = query_result.raw_data if query_result else {}
# Log final result counts - adapt to new data format from convert_to_user_format # Log final result counts - adapt to new data format from convert_to_user_format
if isinstance(final_data, dict) and "data" in final_data: if final_data and "data" in final_data:
# New format: data is nested under 'data' field
data_section = final_data["data"] data_section = final_data["data"]
entities_count = len(data_section.get("entities", [])) entities_count = len(data_section.get("entities", []))
relationships_count = len(data_section.get("relationships", [])) relationships_count = len(data_section.get("relationships", []))
chunks_count = len(data_section.get("chunks", [])) chunks_count = len(data_section.get("chunks", []))
logger.debug(
f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
)
else: else:
# Fallback for other formats logger.warning("[aquery_data] No data section found in query result")
entities_count = len(final_data.get("entities", []))
relationships_count = len(final_data.get("relationships", []))
chunks_count = len(final_data.get("chunks", []))
logger.debug(
f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
)
await self._query_done() await self._query_done()
return final_data return final_data

View file

@ -39,6 +39,8 @@ from .base import (
BaseVectorStorage, BaseVectorStorage,
TextChunkSchema, TextChunkSchema,
QueryParam, QueryParam,
QueryResult,
QueryContextResult,
) )
from .prompt import PROMPTS from .prompt import PROMPTS
from .constants import ( from .constants import (
@ -2277,16 +2279,38 @@ async def kg_query(
hashing_kv: BaseKVStorage | None = None, hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None, chunks_vdb: BaseVectorStorage = None,
return_raw_data: bool = False, ) -> QueryResult:
) -> str | AsyncIterator[str] | dict[str, Any]: """
Execute knowledge graph query and return unified QueryResult object.
Args:
query: Query string
knowledge_graph_inst: Knowledge graph storage instance
entities_vdb: Entity vector database
relationships_vdb: Relationship vector database
text_chunks_db: Text chunks storage
query_param: Query parameters
global_config: Global configuration
hashing_kv: Cache storage
system_prompt: System prompt
chunks_vdb: Document chunks vector database
Returns:
QueryResult: Unified query result object containing:
- content: Non-streaming response text content
- response_iterator: Streaming response iterator
- raw_data: Complete structured data (including references and metadata)
- is_streaming: Whether this is a streaming result
Based on different query_param settings, different fields will be populated:
- only_need_context=True: content contains context string
- only_need_prompt=True: content contains complete prompt
- stream=True: response_iterator contains streaming response, raw_data contains complete data
- default: content contains LLM response text, raw_data contains complete data
"""
if not query: if not query:
if return_raw_data: return QueryResult(content=PROMPTS["fail_response"])
return {
"status": "failure",
"message": "Query string is empty.",
"data": {},
}
return PROMPTS["fail_response"]
if query_param.model_func: if query_param.model_func:
use_model_func = query_param.model_func use_model_func = query_param.model_func
@ -2315,12 +2339,11 @@ async def kg_query(
) )
if ( if (
cached_result is not None cached_result is not None
and not return_raw_data
and not query_param.only_need_context and not query_param.only_need_context
and not query_param.only_need_prompt and not query_param.only_need_prompt
): ):
cached_response, _ = cached_result # Extract content, ignore timestamp cached_response, _ = cached_result # Extract content, ignore timestamp
return cached_response return QueryResult(content=cached_response)
hl_keywords, ll_keywords = await get_keywords_from_query( hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv query, query_param, global_config, hashing_kv
@ -2339,53 +2362,13 @@ async def kg_query(
logger.warning(f"Forced low_level_keywords to origin query: {query}") logger.warning(f"Forced low_level_keywords to origin query: {query}")
ll_keywords = [query] ll_keywords = [query]
else: else:
if return_raw_data: return QueryResult(content=PROMPTS["fail_response"])
return {
"status": "failure",
"message": "Both high_level_keywords and low_level_keywords are empty",
"data": {},
}
return PROMPTS["fail_response"]
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else "" ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
# If raw data is requested, get both context and raw data # Build query context (unified interface)
if return_raw_data: context_result = await _build_query_context(
context_result = await _build_query_context(
query,
ll_keywords_str,
hl_keywords_str,
knowledge_graph_inst,
entities_vdb,
relationships_vdb,
text_chunks_db,
query_param,
chunks_vdb,
return_raw_data=True,
)
if isinstance(context_result, tuple):
context, raw_data = context_result
logger.debug(f"[kg_query] Context length: {len(context) if context else 0}")
logger.debug(
f"[kg_query] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
)
return raw_data
else:
if not context_result:
return {
"status": "failure",
"message": "Query return empty data set.",
"data": {},
}
else:
raise ValueError(
"Fail to build raw data query result. Invalid return from _build_query_context"
)
# Build context (normal flow)
context = await _build_query_context(
query, query,
ll_keywords_str, ll_keywords_str,
hl_keywords_str, hl_keywords_str,
@ -2397,14 +2380,19 @@ async def kg_query(
chunks_vdb, chunks_vdb,
) )
if query_param.only_need_context and not query_param.only_need_prompt: if context_result is None:
return context if context is not None else PROMPTS["fail_response"] return QueryResult(content=PROMPTS["fail_response"])
if context is None:
return PROMPTS["fail_response"]
# Return different content based on query parameters
if query_param.only_need_context and not query_param.only_need_prompt:
return QueryResult(
content=context_result.context, raw_data=context_result.raw_data
)
# Build system prompt
sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"] sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format( sys_prompt = sys_prompt_temp.format(
context_data=context, context_data=context_result.context,
response_type=query_param.response_type, response_type=query_param.response_type,
) )
@ -2415,8 +2403,10 @@ async def kg_query(
) )
if query_param.only_need_prompt: if query_param.only_need_prompt:
return "\n\n".join([sys_prompt, "---User Query---", user_query]) prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(content=prompt_content, raw_data=context_result.raw_data)
# Call LLM
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(query + sys_prompt)) len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug( logger.debug(
@ -2430,45 +2420,56 @@ async def kg_query(
enable_cot=True, enable_cot=True,
stream=query_param.stream, stream=query_param.stream,
) )
if isinstance(response, str) and len(response) > len(sys_prompt):
response = (
response.replace(sys_prompt, "")
.replace("user", "")
.replace("model", "")
.replace(query, "")
.replace("<system>", "")
.replace("</system>", "")
.strip()
)
if hashing_kv.global_config.get("enable_llm_cache"): # Return unified result based on actual response type
# Save to cache with query parameters if isinstance(response, str):
queryparam_dict = { # Non-streaming response (string)
"mode": query_param.mode, if len(response) > len(sys_prompt):
"response_type": query_param.response_type, response = (
"top_k": query_param.top_k, response.replace(sys_prompt, "")
"chunk_top_k": query_param.chunk_top_k, .replace("user", "")
"max_entity_tokens": query_param.max_entity_tokens, .replace("model", "")
"max_relation_tokens": query_param.max_relation_tokens, .replace(query, "")
"max_total_tokens": query_param.max_total_tokens, .replace("<system>", "")
"hl_keywords": query_param.hl_keywords or [], .replace("</system>", "")
"ll_keywords": query_param.ll_keywords or [], .strip()
"user_prompt": query_param.user_prompt or "", )
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)
return response # Cache response
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = {
"mode": query_param.mode,
"response_type": query_param.response_type,
"top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank,
}
await save_to_cache(
hashing_kv,
CacheData(
args_hash=args_hash,
content=response,
prompt=query,
mode=query_param.mode,
cache_type="query",
queryparam=queryparam_dict,
),
)
return QueryResult(content=response, raw_data=context_result.raw_data)
else:
# Streaming response (AsyncIterator)
return QueryResult(
response_iterator=response,
raw_data=context_result.raw_data,
is_streaming=True,
)
async def get_keywords_from_query( async def get_keywords_from_query(
@ -3123,10 +3124,9 @@ async def _build_llm_context(
query_param: QueryParam, query_param: QueryParam,
global_config: dict[str, str], global_config: dict[str, str],
chunk_tracking: dict = None, chunk_tracking: dict = None,
return_raw_data: bool = False,
entity_id_to_original: dict = None, entity_id_to_original: dict = None,
relation_id_to_original: dict = None, relation_id_to_original: dict = None,
) -> str | tuple[str, dict[str, Any]]: ) -> tuple[str, dict[str, Any]]:
""" """
Build the final LLM context string with token processing. Build the final LLM context string with token processing.
This includes dynamic token calculation and final chunk truncation. This includes dynamic token calculation and final chunk truncation.
@ -3134,22 +3134,17 @@ async def _build_llm_context(
tokenizer = global_config.get("tokenizer") tokenizer = global_config.get("tokenizer")
if not tokenizer: if not tokenizer:
logger.error("Missing tokenizer, cannot build LLM context") logger.error("Missing tokenizer, cannot build LLM context")
# Return empty raw data structure when no tokenizer
if return_raw_data: empty_raw_data = convert_to_user_format(
# Return empty raw data structure when no entities/relations [],
empty_raw_data = convert_to_user_format( [],
[], [],
[], [],
[], query_param.mode,
[], )
query_param.mode, empty_raw_data["status"] = "failure"
) empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
empty_raw_data["status"] = "failure" return "", empty_raw_data
empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
return None, empty_raw_data
else:
logger.error("Tokenizer not found in global configuration.")
return None
# Get token limits # Get token limits
max_total_tokens = getattr( max_total_tokens = getattr(
@ -3268,20 +3263,17 @@ The reference documents list in Document Chunks(DC) is as follows (reference_id
# not necessary to use LLM to generate a response # not necessary to use LLM to generate a response
if not entities_context and not relations_context: if not entities_context and not relations_context:
if return_raw_data: # Return empty raw data structure when no entities/relations
# Return empty raw data structure when no entities/relations empty_raw_data = convert_to_user_format(
empty_raw_data = convert_to_user_format( [],
[], [],
[], [],
[], [],
[], query_param.mode,
query_param.mode, )
) empty_raw_data["status"] = "failure"
empty_raw_data["status"] = "failure" empty_raw_data["message"] = "Query returned empty dataset."
empty_raw_data["message"] = "Query returned empty dataset." return "", empty_raw_data
return None, empty_raw_data
else:
return None
# output chunks tracking infomations # output chunks tracking infomations
# format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1) # format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
@ -3342,26 +3334,23 @@ Document Chunks (DC) reference documents : (Each entry begins with [reference_id
""" """
# If final data is requested, return both context and complete data structure # Always return both context and complete data structure (unified approach)
if return_raw_data: logger.debug(
logger.debug( f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks" )
) final_data = convert_to_user_format(
final_data = convert_to_user_format( entities_context,
entities_context, relations_context,
relations_context, truncated_chunks,
truncated_chunks, reference_list,
reference_list, query_param.mode,
query_param.mode, entity_id_to_original,
entity_id_to_original, relation_id_to_original,
relation_id_to_original, )
) logger.debug(
logger.debug( f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks" )
) return result, final_data
return result, final_data
else:
return result
# Now let's update the old _build_query_context to use the new architecture # Now let's update the old _build_query_context to use the new architecture
@ -3375,16 +3364,17 @@ async def _build_query_context(
text_chunks_db: BaseKVStorage, text_chunks_db: BaseKVStorage,
query_param: QueryParam, query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None, chunks_vdb: BaseVectorStorage = None,
return_raw_data: bool = False, ) -> QueryContextResult | None:
) -> str | None | tuple[str, dict[str, Any]]:
""" """
Main query context building function using the new 4-stage architecture: Main query context building function using the new 4-stage architecture:
1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context 1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
Returns unified QueryContextResult containing both context and raw_data.
""" """
if not query: if not query:
logger.warning("Query is empty, skipping context building") logger.warning("Query is empty, skipping context building")
return "" return None
# Stage 1: Pure search # Stage 1: Pure search
search_result = await _perform_kg_search( search_result = await _perform_kg_search(
@ -3435,71 +3425,53 @@ async def _build_query_context(
return None return None
# Stage 4: Build final LLM context with dynamic token processing # Stage 4: Build final LLM context with dynamic token processing
# _build_llm_context now always returns tuple[str, dict]
context, raw_data = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
entity_id_to_original=truncation_result["entity_id_to_original"],
relation_id_to_original=truncation_result["relation_id_to_original"],
)
if return_raw_data: # Convert keywords strings to lists and add complete metadata to raw_data
# Convert keywords strings to lists hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else [] ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
# Get both context and final data - when return_raw_data=True, _build_llm_context always returns tuple # Add complete metadata to raw_data (preserve existing metadata including query_mode)
context, raw_data = await _build_llm_context( if "metadata" not in raw_data:
entities_context=truncation_result["entities_context"], raw_data["metadata"] = {}
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
return_raw_data=True,
entity_id_to_original=truncation_result["entity_id_to_original"],
relation_id_to_original=truncation_result["relation_id_to_original"],
)
# Convert keywords strings to lists and add complete metadata to raw_data # Update keywords while preserving existing metadata
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else [] raw_data["metadata"]["keywords"] = {
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else [] "high_level": hl_keywords_list,
"low_level": ll_keywords_list,
}
raw_data["metadata"]["processing_info"] = {
"total_entities_found": len(search_result.get("final_entities", [])),
"total_relations_found": len(search_result.get("final_relations", [])),
"entities_after_truncation": len(
truncation_result.get("filtered_entities", [])
),
"relations_after_truncation": len(
truncation_result.get("filtered_relations", [])
),
"merged_chunks_count": len(merged_chunks),
"final_chunks_count": len(raw_data.get("data", {}).get("chunks", [])),
}
# Add complete metadata to raw_data (preserve existing metadata including query_mode) logger.debug(
if "metadata" not in raw_data: f"[_build_query_context] Context length: {len(context) if context else 0}"
raw_data["metadata"] = {} )
logger.debug(
f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}"
)
# Update keywords while preserving existing metadata return QueryContextResult(context=context, raw_data=raw_data)
raw_data["metadata"]["keywords"] = {
"high_level": hl_keywords_list,
"low_level": ll_keywords_list,
}
raw_data["metadata"]["processing_info"] = {
"total_entities_found": len(search_result.get("final_entities", [])),
"total_relations_found": len(search_result.get("final_relations", [])),
"entities_after_truncation": len(
truncation_result.get("filtered_entities", [])
),
"relations_after_truncation": len(
truncation_result.get("filtered_relations", [])
),
"merged_chunks_count": len(merged_chunks),
"final_chunks_count": len(raw_data.get("chunks", [])),
}
logger.debug(
f"[_build_query_context] Context length: {len(context) if context else 0}"
)
logger.debug(
f"[_build_query_context] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
)
return context, raw_data
else:
# Normal context building (existing logic)
context = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
query=query,
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
)
return context
async def _get_node_data( async def _get_node_data(
@ -4105,19 +4077,28 @@ async def naive_query(
global_config: dict[str, str], global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None, hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
return_raw_data: bool = False, ) -> QueryResult:
) -> str | AsyncIterator[str] | dict[str, Any]: """
Execute naive query and return unified QueryResult object.
Args:
query: Query string
chunks_vdb: Document chunks vector database
query_param: Query parameters
global_config: Global configuration
hashing_kv: Cache storage
system_prompt: System prompt
Returns:
QueryResult: Unified query result object containing:
- content: Non-streaming response text content
- response_iterator: Streaming response iterator
- raw_data: Complete structured data (including references and metadata)
- is_streaming: Whether this is a streaming result
"""
if not query: if not query:
if return_raw_data: return QueryResult(content=PROMPTS["fail_response"])
# Return empty raw data structure when query is empty
empty_raw_data = {
"status": "failure",
"message": "Query string is empty.",
"data": {},
}
return empty_raw_data
else:
return PROMPTS["fail_response"]
if query_param.model_func: if query_param.model_func:
use_model_func = query_param.model_func use_model_func = query_param.model_func
@ -4147,41 +4128,28 @@ async def naive_query(
if cached_result is not None: if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp cached_response, _ = cached_result # Extract content, ignore timestamp
if not query_param.only_need_context and not query_param.only_need_prompt: if not query_param.only_need_context and not query_param.only_need_prompt:
return cached_response return QueryResult(content=cached_response)
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = global_config["tokenizer"]
if not tokenizer: if not tokenizer:
if return_raw_data: logger.error("Tokenizer not found in global configuration.")
# Return empty raw data structure when tokenizer is missing return QueryResult(content=PROMPTS["fail_response"])
empty_raw_data = {
"status": "failure",
"message": "Tokenizer not found in global configuration.",
"data": {},
}
return empty_raw_data
else:
logger.error("Tokenizer not found in global configuration.")
return PROMPTS["fail_response"]
chunks = await _get_vector_context(query, chunks_vdb, query_param, None) chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
if chunks is None or len(chunks) == 0: if chunks is None or len(chunks) == 0:
# If only raw data is requested, return it directly # Build empty raw data structure for naive mode
if return_raw_data: empty_raw_data = convert_to_user_format(
empty_raw_data = convert_to_user_format( [], # naive mode has no entities
[], # naive mode has no entities [], # naive mode has no relationships
[], # naive mode has no relationships [], # no chunks
[], # no chunks [], # no references
[], # no references "naive",
"naive", )
) empty_raw_data["message"] = "No relevant document chunks found."
empty_raw_data["message"] = "No relevant document chunks found." return QueryResult(content=PROMPTS["fail_response"], raw_data=empty_raw_data)
return empty_raw_data
else:
return PROMPTS["fail_response"]
# Calculate dynamic token limit for chunks # Calculate dynamic token limit for chunks
# Get token limits from query_param (with fallback to global_config)
max_total_tokens = getattr( max_total_tokens = getattr(
query_param, query_param,
"max_total_tokens", "max_total_tokens",
@ -4240,30 +4208,26 @@ async def naive_query(
logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks") logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks")
# If only raw data is requested, return it directly # Build raw data structure for naive mode using processed chunks with reference IDs
if return_raw_data: raw_data = convert_to_user_format(
# Build raw data structure for naive mode using processed chunks with reference IDs [], # naive mode has no entities
raw_data = convert_to_user_format( [], # naive mode has no relationships
[], # naive mode has no entities processed_chunks_with_ref_ids,
[], # naive mode has no relationships reference_list,
processed_chunks_with_ref_ids, "naive",
reference_list, )
"naive",
)
# Add complete metadata for naive mode # Add complete metadata for naive mode
if "metadata" not in raw_data: if "metadata" not in raw_data:
raw_data["metadata"] = {} raw_data["metadata"] = {}
raw_data["metadata"]["keywords"] = { raw_data["metadata"]["keywords"] = {
"high_level": [], # naive mode has no keyword extraction "high_level": [], # naive mode has no keyword extraction
"low_level": [], # naive mode has no keyword extraction "low_level": [], # naive mode has no keyword extraction
} }
raw_data["metadata"]["processing_info"] = { raw_data["metadata"]["processing_info"] = {
"total_chunks_found": len(chunks), "total_chunks_found": len(chunks),
"final_chunks_count": len(processed_chunks_with_ref_ids), "final_chunks_count": len(processed_chunks_with_ref_ids),
} }
return raw_data
# Build text_units_context from processed chunks with reference IDs # Build text_units_context from processed chunks with reference IDs
text_units_context = [] text_units_context = []
@ -4284,8 +4248,7 @@ async def naive_query(
if ref["reference_id"] if ref["reference_id"]
) )
if query_param.only_need_context and not query_param.only_need_prompt: context_content = f"""
return f"""
---Document Chunks(DC)--- ---Document Chunks(DC)---
```json ```json
@ -4297,6 +4260,10 @@ async def naive_query(
{reference_list_str} {reference_list_str}
""" """
if query_param.only_need_context and not query_param.only_need_prompt:
return QueryResult(content=context_content, raw_data=raw_data)
user_query = ( user_query = (
"\n\n".join([query, query_param.user_prompt]) "\n\n".join([query, query_param.user_prompt])
if query_param.user_prompt if query_param.user_prompt
@ -4310,7 +4277,8 @@ async def naive_query(
) )
if query_param.only_need_prompt: if query_param.only_need_prompt:
return "\n\n".join([sys_prompt, "---User Query---", user_query]) prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(content=prompt_content, raw_data=raw_data)
len_of_prompts = len(tokenizer.encode(query + sys_prompt)) len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug( logger.debug(
@ -4325,43 +4293,51 @@ async def naive_query(
stream=query_param.stream, stream=query_param.stream,
) )
if isinstance(response, str) and len(response) > len(sys_prompt): # Return unified result based on actual response type
response = ( if isinstance(response, str):
response[len(sys_prompt) :] # Non-streaming response (string)
.replace(sys_prompt, "") if len(response) > len(sys_prompt):
.replace("user", "") response = (
.replace("model", "") response[len(sys_prompt) :]
.replace(query, "") .replace(sys_prompt, "")
.replace("<system>", "") .replace("user", "")
.replace("</system>", "") .replace("model", "")
.strip() .replace(query, "")
) .replace("<system>", "")
.replace("</system>", "")
.strip()
)
if hashing_kv.global_config.get("enable_llm_cache"): # Cache response
# Save to cache with query parameters if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
queryparam_dict = { queryparam_dict = {
"mode": query_param.mode, "mode": query_param.mode,
"response_type": query_param.response_type, "response_type": query_param.response_type,
"top_k": query_param.top_k, "top_k": query_param.top_k,
"chunk_top_k": query_param.chunk_top_k, "chunk_top_k": query_param.chunk_top_k,
"max_entity_tokens": query_param.max_entity_tokens, "max_entity_tokens": query_param.max_entity_tokens,
"max_relation_tokens": query_param.max_relation_tokens, "max_relation_tokens": query_param.max_relation_tokens,
"max_total_tokens": query_param.max_total_tokens, "max_total_tokens": query_param.max_total_tokens,
"hl_keywords": query_param.hl_keywords or [], "hl_keywords": query_param.hl_keywords or [],
"ll_keywords": query_param.ll_keywords or [], "ll_keywords": query_param.ll_keywords or [],
"user_prompt": query_param.user_prompt or "", "user_prompt": query_param.user_prompt or "",
"enable_rerank": query_param.enable_rerank, "enable_rerank": query_param.enable_rerank,
} }
await save_to_cache( await save_to_cache(
hashing_kv, hashing_kv,
CacheData( CacheData(
args_hash=args_hash, args_hash=args_hash,
content=response, content=response,
prompt=query, prompt=query,
mode=query_param.mode, mode=query_param.mode,
cache_type="query", cache_type="query",
queryparam=queryparam_dict, queryparam=queryparam_dict,
), ),
) )
return response return QueryResult(content=response, raw_data=raw_data)
else:
# Streaming response (AsyncIterator)
return QueryResult(
response_iterator=response, raw_data=raw_data, is_streaming=True
)

View file

@ -11,7 +11,8 @@ Updated to handle the new data format where:
import requests import requests
import time import time
from typing import Dict, Any import json
from typing import Dict, Any, List, Optional
# API configuration # API configuration
API_KEY = "your-secure-api-key-here-123" API_KEY = "your-secure-api-key-here-123"
@ -21,6 +22,456 @@ BASE_URL = "http://localhost:9621"
AUTH_HEADERS = {"Content-Type": "application/json", "X-API-Key": API_KEY} AUTH_HEADERS = {"Content-Type": "application/json", "X-API-Key": API_KEY}
def validate_references_format(references: List[Dict[str, Any]]) -> bool:
    """Validate that a references payload is well-formed.

    A valid payload is a list of dicts, each carrying string-valued
    'reference_id' and 'file_path' fields. The first problem found is
    printed as a diagnostic and False is returned; True means every
    entry passed.
    """
    if not isinstance(references, list):
        print(f"❌ References should be a list, got {type(references)}")
        return False

    for idx, entry in enumerate(references):
        if not isinstance(entry, dict):
            print(f"❌ Reference {idx} should be a dict, got {type(entry)}")
            return False

        for key in ("reference_id", "file_path"):
            if key not in entry:
                print(f"❌ Reference {idx} missing required field: {key}")
                return False
            if not isinstance(entry[key], str):
                print(
                    f"❌ Reference {idx} field '{key}' should be string, got {type(entry[key])}"
                )
                return False

    return True
def parse_streaming_response(
    response_text: str,
) -> tuple[Optional[List[Dict]], List[str], List[str]]:
    """Parse a streamed NDJSON/SSE response body.

    Each non-empty line is expected to hold a JSON object, optionally
    prefixed with the SSE ``data: `` marker. Objects are routed by key:
    ``references`` (last one seen wins), ``response`` (fragments
    accumulated in order), ``error`` (collected).

    Args:
        response_text: Raw streamed body as a single string.

    Returns:
        Tuple of (references, response_chunks, errors) where references
        is the captured references list or None if none was seen.
    """
    references = None
    response_chunks = []
    errors = []

    for line in response_text.strip().split("\n"):
        line = line.strip()
        # Strip the SSE 'data: ' prefix when present, then skip blanks.
        # (Flattened from the original redundant nested conditional.)
        if line.startswith("data: "):
            line = line[6:]
        if not line:
            continue

        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            # Skip non-JSON lines (like SSE comments)
            continue

        if "references" in data:
            references = data["references"]
        elif "response" in data:
            response_chunks.append(data["response"])
        elif "error" in data:
            errors.append(data["error"])

    return references, response_chunks, errors
def test_query_endpoint_references():
    """Test /query endpoint references functionality.

    Runs two sub-tests against the non-streaming /query endpoint:
      1. include_references=True  -- the JSON body must contain a non-None,
         well-formed 'references' list alongside the 'response' text.
      2. include_references=False -- 'references' must be absent or None.

    Returns:
        bool: True when both sub-tests pass, False on any failure
        (HTTP error, malformed payload, or request exception).
    """
    print("\n" + "=" * 60)
    print("Testing /query endpoint references functionality")
    print("=" * 60)

    query_text = "who authored LightRAG"
    endpoint = f"{BASE_URL}/query"

    # Test 1: References enabled (default)
    print("\n🧪 Test 1: References enabled (default)")
    print("-" * 40)

    try:
        response = requests.post(
            endpoint,
            json={"query": query_text, "mode": "mix", "include_references": True},
            headers=AUTH_HEADERS,
            timeout=30,
        )

        if response.status_code == 200:
            data = response.json()

            # Check response structure
            if "response" not in data:
                print("❌ Missing 'response' field")
                return False

            if "references" not in data:
                print("❌ Missing 'references' field when include_references=True")
                return False

            references = data["references"]
            if references is None:
                print("❌ References should not be None when include_references=True")
                return False

            # Delegates per-entry schema checks (string reference_id/file_path)
            if not validate_references_format(references):
                return False

            print(f"✅ References enabled: Found {len(references)} references")
            print(f" Response length: {len(data['response'])} characters")

            # Display reference list
            if references:
                print(" 📚 Reference List:")
                for i, ref in enumerate(references, 1):
                    ref_id = ref.get("reference_id", "Unknown")
                    file_path = ref.get("file_path", "Unknown")
                    print(f" {i}. ID: {ref_id} | File: {file_path}")
        else:
            print(f"❌ Request failed: {response.status_code}")
            print(f" Error: {response.text}")
            return False

    except Exception as e:
        print(f"❌ Test 1 failed: {str(e)}")
        return False

    # Test 2: References disabled
    print("\n🧪 Test 2: References disabled")
    print("-" * 40)

    try:
        response = requests.post(
            endpoint,
            json={"query": query_text, "mode": "mix", "include_references": False},
            headers=AUTH_HEADERS,
            timeout=30,
        )

        if response.status_code == 200:
            data = response.json()

            # Check response structure
            if "response" not in data:
                print("❌ Missing 'response' field")
                return False

            # .get() tolerates a missing key; both absent and explicit null pass
            references = data.get("references")
            if references is not None:
                print("❌ References should be None when include_references=False")
                return False

            print("✅ References disabled: No references field present")
            print(f" Response length: {len(data['response'])} characters")
        else:
            print(f"❌ Request failed: {response.status_code}")
            print(f" Error: {response.text}")
            return False

    except Exception as e:
        print(f"❌ Test 2 failed: {str(e)}")
        return False

    print("\n✅ /query endpoint references tests passed!")
    return True
def test_query_stream_endpoint_references():
    """Test /query/stream endpoint references functionality.

    Streams two requests and parses the accumulated body with
    parse_streaming_response():
      1. include_references=True  -- a well-formed references list plus at
         least one response chunk must be present, with no error events.
      2. include_references=False -- response chunks only; references must
         be absent.

    Returns:
        bool: True when both sub-tests pass, False otherwise.
    """
    print("\n" + "=" * 60)
    print("Testing /query/stream endpoint references functionality")
    print("=" * 60)

    query_text = "who authored LightRAG"
    endpoint = f"{BASE_URL}/query/stream"

    # Test 1: Streaming with references enabled
    print("\n🧪 Test 1: Streaming with references enabled")
    print("-" * 40)

    try:
        response = requests.post(
            endpoint,
            json={"query": query_text, "mode": "mix", "include_references": True},
            headers=AUTH_HEADERS,
            timeout=30,
            stream=True,
        )

        if response.status_code == 200:
            # Collect streaming response
            full_response = ""
            for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
                if chunk:
                    # Ensure chunk is string type
                    if isinstance(chunk, bytes):
                        chunk = chunk.decode("utf-8")
                    full_response += chunk

            # Parse streaming response
            references, response_chunks, errors = parse_streaming_response(
                full_response
            )

            if errors:
                print(f"❌ Errors in streaming response: {errors}")
                return False

            if references is None:
                print("❌ No references found in streaming response")
                return False

            if not validate_references_format(references):
                return False

            if not response_chunks:
                print("❌ No response chunks found in streaming response")
                return False

            print(f"✅ Streaming with references: Found {len(references)} references")
            print(f" Response chunks: {len(response_chunks)}")
            print(
                f" Total response length: {sum(len(chunk) for chunk in response_chunks)} characters"
            )

            # Display reference list
            if references:
                print(" 📚 Reference List:")
                for i, ref in enumerate(references, 1):
                    ref_id = ref.get("reference_id", "Unknown")
                    file_path = ref.get("file_path", "Unknown")
                    print(f" {i}. ID: {ref_id} | File: {file_path}")
        else:
            print(f"❌ Request failed: {response.status_code}")
            print(f" Error: {response.text}")
            return False

    except Exception as e:
        print(f"❌ Test 1 failed: {str(e)}")
        return False

    # Test 2: Streaming with references disabled
    print("\n🧪 Test 2: Streaming with references disabled")
    print("-" * 40)

    try:
        response = requests.post(
            endpoint,
            json={"query": query_text, "mode": "mix", "include_references": False},
            headers=AUTH_HEADERS,
            timeout=30,
            stream=True,
        )

        if response.status_code == 200:
            # Collect streaming response
            full_response = ""
            for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
                if chunk:
                    # Ensure chunk is string type
                    if isinstance(chunk, bytes):
                        chunk = chunk.decode("utf-8")
                    full_response += chunk

            # Parse streaming response
            references, response_chunks, errors = parse_streaming_response(
                full_response
            )

            if errors:
                print(f"❌ Errors in streaming response: {errors}")
                return False

            if references is not None:
                print("❌ References should be None when include_references=False")
                return False

            if not response_chunks:
                print("❌ No response chunks found in streaming response")
                return False

            print("✅ Streaming without references: No references present")
            print(f" Response chunks: {len(response_chunks)}")
            print(
                f" Total response length: {sum(len(chunk) for chunk in response_chunks)} characters"
            )
        else:
            print(f"❌ Request failed: {response.status_code}")
            print(f" Error: {response.text}")
            return False

    except Exception as e:
        print(f"❌ Test 2 failed: {str(e)}")
        return False

    print("\n✅ /query/stream endpoint references tests passed!")
    return True
def test_references_consistency():
    """Test references consistency across all endpoints.

    Issues the same query (include_references=True) against /query,
    /query/stream and /query/data, then compares the resulting
    (reference_id, file_path) sets pairwise.

    Returns:
        bool: True when all three endpoints return the same reference set,
        False on any request failure or mismatch.
    """
    print("\n" + "=" * 60)
    print("Testing references consistency across endpoints")
    print("=" * 60)

    query_text = "who authored LightRAG"
    query_params = {
        "query": query_text,
        "mode": "mix",
        "top_k": 10,
        "chunk_top_k": 8,
        "include_references": True,
    }

    # Per-endpoint reference lists, keyed by endpoint nickname
    references_data = {}

    # Test /query endpoint
    print("\n🧪 Testing /query endpoint")
    print("-" * 40)

    try:
        response = requests.post(
            f"{BASE_URL}/query", json=query_params, headers=AUTH_HEADERS, timeout=30
        )

        if response.status_code == 200:
            data = response.json()
            references_data["query"] = data.get("references", [])
            print(f"✅ /query: {len(references_data['query'])} references")
        else:
            print(f"❌ /query failed: {response.status_code}")
            return False

    except Exception as e:
        print(f"❌ /query test failed: {str(e)}")
        return False

    # Test /query/stream endpoint
    print("\n🧪 Testing /query/stream endpoint")
    print("-" * 40)

    try:
        response = requests.post(
            f"{BASE_URL}/query/stream",
            json=query_params,
            headers=AUTH_HEADERS,
            timeout=30,
            stream=True,
        )

        if response.status_code == 200:
            full_response = ""
            for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
                if chunk:
                    # Ensure chunk is string type
                    if isinstance(chunk, bytes):
                        chunk = chunk.decode("utf-8")
                    full_response += chunk

            references, _, errors = parse_streaming_response(full_response)

            if errors:
                print(f"❌ Errors: {errors}")
                return False

            # `references or []` normalizes a missing list to empty
            references_data["stream"] = references or []
            print(f"✅ /query/stream: {len(references_data['stream'])} references")
        else:
            print(f"❌ /query/stream failed: {response.status_code}")
            return False

    except Exception as e:
        print(f"❌ /query/stream test failed: {str(e)}")
        return False

    # Test /query/data endpoint
    print("\n🧪 Testing /query/data endpoint")
    print("-" * 40)

    try:
        response = requests.post(
            f"{BASE_URL}/query/data",
            json=query_params,
            headers=AUTH_HEADERS,
            timeout=30,
        )

        if response.status_code == 200:
            data = response.json()
            # /query/data nests its references inside the 'data' payload
            query_data = data.get("data", {})
            references_data["data"] = query_data.get("references", [])
            print(f"✅ /query/data: {len(references_data['data'])} references")
        else:
            print(f"❌ /query/data failed: {response.status_code}")
            return False

    except Exception as e:
        print(f"❌ /query/data test failed: {str(e)}")
        return False

    # Compare references consistency
    print("\n🔍 Comparing references consistency")
    print("-" * 40)

    # Convert to sets of (reference_id, file_path) tuples for comparison
    def refs_to_set(refs):
        return set(
            (ref.get("reference_id", ""), ref.get("file_path", "")) for ref in refs
        )

    query_refs = refs_to_set(references_data["query"])
    stream_refs = refs_to_set(references_data["stream"])
    data_refs = refs_to_set(references_data["data"])

    # Check consistency
    consistency_passed = True

    if query_refs != stream_refs:
        print("❌ References mismatch between /query and /query/stream")
        print(f" /query only: {query_refs - stream_refs}")
        print(f" /query/stream only: {stream_refs - query_refs}")
        consistency_passed = False

    if query_refs != data_refs:
        print("❌ References mismatch between /query and /query/data")
        print(f" /query only: {query_refs - data_refs}")
        print(f" /query/data only: {data_refs - query_refs}")
        consistency_passed = False

    if stream_refs != data_refs:
        print("❌ References mismatch between /query/stream and /query/data")
        print(f" /query/stream only: {stream_refs - data_refs}")
        print(f" /query/data only: {data_refs - stream_refs}")
        consistency_passed = False

    if consistency_passed:
        print("✅ All endpoints return consistent references")
        print(f" Common references count: {len(query_refs)}")

        # Display common reference list
        if query_refs:
            print(" 📚 Common Reference List:")
            for i, (ref_id, file_path) in enumerate(sorted(query_refs), 1):
                print(f" {i}. ID: {ref_id} | File: {file_path}")

    return consistency_passed
def test_aquery_data_endpoint(): def test_aquery_data_endpoint():
"""Test the /query/data endpoint""" """Test the /query/data endpoint"""
@ -239,15 +690,79 @@ def compare_with_regular_query():
print(f" Regular query error: {str(e)}") print(f" Regular query error: {str(e)}")
def run_all_reference_tests():
    """Run all reference-related tests and print a summary.

    Each sub-test is isolated: an exception in one is reported and
    counted as a failure without stopping the remaining tests.

    Returns:
        bool: True when every sub-test passed, False otherwise.
    """
    print("\n" + "🚀" * 20)
    print("LightRAG References Test Suite")
    print("🚀" * 20)

    all_tests_passed = True

    # (label used in the failure message, test callable) — replaces three
    # copy-pasted try/except stanzas that differed only in the label.
    test_cases = [
        ("/query endpoint", test_query_endpoint_references),
        ("/query/stream endpoint", test_query_stream_endpoint_references),
        ("References consistency", test_references_consistency),
    ]

    for label, test_func in test_cases:
        try:
            if not test_func():
                all_tests_passed = False
        except Exception as e:
            print(f"❌ {label} test failed with exception: {str(e)}")
            all_tests_passed = False

    # Final summary
    print("\n" + "=" * 60)
    print("TEST SUITE SUMMARY")
    print("=" * 60)

    if all_tests_passed:
        print("🎉 ALL TESTS PASSED!")
        print("✅ /query endpoint references functionality works correctly")
        print("✅ /query/stream endpoint references functionality works correctly")
        print("✅ References are consistent across all endpoints")
        print("\n🔧 System is ready for production use with reference support!")
    else:
        print("❌ SOME TESTS FAILED!")
        print("Please check the error messages above and fix the issues.")
        print("\n🔧 System needs attention before production deployment.")

    return all_tests_passed
if __name__ == "__main__": if __name__ == "__main__":
# Run main test import sys
test_aquery_data_endpoint()
# Run comparison test if len(sys.argv) > 1 and sys.argv[1] == "--references-only":
compare_with_regular_query() # Run only the new reference tests
success = run_all_reference_tests()
sys.exit(0 if success else 1)
else:
# Run original tests plus new reference tests
print("Running original aquery_data endpoint test...")
test_aquery_data_endpoint()
print("\n💡 Usage tips:") print("\nRunning comparison test...")
print("1. Ensure LightRAG API service is running") compare_with_regular_query()
print("2. Adjust base_url and authentication information as needed")
print("3. Modify query parameters to test different retrieval strategies") print("\nRunning new reference tests...")
print("4. Data query results can be used for further analysis and processing") run_all_reference_tests()
print("\n💡 Usage tips:")
print("1. Ensure LightRAG API service is running")
print("2. Adjust base_url and authentication information as needed")
print("3. Modify query parameters to test different retrieval strategies")
print("4. Data query results can be used for further analysis and processing")
print("5. Run with --references-only flag to test only reference functionality")