Fix linting
This commit is contained in:
parent
b1c8206346
commit
c0d5abba6b
4 changed files with 79 additions and 66 deletions
|
|
@ -237,7 +237,9 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/query/data", response_model=QueryDataResponse, dependencies=[Depends(combined_auth)]
|
"/query/data",
|
||||||
|
response_model=QueryDataResponse,
|
||||||
|
dependencies=[Depends(combined_auth)],
|
||||||
)
|
)
|
||||||
async def query_data(request: QueryRequest):
|
async def query_data(request: QueryRequest):
|
||||||
"""
|
"""
|
||||||
|
|
@ -269,7 +271,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
relationships = response.get("relationships", [])
|
relationships = response.get("relationships", [])
|
||||||
chunks = response.get("chunks", [])
|
chunks = response.get("chunks", [])
|
||||||
metadata = response.get("metadata", {})
|
metadata = response.get("metadata", {})
|
||||||
|
|
||||||
# Validate data types
|
# Validate data types
|
||||||
if not isinstance(entities, list):
|
if not isinstance(entities, list):
|
||||||
entities = []
|
entities = []
|
||||||
|
|
@ -279,12 +281,12 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
chunks = []
|
chunks = []
|
||||||
if not isinstance(metadata, dict):
|
if not isinstance(metadata, dict):
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
return QueryDataResponse(
|
return QueryDataResponse(
|
||||||
entities=entities,
|
entities=entities,
|
||||||
relationships=relationships,
|
relationships=relationships,
|
||||||
chunks=chunks,
|
chunks=chunks,
|
||||||
metadata=metadata
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Fallback for unexpected response format
|
# Fallback for unexpected response format
|
||||||
|
|
@ -292,7 +294,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
entities=[],
|
entities=[],
|
||||||
relationships=[],
|
relationships=[],
|
||||||
chunks=[],
|
chunks=[],
|
||||||
metadata={"error": "Unexpected response format", "raw_response": str(response)}
|
metadata={
|
||||||
|
"error": "Unexpected response format",
|
||||||
|
"raw_response": str(response),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
trace_exception(e)
|
trace_exception(e)
|
||||||
|
|
|
||||||
|
|
@ -2180,7 +2180,9 @@ class LightRAG:
|
||||||
entities_count = len(final_data.get("entities", []))
|
entities_count = len(final_data.get("entities", []))
|
||||||
relationships_count = len(final_data.get("relationships", []))
|
relationships_count = len(final_data.get("relationships", []))
|
||||||
chunks_count = len(final_data.get("chunks", []))
|
chunks_count = len(final_data.get("chunks", []))
|
||||||
logger.debug(f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks")
|
logger.debug(
|
||||||
|
f"[aquery_data] Final result: {entities_count} entities, {relationships_count} relationships, {chunks_count} chunks"
|
||||||
|
)
|
||||||
|
|
||||||
await self._query_done()
|
await self._query_done()
|
||||||
return final_data
|
return final_data
|
||||||
|
|
|
||||||
|
|
@ -2346,7 +2346,9 @@ async def kg_query(
|
||||||
if isinstance(context_result, tuple):
|
if isinstance(context_result, tuple):
|
||||||
context, raw_data = context_result
|
context, raw_data = context_result
|
||||||
logger.debug(f"[kg_query] Context length: {len(context) if context else 0}")
|
logger.debug(f"[kg_query] Context length: {len(context) if context else 0}")
|
||||||
logger.debug(f"[kg_query] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}")
|
logger.debug(
|
||||||
|
f"[kg_query] Raw data entities: {len(raw_data.get('entities', []))}, relationships: {len(raw_data.get('relationships', []))}, chunks: {len(raw_data.get('chunks', []))}"
|
||||||
|
)
|
||||||
return raw_data
|
return raw_data
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|
|
||||||
|
|
@ -13,18 +13,15 @@ API_KEY = "your-secure-api-key-here-123"
|
||||||
BASE_URL = "http://localhost:9621"
|
BASE_URL = "http://localhost:9621"
|
||||||
|
|
||||||
# Unified authentication headers
|
# Unified authentication headers
|
||||||
AUTH_HEADERS = {
|
AUTH_HEADERS = {"Content-Type": "application/json", "X-API-Key": API_KEY}
|
||||||
"Content-Type": "application/json",
|
|
||||||
"X-API-Key": API_KEY
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_aquery_data_endpoint():
|
def test_aquery_data_endpoint():
|
||||||
"""Test the /query/data endpoint"""
|
"""Test the /query/data endpoint"""
|
||||||
|
|
||||||
# Use unified configuration
|
# Use unified configuration
|
||||||
endpoint = f"{BASE_URL}/query/data"
|
endpoint = f"{BASE_URL}/query/data"
|
||||||
|
|
||||||
# Query request
|
# Query request
|
||||||
query_request = {
|
query_request = {
|
||||||
"query": "who authored LighRAG",
|
"query": "who authored LighRAG",
|
||||||
|
|
@ -34,12 +31,14 @@ def test_aquery_data_endpoint():
|
||||||
"max_entity_tokens": 4000,
|
"max_entity_tokens": 4000,
|
||||||
"max_relation_tokens": 4000,
|
"max_relation_tokens": 4000,
|
||||||
"max_total_tokens": 16000,
|
"max_total_tokens": 16000,
|
||||||
"enable_rerank": True
|
"enable_rerank": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("LightRAG aquery_data endpoint test")
|
print("LightRAG aquery_data endpoint test")
|
||||||
print(" Returns structured data including entities, relationships and text chunks")
|
print(
|
||||||
|
" Returns structured data including entities, relationships and text chunks"
|
||||||
|
)
|
||||||
print(" Can be used for custom processing and analysis")
|
print(" Can be used for custom processing and analysis")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(f"Query content: {query_request['query']}")
|
print(f"Query content: {query_request['query']}")
|
||||||
|
|
@ -51,27 +50,24 @@ def test_aquery_data_endpoint():
|
||||||
# Send request
|
# Send request
|
||||||
print("Sending request...")
|
print("Sending request...")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
endpoint,
|
endpoint, json=query_request, headers=AUTH_HEADERS, timeout=30
|
||||||
json=query_request,
|
|
||||||
headers=AUTH_HEADERS,
|
|
||||||
timeout=30
|
|
||||||
)
|
)
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
response_time = end_time - start_time
|
response_time = end_time - start_time
|
||||||
|
|
||||||
print(f"Response time: {response_time:.2f} seconds")
|
print(f"Response time: {response_time:.2f} seconds")
|
||||||
print(f"HTTP status code: {response.status_code}")
|
print(f"HTTP status code: {response.status_code}")
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
print_query_results(data)
|
print_query_results(data)
|
||||||
else:
|
else:
|
||||||
print(f"Request failed: {response.status_code}")
|
print(f"Request failed: {response.status_code}")
|
||||||
print(f"Error message: {response.text}")
|
print(f"Error message: {response.text}")
|
||||||
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
print("❌ Connection failed: Please ensure LightRAG API service is running")
|
print("❌ Connection failed: Please ensure LightRAG API service is running")
|
||||||
print(" Start command: python -m lightrag.api.lightrag_server")
|
print(" Start command: python -m lightrag.api.lightrag_server")
|
||||||
|
|
@ -83,91 +79,97 @@ def test_aquery_data_endpoint():
|
||||||
|
|
||||||
def print_query_results(data: Dict[str, Any]):
|
def print_query_results(data: Dict[str, Any]):
|
||||||
"""Format and print query results"""
|
"""Format and print query results"""
|
||||||
|
|
||||||
entities = data.get("entities", [])
|
entities = data.get("entities", [])
|
||||||
relationships = data.get("relationships", [])
|
relationships = data.get("relationships", [])
|
||||||
chunks = data.get("chunks", [])
|
chunks = data.get("chunks", [])
|
||||||
metadata = data.get("metadata", {})
|
metadata = data.get("metadata", {})
|
||||||
|
|
||||||
print("\n📊 Query result statistics:")
|
print("\n📊 Query result statistics:")
|
||||||
print(f" Entity count: {len(entities)}")
|
print(f" Entity count: {len(entities)}")
|
||||||
print(f" Relationship count: {len(relationships)}")
|
print(f" Relationship count: {len(relationships)}")
|
||||||
print(f" Text chunk count: {len(chunks)}")
|
print(f" Text chunk count: {len(chunks)}")
|
||||||
|
|
||||||
# Print metadata
|
# Print metadata
|
||||||
if metadata:
|
if metadata:
|
||||||
print("\n🔍 Query metadata:")
|
print("\n🔍 Query metadata:")
|
||||||
print(f" Query mode: {metadata.get('query_mode', 'unknown')}")
|
print(f" Query mode: {metadata.get('query_mode', 'unknown')}")
|
||||||
|
|
||||||
keywords = metadata.get('keywords', {})
|
keywords = metadata.get("keywords", {})
|
||||||
if keywords:
|
if keywords:
|
||||||
high_level = keywords.get('high_level', [])
|
high_level = keywords.get("high_level", [])
|
||||||
low_level = keywords.get('low_level', [])
|
low_level = keywords.get("low_level", [])
|
||||||
if high_level:
|
if high_level:
|
||||||
print(f" High-level keywords: {', '.join(high_level)}")
|
print(f" High-level keywords: {', '.join(high_level)}")
|
||||||
if low_level:
|
if low_level:
|
||||||
print(f" Low-level keywords: {', '.join(low_level)}")
|
print(f" Low-level keywords: {', '.join(low_level)}")
|
||||||
|
|
||||||
processing_info = metadata.get('processing_info', {})
|
processing_info = metadata.get("processing_info", {})
|
||||||
if processing_info:
|
if processing_info:
|
||||||
print(" Processing info:")
|
print(" Processing info:")
|
||||||
for key, value in processing_info.items():
|
for key, value in processing_info.items():
|
||||||
print(f" {key}: {value}")
|
print(f" {key}: {value}")
|
||||||
|
|
||||||
# Print entity information
|
# Print entity information
|
||||||
if entities:
|
if entities:
|
||||||
print("\n👥 Retrieved entities (first 5):")
|
print("\n👥 Retrieved entities (first 5):")
|
||||||
for i, entity in enumerate(entities[:5]):
|
for i, entity in enumerate(entities[:5]):
|
||||||
entity_name = entity.get('entity_name', 'Unknown')
|
entity_name = entity.get("entity_name", "Unknown")
|
||||||
entity_type = entity.get('entity_type', 'Unknown')
|
entity_type = entity.get("entity_type", "Unknown")
|
||||||
description = entity.get('description', 'No description')
|
description = entity.get("description", "No description")
|
||||||
file_path = entity.get('file_path', 'Unknown source')
|
file_path = entity.get("file_path", "Unknown source")
|
||||||
|
|
||||||
print(f" {i+1}. {entity_name} ({entity_type})")
|
print(f" {i+1}. {entity_name} ({entity_type})")
|
||||||
print(f" Description: {description[:100]}{'...' if len(description) > 100 else ''}")
|
print(
|
||||||
|
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
|
||||||
|
)
|
||||||
print(f" Source: {file_path}")
|
print(f" Source: {file_path}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Print relationship information
|
# Print relationship information
|
||||||
if relationships:
|
if relationships:
|
||||||
print("🔗 Retrieved relationships (first 5):")
|
print("🔗 Retrieved relationships (first 5):")
|
||||||
for i, rel in enumerate(relationships[:5]):
|
for i, rel in enumerate(relationships[:5]):
|
||||||
src = rel.get('src_id', 'Unknown')
|
src = rel.get("src_id", "Unknown")
|
||||||
tgt = rel.get('tgt_id', 'Unknown')
|
tgt = rel.get("tgt_id", "Unknown")
|
||||||
description = rel.get('description', 'No description')
|
description = rel.get("description", "No description")
|
||||||
keywords = rel.get('keywords', 'No keywords')
|
keywords = rel.get("keywords", "No keywords")
|
||||||
file_path = rel.get('file_path', 'Unknown source')
|
file_path = rel.get("file_path", "Unknown source")
|
||||||
|
|
||||||
print(f" {i+1}. {src} → {tgt}")
|
print(f" {i+1}. {src} → {tgt}")
|
||||||
print(f" Keywords: {keywords}")
|
print(f" Keywords: {keywords}")
|
||||||
print(f" Description: {description[:100]}{'...' if len(description) > 100 else ''}")
|
print(
|
||||||
|
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
|
||||||
|
)
|
||||||
print(f" Source: {file_path}")
|
print(f" Source: {file_path}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Print text chunk information
|
# Print text chunk information
|
||||||
if chunks:
|
if chunks:
|
||||||
print("📄 Retrieved text chunks (first 3):")
|
print("📄 Retrieved text chunks (first 3):")
|
||||||
for i, chunk in enumerate(chunks[:3]):
|
for i, chunk in enumerate(chunks[:3]):
|
||||||
content = chunk.get('content', 'No content')
|
content = chunk.get("content", "No content")
|
||||||
file_path = chunk.get('file_path', 'Unknown source')
|
file_path = chunk.get("file_path", "Unknown source")
|
||||||
chunk_id = chunk.get('chunk_id', 'Unknown ID')
|
chunk_id = chunk.get("chunk_id", "Unknown ID")
|
||||||
|
|
||||||
print(f" {i+1}. Text chunk ID: {chunk_id}")
|
print(f" {i+1}. Text chunk ID: {chunk_id}")
|
||||||
print(f" Source: {file_path}")
|
print(f" Source: {file_path}")
|
||||||
print(f" Content: {content[:200]}{'...' if len(content) > 200 else ''}")
|
print(
|
||||||
|
f" Content: {content[:200]}{'...' if len(content) > 200 else ''}"
|
||||||
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
def compare_with_regular_query():
|
def compare_with_regular_query():
|
||||||
"""Compare results between regular query and data query"""
|
"""Compare results between regular query and data query"""
|
||||||
|
|
||||||
query_text = "LightRAG的作者是谁"
|
query_text = "LightRAG的作者是谁"
|
||||||
|
|
||||||
print("\n🔄 Comparison test: Regular query vs Data query")
|
print("\n🔄 Comparison test: Regular query vs Data query")
|
||||||
print("-" * 60)
|
print("-" * 60)
|
||||||
|
|
||||||
# Regular query
|
# Regular query
|
||||||
try:
|
try:
|
||||||
print("1. Regular query (/query):")
|
print("1. Regular query (/query):")
|
||||||
|
|
@ -175,13 +177,15 @@ def compare_with_regular_query():
|
||||||
f"{BASE_URL}/query",
|
f"{BASE_URL}/query",
|
||||||
json={"query": query_text, "mode": "mix"},
|
json={"query": query_text, "mode": "mix"},
|
||||||
headers=AUTH_HEADERS,
|
headers=AUTH_HEADERS,
|
||||||
timeout=30
|
timeout=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
if regular_response.status_code == 200:
|
if regular_response.status_code == 200:
|
||||||
regular_data = regular_response.json()
|
regular_data = regular_response.json()
|
||||||
response_text = regular_data.get('response', 'No response')
|
response_text = regular_data.get("response", "No response")
|
||||||
print(f" Generated answer: {response_text[:300]}{'...' if len(response_text) > 300 else ''}")
|
print(
|
||||||
|
f" Generated answer: {response_text[:300]}{'...' if len(response_text) > 300 else ''}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f" Regular query failed: {regular_response.status_code}")
|
print(f" Regular query failed: {regular_response.status_code}")
|
||||||
if regular_response.status_code == 403:
|
if regular_response.status_code == 403:
|
||||||
|
|
@ -189,7 +193,7 @@ def compare_with_regular_query():
|
||||||
elif regular_response.status_code == 401:
|
elif regular_response.status_code == 401:
|
||||||
print(" Unauthorized - Please check authentication information")
|
print(" Unauthorized - Please check authentication information")
|
||||||
print(f" Error details: {regular_response.text}")
|
print(f" Error details: {regular_response.text}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" Regular query error: {str(e)}")
|
print(f" Regular query error: {str(e)}")
|
||||||
|
|
||||||
|
|
@ -197,10 +201,10 @@ def compare_with_regular_query():
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Run main test
|
# Run main test
|
||||||
test_aquery_data_endpoint()
|
test_aquery_data_endpoint()
|
||||||
|
|
||||||
# Run comparison test
|
# Run comparison test
|
||||||
compare_with_regular_query()
|
compare_with_regular_query()
|
||||||
|
|
||||||
print("\n💡 Usage tips:")
|
print("\n💡 Usage tips:")
|
||||||
print("1. Ensure LightRAG API service is running")
|
print("1. Ensure LightRAG API service is running")
|
||||||
print("2. Adjust base_url and authentication information as needed")
|
print("2. Adjust base_url and authentication information as needed")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue