Merge pull request #2147 from danielaskdd/return-reference-on-query
Feature: Add Reference List Support for All Query Endpoints
This commit is contained in:
commit
b4cc249dca
6 changed files with 1009 additions and 359 deletions
|
|
@ -1 +1 @@
|
||||||
__api_version__ = "0230"
|
__api_version__ = "0231"
|
||||||
|
|
|
||||||
|
|
@ -88,6 +88,11 @@ class QueryRequest(BaseModel):
|
||||||
description="Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. Default is True.",
|
description="Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. Default is True.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
include_references: Optional[bool] = Field(
|
||||||
|
default=True,
|
||||||
|
description="If True, includes reference list in responses. Affects /query and /query/stream endpoints. /query/data always includes references.",
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("query", mode="after")
|
@field_validator("query", mode="after")
|
||||||
@classmethod
|
@classmethod
|
||||||
def query_strip_after(cls, query: str) -> str:
|
def query_strip_after(cls, query: str) -> str:
|
||||||
|
|
@ -122,6 +127,10 @@ class QueryResponse(BaseModel):
|
||||||
response: str = Field(
|
response: str = Field(
|
||||||
description="The generated response",
|
description="The generated response",
|
||||||
)
|
)
|
||||||
|
references: Optional[List[Dict[str, str]]] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Reference list (only included when include_references=True, /query/data always includes references.)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QueryDataResponse(BaseModel):
|
class QueryDataResponse(BaseModel):
|
||||||
|
|
@ -149,6 +158,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
request (QueryRequest): The request object containing the query parameters.
|
request (QueryRequest): The request object containing the query parameters.
|
||||||
Returns:
|
Returns:
|
||||||
QueryResponse: A Pydantic model containing the result of the query processing.
|
QueryResponse: A Pydantic model containing the result of the query processing.
|
||||||
|
If include_references=True, also includes reference list.
|
||||||
If a string is returned (e.g., cache hit), it's directly returned.
|
If a string is returned (e.g., cache hit), it's directly returned.
|
||||||
Otherwise, an async generator may be used to build the response.
|
Otherwise, an async generator may be used to build the response.
|
||||||
|
|
||||||
|
|
@ -160,15 +170,26 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
param = request.to_query_params(False)
|
param = request.to_query_params(False)
|
||||||
response = await rag.aquery(request.query, param=param)
|
response = await rag.aquery(request.query, param=param)
|
||||||
|
|
||||||
# If response is a string (e.g. cache hit), return directly
|
# Get reference list if requested
|
||||||
if isinstance(response, str):
|
reference_list = None
|
||||||
return QueryResponse(response=response)
|
if request.include_references:
|
||||||
|
try:
|
||||||
|
# Use aquery_data to get reference list independently
|
||||||
|
data_result = await rag.aquery_data(request.query, param=param)
|
||||||
|
if isinstance(data_result, dict) and "data" in data_result:
|
||||||
|
reference_list = data_result["data"].get("references", [])
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to get reference list: {str(e)}")
|
||||||
|
reference_list = []
|
||||||
|
|
||||||
if isinstance(response, dict):
|
# Process response and return with optional references
|
||||||
|
if isinstance(response, str):
|
||||||
|
return QueryResponse(response=response, references=reference_list)
|
||||||
|
elif isinstance(response, dict):
|
||||||
result = json.dumps(response, indent=2)
|
result = json.dumps(response, indent=2)
|
||||||
return QueryResponse(response=result)
|
return QueryResponse(response=result, references=reference_list)
|
||||||
else:
|
else:
|
||||||
return QueryResponse(response=str(response))
|
return QueryResponse(response=str(response), references=reference_list)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
trace_exception(e)
|
trace_exception(e)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
@ -178,12 +199,18 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
"""
|
"""
|
||||||
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
|
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response.
|
||||||
|
|
||||||
|
The streaming response includes:
|
||||||
|
1. Reference list (sent first as a single message, if include_references=True)
|
||||||
|
2. LLM response content (streamed as multiple chunks)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request (QueryRequest): The request object containing the query parameters.
|
request (QueryRequest): The request object containing the query parameters.
|
||||||
optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
StreamingResponse: A streaming response containing the RAG query results.
|
StreamingResponse: A streaming response containing:
|
||||||
|
- First message: {"references": [...]} - Complete reference list (if requested)
|
||||||
|
- Subsequent messages: {"response": "..."} - LLM response chunks
|
||||||
|
- Error messages: {"error": "..."} - If any errors occur
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
param = request.to_query_params(True)
|
param = request.to_query_params(True)
|
||||||
|
|
@ -192,6 +219,28 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
async def stream_generator():
|
async def stream_generator():
|
||||||
|
# Get reference list if requested (default is True for backward compatibility)
|
||||||
|
reference_list = []
|
||||||
|
if request.include_references:
|
||||||
|
try:
|
||||||
|
# Use aquery_data to get reference list independently
|
||||||
|
data_param = request.to_query_params(
|
||||||
|
False
|
||||||
|
) # Non-streaming for data
|
||||||
|
data_result = await rag.aquery_data(
|
||||||
|
request.query, param=data_param
|
||||||
|
)
|
||||||
|
if isinstance(data_result, dict) and "data" in data_result:
|
||||||
|
reference_list = data_result["data"].get("references", [])
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to get reference list: {str(e)}")
|
||||||
|
reference_list = []
|
||||||
|
|
||||||
|
# Send reference list first (if requested)
|
||||||
|
if request.include_references:
|
||||||
|
yield f"{json.dumps({'references': reference_list})}\n"
|
||||||
|
|
||||||
|
# Then stream the response content
|
||||||
if isinstance(response, str):
|
if isinstance(response, str):
|
||||||
# If it's a string, send it all at once
|
# If it's a string, send it all at once
|
||||||
yield f"{json.dumps({'response': response})}\n"
|
yield f"{json.dumps({'response': response})}\n"
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,10 @@ from typing import (
|
||||||
TypedDict,
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Callable,
|
Callable,
|
||||||
|
Optional,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
AsyncIterator,
|
||||||
)
|
)
|
||||||
from .utils import EmbeddingFunc
|
from .utils import EmbeddingFunc
|
||||||
from .types import KnowledgeGraph
|
from .types import KnowledgeGraph
|
||||||
|
|
@ -158,6 +162,12 @@ class QueryParam:
|
||||||
Default is True to enable reranking when rerank model is available.
|
Default is True to enable reranking when rerank model is available.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
include_references: bool = False
|
||||||
|
"""If True, includes reference list in the response for supported endpoints.
|
||||||
|
This parameter controls whether the API response includes a references field
|
||||||
|
containing citation information for the retrieved content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StorageNameSpace(ABC):
|
class StorageNameSpace(ABC):
|
||||||
|
|
@ -814,3 +824,68 @@ class DeletionResult:
|
||||||
message: str
|
message: str
|
||||||
status_code: int = 200
|
status_code: int = 200
|
||||||
file_path: str | None = None
|
file_path: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Unified Query Result Data Structures for Reference List Support
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QueryResult:
|
||||||
|
"""
|
||||||
|
Unified query result data structure for all query modes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
content: Text content for non-streaming responses
|
||||||
|
response_iterator: Streaming response iterator for streaming responses
|
||||||
|
raw_data: Complete structured data including references and metadata
|
||||||
|
is_streaming: Whether this is a streaming result
|
||||||
|
"""
|
||||||
|
|
||||||
|
content: Optional[str] = None
|
||||||
|
response_iterator: Optional[AsyncIterator[str]] = None
|
||||||
|
raw_data: Optional[Dict[str, Any]] = None
|
||||||
|
is_streaming: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_list(self) -> List[Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
Convenient property to extract reference list from raw_data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, str]]: Reference list in format:
|
||||||
|
[{"reference_id": "1", "file_path": "/path/to/file.pdf"}, ...]
|
||||||
|
"""
|
||||||
|
if self.raw_data:
|
||||||
|
return self.raw_data.get("data", {}).get("references", [])
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metadata(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convenient property to extract metadata from raw_data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Query metadata including query_mode, keywords, etc.
|
||||||
|
"""
|
||||||
|
if self.raw_data:
|
||||||
|
return self.raw_data.get("metadata", {})
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QueryContextResult:
|
||||||
|
"""
|
||||||
|
Unified query context result data structure.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
context: LLM context string
|
||||||
|
raw_data: Complete structured data including reference_list
|
||||||
|
"""
|
||||||
|
|
||||||
|
context: str
|
||||||
|
raw_data: Dict[str, Any]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reference_list(self) -> List[Dict[str, str]]:
|
||||||
|
"""Convenient property to extract reference list from raw_data."""
|
||||||
|
return self.raw_data.get("data", {}).get("references", [])
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,7 @@ from lightrag.base import (
|
||||||
StoragesStatus,
|
StoragesStatus,
|
||||||
DeletionResult,
|
DeletionResult,
|
||||||
OllamaServerInfos,
|
OllamaServerInfos,
|
||||||
|
QueryResult,
|
||||||
)
|
)
|
||||||
from lightrag.namespace import NameSpace
|
from lightrag.namespace import NameSpace
|
||||||
from lightrag.operate import (
|
from lightrag.operate import (
|
||||||
|
|
@ -2075,8 +2076,10 @@ class LightRAG:
|
||||||
# If a custom model is provided in param, temporarily update global config
|
# If a custom model is provided in param, temporarily update global config
|
||||||
global_config = asdict(self)
|
global_config = asdict(self)
|
||||||
|
|
||||||
|
query_result = None
|
||||||
|
|
||||||
if param.mode in ["local", "global", "hybrid", "mix"]:
|
if param.mode in ["local", "global", "hybrid", "mix"]:
|
||||||
response = await kg_query(
|
query_result = await kg_query(
|
||||||
query.strip(),
|
query.strip(),
|
||||||
self.chunk_entity_relation_graph,
|
self.chunk_entity_relation_graph,
|
||||||
self.entities_vdb,
|
self.entities_vdb,
|
||||||
|
|
@ -2089,7 +2092,7 @@ class LightRAG:
|
||||||
chunks_vdb=self.chunks_vdb,
|
chunks_vdb=self.chunks_vdb,
|
||||||
)
|
)
|
||||||
elif param.mode == "naive":
|
elif param.mode == "naive":
|
||||||
response = await naive_query(
|
query_result = await naive_query(
|
||||||
query.strip(),
|
query.strip(),
|
||||||
self.chunks_vdb,
|
self.chunks_vdb,
|
||||||
param,
|
param,
|
||||||
|
|
@ -2111,10 +2114,22 @@ class LightRAG:
|
||||||
enable_cot=True,
|
enable_cot=True,
|
||||||
stream=param.stream,
|
stream=param.stream,
|
||||||
)
|
)
|
||||||
|
# Create QueryResult for bypass mode
|
||||||
|
query_result = QueryResult(
|
||||||
|
content=response if not param.stream else None,
|
||||||
|
response_iterator=response if param.stream else None,
|
||||||
|
is_streaming=param.stream,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown mode {param.mode}")
|
raise ValueError(f"Unknown mode {param.mode}")
|
||||||
|
|
||||||
await self._query_done()
|
await self._query_done()
|
||||||
return response
|
|
||||||
|
# Return appropriate response based on streaming mode
|
||||||
|
if query_result.is_streaming:
|
||||||
|
return query_result.response_iterator
|
||||||
|
else:
|
||||||
|
return query_result.content
|
||||||
|
|
||||||
async def aquery_data(
|
async def aquery_data(
|
||||||
self,
|
self,
|
||||||
|
|
@ -2229,61 +2244,81 @@ class LightRAG:
|
||||||
"""
|
"""
|
||||||
global_config = asdict(self)
|
global_config = asdict(self)
|
||||||
|
|
||||||
if param.mode in ["local", "global", "hybrid", "mix"]:
|
# Create a copy of param to avoid modifying the original
|
||||||
logger.debug(f"[aquery_data] Using kg_query for mode: {param.mode}")
|
data_param = QueryParam(
|
||||||
final_data = await kg_query(
|
mode=param.mode,
|
||||||
|
only_need_context=True, # Skip LLM generation, only get context and data
|
||||||
|
only_need_prompt=False,
|
||||||
|
response_type=param.response_type,
|
||||||
|
stream=False, # Data retrieval doesn't need streaming
|
||||||
|
top_k=param.top_k,
|
||||||
|
chunk_top_k=param.chunk_top_k,
|
||||||
|
max_entity_tokens=param.max_entity_tokens,
|
||||||
|
max_relation_tokens=param.max_relation_tokens,
|
||||||
|
max_total_tokens=param.max_total_tokens,
|
||||||
|
hl_keywords=param.hl_keywords,
|
||||||
|
ll_keywords=param.ll_keywords,
|
||||||
|
conversation_history=param.conversation_history,
|
||||||
|
history_turns=param.history_turns,
|
||||||
|
model_func=param.model_func,
|
||||||
|
user_prompt=param.user_prompt,
|
||||||
|
enable_rerank=param.enable_rerank,
|
||||||
|
)
|
||||||
|
|
||||||
|
query_result = None
|
||||||
|
|
||||||
|
if data_param.mode in ["local", "global", "hybrid", "mix"]:
|
||||||
|
logger.debug(f"[aquery_data] Using kg_query for mode: {data_param.mode}")
|
||||||
|
query_result = await kg_query(
|
||||||
query.strip(),
|
query.strip(),
|
||||||
self.chunk_entity_relation_graph,
|
self.chunk_entity_relation_graph,
|
||||||
self.entities_vdb,
|
self.entities_vdb,
|
||||||
self.relationships_vdb,
|
self.relationships_vdb,
|
||||||
self.text_chunks,
|
self.text_chunks,
|
||||||
param,
|
data_param, # Use data_param with only_need_context=True
|
||||||
global_config,
|
global_config,
|
||||||
hashing_kv=self.llm_response_cache,
|
hashing_kv=self.llm_response_cache,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
chunks_vdb=self.chunks_vdb,
|
chunks_vdb=self.chunks_vdb,
|
||||||
return_raw_data=True, # Get final processed data
|
|
||||||
)
|
)
|
||||||
elif param.mode == "naive":
|
elif data_param.mode == "naive":
|
||||||
logger.debug(f"[aquery_data] Using naive_query for mode: {param.mode}")
|
logger.debug(f"[aquery_data] Using naive_query for mode: {data_param.mode}")
|
||||||
final_data = await naive_query(
|
query_result = await naive_query(
|
||||||
query.strip(),
|
query.strip(),
|
||||||
self.chunks_vdb,
|
self.chunks_vdb,
|
||||||
param,
|
data_param, # Use data_param with only_need_context=True
|
||||||
global_config,
|
global_config,
|
||||||
hashing_kv=self.llm_response_cache,
|
hashing_kv=self.llm_response_cache,
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
return_raw_data=True, # Get final processed data
|
|
||||||
)
|
)
|
||||||
elif param.mode == "bypass":
|
elif data_param.mode == "bypass":
|
||||||
logger.debug("[aquery_data] Using bypass mode")
|
logger.debug("[aquery_data] Using bypass mode")
|
||||||
# bypass mode returns empty data using convert_to_user_format
|
# bypass mode returns empty data using convert_to_user_format
|
||||||
final_data = convert_to_user_format(
|
empty_raw_data = convert_to_user_format(
|
||||||
[], # no entities
|
[], # no entities
|
||||||
[], # no relationships
|
[], # no relationships
|
||||||
[], # no chunks
|
[], # no chunks
|
||||||
[], # no references
|
[], # no references
|
||||||
"bypass",
|
"bypass",
|
||||||
)
|
)
|
||||||
|
query_result = QueryResult(content="", raw_data=empty_raw_data)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown mode {param.mode}")
|
raise ValueError(f"Unknown mode {data_param.mode}")
|
||||||
|
|
||||||
|
# Extract raw_data from QueryResult
|
||||||
|
final_data = query_result.raw_data if query_result else {}
|
||||||
|
|
||||||
# Log final result counts - adapt to new data format from convert_to_user_format
|
# Log final result counts - adapt to new data format from convert_to_user_format
|
||||||
if isinstance(final_data, dict) and "data" in final_data:
|
if final_data and "data" in final_data:
|
||||||
# New format: data is nested under 'data' field
|
|
||||||
data_section = final_data["data"]
|
data_section = final_data["data"]
|
||||||
entities_count = len(data_section.get("entities", []))
|
entities_count = len(data_section.get("entities", []))
|
||||||
relationships_count = len(data_section.get("relationships", []))
|
relationships_count = len(data_section.get("relationships", []))
|
||||||
chunks_count = len(data_section.get("chunks", []))
|
chunks_count = len(data_section.get("chunks", []))
|
||||||
|
logger.debug(
|
||||||
|
f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Fallback for other formats
|
logger.warning("[aquery_data] No data section found in query result")
|
||||||
entities_count = len(final_data.get("entities", []))
|
|
||||||
relationships_count = len(final_data.get("relationships", []))
|
|
||||||
chunks_count = len(final_data.get("chunks", []))
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._query_done()
|
await self._query_done()
|
||||||
return final_data
|
return final_data
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,8 @@ from .base import (
|
||||||
BaseVectorStorage,
|
BaseVectorStorage,
|
||||||
TextChunkSchema,
|
TextChunkSchema,
|
||||||
QueryParam,
|
QueryParam,
|
||||||
|
QueryResult,
|
||||||
|
QueryContextResult,
|
||||||
)
|
)
|
||||||
from .prompt import PROMPTS
|
from .prompt import PROMPTS
|
||||||
from .constants import (
|
from .constants import (
|
||||||
|
|
@ -2277,16 +2279,38 @@ async def kg_query(
|
||||||
hashing_kv: BaseKVStorage | None = None,
|
hashing_kv: BaseKVStorage | None = None,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
chunks_vdb: BaseVectorStorage = None,
|
chunks_vdb: BaseVectorStorage = None,
|
||||||
return_raw_data: bool = False,
|
) -> QueryResult:
|
||||||
) -> str | AsyncIterator[str] | dict[str, Any]:
|
"""
|
||||||
|
Execute knowledge graph query and return unified QueryResult object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query string
|
||||||
|
knowledge_graph_inst: Knowledge graph storage instance
|
||||||
|
entities_vdb: Entity vector database
|
||||||
|
relationships_vdb: Relationship vector database
|
||||||
|
text_chunks_db: Text chunks storage
|
||||||
|
query_param: Query parameters
|
||||||
|
global_config: Global configuration
|
||||||
|
hashing_kv: Cache storage
|
||||||
|
system_prompt: System prompt
|
||||||
|
chunks_vdb: Document chunks vector database
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QueryResult: Unified query result object containing:
|
||||||
|
- content: Non-streaming response text content
|
||||||
|
- response_iterator: Streaming response iterator
|
||||||
|
- raw_data: Complete structured data (including references and metadata)
|
||||||
|
- is_streaming: Whether this is a streaming result
|
||||||
|
|
||||||
|
Based on different query_param settings, different fields will be populated:
|
||||||
|
- only_need_context=True: content contains context string
|
||||||
|
- only_need_prompt=True: content contains complete prompt
|
||||||
|
- stream=True: response_iterator contains streaming response, raw_data contains complete data
|
||||||
|
- default: content contains LLM response text, raw_data contains complete data
|
||||||
|
"""
|
||||||
|
|
||||||
if not query:
|
if not query:
|
||||||
if return_raw_data:
|
return QueryResult(content=PROMPTS["fail_response"])
|
||||||
return {
|
|
||||||
"status": "failure",
|
|
||||||
"message": "Query string is empty.",
|
|
||||||
"data": {},
|
|
||||||
}
|
|
||||||
return PROMPTS["fail_response"]
|
|
||||||
|
|
||||||
if query_param.model_func:
|
if query_param.model_func:
|
||||||
use_model_func = query_param.model_func
|
use_model_func = query_param.model_func
|
||||||
|
|
@ -2315,12 +2339,11 @@ async def kg_query(
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
cached_result is not None
|
cached_result is not None
|
||||||
and not return_raw_data
|
|
||||||
and not query_param.only_need_context
|
and not query_param.only_need_context
|
||||||
and not query_param.only_need_prompt
|
and not query_param.only_need_prompt
|
||||||
):
|
):
|
||||||
cached_response, _ = cached_result # Extract content, ignore timestamp
|
cached_response, _ = cached_result # Extract content, ignore timestamp
|
||||||
return cached_response
|
return QueryResult(content=cached_response)
|
||||||
|
|
||||||
hl_keywords, ll_keywords = await get_keywords_from_query(
|
hl_keywords, ll_keywords = await get_keywords_from_query(
|
||||||
query, query_param, global_config, hashing_kv
|
query, query_param, global_config, hashing_kv
|
||||||
|
|
@ -2339,53 +2362,13 @@ async def kg_query(
|
||||||
logger.warning(f"Forced low_level_keywords to origin query: {query}")
|
logger.warning(f"Forced low_level_keywords to origin query: {query}")
|
||||||
ll_keywords = [query]
|
ll_keywords = [query]
|
||||||
else:
|
else:
|
||||||
if return_raw_data:
|
return QueryResult(content=PROMPTS["fail_response"])
|
||||||
return {
|
|
||||||
"status": "failure",
|
|
||||||
"message": "Both high_level_keywords and low_level_keywords are empty",
|
|
||||||
"data": {},
|
|
||||||
}
|
|
||||||
return PROMPTS["fail_response"]
|
|
||||||
|
|
||||||
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
|
ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
|
||||||
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
|
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
|
||||||
|
|
||||||
# If raw data is requested, get both context and raw data
|
# Build query context (unified interface)
|
||||||
if return_raw_data:
|
context_result = await _build_query_context(
|
||||||
context_result = await _build_query_context(
|
|
||||||
query,
|
|
||||||
ll_keywords_str,
|
|
||||||
hl_keywords_str,
|
|
||||||
knowledge_graph_inst,
|
|
||||||
entities_vdb,
|
|
||||||
relationships_vdb,
|
|
||||||
text_chunks_db,
|
|
||||||
query_param,
|
|
||||||
chunks_vdb,
|
|
||||||
return_raw_data=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(context_result, tuple):
|
|
||||||
context, raw_data = context_result
|
|
||||||
logger.debug(f"[kg_query] Context length: {len(context) if context else 0}")
|
|
||||||
logger.debug(
|
|
||||||
f"[kg_query] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
|
|
||||||
)
|
|
||||||
return raw_data
|
|
||||||
else:
|
|
||||||
if not context_result:
|
|
||||||
return {
|
|
||||||
"status": "failure",
|
|
||||||
"message": "Query return empty data set.",
|
|
||||||
"data": {},
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Fail to build raw data query result. Invalid return from _build_query_context"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build context (normal flow)
|
|
||||||
context = await _build_query_context(
|
|
||||||
query,
|
query,
|
||||||
ll_keywords_str,
|
ll_keywords_str,
|
||||||
hl_keywords_str,
|
hl_keywords_str,
|
||||||
|
|
@ -2397,14 +2380,19 @@ async def kg_query(
|
||||||
chunks_vdb,
|
chunks_vdb,
|
||||||
)
|
)
|
||||||
|
|
||||||
if query_param.only_need_context and not query_param.only_need_prompt:
|
if context_result is None:
|
||||||
return context if context is not None else PROMPTS["fail_response"]
|
return QueryResult(content=PROMPTS["fail_response"])
|
||||||
if context is None:
|
|
||||||
return PROMPTS["fail_response"]
|
|
||||||
|
|
||||||
|
# Return different content based on query parameters
|
||||||
|
if query_param.only_need_context and not query_param.only_need_prompt:
|
||||||
|
return QueryResult(
|
||||||
|
content=context_result.context, raw_data=context_result.raw_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build system prompt
|
||||||
sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
|
sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
|
||||||
sys_prompt = sys_prompt_temp.format(
|
sys_prompt = sys_prompt_temp.format(
|
||||||
context_data=context,
|
context_data=context_result.context,
|
||||||
response_type=query_param.response_type,
|
response_type=query_param.response_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2415,8 +2403,10 @@ async def kg_query(
|
||||||
)
|
)
|
||||||
|
|
||||||
if query_param.only_need_prompt:
|
if query_param.only_need_prompt:
|
||||||
return "\n\n".join([sys_prompt, "---User Query---", user_query])
|
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
|
||||||
|
return QueryResult(content=prompt_content, raw_data=context_result.raw_data)
|
||||||
|
|
||||||
|
# Call LLM
|
||||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||||
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
@ -2430,45 +2420,56 @@ async def kg_query(
|
||||||
enable_cot=True,
|
enable_cot=True,
|
||||||
stream=query_param.stream,
|
stream=query_param.stream,
|
||||||
)
|
)
|
||||||
if isinstance(response, str) and len(response) > len(sys_prompt):
|
|
||||||
response = (
|
|
||||||
response.replace(sys_prompt, "")
|
|
||||||
.replace("user", "")
|
|
||||||
.replace("model", "")
|
|
||||||
.replace(query, "")
|
|
||||||
.replace("<system>", "")
|
|
||||||
.replace("</system>", "")
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
if hashing_kv.global_config.get("enable_llm_cache"):
|
# Return unified result based on actual response type
|
||||||
# Save to cache with query parameters
|
if isinstance(response, str):
|
||||||
queryparam_dict = {
|
# Non-streaming response (string)
|
||||||
"mode": query_param.mode,
|
if len(response) > len(sys_prompt):
|
||||||
"response_type": query_param.response_type,
|
response = (
|
||||||
"top_k": query_param.top_k,
|
response.replace(sys_prompt, "")
|
||||||
"chunk_top_k": query_param.chunk_top_k,
|
.replace("user", "")
|
||||||
"max_entity_tokens": query_param.max_entity_tokens,
|
.replace("model", "")
|
||||||
"max_relation_tokens": query_param.max_relation_tokens,
|
.replace(query, "")
|
||||||
"max_total_tokens": query_param.max_total_tokens,
|
.replace("<system>", "")
|
||||||
"hl_keywords": query_param.hl_keywords or [],
|
.replace("</system>", "")
|
||||||
"ll_keywords": query_param.ll_keywords or [],
|
.strip()
|
||||||
"user_prompt": query_param.user_prompt or "",
|
)
|
||||||
"enable_rerank": query_param.enable_rerank,
|
|
||||||
}
|
|
||||||
await save_to_cache(
|
|
||||||
hashing_kv,
|
|
||||||
CacheData(
|
|
||||||
args_hash=args_hash,
|
|
||||||
content=response,
|
|
||||||
prompt=query,
|
|
||||||
mode=query_param.mode,
|
|
||||||
cache_type="query",
|
|
||||||
queryparam=queryparam_dict,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
# Cache response
|
||||||
|
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
||||||
|
queryparam_dict = {
|
||||||
|
"mode": query_param.mode,
|
||||||
|
"response_type": query_param.response_type,
|
||||||
|
"top_k": query_param.top_k,
|
||||||
|
"chunk_top_k": query_param.chunk_top_k,
|
||||||
|
"max_entity_tokens": query_param.max_entity_tokens,
|
||||||
|
"max_relation_tokens": query_param.max_relation_tokens,
|
||||||
|
"max_total_tokens": query_param.max_total_tokens,
|
||||||
|
"hl_keywords": query_param.hl_keywords or [],
|
||||||
|
"ll_keywords": query_param.ll_keywords or [],
|
||||||
|
"user_prompt": query_param.user_prompt or "",
|
||||||
|
"enable_rerank": query_param.enable_rerank,
|
||||||
|
}
|
||||||
|
await save_to_cache(
|
||||||
|
hashing_kv,
|
||||||
|
CacheData(
|
||||||
|
args_hash=args_hash,
|
||||||
|
content=response,
|
||||||
|
prompt=query,
|
||||||
|
mode=query_param.mode,
|
||||||
|
cache_type="query",
|
||||||
|
queryparam=queryparam_dict,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return QueryResult(content=response, raw_data=context_result.raw_data)
|
||||||
|
else:
|
||||||
|
# Streaming response (AsyncIterator)
|
||||||
|
return QueryResult(
|
||||||
|
response_iterator=response,
|
||||||
|
raw_data=context_result.raw_data,
|
||||||
|
is_streaming=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_keywords_from_query(
|
async def get_keywords_from_query(
|
||||||
|
|
@ -3123,10 +3124,9 @@ async def _build_llm_context(
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
global_config: dict[str, str],
|
global_config: dict[str, str],
|
||||||
chunk_tracking: dict = None,
|
chunk_tracking: dict = None,
|
||||||
return_raw_data: bool = False,
|
|
||||||
entity_id_to_original: dict = None,
|
entity_id_to_original: dict = None,
|
||||||
relation_id_to_original: dict = None,
|
relation_id_to_original: dict = None,
|
||||||
) -> str | tuple[str, dict[str, Any]]:
|
) -> tuple[str, dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Build the final LLM context string with token processing.
|
Build the final LLM context string with token processing.
|
||||||
This includes dynamic token calculation and final chunk truncation.
|
This includes dynamic token calculation and final chunk truncation.
|
||||||
|
|
@ -3134,22 +3134,17 @@ async def _build_llm_context(
|
||||||
tokenizer = global_config.get("tokenizer")
|
tokenizer = global_config.get("tokenizer")
|
||||||
if not tokenizer:
|
if not tokenizer:
|
||||||
logger.error("Missing tokenizer, cannot build LLM context")
|
logger.error("Missing tokenizer, cannot build LLM context")
|
||||||
|
# Return empty raw data structure when no tokenizer
|
||||||
if return_raw_data:
|
empty_raw_data = convert_to_user_format(
|
||||||
# Return empty raw data structure when no entities/relations
|
[],
|
||||||
empty_raw_data = convert_to_user_format(
|
[],
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
[],
|
query_param.mode,
|
||||||
[],
|
)
|
||||||
query_param.mode,
|
empty_raw_data["status"] = "failure"
|
||||||
)
|
empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
|
||||||
empty_raw_data["status"] = "failure"
|
return "", empty_raw_data
|
||||||
empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context."
|
|
||||||
return None, empty_raw_data
|
|
||||||
else:
|
|
||||||
logger.error("Tokenizer not found in global configuration.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Get token limits
|
# Get token limits
|
||||||
max_total_tokens = getattr(
|
max_total_tokens = getattr(
|
||||||
|
|
@ -3268,20 +3263,17 @@ The reference documents list in Document Chunks(DC) is as follows (reference_id
|
||||||
|
|
||||||
# not necessary to use LLM to generate a response
|
# not necessary to use LLM to generate a response
|
||||||
if not entities_context and not relations_context:
|
if not entities_context and not relations_context:
|
||||||
if return_raw_data:
|
# Return empty raw data structure when no entities/relations
|
||||||
# Return empty raw data structure when no entities/relations
|
empty_raw_data = convert_to_user_format(
|
||||||
empty_raw_data = convert_to_user_format(
|
[],
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
[],
|
[],
|
||||||
[],
|
query_param.mode,
|
||||||
query_param.mode,
|
)
|
||||||
)
|
empty_raw_data["status"] = "failure"
|
||||||
empty_raw_data["status"] = "failure"
|
empty_raw_data["message"] = "Query returned empty dataset."
|
||||||
empty_raw_data["message"] = "Query returned empty dataset."
|
return "", empty_raw_data
|
||||||
return None, empty_raw_data
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# output chunks tracking infomations
|
# output chunks tracking infomations
|
||||||
# format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
|
# format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
|
||||||
|
|
@ -3342,26 +3334,23 @@ Document Chunks (DC) reference documents : (Each entry begins with [reference_id
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If final data is requested, return both context and complete data structure
|
# Always return both context and complete data structure (unified approach)
|
||||||
if return_raw_data:
|
logger.debug(
|
||||||
logger.debug(
|
f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
|
||||||
f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks"
|
)
|
||||||
)
|
final_data = convert_to_user_format(
|
||||||
final_data = convert_to_user_format(
|
entities_context,
|
||||||
entities_context,
|
relations_context,
|
||||||
relations_context,
|
truncated_chunks,
|
||||||
truncated_chunks,
|
reference_list,
|
||||||
reference_list,
|
query_param.mode,
|
||||||
query_param.mode,
|
entity_id_to_original,
|
||||||
entity_id_to_original,
|
relation_id_to_original,
|
||||||
relation_id_to_original,
|
)
|
||||||
)
|
logger.debug(
|
||||||
logger.debug(
|
f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
|
||||||
f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks"
|
)
|
||||||
)
|
return result, final_data
|
||||||
return result, final_data
|
|
||||||
else:
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# Now let's update the old _build_query_context to use the new architecture
|
# Now let's update the old _build_query_context to use the new architecture
|
||||||
|
|
@ -3375,16 +3364,17 @@ async def _build_query_context(
|
||||||
text_chunks_db: BaseKVStorage,
|
text_chunks_db: BaseKVStorage,
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
chunks_vdb: BaseVectorStorage = None,
|
chunks_vdb: BaseVectorStorage = None,
|
||||||
return_raw_data: bool = False,
|
) -> QueryContextResult | None:
|
||||||
) -> str | None | tuple[str, dict[str, Any]]:
|
|
||||||
"""
|
"""
|
||||||
Main query context building function using the new 4-stage architecture:
|
Main query context building function using the new 4-stage architecture:
|
||||||
1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
|
1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context
|
||||||
|
|
||||||
|
Returns unified QueryContextResult containing both context and raw_data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not query:
|
if not query:
|
||||||
logger.warning("Query is empty, skipping context building")
|
logger.warning("Query is empty, skipping context building")
|
||||||
return ""
|
return None
|
||||||
|
|
||||||
# Stage 1: Pure search
|
# Stage 1: Pure search
|
||||||
search_result = await _perform_kg_search(
|
search_result = await _perform_kg_search(
|
||||||
|
|
@ -3435,71 +3425,53 @@ async def _build_query_context(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Stage 4: Build final LLM context with dynamic token processing
|
# Stage 4: Build final LLM context with dynamic token processing
|
||||||
|
# _build_llm_context now always returns tuple[str, dict]
|
||||||
|
context, raw_data = await _build_llm_context(
|
||||||
|
entities_context=truncation_result["entities_context"],
|
||||||
|
relations_context=truncation_result["relations_context"],
|
||||||
|
merged_chunks=merged_chunks,
|
||||||
|
query=query,
|
||||||
|
query_param=query_param,
|
||||||
|
global_config=text_chunks_db.global_config,
|
||||||
|
chunk_tracking=search_result["chunk_tracking"],
|
||||||
|
entity_id_to_original=truncation_result["entity_id_to_original"],
|
||||||
|
relation_id_to_original=truncation_result["relation_id_to_original"],
|
||||||
|
)
|
||||||
|
|
||||||
if return_raw_data:
|
# Convert keywords strings to lists and add complete metadata to raw_data
|
||||||
# Convert keywords strings to lists
|
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
|
||||||
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
|
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
|
||||||
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
|
|
||||||
|
|
||||||
# Get both context and final data - when return_raw_data=True, _build_llm_context always returns tuple
|
# Add complete metadata to raw_data (preserve existing metadata including query_mode)
|
||||||
context, raw_data = await _build_llm_context(
|
if "metadata" not in raw_data:
|
||||||
entities_context=truncation_result["entities_context"],
|
raw_data["metadata"] = {}
|
||||||
relations_context=truncation_result["relations_context"],
|
|
||||||
merged_chunks=merged_chunks,
|
|
||||||
query=query,
|
|
||||||
query_param=query_param,
|
|
||||||
global_config=text_chunks_db.global_config,
|
|
||||||
chunk_tracking=search_result["chunk_tracking"],
|
|
||||||
return_raw_data=True,
|
|
||||||
entity_id_to_original=truncation_result["entity_id_to_original"],
|
|
||||||
relation_id_to_original=truncation_result["relation_id_to_original"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert keywords strings to lists and add complete metadata to raw_data
|
# Update keywords while preserving existing metadata
|
||||||
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
|
raw_data["metadata"]["keywords"] = {
|
||||||
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
|
"high_level": hl_keywords_list,
|
||||||
|
"low_level": ll_keywords_list,
|
||||||
|
}
|
||||||
|
raw_data["metadata"]["processing_info"] = {
|
||||||
|
"total_entities_found": len(search_result.get("final_entities", [])),
|
||||||
|
"total_relations_found": len(search_result.get("final_relations", [])),
|
||||||
|
"entities_after_truncation": len(
|
||||||
|
truncation_result.get("filtered_entities", [])
|
||||||
|
),
|
||||||
|
"relations_after_truncation": len(
|
||||||
|
truncation_result.get("filtered_relations", [])
|
||||||
|
),
|
||||||
|
"merged_chunks_count": len(merged_chunks),
|
||||||
|
"final_chunks_count": len(raw_data.get("data", {}).get("chunks", [])),
|
||||||
|
}
|
||||||
|
|
||||||
# Add complete metadata to raw_data (preserve existing metadata including query_mode)
|
logger.debug(
|
||||||
if "metadata" not in raw_data:
|
f"[_build_query_context] Context length: {len(context) if context else 0}"
|
||||||
raw_data["metadata"] = {}
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}"
|
||||||
|
)
|
||||||
|
|
||||||
# Update keywords while preserving existing metadata
|
return QueryContextResult(context=context, raw_data=raw_data)
|
||||||
raw_data["metadata"]["keywords"] = {
|
|
||||||
"high_level": hl_keywords_list,
|
|
||||||
"low_level": ll_keywords_list,
|
|
||||||
}
|
|
||||||
raw_data["metadata"]["processing_info"] = {
|
|
||||||
"total_entities_found": len(search_result.get("final_entities", [])),
|
|
||||||
"total_relations_found": len(search_result.get("final_relations", [])),
|
|
||||||
"entities_after_truncation": len(
|
|
||||||
truncation_result.get("filtered_entities", [])
|
|
||||||
),
|
|
||||||
"relations_after_truncation": len(
|
|
||||||
truncation_result.get("filtered_relations", [])
|
|
||||||
),
|
|
||||||
"merged_chunks_count": len(merged_chunks),
|
|
||||||
"final_chunks_count": len(raw_data.get("chunks", [])),
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"[_build_query_context] Context length: {len(context) if context else 0}"
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"[_build_query_context] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
|
|
||||||
)
|
|
||||||
return context, raw_data
|
|
||||||
else:
|
|
||||||
# Normal context building (existing logic)
|
|
||||||
context = await _build_llm_context(
|
|
||||||
entities_context=truncation_result["entities_context"],
|
|
||||||
relations_context=truncation_result["relations_context"],
|
|
||||||
merged_chunks=merged_chunks,
|
|
||||||
query=query,
|
|
||||||
query_param=query_param,
|
|
||||||
global_config=text_chunks_db.global_config,
|
|
||||||
chunk_tracking=search_result["chunk_tracking"],
|
|
||||||
)
|
|
||||||
return context
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_node_data(
|
async def _get_node_data(
|
||||||
|
|
@ -4105,19 +4077,28 @@ async def naive_query(
|
||||||
global_config: dict[str, str],
|
global_config: dict[str, str],
|
||||||
hashing_kv: BaseKVStorage | None = None,
|
hashing_kv: BaseKVStorage | None = None,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
return_raw_data: bool = False,
|
) -> QueryResult:
|
||||||
) -> str | AsyncIterator[str] | dict[str, Any]:
|
"""
|
||||||
|
Execute naive query and return unified QueryResult object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query string
|
||||||
|
chunks_vdb: Document chunks vector database
|
||||||
|
query_param: Query parameters
|
||||||
|
global_config: Global configuration
|
||||||
|
hashing_kv: Cache storage
|
||||||
|
system_prompt: System prompt
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QueryResult: Unified query result object containing:
|
||||||
|
- content: Non-streaming response text content
|
||||||
|
- response_iterator: Streaming response iterator
|
||||||
|
- raw_data: Complete structured data (including references and metadata)
|
||||||
|
- is_streaming: Whether this is a streaming result
|
||||||
|
"""
|
||||||
|
|
||||||
if not query:
|
if not query:
|
||||||
if return_raw_data:
|
return QueryResult(content=PROMPTS["fail_response"])
|
||||||
# Return empty raw data structure when query is empty
|
|
||||||
empty_raw_data = {
|
|
||||||
"status": "failure",
|
|
||||||
"message": "Query string is empty.",
|
|
||||||
"data": {},
|
|
||||||
}
|
|
||||||
return empty_raw_data
|
|
||||||
else:
|
|
||||||
return PROMPTS["fail_response"]
|
|
||||||
|
|
||||||
if query_param.model_func:
|
if query_param.model_func:
|
||||||
use_model_func = query_param.model_func
|
use_model_func = query_param.model_func
|
||||||
|
|
@ -4147,41 +4128,28 @@ async def naive_query(
|
||||||
if cached_result is not None:
|
if cached_result is not None:
|
||||||
cached_response, _ = cached_result # Extract content, ignore timestamp
|
cached_response, _ = cached_result # Extract content, ignore timestamp
|
||||||
if not query_param.only_need_context and not query_param.only_need_prompt:
|
if not query_param.only_need_context and not query_param.only_need_prompt:
|
||||||
return cached_response
|
return QueryResult(content=cached_response)
|
||||||
|
|
||||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||||
if not tokenizer:
|
if not tokenizer:
|
||||||
if return_raw_data:
|
logger.error("Tokenizer not found in global configuration.")
|
||||||
# Return empty raw data structure when tokenizer is missing
|
return QueryResult(content=PROMPTS["fail_response"])
|
||||||
empty_raw_data = {
|
|
||||||
"status": "failure",
|
|
||||||
"message": "Tokenizer not found in global configuration.",
|
|
||||||
"data": {},
|
|
||||||
}
|
|
||||||
return empty_raw_data
|
|
||||||
else:
|
|
||||||
logger.error("Tokenizer not found in global configuration.")
|
|
||||||
return PROMPTS["fail_response"]
|
|
||||||
|
|
||||||
chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
|
chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
|
||||||
|
|
||||||
if chunks is None or len(chunks) == 0:
|
if chunks is None or len(chunks) == 0:
|
||||||
# If only raw data is requested, return it directly
|
# Build empty raw data structure for naive mode
|
||||||
if return_raw_data:
|
empty_raw_data = convert_to_user_format(
|
||||||
empty_raw_data = convert_to_user_format(
|
[], # naive mode has no entities
|
||||||
[], # naive mode has no entities
|
[], # naive mode has no relationships
|
||||||
[], # naive mode has no relationships
|
[], # no chunks
|
||||||
[], # no chunks
|
[], # no references
|
||||||
[], # no references
|
"naive",
|
||||||
"naive",
|
)
|
||||||
)
|
empty_raw_data["message"] = "No relevant document chunks found."
|
||||||
empty_raw_data["message"] = "No relevant document chunks found."
|
return QueryResult(content=PROMPTS["fail_response"], raw_data=empty_raw_data)
|
||||||
return empty_raw_data
|
|
||||||
else:
|
|
||||||
return PROMPTS["fail_response"]
|
|
||||||
|
|
||||||
# Calculate dynamic token limit for chunks
|
# Calculate dynamic token limit for chunks
|
||||||
# Get token limits from query_param (with fallback to global_config)
|
|
||||||
max_total_tokens = getattr(
|
max_total_tokens = getattr(
|
||||||
query_param,
|
query_param,
|
||||||
"max_total_tokens",
|
"max_total_tokens",
|
||||||
|
|
@ -4240,30 +4208,26 @@ async def naive_query(
|
||||||
|
|
||||||
logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks")
|
logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks")
|
||||||
|
|
||||||
# If only raw data is requested, return it directly
|
# Build raw data structure for naive mode using processed chunks with reference IDs
|
||||||
if return_raw_data:
|
raw_data = convert_to_user_format(
|
||||||
# Build raw data structure for naive mode using processed chunks with reference IDs
|
[], # naive mode has no entities
|
||||||
raw_data = convert_to_user_format(
|
[], # naive mode has no relationships
|
||||||
[], # naive mode has no entities
|
processed_chunks_with_ref_ids,
|
||||||
[], # naive mode has no relationships
|
reference_list,
|
||||||
processed_chunks_with_ref_ids,
|
"naive",
|
||||||
reference_list,
|
)
|
||||||
"naive",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add complete metadata for naive mode
|
# Add complete metadata for naive mode
|
||||||
if "metadata" not in raw_data:
|
if "metadata" not in raw_data:
|
||||||
raw_data["metadata"] = {}
|
raw_data["metadata"] = {}
|
||||||
raw_data["metadata"]["keywords"] = {
|
raw_data["metadata"]["keywords"] = {
|
||||||
"high_level": [], # naive mode has no keyword extraction
|
"high_level": [], # naive mode has no keyword extraction
|
||||||
"low_level": [], # naive mode has no keyword extraction
|
"low_level": [], # naive mode has no keyword extraction
|
||||||
}
|
}
|
||||||
raw_data["metadata"]["processing_info"] = {
|
raw_data["metadata"]["processing_info"] = {
|
||||||
"total_chunks_found": len(chunks),
|
"total_chunks_found": len(chunks),
|
||||||
"final_chunks_count": len(processed_chunks_with_ref_ids),
|
"final_chunks_count": len(processed_chunks_with_ref_ids),
|
||||||
}
|
}
|
||||||
|
|
||||||
return raw_data
|
|
||||||
|
|
||||||
# Build text_units_context from processed chunks with reference IDs
|
# Build text_units_context from processed chunks with reference IDs
|
||||||
text_units_context = []
|
text_units_context = []
|
||||||
|
|
@ -4284,8 +4248,7 @@ async def naive_query(
|
||||||
if ref["reference_id"]
|
if ref["reference_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
if query_param.only_need_context and not query_param.only_need_prompt:
|
context_content = f"""
|
||||||
return f"""
|
|
||||||
---Document Chunks(DC)---
|
---Document Chunks(DC)---
|
||||||
|
|
||||||
```json
|
```json
|
||||||
|
|
@ -4297,6 +4260,10 @@ async def naive_query(
|
||||||
{reference_list_str}
|
{reference_list_str}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if query_param.only_need_context and not query_param.only_need_prompt:
|
||||||
|
return QueryResult(content=context_content, raw_data=raw_data)
|
||||||
|
|
||||||
user_query = (
|
user_query = (
|
||||||
"\n\n".join([query, query_param.user_prompt])
|
"\n\n".join([query, query_param.user_prompt])
|
||||||
if query_param.user_prompt
|
if query_param.user_prompt
|
||||||
|
|
@ -4310,7 +4277,8 @@ async def naive_query(
|
||||||
)
|
)
|
||||||
|
|
||||||
if query_param.only_need_prompt:
|
if query_param.only_need_prompt:
|
||||||
return "\n\n".join([sys_prompt, "---User Query---", user_query])
|
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
|
||||||
|
return QueryResult(content=prompt_content, raw_data=raw_data)
|
||||||
|
|
||||||
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
len_of_prompts = len(tokenizer.encode(query + sys_prompt))
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
@ -4325,43 +4293,51 @@ async def naive_query(
|
||||||
stream=query_param.stream,
|
stream=query_param.stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(response, str) and len(response) > len(sys_prompt):
|
# Return unified result based on actual response type
|
||||||
response = (
|
if isinstance(response, str):
|
||||||
response[len(sys_prompt) :]
|
# Non-streaming response (string)
|
||||||
.replace(sys_prompt, "")
|
if len(response) > len(sys_prompt):
|
||||||
.replace("user", "")
|
response = (
|
||||||
.replace("model", "")
|
response[len(sys_prompt) :]
|
||||||
.replace(query, "")
|
.replace(sys_prompt, "")
|
||||||
.replace("<system>", "")
|
.replace("user", "")
|
||||||
.replace("</system>", "")
|
.replace("model", "")
|
||||||
.strip()
|
.replace(query, "")
|
||||||
)
|
.replace("<system>", "")
|
||||||
|
.replace("</system>", "")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
|
||||||
if hashing_kv.global_config.get("enable_llm_cache"):
|
# Cache response
|
||||||
# Save to cache with query parameters
|
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
||||||
queryparam_dict = {
|
queryparam_dict = {
|
||||||
"mode": query_param.mode,
|
"mode": query_param.mode,
|
||||||
"response_type": query_param.response_type,
|
"response_type": query_param.response_type,
|
||||||
"top_k": query_param.top_k,
|
"top_k": query_param.top_k,
|
||||||
"chunk_top_k": query_param.chunk_top_k,
|
"chunk_top_k": query_param.chunk_top_k,
|
||||||
"max_entity_tokens": query_param.max_entity_tokens,
|
"max_entity_tokens": query_param.max_entity_tokens,
|
||||||
"max_relation_tokens": query_param.max_relation_tokens,
|
"max_relation_tokens": query_param.max_relation_tokens,
|
||||||
"max_total_tokens": query_param.max_total_tokens,
|
"max_total_tokens": query_param.max_total_tokens,
|
||||||
"hl_keywords": query_param.hl_keywords or [],
|
"hl_keywords": query_param.hl_keywords or [],
|
||||||
"ll_keywords": query_param.ll_keywords or [],
|
"ll_keywords": query_param.ll_keywords or [],
|
||||||
"user_prompt": query_param.user_prompt or "",
|
"user_prompt": query_param.user_prompt or "",
|
||||||
"enable_rerank": query_param.enable_rerank,
|
"enable_rerank": query_param.enable_rerank,
|
||||||
}
|
}
|
||||||
await save_to_cache(
|
await save_to_cache(
|
||||||
hashing_kv,
|
hashing_kv,
|
||||||
CacheData(
|
CacheData(
|
||||||
args_hash=args_hash,
|
args_hash=args_hash,
|
||||||
content=response,
|
content=response,
|
||||||
prompt=query,
|
prompt=query,
|
||||||
mode=query_param.mode,
|
mode=query_param.mode,
|
||||||
cache_type="query",
|
cache_type="query",
|
||||||
queryparam=queryparam_dict,
|
queryparam=queryparam_dict,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return QueryResult(content=response, raw_data=raw_data)
|
||||||
|
else:
|
||||||
|
# Streaming response (AsyncIterator)
|
||||||
|
return QueryResult(
|
||||||
|
response_iterator=response, raw_data=raw_data, is_streaming=True
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,8 @@ Updated to handle the new data format where:
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Any
|
import json
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
# API configuration
|
# API configuration
|
||||||
API_KEY = "your-secure-api-key-here-123"
|
API_KEY = "your-secure-api-key-here-123"
|
||||||
|
|
@ -21,6 +22,456 @@ BASE_URL = "http://localhost:9621"
|
||||||
AUTH_HEADERS = {"Content-Type": "application/json", "X-API-Key": API_KEY}
|
AUTH_HEADERS = {"Content-Type": "application/json", "X-API-Key": API_KEY}
|
||||||
|
|
||||||
|
|
||||||
|
def validate_references_format(references: List[Dict[str, Any]]) -> bool:
|
||||||
|
"""Validate the format of references list"""
|
||||||
|
if not isinstance(references, list):
|
||||||
|
print(f"❌ References should be a list, got {type(references)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
for i, ref in enumerate(references):
|
||||||
|
if not isinstance(ref, dict):
|
||||||
|
print(f"❌ Reference {i} should be a dict, got {type(ref)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
required_fields = ["reference_id", "file_path"]
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in ref:
|
||||||
|
print(f"❌ Reference {i} missing required field: {field}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not isinstance(ref[field], str):
|
||||||
|
print(
|
||||||
|
f"❌ Reference {i} field '{field}' should be string, got {type(ref[field])}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def parse_streaming_response(
|
||||||
|
response_text: str,
|
||||||
|
) -> tuple[Optional[List[Dict]], List[str], List[str]]:
|
||||||
|
"""Parse streaming response and extract references, response chunks, and errors"""
|
||||||
|
references = None
|
||||||
|
response_chunks = []
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
lines = response_text.strip().split("\n")
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("data: "):
|
||||||
|
if line.startswith("data: "):
|
||||||
|
line = line[6:] # Remove 'data: ' prefix
|
||||||
|
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
|
||||||
|
if "references" in data:
|
||||||
|
references = data["references"]
|
||||||
|
elif "response" in data:
|
||||||
|
response_chunks.append(data["response"])
|
||||||
|
elif "error" in data:
|
||||||
|
errors.append(data["error"])
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Skip non-JSON lines (like SSE comments)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return references, response_chunks, errors
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_endpoint_references():
|
||||||
|
"""Test /query endpoint references functionality"""
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Testing /query endpoint references functionality")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
query_text = "who authored LightRAG"
|
||||||
|
endpoint = f"{BASE_URL}/query"
|
||||||
|
|
||||||
|
# Test 1: References enabled (default)
|
||||||
|
print("\n🧪 Test 1: References enabled (default)")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
endpoint,
|
||||||
|
json={"query": query_text, "mode": "mix", "include_references": True},
|
||||||
|
headers=AUTH_HEADERS,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Check response structure
|
||||||
|
if "response" not in data:
|
||||||
|
print("❌ Missing 'response' field")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if "references" not in data:
|
||||||
|
print("❌ Missing 'references' field when include_references=True")
|
||||||
|
return False
|
||||||
|
|
||||||
|
references = data["references"]
|
||||||
|
if references is None:
|
||||||
|
print("❌ References should not be None when include_references=True")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not validate_references_format(references):
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"✅ References enabled: Found {len(references)} references")
|
||||||
|
print(f" Response length: {len(data['response'])} characters")
|
||||||
|
|
||||||
|
# Display reference list
|
||||||
|
if references:
|
||||||
|
print(" 📚 Reference List:")
|
||||||
|
for i, ref in enumerate(references, 1):
|
||||||
|
ref_id = ref.get("reference_id", "Unknown")
|
||||||
|
file_path = ref.get("file_path", "Unknown")
|
||||||
|
print(f" {i}. ID: {ref_id} | File: {file_path}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"❌ Request failed: {response.status_code}")
|
||||||
|
print(f" Error: {response.text}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Test 1 failed: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Test 2: References disabled
|
||||||
|
print("\n🧪 Test 2: References disabled")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
endpoint,
|
||||||
|
json={"query": query_text, "mode": "mix", "include_references": False},
|
||||||
|
headers=AUTH_HEADERS,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Check response structure
|
||||||
|
if "response" not in data:
|
||||||
|
print("❌ Missing 'response' field")
|
||||||
|
return False
|
||||||
|
|
||||||
|
references = data.get("references")
|
||||||
|
if references is not None:
|
||||||
|
print("❌ References should be None when include_references=False")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("✅ References disabled: No references field present")
|
||||||
|
print(f" Response length: {len(data['response'])} characters")
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"❌ Request failed: {response.status_code}")
|
||||||
|
print(f" Error: {response.text}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Test 2 failed: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("\n✅ /query endpoint references tests passed!")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_stream_endpoint_references():
|
||||||
|
"""Test /query/stream endpoint references functionality"""
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Testing /query/stream endpoint references functionality")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
query_text = "who authored LightRAG"
|
||||||
|
endpoint = f"{BASE_URL}/query/stream"
|
||||||
|
|
||||||
|
# Test 1: Streaming with references enabled
|
||||||
|
print("\n🧪 Test 1: Streaming with references enabled")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
endpoint,
|
||||||
|
json={"query": query_text, "mode": "mix", "include_references": True},
|
||||||
|
headers=AUTH_HEADERS,
|
||||||
|
timeout=30,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
# Collect streaming response
|
||||||
|
full_response = ""
|
||||||
|
for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
|
||||||
|
if chunk:
|
||||||
|
# Ensure chunk is string type
|
||||||
|
if isinstance(chunk, bytes):
|
||||||
|
chunk = chunk.decode("utf-8")
|
||||||
|
full_response += chunk
|
||||||
|
|
||||||
|
# Parse streaming response
|
||||||
|
references, response_chunks, errors = parse_streaming_response(
|
||||||
|
full_response
|
||||||
|
)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
print(f"❌ Errors in streaming response: {errors}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if references is None:
|
||||||
|
print("❌ No references found in streaming response")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not validate_references_format(references):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not response_chunks:
|
||||||
|
print("❌ No response chunks found in streaming response")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"✅ Streaming with references: Found {len(references)} references")
|
||||||
|
print(f" Response chunks: {len(response_chunks)}")
|
||||||
|
print(
|
||||||
|
f" Total response length: {sum(len(chunk) for chunk in response_chunks)} characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Display reference list
|
||||||
|
if references:
|
||||||
|
print(" 📚 Reference List:")
|
||||||
|
for i, ref in enumerate(references, 1):
|
||||||
|
ref_id = ref.get("reference_id", "Unknown")
|
||||||
|
file_path = ref.get("file_path", "Unknown")
|
||||||
|
print(f" {i}. ID: {ref_id} | File: {file_path}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"❌ Request failed: {response.status_code}")
|
||||||
|
print(f" Error: {response.text}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Test 1 failed: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Test 2: Streaming with references disabled
|
||||||
|
print("\n🧪 Test 2: Streaming with references disabled")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
endpoint,
|
||||||
|
json={"query": query_text, "mode": "mix", "include_references": False},
|
||||||
|
headers=AUTH_HEADERS,
|
||||||
|
timeout=30,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
# Collect streaming response
|
||||||
|
full_response = ""
|
||||||
|
for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
|
||||||
|
if chunk:
|
||||||
|
# Ensure chunk is string type
|
||||||
|
if isinstance(chunk, bytes):
|
||||||
|
chunk = chunk.decode("utf-8")
|
||||||
|
full_response += chunk
|
||||||
|
|
||||||
|
# Parse streaming response
|
||||||
|
references, response_chunks, errors = parse_streaming_response(
|
||||||
|
full_response
|
||||||
|
)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
print(f"❌ Errors in streaming response: {errors}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if references is not None:
|
||||||
|
print("❌ References should be None when include_references=False")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not response_chunks:
|
||||||
|
print("❌ No response chunks found in streaming response")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("✅ Streaming without references: No references present")
|
||||||
|
print(f" Response chunks: {len(response_chunks)}")
|
||||||
|
print(
|
||||||
|
f" Total response length: {sum(len(chunk) for chunk in response_chunks)} characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"❌ Request failed: {response.status_code}")
|
||||||
|
print(f" Error: {response.text}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Test 2 failed: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("\n✅ /query/stream endpoint references tests passed!")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_references_consistency():
    """Verify that all three query endpoints return the same reference list.

    Sends one identical query (mode=mix, include_references=True) to
    /query, /query/stream and /query/data, extracts the reference list from
    each response, and compares them pairwise as sets of
    (reference_id, file_path) tuples.

    Returns:
        bool: True when every endpoint responded successfully and all three
        reference lists agree; False on any request failure, streaming
        error, or mismatch.
    """
    print("\n" + "=" * 60)
    print("Testing references consistency across endpoints")
    print("=" * 60)

    query_text = "who authored LightRAG"
    query_params = {
        "query": query_text,
        "mode": "mix",
        "top_k": 10,
        "chunk_top_k": 8,
        "include_references": True,
    }

    references_data = {}

    # --- /query -----------------------------------------------------------
    print("\n🧪 Testing /query endpoint")
    print("-" * 40)
    try:
        response = requests.post(
            f"{BASE_URL}/query", json=query_params, headers=AUTH_HEADERS, timeout=30
        )
        if response.status_code == 200:
            data = response.json()
            references_data["query"] = data.get("references", [])
            print(f"✅ /query: {len(references_data['query'])} references")
        else:
            print(f"❌ /query failed: {response.status_code}")
            return False
    except Exception as e:
        print(f"❌ /query test failed: {str(e)}")
        return False

    # --- /query/stream ----------------------------------------------------
    print("\n🧪 Testing /query/stream endpoint")
    print("-" * 40)
    try:
        response = requests.post(
            f"{BASE_URL}/query/stream",
            json=query_params,
            headers=AUTH_HEADERS,
            timeout=30,
            stream=True,
        )
        if response.status_code == 200:
            full_response = ""
            for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
                if chunk:
                    # iter_content may yield bytes when no charset is advertised
                    if isinstance(chunk, bytes):
                        chunk = chunk.decode("utf-8")
                    full_response += chunk

            references, _, errors = parse_streaming_response(full_response)

            if errors:
                print(f"❌ Errors: {errors}")
                return False

            references_data["stream"] = references or []
            print(f"✅ /query/stream: {len(references_data['stream'])} references")
        else:
            print(f"❌ /query/stream failed: {response.status_code}")
            return False
    except Exception as e:
        print(f"❌ /query/stream test failed: {str(e)}")
        return False

    # --- /query/data ------------------------------------------------------
    print("\n🧪 Testing /query/data endpoint")
    print("-" * 40)
    try:
        response = requests.post(
            f"{BASE_URL}/query/data",
            json=query_params,
            headers=AUTH_HEADERS,
            timeout=30,
        )
        if response.status_code == 200:
            data = response.json()
            # /query/data nests its payload under "data"
            query_data = data.get("data", {})
            references_data["data"] = query_data.get("references", [])
            print(f"✅ /query/data: {len(references_data['data'])} references")
        else:
            print(f"❌ /query/data failed: {response.status_code}")
            return False
    except Exception as e:
        print(f"❌ /query/data test failed: {str(e)}")
        return False

    # --- Pairwise comparison ---------------------------------------------
    print("\n🔍 Comparing references consistency")
    print("-" * 40)

    def refs_to_set(refs):
        # Normalize to (reference_id, file_path) tuples so ordering and any
        # extra fields in the reference dicts do not affect the comparison.
        return set(
            (ref.get("reference_id", ""), ref.get("file_path", "")) for ref in refs
        )

    def check_pair(name_a, refs_a, name_b, refs_b):
        # Report a mismatch between two endpoints; True when they agree.
        if refs_a == refs_b:
            return True
        print(f"❌ References mismatch between {name_a} and {name_b}")
        print(f"   {name_a} only: {refs_a - refs_b}")
        print(f"   {name_b} only: {refs_b - refs_a}")
        return False

    query_refs = refs_to_set(references_data["query"])
    stream_refs = refs_to_set(references_data["stream"])
    data_refs = refs_to_set(references_data["data"])

    # All three pairwise checks run even after a failure so every mismatch
    # is reported in one pass.
    consistency_passed = True
    consistency_passed &= check_pair("/query", query_refs, "/query/stream", stream_refs)
    consistency_passed &= check_pair("/query", query_refs, "/query/data", data_refs)
    consistency_passed &= check_pair(
        "/query/stream", stream_refs, "/query/data", data_refs
    )

    if consistency_passed:
        print("✅ All endpoints return consistent references")
        print(f"   Common references count: {len(query_refs)}")

        # Display common reference list
        if query_refs:
            print("   📚 Common Reference List:")
            for i, (ref_id, file_path) in enumerate(sorted(query_refs), 1):
                print(f"      {i}. ID: {ref_id} | File: {file_path}")

    return bool(consistency_passed)
|
||||||
def test_aquery_data_endpoint():
|
def test_aquery_data_endpoint():
|
||||||
"""Test the /query/data endpoint"""
|
"""Test the /query/data endpoint"""
|
||||||
|
|
||||||
|
|
@ -239,15 +690,79 @@ def compare_with_regular_query():
|
||||||
print(f" Regular query error: {str(e)}")
|
print(f" Regular query error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_all_reference_tests():
    """Run the full reference-support test suite and print a summary.

    Executes each reference test in order, catching exceptions so one
    failing test cannot abort the rest of the suite.

    Returns:
        bool: True when every test passed, False otherwise.
    """
    print("\n" + "🚀" * 20)
    print("LightRAG References Test Suite")
    print("🚀" * 20)

    # (label, callable) pairs — the label matches the original per-test
    # exception messages.
    test_cases = [
        ("/query endpoint", test_query_endpoint_references),
        ("/query/stream endpoint", test_query_stream_endpoint_references),
        ("References consistency", test_references_consistency),
    ]

    all_tests_passed = True
    for label, test_fn in test_cases:
        try:
            if not test_fn():
                all_tests_passed = False
        except Exception as e:
            print(f"❌ {label} test failed with exception: {str(e)}")
            all_tests_passed = False

    # Final summary
    print("\n" + "=" * 60)
    print("TEST SUITE SUMMARY")
    print("=" * 60)

    if all_tests_passed:
        print("🎉 ALL TESTS PASSED!")
        print("✅ /query endpoint references functionality works correctly")
        print("✅ /query/stream endpoint references functionality works correctly")
        print("✅ References are consistent across all endpoints")
        print("\n🔧 System is ready for production use with reference support!")
    else:
        print("❌ SOME TESTS FAILED!")
        print("Please check the error messages above and fix the issues.")
        print("\n🔧 System needs attention before production deployment.")

    return all_tests_passed
|
||||||
if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == "--references-only":
        # Run only the new reference tests; exit code reflects the result
        # so the script can gate CI pipelines.
        success = run_all_reference_tests()
        sys.exit(0 if success else 1)
    else:
        # Run original tests plus new reference tests
        print("Running original aquery_data endpoint test...")
        test_aquery_data_endpoint()

        print("\nRunning comparison test...")
        compare_with_regular_query()

        print("\nRunning new reference tests...")
        run_all_reference_tests()

    print("\n💡 Usage tips:")
    print("1. Ensure LightRAG API service is running")
    print("2. Adjust base_url and authentication information as needed")
    print("3. Modify query parameters to test different retrieval strategies")
    print("4. Data query results can be used for further analysis and processing")
    print("5. Run with --references-only flag to test only reference functionality")
|
||||||
Loading…
Add table
Reference in a new issue