Add aquery_data endpoint for structured retrieval without LLM generation

- Add QueryDataResponse model
- Implement /query/data endpoint
- Add aquery_data method to LightRAG
- Return entities, relationships, chunks
This commit is contained in:
yangdx 2025-09-15 02:15:14 +08:00
parent 99137446d0
commit b1c8206346
5 changed files with 341 additions and 53 deletions

View file

@ -134,6 +134,21 @@ class QueryResponse(BaseModel):
)
class QueryDataResponse(BaseModel):
    """Response payload for the /query/data endpoint.

    Carries the raw retrieval results (entities, relationships, text chunks)
    plus query metadata, without any LLM-generated answer.
    """

    # Entities retrieved from the knowledge graph (one dict per entity)
    entities: List[Dict[str, Any]] = Field(
        description="Retrieved entities from knowledge graph"
    )
    # Relationships retrieved from the knowledge graph (one dict per edge)
    relationships: List[Dict[str, Any]] = Field(
        description="Retrieved relationships from knowledge graph"
    )
    # Text chunks drawn from the source documents
    chunks: List[Dict[str, Any]] = Field(
        description="Retrieved text chunks from documents"
    )
    # Query mode, keywords, and processing statistics
    metadata: Dict[str, Any] = Field(
        description="Query metadata including mode, keywords, and processing information"
    )
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
combined_auth = get_combined_auth_dependency(api_key)
@ -221,4 +236,66 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@router.post(
"/query/data", response_model=QueryDataResponse, dependencies=[Depends(combined_auth)]
)
async def query_data(request: QueryRequest):
"""
Retrieve structured data without LLM generation.
This endpoint returns raw retrieval results including entities, relationships,
and text chunks that would be used for RAG, but without generating a final response.
All parameters are compatible with the regular /query endpoint.
Parameters:
request (QueryRequest): The request object containing the query parameters.
Returns:
QueryDataResponse: A Pydantic model containing structured data with entities,
relationships, chunks, and metadata.
Raises:
HTTPException: Raised when an error occurs during the request handling process,
with status code 500 and detail containing the exception message.
"""
try:
param = request.to_query_params(False) # No streaming for data endpoint
response = await rag.aquery_data(request.query, param=param)
# The aquery_data method returns a dict with entities, relationships, chunks, and metadata
if isinstance(response, dict):
# Ensure all required fields exist and are lists/dicts
entities = response.get("entities", [])
relationships = response.get("relationships", [])
chunks = response.get("chunks", [])
metadata = response.get("metadata", {})
# Validate data types
if not isinstance(entities, list):
entities = []
if not isinstance(relationships, list):
relationships = []
if not isinstance(chunks, list):
chunks = []
if not isinstance(metadata, dict):
metadata = {}
return QueryDataResponse(
entities=entities,
relationships=relationships,
chunks=chunks,
metadata=metadata
)
else:
# Fallback for unexpected response format
return QueryDataResponse(
entities=[],
relationships=[],
chunks=[],
metadata={"error": "Unexpected response format", "raw_response": str(response)}
)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
return router

View file

@ -2136,6 +2136,7 @@ class LightRAG:
global_config = asdict(self)
if param.mode in ["local", "global", "hybrid", "mix"]:
logger.debug(f"[aquery_data] Using kg_query for mode: {param.mode}")
final_data = await kg_query(
query.strip(),
self.chunk_entity_relation_graph,
@ -2150,6 +2151,7 @@ class LightRAG:
return_raw_data=True, # Get final processed data
)
elif param.mode == "naive":
logger.debug(f"[aquery_data] Using naive_query for mode: {param.mode}")
final_data = await naive_query(
query.strip(),
self.chunks_vdb,
@ -2160,6 +2162,7 @@ class LightRAG:
return_raw_data=True, # Get final processed data
)
elif param.mode == "bypass":
logger.debug("[aquery_data] Using bypass mode")
# bypass mode returns empty data
final_data = {
"entities": [],
@ -2173,6 +2176,12 @@ class LightRAG:
else:
raise ValueError(f"Unknown mode {param.mode}")
# Log final result counts
entities_count = len(final_data.get("entities", []))
relationships_count = len(final_data.get("relationships", []))
chunks_count = len(final_data.get("chunks", []))
logger.debug(f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks")
await self._query_done()
return final_data

View file

@ -2344,7 +2344,9 @@ async def kg_query(
)
if isinstance(context_result, tuple):
_, raw_data = context_result
context, raw_data = context_result
logger.debug(f"[kg_query] Context length: {len(context) if context else 0}")
logger.debug(f"[kg_query] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}")
return raw_data
else:
raise RuntimeError(
@ -3071,7 +3073,9 @@ async def _build_llm_context(
query_param: QueryParam,
global_config: dict[str, str],
chunk_tracking: dict = None,
return_final_data: bool = False,
return_raw_data: bool = False,
hl_keywords: list[str] = None,
ll_keywords: list[str] = None,
) -> str | tuple[str, dict[str, Any]]:
"""
Build the final LLM context string with token processing.
@ -3230,7 +3234,16 @@ async def _build_llm_context(
# not necessary to use LLM to generate a response
if not entities_context and not relations_context:
return None
if return_raw_data:
# Return empty raw data structure when no entities/relations
empty_raw_data = _convert_to_user_format(
[], [], [], query_param.mode,
hl_keywords=hl_keywords,
ll_keywords=ll_keywords,
)
return None, empty_raw_data
else:
return None
# output chunks tracking infomations
# format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
@ -3281,10 +3294,17 @@ async def _build_llm_context(
"""
# If final data is requested, return both context and complete data structure
if return_final_data:
if return_raw_data:
logger.debug(f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks")
final_data = _convert_to_user_format(
entities_context, relations_context, truncated_chunks, query_param.mode
entities_context,
relations_context,
truncated_chunks,
query_param.mode,
hl_keywords=hl_keywords,
ll_keywords=ll_keywords,
)
logger.debug(f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks")
return result, final_data
else:
return result
@ -3361,9 +3381,14 @@ async def _build_query_context(
return None
# Stage 4: Build final LLM context with dynamic token processing
if return_raw_data:
# Get both context and final data
context_result = await _build_llm_context(
# Convert keywords strings to lists
hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
# Get both context and final data - when return_raw_data=True, _build_llm_context always returns tuple
context, raw_data = await _build_llm_context(
entities_context=truncation_result["entities_context"],
relations_context=truncation_result["relations_context"],
merged_chunks=merged_chunks,
@ -3371,45 +3396,13 @@ async def _build_query_context(
query_param=query_param,
global_config=text_chunks_db.global_config,
chunk_tracking=search_result["chunk_tracking"],
return_final_data=True,
return_raw_data=True,
hl_keywords=hl_keywords_list,
ll_keywords=ll_keywords_list,
)
if isinstance(context_result, tuple):
context, final_chunks = context_result
else:
# Handle case where no final chunks are returned
context = context_result
final_chunks = merged_chunks
# Build raw data structure with the same data that goes to LLM
raw_data = {
"entities": truncation_result[
"filtered_entities"
], # Use filtered entities (same as LLM)
"relationships": truncation_result[
"filtered_relations"
], # Use filtered relations (same as LLM)
"chunks": final_chunks, # Use final processed chunks (same as LLM)
"metadata": {
"query_mode": query_param.mode,
"keywords": {
"high_level": hl_keywords.split(", ") if hl_keywords else [],
"low_level": ll_keywords.split(", ") if ll_keywords else [],
},
"processing_info": {
"total_entities_found": len(search_result["final_entities"]),
"total_relations_found": len(search_result["final_relations"]),
"entities_after_truncation": len(
truncation_result["filtered_entities"]
),
"relations_after_truncation": len(
truncation_result["filtered_relations"]
),
"merged_chunks_count": len(merged_chunks),
"final_chunks_count": len(final_chunks),
},
},
}
logger.debug(f"[_build_query_context] Context length: {len(context) if context else 0}")
logger.debug(f"[_build_query_context] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}")
return context, raw_data
else:
# Normal context building (existing logic)

View file

@ -2758,14 +2758,15 @@ def _convert_to_user_format(
# Convert chunks format
formatted_chunks = []
for chunk in final_chunks:
formatted_chunks.append(
{
"content": chunk.get("content", ""),
"file_path": chunk.get("file_path", "unknown_source"),
"chunk_id": chunk.get("chunk_id", ""),
}
)
for i, chunk in enumerate(final_chunks):
chunk_data = {
"content": chunk.get("content", ""),
"file_path": chunk.get("file_path", "unknown_source"),
"chunk_id": chunk.get("chunk_id", ""),
}
formatted_chunks.append(chunk_data)
logger.debug(f"[_convert_to_user_format] Formatted {len(formatted_chunks)}/{len(final_chunks)} chunks")
# Build metadata with processing info
metadata = {

View file

@ -0,0 +1,208 @@
#!/usr/bin/env python3
"""
Test script: Demonstrates usage of aquery_data FastAPI endpoint
Query content: Who is the author of LightRAG
"""
import requests
import time
from typing import Dict, Any
# API configuration
# NOTE(review): placeholder credentials — replace with the running server's
# actual API key and address before use.
API_KEY = "your-secure-api-key-here-123"
BASE_URL = "http://localhost:9621"

# Unified authentication headers
# The key is sent via the X-API-Key header on every request below.
AUTH_HEADERS = {
    "Content-Type": "application/json",
    "X-API-Key": API_KEY
}
def test_aquery_data_endpoint():
    """Exercise the /query/data endpoint and pretty-print the structured result.

    Sends one POST request with a full set of retrieval parameters, then hands
    the JSON payload to print_query_results. Connection failures and timeouts
    are reported to stdout rather than raised, so the demo degrades gracefully
    when no server is running.
    """
    # Use unified configuration
    endpoint = f"{BASE_URL}/query/data"

    # Query request — parameters are compatible with the regular /query endpoint
    query_request = {
        "query": "who authored LightRAG",  # fixed typo: was "LighRAG"
        "mode": "mix",  # Use mixed mode to get the most comprehensive results
        "top_k": 20,
        "chunk_top_k": 15,
        "max_entity_tokens": 4000,
        "max_relation_tokens": 4000,
        "max_total_tokens": 16000,
        "enable_rerank": True
    }

    print("=" * 60)
    print("LightRAG aquery_data endpoint test")
    print(" Returns structured data including entities, relationships and text chunks")
    print(" Can be used for custom processing and analysis")
    print("=" * 60)
    print(f"Query content: {query_request['query']}")
    print(f"Query mode: {query_request['mode']}")
    print(f"API endpoint: {endpoint}")
    print("-" * 60)

    try:
        # Send request and time the round trip
        print("Sending request...")
        start_time = time.time()
        response = requests.post(
            endpoint,
            json=query_request,
            headers=AUTH_HEADERS,
            timeout=30
        )
        response_time = time.time() - start_time

        print(f"Response time: {response_time:.2f} seconds")
        print(f"HTTP status code: {response.status_code}")

        if response.status_code == 200:
            data = response.json()
            print_query_results(data)
        else:
            print(f"Request failed: {response.status_code}")
            print(f"Error message: {response.text}")
    except requests.exceptions.ConnectionError:
        print("❌ Connection failed: Please ensure LightRAG API service is running")
        print(" Start command: python -m lightrag.api.lightrag_server")
    except requests.exceptions.Timeout:
        print("❌ Request timeout: Query processing took too long")
    except Exception as e:
        print(f"❌ Error occurred: {str(e)}")
def print_query_results(data: Dict[str, Any]):
    """Pretty-print the structured payload returned by the /query/data endpoint.

    Shows summary counts, query metadata, and a capped preview of entities
    (first 5), relationships (first 5), and text chunks (first 3).
    """

    def clip(text, limit):
        # Truncate long text, appending an ellipsis only when something was cut
        return text[:limit] + ("..." if len(text) > limit else "")

    entities = data.get("entities", [])
    relationships = data.get("relationships", [])
    chunks = data.get("chunks", [])
    metadata = data.get("metadata", {})

    print("\n📊 Query result statistics:")
    print(f" Entity count: {len(entities)}")
    print(f" Relationship count: {len(relationships)}")
    print(f" Text chunk count: {len(chunks)}")

    # Metadata section: mode, extracted keywords, processing statistics
    if metadata:
        print("\n🔍 Query metadata:")
        print(f" Query mode: {metadata.get('query_mode', 'unknown')}")
        kw = metadata.get('keywords', {})
        if kw:
            hl = kw.get('high_level', [])
            ll = kw.get('low_level', [])
            if hl:
                print(f" High-level keywords: {', '.join(hl)}")
            if ll:
                print(f" Low-level keywords: {', '.join(ll)}")
        proc = metadata.get('processing_info', {})
        if proc:
            print(" Processing info:")
            for name, value in proc.items():
                print(f" {name}: {value}")

    if entities:
        print("\n👥 Retrieved entities (first 5):")
        for idx, ent in enumerate(entities[:5], start=1):
            print(f" {idx}. {ent.get('entity_name', 'Unknown')} ({ent.get('entity_type', 'Unknown')})")
            print(f" Description: {clip(ent.get('description', 'No description'), 100)}")
            print(f" Source: {ent.get('file_path', 'Unknown source')}")
            print()

    if relationships:
        print("🔗 Retrieved relationships (first 5):")
        for idx, rel in enumerate(relationships[:5], start=1):
            print(f" {idx}. {rel.get('src_id', 'Unknown')}{rel.get('tgt_id', 'Unknown')}")
            print(f" Keywords: {rel.get('keywords', 'No keywords')}")
            print(f" Description: {clip(rel.get('description', 'No description'), 100)}")
            print(f" Source: {rel.get('file_path', 'Unknown source')}")
            print()

    if chunks:
        print("📄 Retrieved text chunks (first 3):")
        for idx, chunk in enumerate(chunks[:3], start=1):
            print(f" {idx}. Text chunk ID: {chunk.get('chunk_id', 'Unknown ID')}")
            print(f" Source: {chunk.get('file_path', 'Unknown source')}")
            print(f" Content: {clip(chunk.get('content', 'No content'), 200)}")
            print()

    print("=" * 60)
def compare_with_regular_query():
    """Compare results between regular query and data query.

    Issues the same question against /query (LLM-generated answer) and
    /query/data (raw retrieval results) so the two response shapes can be
    contrasted side by side. The original version only performed the /query
    half despite its name; the /query/data half is added here.
    """
    query_text = "LightRAG的作者是谁"
    print("\n🔄 Comparison test: Regular query vs Data query")
    print("-" * 60)

    # Regular query
    try:
        print("1. Regular query (/query):")
        regular_response = requests.post(
            f"{BASE_URL}/query",
            json={"query": query_text, "mode": "mix"},
            headers=AUTH_HEADERS,
            timeout=30
        )
        if regular_response.status_code == 200:
            regular_data = regular_response.json()
            response_text = regular_data.get('response', 'No response')
            print(f" Generated answer: {response_text[:300]}{'...' if len(response_text) > 300 else ''}")
        else:
            print(f" Regular query failed: {regular_response.status_code}")
            if regular_response.status_code == 403:
                print(" Authentication failed - Please check API Key configuration")
            elif regular_response.status_code == 401:
                print(" Unauthorized - Please check authentication information")
            print(f" Error details: {regular_response.text}")
    except Exception as e:
        print(f" Regular query error: {str(e)}")

    # Data query — the missing half of the comparison
    try:
        print("2. Data query (/query/data):")
        data_response = requests.post(
            f"{BASE_URL}/query/data",
            json={"query": query_text, "mode": "mix"},
            headers=AUTH_HEADERS,
            timeout=30
        )
        if data_response.status_code == 200:
            payload = data_response.json()
            print(
                f" Structured result: {len(payload.get('entities', []))} entities, "
                f"{len(payload.get('relationships', []))} relationships, "
                f"{len(payload.get('chunks', []))} chunks"
            )
        else:
            print(f" Data query failed: {data_response.status_code}")
            print(f" Error details: {data_response.text}")
    except Exception as e:
        print(f" Data query error: {str(e)}")
if __name__ == "__main__":
    # Run main test: query /query/data and print the structured payload
    test_aquery_data_endpoint()
    # Run comparison test against the regular /query endpoint
    compare_with_regular_query()
    # Closing hints for the reader
    print("\n💡 Usage tips:")
    print("1. Ensure LightRAG API service is running")
    print("2. Adjust base_url and authentication information as needed")
    print("3. Modify query parameters to test different retrieval strategies")
    print("4. Data query results can be used for further analysis and processing")