Add aquery_data endpoint for structured retrieval without LLM generation
- Add QueryDataResponse model
- Implement /query/data endpoint
- Add aquery_data method to LightRAG
- Return entities, relationships, chunks
This commit is contained in: parent 99137446d0 · commit b1c8206346
5 changed files with 341 additions and 53 deletions
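A minimal client call against the new endpoint might look like this (a sketch assuming the same local server and X-API-Key header used by the bundled test script below; adjust URL and key for your deployment):

    import requests

    resp = requests.post(
        "http://localhost:9621/query/data",
        json={"query": "who authored LightRAG", "mode": "mix"},
        headers={"Content-Type": "application/json", "X-API-Key": "your-secure-api-key-here-123"},
        timeout=30,
    )
    resp.raise_for_status()
    data = resp.json()
    # The response carries the raw retrieval results instead of an LLM answer
    print(len(data["entities"]), len(data["relationships"]), len(data["chunks"]))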
@@ -134,6 +134,21 @@ class QueryResponse(BaseModel):
     )
 
 
+class QueryDataResponse(BaseModel):
+    entities: List[Dict[str, Any]] = Field(
+        description="Retrieved entities from knowledge graph"
+    )
+    relationships: List[Dict[str, Any]] = Field(
+        description="Retrieved relationships from knowledge graph"
+    )
+    chunks: List[Dict[str, Any]] = Field(
+        description="Retrieved text chunks from documents"
+    )
+    metadata: Dict[str, Any] = Field(
+        description="Query metadata including mode, keywords, and processing information"
+    )
+
+
 def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
     combined_auth = get_combined_auth_dependency(api_key)
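The model mirrors the dict produced by aquery_data. A serialized response might look roughly like the following sketch; every field value here is hypothetical, and the per-item keys are the ones used elsewhere in this commit (entity_name, src_id/tgt_id, chunk_id, and so on):

    example_response = {
        "entities": [
            {"entity_name": "LightRAG", "entity_type": "Technology",
             "description": "A simple and fast retrieval-augmented generation framework",
             "file_path": "docs/README.md"},
        ],
        "relationships": [
            {"src_id": "LightRAG", "tgt_id": "HKUDS",
             "description": "LightRAG is developed by HKUDS",
             "keywords": "authorship", "file_path": "docs/README.md"},
        ],
        "chunks": [
            {"content": "LightRAG: Simple and Fast Retrieval-Augmented Generation ...",
             "file_path": "docs/README.md", "chunk_id": "chunk-123abc"},
        ],
        "metadata": {
            "query_mode": "mix",
            "keywords": {"high_level": ["authorship"], "low_level": ["LightRAG"]},
            "processing_info": {"total_entities_found": 42},
        },
    }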
@@ -221,4 +236,66 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
             trace_exception(e)
             raise HTTPException(status_code=500, detail=str(e))
 
+    @router.post(
+        "/query/data", response_model=QueryDataResponse, dependencies=[Depends(combined_auth)]
+    )
+    async def query_data(request: QueryRequest):
+        """
+        Retrieve structured data without LLM generation.
+
+        This endpoint returns raw retrieval results including entities, relationships,
+        and text chunks that would be used for RAG, but without generating a final response.
+        All parameters are compatible with the regular /query endpoint.
+
+        Parameters:
+            request (QueryRequest): The request object containing the query parameters.
+
+        Returns:
+            QueryDataResponse: A Pydantic model containing structured data with entities,
+                relationships, chunks, and metadata.
+
+        Raises:
+            HTTPException: Raised when an error occurs during request handling,
+                with status code 500 and detail containing the exception message.
+        """
+        try:
+            param = request.to_query_params(False)  # No streaming for data endpoint
+            response = await rag.aquery_data(request.query, param=param)
+
+            # The aquery_data method returns a dict with entities, relationships, chunks, and metadata
+            if isinstance(response, dict):
+                # Ensure all required fields exist and are lists/dicts
+                entities = response.get("entities", [])
+                relationships = response.get("relationships", [])
+                chunks = response.get("chunks", [])
+                metadata = response.get("metadata", {})
+
+                # Validate data types
+                if not isinstance(entities, list):
+                    entities = []
+                if not isinstance(relationships, list):
+                    relationships = []
+                if not isinstance(chunks, list):
+                    chunks = []
+                if not isinstance(metadata, dict):
+                    metadata = {}
+
+                return QueryDataResponse(
+                    entities=entities,
+                    relationships=relationships,
+                    chunks=chunks,
+                    metadata=metadata,
+                )
+            else:
+                # Fallback for unexpected response format
+                return QueryDataResponse(
+                    entities=[],
+                    relationships=[],
+                    chunks=[],
+                    metadata={"error": "Unexpected response format", "raw_response": str(response)},
+                )
+        except Exception as e:
+            trace_exception(e)
+            raise HTTPException(status_code=500, detail=str(e))
 
     return router
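Since the handler always returns all four fields, falling back to empty lists plus an error marker in metadata when the response shape is unexpected, a client can branch on metadata instead of probing for missing keys. A possible guard (hypothetical helper, not part of this commit):

    def has_retrieval_data(payload: dict) -> bool:
        """Return False when the server used the fallback branch above."""
        if "error" in payload.get("metadata", {}):
            return False
        return bool(payload.get("entities") or payload.get("relationships") or payload.get("chunks"))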
@@ -2136,6 +2136,7 @@ class LightRAG:
         global_config = asdict(self)
 
         if param.mode in ["local", "global", "hybrid", "mix"]:
+            logger.debug(f"[aquery_data] Using kg_query for mode: {param.mode}")
             final_data = await kg_query(
                 query.strip(),
                 self.chunk_entity_relation_graph,

@@ -2150,6 +2151,7 @@ class LightRAG:
                 return_raw_data=True,  # Get final processed data
             )
         elif param.mode == "naive":
+            logger.debug(f"[aquery_data] Using naive_query for mode: {param.mode}")
             final_data = await naive_query(
                 query.strip(),
                 self.chunks_vdb,

@@ -2160,6 +2162,7 @@ class LightRAG:
                 return_raw_data=True,  # Get final processed data
             )
         elif param.mode == "bypass":
+            logger.debug("[aquery_data] Using bypass mode")
             # bypass mode returns empty data
             final_data = {
                 "entities": [],

@@ -2173,6 +2176,12 @@ class LightRAG:
         else:
             raise ValueError(f"Unknown mode {param.mode}")
 
+        # Log final result counts
+        entities_count = len(final_data.get("entities", []))
+        relationships_count = len(final_data.get("relationships", []))
+        chunks_count = len(final_data.get("chunks", []))
+        logger.debug(f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks")
+
         await self._query_done()
         return final_data
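The same data is reachable at the library level without HTTP. A minimal sketch, assuming a configured LightRAG instance (embedding/LLM setup and storage initialization omitted here; QueryParam's mode values dispatch exactly as in the branches above):

    import asyncio

    from lightrag import LightRAG, QueryParam

    async def main():
        # Sketch only: a real instance also needs embedding_func/llm_model_func configured
        rag = LightRAG(working_dir="./rag_storage")
        data = await rag.aquery_data(
            "who authored LightRAG",
            param=QueryParam(mode="mix"),  # "local", "global", "hybrid", "naive", "bypass" also work
        )
        print(f"{len(data['entities'])} entities, {len(data['relationships'])} relationships")

    asyncio.run(main())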
@@ -2344,7 +2344,9 @@ async def kg_query(
     )
 
     if isinstance(context_result, tuple):
-        _, raw_data = context_result
+        context, raw_data = context_result
+        logger.debug(f"[kg_query] Context length: {len(context) if context else 0}")
+        logger.debug(f"[kg_query] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}")
         return raw_data
     else:
         raise RuntimeError(
@@ -3071,7 +3073,9 @@ async def _build_llm_context(
     query_param: QueryParam,
     global_config: dict[str, str],
     chunk_tracking: dict = None,
     return_final_data: bool = False,
     return_raw_data: bool = False,
+    hl_keywords: list[str] = None,
+    ll_keywords: list[str] = None,
 ) -> str | tuple[str, dict[str, Any]]:
     """
     Build the final LLM context string with token processing.
@@ -3230,7 +3234,16 @@ async def _build_llm_context(
 
     # not necessary to use LLM to generate a response
     if not entities_context and not relations_context:
-        return None
+        if return_raw_data:
+            # Return empty raw data structure when no entities/relations
+            empty_raw_data = _convert_to_user_format(
+                [], [], [], query_param.mode,
+                hl_keywords=hl_keywords,
+                ll_keywords=ll_keywords,
+            )
+            return None, empty_raw_data
+        else:
+            return None
 
     # output chunk tracking information
     # format: <source><frequency>/<order> (e.g., E5/2 R2/1 C1/1)
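So an empty retrieval still round-trips through _convert_to_user_format and keeps the response schema stable. Assuming the metadata layout shown later in this diff, the raw-data half of the tuple would look roughly like:

    # Illustrative sketch of the empty-retrieval return value (field values assumed)
    context, empty_raw_data = None, {
        "entities": [],
        "relationships": [],
        "chunks": [],
        "metadata": {
            "query_mode": "mix",
            "keywords": {"high_level": [], "low_level": []},
        },
    }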
@@ -3281,10 +3294,17 @@ async def _build_llm_context(
     """
 
     # If final data is requested, return both context and complete data structure
     if return_final_data:
+        if return_raw_data:
+            logger.debug(f"[_build_llm_context] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks")
         final_data = _convert_to_user_format(
-            entities_context, relations_context, truncated_chunks, query_param.mode
+            entities_context,
+            relations_context,
+            truncated_chunks,
+            query_param.mode,
+            hl_keywords=hl_keywords,
+            ll_keywords=ll_keywords,
         )
+        logger.debug(f"[_build_llm_context] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks")
         return result, final_data
     else:
         return result
@@ -3361,9 +3381,14 @@ async def _build_query_context(
         return None
 
     # Stage 4: Build final LLM context with dynamic token processing
     if return_raw_data:
-        # Get both context and final data
-        context_result = await _build_llm_context(
+        # Convert keywords strings to lists
+        hl_keywords_list = hl_keywords.split(", ") if hl_keywords else []
+        ll_keywords_list = ll_keywords.split(", ") if ll_keywords else []
+
+        # Get both context and final data - when return_raw_data=True, _build_llm_context always returns a tuple
+        context, raw_data = await _build_llm_context(
             entities_context=truncation_result["entities_context"],
             relations_context=truncation_result["relations_context"],
             merged_chunks=merged_chunks,
@@ -3371,45 +3396,13 @@ async def _build_query_context(
             query_param=query_param,
             global_config=text_chunks_db.global_config,
             chunk_tracking=search_result["chunk_tracking"],
             return_final_data=True,
+            return_raw_data=True,
+            hl_keywords=hl_keywords_list,
+            ll_keywords=ll_keywords_list,
         )
 
-        if isinstance(context_result, tuple):
-            context, final_chunks = context_result
-        else:
-            # Handle case where no final chunks are returned
-            context = context_result
-            final_chunks = merged_chunks
-
-        # Build raw data structure with the same data that goes to LLM
-        raw_data = {
-            "entities": truncation_result[
-                "filtered_entities"
-            ],  # Use filtered entities (same as LLM)
-            "relationships": truncation_result[
-                "filtered_relations"
-            ],  # Use filtered relations (same as LLM)
-            "chunks": final_chunks,  # Use final processed chunks (same as LLM)
-            "metadata": {
-                "query_mode": query_param.mode,
-                "keywords": {
-                    "high_level": hl_keywords.split(", ") if hl_keywords else [],
-                    "low_level": ll_keywords.split(", ") if ll_keywords else [],
-                },
-                "processing_info": {
-                    "total_entities_found": len(search_result["final_entities"]),
-                    "total_relations_found": len(search_result["final_relations"]),
-                    "entities_after_truncation": len(
-                        truncation_result["filtered_entities"]
-                    ),
-                    "relations_after_truncation": len(
-                        truncation_result["filtered_relations"]
-                    ),
-                    "merged_chunks_count": len(merged_chunks),
-                    "final_chunks_count": len(final_chunks),
-                },
-            },
-        }
-
+        logger.debug(f"[_build_query_context] Context length: {len(context) if context else 0}")
+        logger.debug(f"[_build_query_context] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}")
         return context, raw_data
     else:
         # Normal context building (existing logic)
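The net effect is a simpler contract: with return_raw_data=True the call always yields a (context, raw_data) tuple, so the old isinstance() guard and the ad-hoc raw_data assembly above become unnecessary. A toy model of that contract (not the library function itself):

    from typing import Optional, Tuple, Union

    def build_context(return_raw_data: bool) -> Union[Optional[str], Tuple[Optional[str], dict]]:
        context = None  # pretend nothing was retrieved
        if return_raw_data:
            # Always a tuple on this path, even when retrieval comes up empty
            return context, {"entities": [], "relationships": [], "chunks": [], "metadata": {}}
        return context

    context, raw_data = build_context(return_raw_data=True)  # unpacking never raises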
@@ -2758,14 +2758,15 @@ def _convert_to_user_format(
 
     # Convert chunks format
     formatted_chunks = []
-    for chunk in final_chunks:
-        formatted_chunks.append(
-            {
-                "content": chunk.get("content", ""),
-                "file_path": chunk.get("file_path", "unknown_source"),
-                "chunk_id": chunk.get("chunk_id", ""),
-            }
-        )
+    for i, chunk in enumerate(final_chunks):
+        chunk_data = {
+            "content": chunk.get("content", ""),
+            "file_path": chunk.get("file_path", "unknown_source"),
+            "chunk_id": chunk.get("chunk_id", ""),
+        }
+        formatted_chunks.append(chunk_data)
+
+    logger.debug(f"[_convert_to_user_format] Formatted {len(formatted_chunks)}/{len(final_chunks)} chunks")
 
     # Build metadata with processing info
     metadata = {
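The loop keeps only the three user-facing fields and drops everything else a chunk may carry. A standalone illustration of the same transformation, with made-up sample data:

    final_chunks = [
        {"content": "LightRAG is a fast RAG framework.", "file_path": "docs/intro.md",
         "chunk_id": "chunk-0001", "tokens": 9, "order": 1},
    ]
    formatted_chunks = []
    for chunk in final_chunks:
        formatted_chunks.append({
            "content": chunk.get("content", ""),
            "file_path": chunk.get("file_path", "unknown_source"),
            "chunk_id": chunk.get("chunk_id", ""),
        })
    print(formatted_chunks)  # extra keys such as "tokens" and "order" are dropped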
tests/test_aquery_data_endpoint.py · 208 lines · new file

@@ -0,0 +1,208 @@
#!/usr/bin/env python3
"""
Test script: Demonstrates usage of the aquery_data FastAPI endpoint.
Query content: Who is the author of LightRAG?
"""

import time
from typing import Any, Dict

import requests

# API configuration
API_KEY = "your-secure-api-key-here-123"
BASE_URL = "http://localhost:9621"

# Unified authentication headers
AUTH_HEADERS = {
    "Content-Type": "application/json",
    "X-API-Key": API_KEY,
}


def test_aquery_data_endpoint():
    """Test the /query/data endpoint."""

    # Use the unified configuration
    endpoint = f"{BASE_URL}/query/data"

    # Query request
    query_request = {
        "query": "who authored LightRAG",
        "mode": "mix",  # Use mix mode to get the most comprehensive results
        "top_k": 20,
        "chunk_top_k": 15,
        "max_entity_tokens": 4000,
        "max_relation_tokens": 4000,
        "max_total_tokens": 16000,
        "enable_rerank": True,
    }

    print("=" * 60)
    print("LightRAG aquery_data endpoint test")
    print("  Returns structured data including entities, relationships and text chunks")
    print("  Can be used for custom processing and analysis")
    print("=" * 60)
    print(f"Query content: {query_request['query']}")
    print(f"Query mode: {query_request['mode']}")
    print(f"API endpoint: {endpoint}")
    print("-" * 60)

    try:
        # Send request
        print("Sending request...")
        start_time = time.time()

        response = requests.post(
            endpoint,
            json=query_request,
            headers=AUTH_HEADERS,
            timeout=30,
        )

        end_time = time.time()
        response_time = end_time - start_time

        print(f"Response time: {response_time:.2f} seconds")
        print(f"HTTP status code: {response.status_code}")

        if response.status_code == 200:
            data = response.json()
            print_query_results(data)
        else:
            print(f"Request failed: {response.status_code}")
            print(f"Error message: {response.text}")

    except requests.exceptions.ConnectionError:
        print("❌ Connection failed: Please ensure the LightRAG API service is running")
        print("   Start command: python -m lightrag.api.lightrag_server")
    except requests.exceptions.Timeout:
        print("❌ Request timeout: Query processing took too long")
    except Exception as e:
        print(f"❌ Error occurred: {str(e)}")


def print_query_results(data: Dict[str, Any]):
    """Format and print query results."""

    entities = data.get("entities", [])
    relationships = data.get("relationships", [])
    chunks = data.get("chunks", [])
    metadata = data.get("metadata", {})

    print("\n📊 Query result statistics:")
    print(f"   Entity count: {len(entities)}")
    print(f"   Relationship count: {len(relationships)}")
    print(f"   Text chunk count: {len(chunks)}")

    # Print metadata
    if metadata:
        print("\n🔍 Query metadata:")
        print(f"   Query mode: {metadata.get('query_mode', 'unknown')}")

        keywords = metadata.get("keywords", {})
        if keywords:
            high_level = keywords.get("high_level", [])
            low_level = keywords.get("low_level", [])
            if high_level:
                print(f"   High-level keywords: {', '.join(high_level)}")
            if low_level:
                print(f"   Low-level keywords: {', '.join(low_level)}")

        processing_info = metadata.get("processing_info", {})
        if processing_info:
            print("   Processing info:")
            for key, value in processing_info.items():
                print(f"     {key}: {value}")

    # Print entity information
    if entities:
        print("\n👥 Retrieved entities (first 5):")
        for i, entity in enumerate(entities[:5]):
            entity_name = entity.get("entity_name", "Unknown")
            entity_type = entity.get("entity_type", "Unknown")
            description = entity.get("description", "No description")
            file_path = entity.get("file_path", "Unknown source")

            print(f"   {i + 1}. {entity_name} ({entity_type})")
            print(f"      Description: {description[:100]}{'...' if len(description) > 100 else ''}")
            print(f"      Source: {file_path}")
            print()

    # Print relationship information
    if relationships:
        print("🔗 Retrieved relationships (first 5):")
        for i, rel in enumerate(relationships[:5]):
            src = rel.get("src_id", "Unknown")
            tgt = rel.get("tgt_id", "Unknown")
            description = rel.get("description", "No description")
            keywords = rel.get("keywords", "No keywords")
            file_path = rel.get("file_path", "Unknown source")

            print(f"   {i + 1}. {src} → {tgt}")
            print(f"      Keywords: {keywords}")
            print(f"      Description: {description[:100]}{'...' if len(description) > 100 else ''}")
            print(f"      Source: {file_path}")
            print()

    # Print text chunk information
    if chunks:
        print("📄 Retrieved text chunks (first 3):")
        for i, chunk in enumerate(chunks[:3]):
            content = chunk.get("content", "No content")
            file_path = chunk.get("file_path", "Unknown source")
            chunk_id = chunk.get("chunk_id", "Unknown ID")

            print(f"   {i + 1}. Text chunk ID: {chunk_id}")
            print(f"      Source: {file_path}")
            print(f"      Content: {content[:200]}{'...' if len(content) > 200 else ''}")
            print()

    print("=" * 60)


def compare_with_regular_query():
    """Compare results between a regular query and a data query."""

    query_text = "Who is the author of LightRAG?"

    print("\n🔄 Comparison test: Regular query vs data query")
    print("-" * 60)

    # Regular query
    try:
        print("1. Regular query (/query):")
        regular_response = requests.post(
            f"{BASE_URL}/query",
            json={"query": query_text, "mode": "mix"},
            headers=AUTH_HEADERS,
            timeout=30,
        )

        if regular_response.status_code == 200:
            regular_data = regular_response.json()
            response_text = regular_data.get("response", "No response")
            print(f"   Generated answer: {response_text[:300]}{'...' if len(response_text) > 300 else ''}")
        else:
            print(f"   Regular query failed: {regular_response.status_code}")
            if regular_response.status_code == 403:
                print("   Authentication failed - please check the API key configuration")
            elif regular_response.status_code == 401:
                print("   Unauthorized - please check the authentication information")
            print(f"   Error details: {regular_response.text}")

    except Exception as e:
        print(f"   Regular query error: {str(e)}")


if __name__ == "__main__":
    # Run the main test
    test_aquery_data_endpoint()

    # Run the comparison test
    compare_with_regular_query()

    print("\n💡 Usage tips:")
    print("1. Ensure the LightRAG API service is running")
    print("2. Adjust BASE_URL and authentication information as needed")
    print("3. Modify query parameters to test different retrieval strategies")
    print("4. Data query results can be used for further analysis and processing")
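When pointing the script at a different deployment, the two constants at the top are the only things to change. One optional pattern is to read them from the environment (variable names here are hypothetical):

    import os

    API_KEY = os.environ.get("LIGHTRAG_API_KEY", "your-secure-api-key-here-123")
    BASE_URL = os.environ.get("LIGHTRAG_BASE_URL", "http://localhost:9621")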