diff --git a/lightrag/api/__init__.py b/lightrag/api/__init__.py
index 3aec4bf9..671a7326 100644
--- a/lightrag/api/__init__.py
+++ b/lightrag/api/__init__.py
@@ -1 +1 @@
-__api_version__ = "0230"
+__api_version__ = "0231"
diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py
index c6e050db..83df2823 100644
--- a/lightrag/api/routers/query_routes.py
+++ b/lightrag/api/routers/query_routes.py
@@ -88,6 +88,11 @@ class QueryRequest(BaseModel):
description="Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. Default is True.",
)
+ include_references: Optional[bool] = Field(
+ default=True,
+ description="If True, includes reference list in responses. Affects /query and /query/stream endpoints. /query/data always includes references.",
+ )
+
@field_validator("query", mode="after")
@classmethod
def query_strip_after(cls, query: str) -> str:
@@ -122,6 +127,10 @@ class QueryResponse(BaseModel):
response: str = Field(
description="The generated response",
)
+ references: Optional[List[Dict[str, str]]] = Field(
+ default=None,
+ description="Reference list (only included when include_references=True; /query/data always includes references)",
+ )
class QueryDataResponse(BaseModel):
@@ -149,6 +158,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
request (QueryRequest): The request object containing the query parameters.
Returns:
QueryResponse: A Pydantic model containing the result of the query processing.
+ If include_references=True, the response also includes the reference list.
If a string is returned (e.g., cache hit), it's directly returned.
Otherwise, an async generator may be used to build the response.
@@ -160,15 +170,26 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
param = request.to_query_params(False)
response = await rag.aquery(request.query, param=param)
- # If response is a string (e.g. cache hit), return directly
- if isinstance(response, str):
- return QueryResponse(response=response)
+ # Get reference list if requested
+ reference_list = None
+ if request.include_references:
+ try:
+ # Use aquery_data to get reference list independently
+ data_result = await rag.aquery_data(request.query, param=param)
+ if isinstance(data_result, dict) and "data" in data_result:
+ reference_list = data_result["data"].get("references", [])
+ except Exception as e:
+ logging.warning(f"Failed to get reference list: {str(e)}")
+ reference_list = []
- if isinstance(response, dict):
+ # Process response and return with optional references
+ if isinstance(response, str):
+ return QueryResponse(response=response, references=reference_list)
+ elif isinstance(response, dict):
result = json.dumps(response, indent=2)
- return QueryResponse(response=result)
+ return QueryResponse(response=result, references=reference_list)
else:
- return QueryResponse(response=str(response))
+ return QueryResponse(response=str(response), references=reference_list)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@@ -178,12 +199,18 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
"""
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
+ The streaming response includes:
+ 1. Reference list (sent first as a single message, if include_references=True)
+ 2. LLM response content (streamed as multiple chunks)
+
Args:
request (QueryRequest): The request object containing the query parameters.
- optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None.
Returns:
- StreamingResponse: A streaming response containing the RAG query results.
+ StreamingResponse: A streaming response containing:
+ - First message: {"references": [...]} - Complete reference list (if requested)
+ - Subsequent messages: {"response": "..."} - LLM response chunks
+ - Error messages: {"error": "..."} - If any errors occur
"""
try:
param = request.to_query_params(True)
@@ -192,6 +219,28 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
from fastapi.responses import StreamingResponse
async def stream_generator():
+ # Get reference list if requested (default is True for backward compatibility)
+ reference_list = []
+ if request.include_references:
+ try:
+ # Use aquery_data to get reference list independently
+ data_param = request.to_query_params(
+ False
+ ) # Non-streaming for data
+ data_result = await rag.aquery_data(
+ request.query, param=data_param
+ )
+ if isinstance(data_result, dict) and "data" in data_result:
+ reference_list = data_result["data"].get("references", [])
+ except Exception as e:
+ logging.warning(f"Failed to get reference list: {str(e)}")
+ reference_list = []
+
+ # Send reference list first (if requested)
+ if request.include_references:
+ yield f"{json.dumps({'references': reference_list})}\n"
+
+ # Then stream the response content
if isinstance(response, str):
# If it's a string, send it all at once
yield f"{json.dumps({'response': response})}\n"
diff --git a/lightrag/base.py b/lightrag/base.py
index 0ee64949..a6420069 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -11,6 +11,10 @@ from typing import (
TypedDict,
TypeVar,
Callable,
+ Optional,
+ Dict,
+ List,
+ AsyncIterator,
)
from .utils import EmbeddingFunc
from .types import KnowledgeGraph
@@ -158,6 +162,12 @@ class QueryParam:
Default is True to enable reranking when rerank model is available.
"""
+ include_references: bool = False
+ """If True, includes reference list in the response for supported endpoints.
+ This parameter controls whether the API response includes a references field
+ containing citation information for the retrieved content.
+ """
+
@dataclass
class StorageNameSpace(ABC):
@@ -814,3 +824,68 @@ class DeletionResult:
message: str
status_code: int = 200
file_path: str | None = None
+
+
+# Unified Query Result Data Structures for Reference List Support
+
+
+@dataclass
+class QueryResult:
+ """
+ Unified query result data structure for all query modes.
+
+ Attributes:
+ content: Text content for non-streaming responses
+ response_iterator: Streaming response iterator for streaming responses
+ raw_data: Complete structured data including references and metadata
+ is_streaming: Whether this is a streaming result
+ """
+
+ content: Optional[str] = None
+ response_iterator: Optional[AsyncIterator[str]] = None
+ raw_data: Optional[Dict[str, Any]] = None
+ is_streaming: bool = False
+
+ @property
+ def reference_list(self) -> List[Dict[str, str]]:
+ """
+ Convenient property to extract reference list from raw_data.
+
+ Returns:
+ List[Dict[str, str]]: Reference list in format:
+ [{"reference_id": "1", "file_path": "/path/to/file.pdf"}, ...]
+ """
+ if self.raw_data:
+ return self.raw_data.get("data", {}).get("references", [])
+ return []
+
+ @property
+ def metadata(self) -> Dict[str, Any]:
+ """
+ Convenient property to extract metadata from raw_data.
+
+ Returns:
+ Dict[str, Any]: Query metadata including query_mode, keywords, etc.
+ """
+ if self.raw_data:
+ return self.raw_data.get("metadata", {})
+ return {}
+
+
+@dataclass
+class QueryContextResult:
+ """
+ Unified query context result data structure.
+
+ Attributes:
+ context: LLM context string
+ raw_data: Complete structured data including reference_list
+ """
+
+ context: str
+ raw_data: Dict[str, Any]
+
+ @property
+ def reference_list(self) -> List[Dict[str, str]]:
+ """Convenient property to extract reference list from raw_data."""
+ return self.raw_data.get("data", {}).get("references", [])
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 32dc89c9..afc0bc5f 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -71,6 +71,7 @@ from lightrag.base import (
StoragesStatus,
DeletionResult,
OllamaServerInfos,
+ QueryResult,
)
from lightrag.namespace import NameSpace
from lightrag.operate import (
@@ -2075,8 +2076,10 @@ class LightRAG:
# If a custom model is provided in param, temporarily update global config
global_config = asdict(self)
+ query_result = None
+
if param.mode in ["local", "global", "hybrid", "mix"]:
- response = await kg_query(
+ query_result = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
@@ -2089,7 +2092,7 @@ class LightRAG:
chunks_vdb=self.chunks_vdb,
)
elif param.mode == "naive":
- response = await naive_query(
+ query_result = await naive_query(
query.strip(),
self.chunks_vdb,
param,
@@ -2111,10 +2114,22 @@ class LightRAG:
enable_cot=True,
stream=param.stream,
)
+ # Create QueryResult for bypass mode
+ query_result = QueryResult(
+ content=response if not param.stream else None,
+ response_iterator=response if param.stream else None,
+ is_streaming=param.stream,
+ )
else:
raise ValueError(f"Unknown mode {param.mode}")
+
await self._query_done()
- return response
+
+ # Return appropriate response based on streaming mode
+ if query_result.is_streaming:
+ return query_result.response_iterator
+ else:
+ return query_result.content
async def aquery_data(
self,
@@ -2229,61 +2244,81 @@ class LightRAG:
"""
global_config = asdict(self)
- if param.mode in ["local", "global", "hybrid", "mix"]:
- logger.debug(f"[aquery_data] Using kg_query for mode: {param.mode}")
- final_data = await kg_query(
+ # Create a copy of param to avoid modifying the original
+ data_param = QueryParam(
+ mode=param.mode,
+ only_need_context=True, # Skip LLM generation, only get context and data
+ only_need_prompt=False,
+ response_type=param.response_type,
+ stream=False, # Data retrieval doesn't need streaming
+ top_k=param.top_k,
+ chunk_top_k=param.chunk_top_k,
+ max_entity_tokens=param.max_entity_tokens,
+ max_relation_tokens=param.max_relation_tokens,
+ max_total_tokens=param.max_total_tokens,
+ hl_keywords=param.hl_keywords,
+ ll_keywords=param.ll_keywords,
+ conversation_history=param.conversation_history,
+ history_turns=param.history_turns,
+ model_func=param.model_func,
+ user_prompt=param.user_prompt,
+ enable_rerank=param.enable_rerank,
+ )
+
+ query_result = None
+
+ if data_param.mode in ["local", "global", "hybrid", "mix"]:
+ logger.debug(f"[aquery_data] Using kg_query for mode: {data_param.mode}")
+ query_result = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
self.entities_vdb,
self.relationships_vdb,
self.text_chunks,
- param,
+ data_param, # Use data_param with only_need_context=True
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=None,
chunks_vdb=self.chunks_vdb,
- return_raw_data=True, # Get final processed data
)
- elif param.mode == "naive":
- logger.debug(f"[aquery_data] Using naive_query for mode: {param.mode}")
- final_data = await naive_query(
+ elif data_param.mode == "naive":
+ logger.debug(f"[aquery_data] Using naive_query for mode: {data_param.mode}")
+ query_result = await naive_query(
query.strip(),
self.chunks_vdb,
- param,
+ data_param, # Use data_param with only_need_context=True
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=None,
- return_raw_data=True, # Get final processed data
)
- elif param.mode == "bypass":
+ elif data_param.mode == "bypass":
logger.debug("[aquery_data] Using bypass mode")
# bypass mode returns empty data using convert_to_user_format
- final_data = convert_to_user_format(
+ empty_raw_data = convert_to_user_format(
[], # no entities
[], # no relationships
[], # no chunks
[], # no references
"bypass",
)
+ query_result = QueryResult(content="", raw_data=empty_raw_data)
else:
- raise ValueError(f"Unknown mode {param.mode}")
+ raise ValueError(f"Unknown mode {data_param.mode}")
+
+ # Extract raw_data from QueryResult
+ final_data = query_result.raw_data if query_result else {}
# Log final result counts - adapt to new data format from convert_to_user_format
- if isinstance(final_data, dict) and "data" in final_data:
- # New format: data is nested under 'data' field
+ if final_data and "data" in final_data:
data_section = final_data["data"]
entities_count = len(data_section.get("entities", []))
relationships_count = len(data_section.get("relationships", []))
chunks_count = len(data_section.get("chunks", []))
+ logger.debug(
+ f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
+ )
else:
- # Fallback for other formats
- entities_count = len(final_data.get("entities", []))
- relationships_count = len(final_data.get("relationships", []))
- chunks_count = len(final_data.get("chunks", []))
-
- logger.debug(
- f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
- )
+ logger.warning("[aquery_data] No data section found in query result")
await self._query_done()
return final_data
diff --git a/lightrag/operate.py b/lightrag/operate.py
index e65d3893..685e86a8 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -39,6 +39,8 @@ from .base import (
BaseVectorStorage,
TextChunkSchema,
QueryParam,
+ QueryResult,
+ QueryContextResult,
)
from .prompt import PROMPTS
from .constants import (
@@ -2277,16 +2279,38 @@ async def kg_query(
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
- return_raw_data: bool = False,
-) -> str | AsyncIterator[str] | dict[str, Any]:
+) -> QueryResult:
+ """
+ Execute knowledge graph query and return unified QueryResult object.
+
+ Args:
+ query: Query string
+ knowledge_graph_inst: Knowledge graph storage instance
+ entities_vdb: Entity vector database
+ relationships_vdb: Relationship vector database
+ text_chunks_db: Text chunks storage
+ query_param: Query parameters
+ global_config: Global configuration
+ hashing_kv: Cache storage
+ system_prompt: System prompt
+ chunks_vdb: Document chunks vector database
+
+ Returns:
+ QueryResult: Unified query result object containing:
+ - content: Non-streaming response text content
+ - response_iterator: Streaming response iterator
+ - raw_data: Complete structured data (including references and metadata)
+ - is_streaming: Whether this is a streaming result
+
+ Based on different query_param settings, different fields will be populated:
+ - only_need_context=True: content contains context string
+ - only_need_prompt=True: content contains complete prompt
+ - stream=True: response_iterator contains streaming response, raw_data contains complete data
+ - default: content contains LLM response text, raw_data contains complete data
+ """
+
if not query:
- if return_raw_data:
- return {
- "status": "failure",
- "message": "Query string is empty.",
- "data": {},
- }
- return PROMPTS["fail_response"]
+ return QueryResult(content=PROMPTS["fail_response"])
if query_param.model_func:
use_model_func = query_param.model_func
@@ -2315,12 +2339,11 @@ async def kg_query(
)
if (
cached_result is not None
- and not return_raw_data
and not query_param.only_need_context
and not query_param.only_need_prompt
):
cached_response, _ = cached_result # Extract content, ignore timestamp
- return cached_response
+ return QueryResult(content=cached_response)
hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
@@ -2339,53 +2362,13 @@ async def kg_query(
logger.warning(f"Forced low_level_keywords to origin query: {query}")
ll_keywords = [query]
else:
- if return_raw_data:
- return {
- "status": "failure",
- "message": "Both high_level_keywords and low_level_keywords are empty",
- "data": {},
- }
- return PROMPTS["fail_response"]
+ return QueryResult(content=PROMPTS["fail_response"])
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
- # If raw data is requested, get both context and raw data
- if return_raw_data:
- context_result = await _build_query_context(
- query,
- ll_keywords_str,
- hl_keywords_str,
- knowledge_graph_inst,
- entities_vdb,
- relationships_vdb,
- text_chunks_db,
- query_param,
- chunks_vdb,
- return_raw_data=True,
- )
-
- if isinstance(context_result, tuple):
- context, raw_data = context_result
- logger.debug(f"[kg_query] Context length: {len(context) if context else 0}")
- logger.debug(
- f"[kg_query] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
- )
- return raw_data
- else:
- if not context_result:
- return {
- "status": "failure",
- "message": "Query return empty data set.",
- "data": {},
- }
- else:
- raise ValueError(
- "Fail to build raw data query result. Invalid return from _build_query_context"
- )
-
- # Build context (normal flow)
- context = await _build_query_context(
+ # Build query context (unified interface)
+ context_result = await _build_query_context(
query,
ll_keywords_str,
hl_keywords_str,
@@ -2397,14 +2380,19 @@ async def kg_query(
chunks_vdb,
)
- if query_param.only_need_context and not query_param.only_need_prompt:
- return context if context is not None else PROMPTS["fail_response"]
- if context is None:
- return PROMPTS["fail_response"]
+ if context_result is None:
+ return QueryResult(content=PROMPTS["fail_response"])
+ # Return different content based on query parameters
+ if query_param.only_need_context and not query_param.only_need_prompt:
+ return QueryResult(
+ content=context_result.context, raw_data=context_result.raw_data
+ )
+
+ # Build system prompt
sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
sys_prompt = sys_prompt_temp.format(
- context_data=context,
+ context_data=context_result.context,
response_type=query_param.response_type,
)
@@ -2415,8 +2403,10 @@ async def kg_query(
)
if query_param.only_need_prompt:
- return "\n\n".join([sys_prompt, "---User Query---", user_query])
+ prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
+ return QueryResult(content=prompt_content, raw_data=context_result.raw_data)
+ # Call LLM
tokenizer: Tokenizer = global_config["tokenizer"]
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(
@@ -2430,45 +2420,56 @@ async def kg_query(
enable_cot=True,
stream=query_param.stream,
)
- if isinstance(response, str) and len(response) > len(sys_prompt):
- response = (
- response.replace(sys_prompt, "")
- .replace("user", "")
- .replace("model", "")
- .replace(query, "")
- .replace("<system>", "")
- .replace("</system>", "")
- .strip()
- )
- if hashing_kv.global_config.get("enable_llm_cache"):
- # Save to cache with query parameters
- queryparam_dict = {
- "mode": query_param.mode,
- "response_type": query_param.response_type,
- "top_k": query_param.top_k,
- "chunk_top_k": query_param.chunk_top_k,
- "max_entity_tokens": query_param.max_entity_tokens,
- "max_relation_tokens": query_param.max_relation_tokens,
- "max_total_tokens": query_param.max_total_tokens,
- "hl_keywords": query_param.hl_keywords or [],
- "ll_keywords": query_param.ll_keywords or [],
- "user_prompt": query_param.user_prompt or "",
- "enable_rerank": query_param.enable_rerank,
- }
- await save_to_cache(
- hashing_kv,
- CacheData(
- args_hash=args_hash,
- content=response,
- prompt=query,
- mode=query_param.mode,
- cache_type="query",
- queryparam=queryparam_dict,
- ),
- )
+ # Return unified result based on actual response type
+ if isinstance(response, str):
+ # Non-streaming response (string)
+ if len(response) > len(sys_prompt):
+ response = (
+ response.replace(sys_prompt, "")
+ .replace("user", "")
+ .replace("model", "")
+ .replace(query, "")
+ .replace("<system>", "")
+ .replace("</system>", "")
+ .strip()
+ )
- return response
+ # Cache response
+ if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
+ queryparam_dict = {
+ "mode": query_param.mode,
+ "response_type": query_param.response_type,
+ "top_k": query_param.top_k,
+ "chunk_top_k": query_param.chunk_top_k,
+ "max_entity_tokens": query_param.max_entity_tokens,
+ "max_relation_tokens": query_param.max_relation_tokens,
+ "max_total_tokens": query_param.max_total_tokens,
+ "hl_keywords": query_param.hl_keywords or [],
+ "ll_keywords": query_param.ll_keywords or [],
+ "user_prompt": query_param.user_prompt or "",
+ "enable_rerank": query_param.enable_rerank,
+ }
+ await save_to_cache(
+ hashing_kv,
+ CacheData(
+ args_hash=args_hash,
+ content=response,
+ prompt=query,
+ mode=query_param.mode,
+ cache_type="query",
+ queryparam=queryparam_dict,
+ ),
+ )
+
+ return QueryResult(content=response, raw_data=context_result.raw_data)
+ else:
+ # Streaming response (AsyncIterator)
+ return QueryResult(
+ response_iterator=response,
+ raw_data=context_result.raw_data,
+ is_streaming=True,
+ )
async def get_keywords_from_query(
@@ -3123,10 +3124,9 @@ async def _build_llm_context(
query_param: QueryParam,
global_config: dict[str, str],
chunk_tracking: dict = None,
- return_raw_data: bool = False,
entity_id_to_original: dict = None,
relation_id_to_original: dict = None,
-) -> str | tuple[str, dict[str, Any]]:
+) -> tuple[str, dict[str, Any]]:
"""
Build the final LLM context string with token processing.
This includes dynamic token calculation and final chunk truncation.
@@ -3134,22 +3134,17 @@ async def _build_llm_context(
tokenizer = global_config.get("tokenizer")
if not tokenizer:
logger.error("Missing tokenizer, cannot build LLM context")
-
- if return_raw_data:
- # Return empty raw data structure when no entities/relations
- empty_raw_data = convert_to_user_format(
- [],
- [],
- [],
- [],
- query_param.mode,
- )
- empty_raw_data["status"] = "failure"
- empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
- return None, empty_raw_data
- else:
- logger.error("Tokenizer not found in global configuration.")
- return None
+ # Return empty raw data structure when no tokenizer
+ empty_raw_data = convert_to_user_format(
+ [],
+ [],
+ [],
+ [],
+ query_param.mode,
+ )
+ empty_raw_data["status"] = "failure"
+ empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
+ return "", empty_raw_data
# Get token limits
max_total_tokens = getattr(
@@ -3268,20 +3263,17 @@ The reference documents list in Document Chunks(DC) is as follows (reference_id
# not necessary to use LLM to generate a response
if not entities_context and not relations_context:
- if return_raw_data:
- # Return empty raw data structure when no entities/relations
- empty_raw_data = convert_to_user_format(
- [],
- [],
- [],
- [],
- query_param.mode,
- )
- empty_raw_data["status"] = "failure"
- empty_raw_data["message"] = "Query returned empty dataset."
- return None, empty_raw_data
- else:
- return None
+ # Return empty raw data structure when no entities/relations
+ empty_raw_data = convert_to_user_format(
+ [],
+ [],
+ [],
+ [],
+ query_param.mode,
+ )
+ empty_raw_data["status"] = "failure"
+ empty_raw_data["message"] = "Query returned empty dataset."
+ return "", empty_raw_data
# output chunks tracking information
# format: / (e.g., E5/2 R2/1 C1/1)
@@ -3342,26 +3334,23 @@ Document Chunks (DC) reference documents : (Each entry begins with [reference_id
"""
- # If final data is requested, return both context and complete data structure
- if return_raw_data:
- logger.debug(
- f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
- )
- final_data = convert_to_user_format(
- entities_context,
- relations_context,
- truncated_chunks,
- reference_list,
- query_param.mode,
- entity_id_to_original,
- relation_id_to_original,
- )
- logger.debug(
- f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
- )
- return result, final_data
- else:
- return result
+ # Always return both context and complete data structure (unified approach)
+ logger.debug(
+ f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
+ )
+ final_data = convert_to_user_format(
+ entities_context,
+ relations_context,
+ truncated_chunks,
+ reference_list,
+ query_param.mode,
+ entity_id_to_original,
+ relation_id_to_original,
+ )
+ logger.debug(
+ f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
+ )
+ return result, final_data
# Now let's update the old _build_query_context to use the new architecture
@@ -3375,16 +3364,17 @@ async def _build_query_context(
text_chunks_db: BaseKVStorage,
query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None,
- return_raw_data: bool = False,
-) -> str | None | tuple[str, dict[str, Any]]:
+) -> QueryContextResult | None:
"""
Main query context building function using the new 4-stage architecture:
1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
+
+ Returns unified QueryContextResult containing both context and raw_data.
"""
if not query:
logger.warning("Query is empty, skipping context building")
- return ""
+ return None
# Stage 1: Pure search
search_result = await _perform_kg_search(
@@ -3435,71 +3425,53 @@ async def _build_query_context(
return None
# Stage 4: Build final LLM context with dynamic token processing
+ # _build_llm_context now always returns tuple[str, dict]
+ context, raw_data = await _build_llm_context(
+ entities_context=truncation_result["entities_context"],
+ relations_context=truncation_result["relations_context"],
+ merged_chunks=merged_chunks,
+ query=query,
+ query_param=query_param,
+ global_config=text_chunks_db.global_config,
+ chunk_tracking=search_result["chunk_tracking"],
+ entity_id_to_original=truncation_result["entity_id_to_original"],
+ relation_id_to_original=truncation_result["relation_id_to_original"],
+ )
- if return_raw_data:
- # Convert keywords strings to lists
- hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
- ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
+ # Convert keywords strings to lists and add complete metadata to raw_data
+ hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
+ ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
- # Get both context and final data - when return_raw_data=True, _build_llm_context always returns tuple
- context, raw_data = await _build_llm_context(
- entities_context=truncation_result["entities_context"],
- relations_context=truncation_result["relations_context"],
- merged_chunks=merged_chunks,
- query=query,
- query_param=query_param,
- global_config=text_chunks_db.global_config,
- chunk_tracking=search_result["chunk_tracking"],
- return_raw_data=True,
- entity_id_to_original=truncation_result["entity_id_to_original"],
- relation_id_to_original=truncation_result["relation_id_to_original"],
- )
+ # Add complete metadata to raw_data (preserve existing metadata including query_mode)
+ if "metadata" not in raw_data:
+ raw_data["metadata"] = {}
- # Convert keywords strings to lists and add complete metadata to raw_data
- hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
- ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
+ # Update keywords while preserving existing metadata
+ raw_data["metadata"]["keywords"] = {
+ "high_level": hl_keywords_list,
+ "low_level": ll_keywords_list,
+ }
+ raw_data["metadata"]["processing_info"] = {
+ "total_entities_found": len(search_result.get("final_entities", [])),
+ "total_relations_found": len(search_result.get("final_relations", [])),
+ "entities_after_truncation": len(
+ truncation_result.get("filtered_entities", [])
+ ),
+ "relations_after_truncation": len(
+ truncation_result.get("filtered_relations", [])
+ ),
+ "merged_chunks_count": len(merged_chunks),
+ "final_chunks_count": len(raw_data.get("data", {}).get("chunks", [])),
+ }
- # Add complete metadata to raw_data (preserve existing metadata including query_mode)
- if "metadata" not in raw_data:
- raw_data["metadata"] = {}
+ logger.debug(
+ f"[_build_query_context] Context length: {len(context) if context else 0}"
+ )
+ logger.debug(
+ f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}"
+ )
- # Update keywords while preserving existing metadata
- raw_data["metadata"]["keywords"] = {
- "high_level": hl_keywords_list,
- "low_level": ll_keywords_list,
- }
- raw_data["metadata"]["processing_info"] = {
- "total_entities_found": len(search_result.get("final_entities", [])),
- "total_relations_found": len(search_result.get("final_relations", [])),
- "entities_after_truncation": len(
- truncation_result.get("filtered_entities", [])
- ),
- "relations_after_truncation": len(
- truncation_result.get("filtered_relations", [])
- ),
- "merged_chunks_count": len(merged_chunks),
- "final_chunks_count": len(raw_data.get("chunks", [])),
- }
-
- logger.debug(
- f"[_build_query_context] Context length: {len(context) if context else 0}"
- )
- logger.debug(
- f"[_build_query_context] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
- )
- return context, raw_data
- else:
- # Normal context building (existing logic)
- context = await _build_llm_context(
- entities_context=truncation_result["entities_context"],
- relations_context=truncation_result["relations_context"],
- merged_chunks=merged_chunks,
- query=query,
- query_param=query_param,
- global_config=text_chunks_db.global_config,
- chunk_tracking=search_result["chunk_tracking"],
- )
- return context
+ return QueryContextResult(context=context, raw_data=raw_data)
async def _get_node_data(
@@ -4105,19 +4077,28 @@ async def naive_query(
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
- return_raw_data: bool = False,
-) -> str | AsyncIterator[str] | dict[str, Any]:
+) -> QueryResult:
+ """
+ Execute naive query and return unified QueryResult object.
+
+ Args:
+ query: Query string
+ chunks_vdb: Document chunks vector database
+ query_param: Query parameters
+ global_config: Global configuration
+ hashing_kv: Cache storage
+ system_prompt: System prompt
+
+ Returns:
+ QueryResult: Unified query result object containing:
+ - content: Non-streaming response text content
+ - response_iterator: Streaming response iterator
+ - raw_data: Complete structured data (including references and metadata)
+ - is_streaming: Whether this is a streaming result
+ """
+
if not query:
- if return_raw_data:
- # Return empty raw data structure when query is empty
- empty_raw_data = {
- "status": "failure",
- "message": "Query string is empty.",
- "data": {},
- }
- return empty_raw_data
- else:
- return PROMPTS["fail_response"]
+ return QueryResult(content=PROMPTS["fail_response"])
if query_param.model_func:
use_model_func = query_param.model_func
@@ -4147,41 +4128,28 @@ async def naive_query(
if cached_result is not None:
cached_response, _ = cached_result # Extract content, ignore timestamp
if not query_param.only_need_context and not query_param.only_need_prompt:
- return cached_response
+ return QueryResult(content=cached_response)
tokenizer: Tokenizer = global_config["tokenizer"]
if not tokenizer:
- if return_raw_data:
- # Return empty raw data structure when tokenizer is missing
- empty_raw_data = {
- "status": "failure",
- "message": "Tokenizer not found in global configuration.",
- "data": {},
- }
- return empty_raw_data
- else:
- logger.error("Tokenizer not found in global configuration.")
- return PROMPTS["fail_response"]
+ logger.error("Tokenizer not found in global configuration.")
+ return QueryResult(content=PROMPTS["fail_response"])
chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
if chunks is None or len(chunks) == 0:
- # If only raw data is requested, return it directly
- if return_raw_data:
- empty_raw_data = convert_to_user_format(
- [], # naive mode has no entities
- [], # naive mode has no relationships
- [], # no chunks
- [], # no references
- "naive",
- )
- empty_raw_data["message"] = "No relevant document chunks found."
- return empty_raw_data
- else:
- return PROMPTS["fail_response"]
+ # Build empty raw data structure for naive mode
+ empty_raw_data = convert_to_user_format(
+ [], # naive mode has no entities
+ [], # naive mode has no relationships
+ [], # no chunks
+ [], # no references
+ "naive",
+ )
+ empty_raw_data["message"] = "No relevant document chunks found."
+ return QueryResult(content=PROMPTS["fail_response"], raw_data=empty_raw_data)
# Calculate dynamic token limit for chunks
- # Get token limits from query_param (with fallback to global_config)
max_total_tokens = getattr(
query_param,
"max_total_tokens",
@@ -4240,30 +4208,26 @@ async def naive_query(
logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks")
- # If only raw data is requested, return it directly
- if return_raw_data:
- # Build raw data structure for naive mode using processed chunks with reference IDs
- raw_data = convert_to_user_format(
- [], # naive mode has no entities
- [], # naive mode has no relationships
- processed_chunks_with_ref_ids,
- reference_list,
- "naive",
- )
+ # Build raw data structure for naive mode using processed chunks with reference IDs
+ raw_data = convert_to_user_format(
+ [], # naive mode has no entities
+ [], # naive mode has no relationships
+ processed_chunks_with_ref_ids,
+ reference_list,
+ "naive",
+ )
- # Add complete metadata for naive mode
- if "metadata" not in raw_data:
- raw_data["metadata"] = {}
- raw_data["metadata"]["keywords"] = {
- "high_level": [], # naive mode has no keyword extraction
- "low_level": [], # naive mode has no keyword extraction
- }
- raw_data["metadata"]["processing_info"] = {
- "total_chunks_found": len(chunks),
- "final_chunks_count": len(processed_chunks_with_ref_ids),
- }
-
- return raw_data
+ # Add complete metadata for naive mode
+ if "metadata" not in raw_data:
+ raw_data["metadata"] = {}
+ raw_data["metadata"]["keywords"] = {
+ "high_level": [], # naive mode has no keyword extraction
+ "low_level": [], # naive mode has no keyword extraction
+ }
+ raw_data["metadata"]["processing_info"] = {
+ "total_chunks_found": len(chunks),
+ "final_chunks_count": len(processed_chunks_with_ref_ids),
+ }
# Build text_units_context from processed chunks with reference IDs
text_units_context = []
@@ -4284,8 +4248,7 @@ async def naive_query(
if ref["reference_id"]
)
- if query_param.only_need_context and not query_param.only_need_prompt:
- return f"""
+ context_content = f"""
---Document Chunks(DC)---
```json
@@ -4297,6 +4260,10 @@ async def naive_query(
{reference_list_str}
"""
+
+ if query_param.only_need_context and not query_param.only_need_prompt:
+ return QueryResult(content=context_content, raw_data=raw_data)
+
user_query = (
"\n\n".join([query, query_param.user_prompt])
if query_param.user_prompt
@@ -4310,7 +4277,8 @@ async def naive_query(
)
if query_param.only_need_prompt:
- return "\n\n".join([sys_prompt, "---User Query---", user_query])
+ prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
+ return QueryResult(content=prompt_content, raw_data=raw_data)
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
logger.debug(
@@ -4325,43 +4293,51 @@ async def naive_query(
stream=query_param.stream,
)
- if isinstance(response, str) and len(response) > len(sys_prompt):
- response = (
- response[len(sys_prompt) :]
- .replace(sys_prompt, "")
- .replace("user", "")
- .replace("model", "")
- .replace(query, "")
- .replace("<system>", "")
- .replace("</system>", "")
- .strip()
- )
+ # Return unified result based on actual response type
+ if isinstance(response, str):
+ # Non-streaming response (string)
+ if len(response) > len(sys_prompt):
+ response = (
+ response[len(sys_prompt) :]
+ .replace(sys_prompt, "")
+ .replace("user", "")
+ .replace("model", "")
+ .replace(query, "")
+ .replace("<system>", "")
+ .replace("</system>", "")
+ .strip()
+ )
- if hashing_kv.global_config.get("enable_llm_cache"):
- # Save to cache with query parameters
- queryparam_dict = {
- "mode": query_param.mode,
- "response_type": query_param.response_type,
- "top_k": query_param.top_k,
- "chunk_top_k": query_param.chunk_top_k,
- "max_entity_tokens": query_param.max_entity_tokens,
- "max_relation_tokens": query_param.max_relation_tokens,
- "max_total_tokens": query_param.max_total_tokens,
- "hl_keywords": query_param.hl_keywords or [],
- "ll_keywords": query_param.ll_keywords or [],
- "user_prompt": query_param.user_prompt or "",
- "enable_rerank": query_param.enable_rerank,
- }
- await save_to_cache(
- hashing_kv,
- CacheData(
- args_hash=args_hash,
- content=response,
- prompt=query,
- mode=query_param.mode,
- cache_type="query",
- queryparam=queryparam_dict,
- ),
- )
+ # Cache response
+ if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
+ queryparam_dict = {
+ "mode": query_param.mode,
+ "response_type": query_param.response_type,
+ "top_k": query_param.top_k,
+ "chunk_top_k": query_param.chunk_top_k,
+ "max_entity_tokens": query_param.max_entity_tokens,
+ "max_relation_tokens": query_param.max_relation_tokens,
+ "max_total_tokens": query_param.max_total_tokens,
+ "hl_keywords": query_param.hl_keywords or [],
+ "ll_keywords": query_param.ll_keywords or [],
+ "user_prompt": query_param.user_prompt or "",
+ "enable_rerank": query_param.enable_rerank,
+ }
+ await save_to_cache(
+ hashing_kv,
+ CacheData(
+ args_hash=args_hash,
+ content=response,
+ prompt=query,
+ mode=query_param.mode,
+ cache_type="query",
+ queryparam=queryparam_dict,
+ ),
+ )
- return response
+ return QueryResult(content=response, raw_data=raw_data)
+ else:
+ # Streaming response (AsyncIterator)
+ return QueryResult(
+ response_iterator=response, raw_data=raw_data, is_streaming=True
+ )
diff --git a/tests/test_aquery_data_endpoint.py b/tests/test_aquery_data_endpoint.py
index 5c629f5e..3e89be6f 100644
--- a/tests/test_aquery_data_endpoint.py
+++ b/tests/test_aquery_data_endpoint.py
@@ -11,7 +11,8 @@ Updated to handle the new data format where:
import requests
import time
-from typing import Dict, Any
+import json
+from typing import Dict, Any, List, Optional
# API configuration
API_KEY = "your-secure-api-key-here-123"
@@ -21,6 +22,456 @@ BASE_URL = "http://localhost:9621"
AUTH_HEADERS = {"Content-Type": "application/json", "X-API-Key": API_KEY}
+def validate_references_format(references: List[Dict[str, Any]]) -> bool:
+ """Validate the format of references list"""
+ if not isinstance(references, list):
+ print(f"โ References should be a list, got {type(references)}")
+ return False
+
+ for i, ref in enumerate(references):
+ if not isinstance(ref, dict):
+ print(f"โ Reference {i} should be a dict, got {type(ref)}")
+ return False
+
+ required_fields = ["reference_id", "file_path"]
+ for field in required_fields:
+ if field not in ref:
+ print(f"โ Reference {i} missing required field: {field}")
+ return False
+
+ if not isinstance(ref[field], str):
+ print(
+ f"โ Reference {i} field '{field}' should be string, got {type(ref[field])}"
+ )
+ return False
+
+ return True
+
+
+def parse_streaming_response(
+ response_text: str,
+) -> tuple[Optional[List[Dict]], List[str], List[str]]:
+ """Parse streaming response and extract references, response chunks, and errors"""
+ references = None
+ response_chunks = []
+ errors = []
+
+ lines = response_text.strip().split("\n")
+
+ for line in lines:
+ line = line.strip()
+ if not line or line.startswith("data: "):
+ if line.startswith("data: "):
+ line = line[6:] # Remove 'data: ' prefix
+
+ if not line:
+ continue
+
+ try:
+ data = json.loads(line)
+
+ if "references" in data:
+ references = data["references"]
+ elif "response" in data:
+ response_chunks.append(data["response"])
+ elif "error" in data:
+ errors.append(data["error"])
+
+ except json.JSONDecodeError:
+ # Skip non-JSON lines (like SSE comments)
+ continue
+
+ return references, response_chunks, errors
+
+
+def test_query_endpoint_references():
+ """Test /query endpoint references functionality"""
+
+ print("\n" + "=" * 60)
+ print("Testing /query endpoint references functionality")
+ print("=" * 60)
+
+ query_text = "who authored LightRAG"
+ endpoint = f"{BASE_URL}/query"
+
+ # Test 1: References enabled (default)
+ print("\n๐งช Test 1: References enabled (default)")
+ print("-" * 40)
+
+ try:
+ response = requests.post(
+ endpoint,
+ json={"query": query_text, "mode": "mix", "include_references": True},
+ headers=AUTH_HEADERS,
+ timeout=30,
+ )
+
+ if response.status_code == 200:
+ data = response.json()
+
+ # Check response structure
+ if "response" not in data:
+ print("โ Missing 'response' field")
+ return False
+
+ if "references" not in data:
+ print("โ Missing 'references' field when include_references=True")
+ return False
+
+ references = data["references"]
+ if references is None:
+ print("โ References should not be None when include_references=True")
+ return False
+
+ if not validate_references_format(references):
+ return False
+
+ print(f"โ
References enabled: Found {len(references)} references")
+ print(f" Response length: {len(data['response'])} characters")
+
+ # Display reference list
+ if references:
+ print(" ๐ Reference List:")
+ for i, ref in enumerate(references, 1):
+ ref_id = ref.get("reference_id", "Unknown")
+ file_path = ref.get("file_path", "Unknown")
+ print(f" {i}. ID: {ref_id} | File: {file_path}")
+
+ else:
+ print(f"โ Request failed: {response.status_code}")
+ print(f" Error: {response.text}")
+ return False
+
+ except Exception as e:
+ print(f"โ Test 1 failed: {str(e)}")
+ return False
+
+ # Test 2: References disabled
+ print("\n๐งช Test 2: References disabled")
+ print("-" * 40)
+
+ try:
+ response = requests.post(
+ endpoint,
+ json={"query": query_text, "mode": "mix", "include_references": False},
+ headers=AUTH_HEADERS,
+ timeout=30,
+ )
+
+ if response.status_code == 200:
+ data = response.json()
+
+ # Check response structure
+ if "response" not in data:
+ print("โ Missing 'response' field")
+ return False
+
+ references = data.get("references")
+ if references is not None:
+ print("โ References should be None when include_references=False")
+ return False
+
+ print("โ
References disabled: No references field present")
+ print(f" Response length: {len(data['response'])} characters")
+
+ else:
+ print(f"โ Request failed: {response.status_code}")
+ print(f" Error: {response.text}")
+ return False
+
+ except Exception as e:
+ print(f"โ Test 2 failed: {str(e)}")
+ return False
+
+ print("\nโ
/query endpoint references tests passed!")
+ return True
+
+
+def test_query_stream_endpoint_references():
+ """Test /query/stream endpoint references functionality"""
+
+ print("\n" + "=" * 60)
+ print("Testing /query/stream endpoint references functionality")
+ print("=" * 60)
+
+ query_text = "who authored LightRAG"
+ endpoint = f"{BASE_URL}/query/stream"
+
+ # Test 1: Streaming with references enabled
+ print("\n๐งช Test 1: Streaming with references enabled")
+ print("-" * 40)
+
+ try:
+ response = requests.post(
+ endpoint,
+ json={"query": query_text, "mode": "mix", "include_references": True},
+ headers=AUTH_HEADERS,
+ timeout=30,
+ stream=True,
+ )
+
+ if response.status_code == 200:
+ # Collect streaming response
+ full_response = ""
+ for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
+ if chunk:
+ # Ensure chunk is string type
+ if isinstance(chunk, bytes):
+ chunk = chunk.decode("utf-8")
+ full_response += chunk
+
+ # Parse streaming response
+ references, response_chunks, errors = parse_streaming_response(
+ full_response
+ )
+
+ if errors:
+ print(f"โ Errors in streaming response: {errors}")
+ return False
+
+ if references is None:
+ print("โ No references found in streaming response")
+ return False
+
+ if not validate_references_format(references):
+ return False
+
+ if not response_chunks:
+ print("โ No response chunks found in streaming response")
+ return False
+
+ print(f"โ
Streaming with references: Found {len(references)} references")
+ print(f" Response chunks: {len(response_chunks)}")
+ print(
+ f" Total response length: {sum(len(chunk) for chunk in response_chunks)} characters"
+ )
+
+ # Display reference list
+ if references:
+ print(" ๐ Reference List:")
+ for i, ref in enumerate(references, 1):
+ ref_id = ref.get("reference_id", "Unknown")
+ file_path = ref.get("file_path", "Unknown")
+ print(f" {i}. ID: {ref_id} | File: {file_path}")
+
+ else:
+ print(f"โ Request failed: {response.status_code}")
+ print(f" Error: {response.text}")
+ return False
+
+ except Exception as e:
+ print(f"โ Test 1 failed: {str(e)}")
+ return False
+
+ # Test 2: Streaming with references disabled
+ print("\n๐งช Test 2: Streaming with references disabled")
+ print("-" * 40)
+
+ try:
+ response = requests.post(
+ endpoint,
+ json={"query": query_text, "mode": "mix", "include_references": False},
+ headers=AUTH_HEADERS,
+ timeout=30,
+ stream=True,
+ )
+
+ if response.status_code == 200:
+ # Collect streaming response
+ full_response = ""
+ for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
+ if chunk:
+ # Ensure chunk is string type
+ if isinstance(chunk, bytes):
+ chunk = chunk.decode("utf-8")
+ full_response += chunk
+
+ # Parse streaming response
+ references, response_chunks, errors = parse_streaming_response(
+ full_response
+ )
+
+ if errors:
+ print(f"โ Errors in streaming response: {errors}")
+ return False
+
+ if references is not None:
+ print("โ References should be None when include_references=False")
+ return False
+
+ if not response_chunks:
+ print("โ No response chunks found in streaming response")
+ return False
+
+ print("โ
Streaming without references: No references present")
+ print(f" Response chunks: {len(response_chunks)}")
+ print(
+ f" Total response length: {sum(len(chunk) for chunk in response_chunks)} characters"
+ )
+
+ else:
+ print(f"โ Request failed: {response.status_code}")
+ print(f" Error: {response.text}")
+ return False
+
+ except Exception as e:
+ print(f"โ Test 2 failed: {str(e)}")
+ return False
+
+ print("\nโ
/query/stream endpoint references tests passed!")
+ return True
+
+
+def test_references_consistency():
+ """Test references consistency across all endpoints"""
+
+ print("\n" + "=" * 60)
+ print("Testing references consistency across endpoints")
+ print("=" * 60)
+
+ query_text = "who authored LightRAG"
+ query_params = {
+ "query": query_text,
+ "mode": "mix",
+ "top_k": 10,
+ "chunk_top_k": 8,
+ "include_references": True,
+ }
+
+ references_data = {}
+
+ # Test /query endpoint
+ print("\n๐งช Testing /query endpoint")
+ print("-" * 40)
+
+ try:
+ response = requests.post(
+ f"{BASE_URL}/query", json=query_params, headers=AUTH_HEADERS, timeout=30
+ )
+
+ if response.status_code == 200:
+ data = response.json()
+ references_data["query"] = data.get("references", [])
+ print(f"โ
/query: {len(references_data['query'])} references")
+ else:
+ print(f"โ /query failed: {response.status_code}")
+ return False
+
+ except Exception as e:
+ print(f"โ /query test failed: {str(e)}")
+ return False
+
+ # Test /query/stream endpoint
+ print("\n๐งช Testing /query/stream endpoint")
+ print("-" * 40)
+
+ try:
+ response = requests.post(
+ f"{BASE_URL}/query/stream",
+ json=query_params,
+ headers=AUTH_HEADERS,
+ timeout=30,
+ stream=True,
+ )
+
+ if response.status_code == 200:
+ full_response = ""
+ for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
+ if chunk:
+ # Ensure chunk is string type
+ if isinstance(chunk, bytes):
+ chunk = chunk.decode("utf-8")
+ full_response += chunk
+
+ references, _, errors = parse_streaming_response(full_response)
+
+ if errors:
+ print(f"โ Errors: {errors}")
+ return False
+
+ references_data["stream"] = references or []
+ print(f"โ
/query/stream: {len(references_data['stream'])} references")
+ else:
+ print(f"โ /query/stream failed: {response.status_code}")
+ return False
+
+ except Exception as e:
+ print(f"โ /query/stream test failed: {str(e)}")
+ return False
+
+ # Test /query/data endpoint
+ print("\n๐งช Testing /query/data endpoint")
+ print("-" * 40)
+
+ try:
+ response = requests.post(
+ f"{BASE_URL}/query/data",
+ json=query_params,
+ headers=AUTH_HEADERS,
+ timeout=30,
+ )
+
+ if response.status_code == 200:
+ data = response.json()
+ query_data = data.get("data", {})
+ references_data["data"] = query_data.get("references", [])
+ print(f"โ
/query/data: {len(references_data['data'])} references")
+ else:
+ print(f"โ /query/data failed: {response.status_code}")
+ return False
+
+ except Exception as e:
+ print(f"โ /query/data test failed: {str(e)}")
+ return False
+
+ # Compare references consistency
+ print("\n๐ Comparing references consistency")
+ print("-" * 40)
+
+ # Convert to sets of (reference_id, file_path) tuples for comparison
+ def refs_to_set(refs):
+ return set(
+ (ref.get("reference_id", ""), ref.get("file_path", "")) for ref in refs
+ )
+
+ query_refs = refs_to_set(references_data["query"])
+ stream_refs = refs_to_set(references_data["stream"])
+ data_refs = refs_to_set(references_data["data"])
+
+ # Check consistency
+ consistency_passed = True
+
+ if query_refs != stream_refs:
+ print("โ References mismatch between /query and /query/stream")
+ print(f" /query only: {query_refs - stream_refs}")
+ print(f" /query/stream only: {stream_refs - query_refs}")
+ consistency_passed = False
+
+ if query_refs != data_refs:
+ print("โ References mismatch between /query and /query/data")
+ print(f" /query only: {query_refs - data_refs}")
+ print(f" /query/data only: {data_refs - query_refs}")
+ consistency_passed = False
+
+ if stream_refs != data_refs:
+ print("โ References mismatch between /query/stream and /query/data")
+ print(f" /query/stream only: {stream_refs - data_refs}")
+ print(f" /query/data only: {data_refs - stream_refs}")
+ consistency_passed = False
+
+ if consistency_passed:
+ print("โ
All endpoints return consistent references")
+ print(f" Common references count: {len(query_refs)}")
+
+ # Display common reference list
+ if query_refs:
+ print(" ๐ Common Reference List:")
+ for i, (ref_id, file_path) in enumerate(sorted(query_refs), 1):
+ print(f" {i}. ID: {ref_id} | File: {file_path}")
+
+ return consistency_passed
+
+
def test_aquery_data_endpoint():
"""Test the /query/data endpoint"""
@@ -239,15 +690,79 @@ def compare_with_regular_query():
print(f" Regular query error: {str(e)}")
+def run_all_reference_tests():
+ """Run all reference-related tests"""
+
+ print("\n" + "๐" * 20)
+ print("LightRAG References Test Suite")
+ print("๐" * 20)
+
+ all_tests_passed = True
+
+ # Test 1: /query endpoint references
+ try:
+ if not test_query_endpoint_references():
+ all_tests_passed = False
+ except Exception as e:
+ print(f"โ /query endpoint test failed with exception: {str(e)}")
+ all_tests_passed = False
+
+ # Test 2: /query/stream endpoint references
+ try:
+ if not test_query_stream_endpoint_references():
+ all_tests_passed = False
+ except Exception as e:
+ print(f"โ /query/stream endpoint test failed with exception: {str(e)}")
+ all_tests_passed = False
+
+ # Test 3: References consistency across endpoints
+ try:
+ if not test_references_consistency():
+ all_tests_passed = False
+ except Exception as e:
+ print(f"โ References consistency test failed with exception: {str(e)}")
+ all_tests_passed = False
+
+ # Final summary
+ print("\n" + "=" * 60)
+ print("TEST SUITE SUMMARY")
+ print("=" * 60)
+
+ if all_tests_passed:
+ print("๐ ALL TESTS PASSED!")
+ print("โ
/query endpoint references functionality works correctly")
+ print("โ
/query/stream endpoint references functionality works correctly")
+ print("โ
References are consistent across all endpoints")
+ print("\n๐ง System is ready for production use with reference support!")
+ else:
+ print("โ SOME TESTS FAILED!")
+ print("Please check the error messages above and fix the issues.")
+ print("\n๐ง System needs attention before production deployment.")
+
+ return all_tests_passed
+
+
if __name__ == "__main__":
- # Run main test
- test_aquery_data_endpoint()
+ import sys
- # Run comparison test
- compare_with_regular_query()
+ if len(sys.argv) > 1 and sys.argv[1] == "--references-only":
+ # Run only the new reference tests
+ success = run_all_reference_tests()
+ sys.exit(0 if success else 1)
+ else:
+ # Run original tests plus new reference tests
+ print("Running original aquery_data endpoint test...")
+ test_aquery_data_endpoint()
- print("\n๐ก Usage tips:")
- print("1. Ensure LightRAG API service is running")
- print("2. Adjust base_url and authentication information as needed")
- print("3. Modify query parameters to test different retrieval strategies")
- print("4. Data query results can be used for further analysis and processing")
+ print("\nRunning comparison test...")
+ compare_with_regular_query()
+
+ print("\nRunning new reference tests...")
+ run_all_reference_tests()
+
+ print("\n๐ก Usage tips:")
+ print("1. Ensure LightRAG API service is running")
+ print("2. Adjust base_url and authentication information as needed")
+ print("3. Modify query parameters to test different retrieval strategies")
+ print("4. Data query results can be used for further analysis and processing")
+ print("5. Run with --references-only flag to test only reference functionality")