LightRAG/tests/test_aquery_data_endpoint.py
clssck 69358d830d test(lightrag,examples,api): comprehensive ruff formatting and type hints
Format entire codebase with ruff and add type hints across all modules:
- Apply ruff formatting to all Python files (121 files, 17K insertions)
- Add type hints to function signatures throughout lightrag core and API
- Update test suite with improved type annotations and docstrings
- Add pyrightconfig.json for static type checking configuration
- Create prompt_optimized.py and test_extraction_prompt_ab.py test files
- Update ruff.toml and .gitignore for improved linting configuration
- Standardize code style across examples, reproduce scripts, and utilities
2025-12-05 15:17:06 +01:00

753 lines
26 KiB
Python

#!/usr/bin/env python3
"""
Test script: Demonstrates usage of aquery_data FastAPI endpoint
Query content: Who is the author of LightRAG
Updated to handle the new data format where:
- Response includes status, message, data, and metadata fields at top level
- Actual query results (entities, relationships, chunks, references) are nested under 'data' field
- Includes backward compatibility with legacy format
"""
import json
import time
from typing import Any
import pytest
import requests
# API configuration
API_KEY = 'your-secure-api-key-here-123'
BASE_URL = 'http://localhost:9621'
# Unified authentication headers
AUTH_HEADERS = {'Content-Type': 'application/json', 'X-API-Key': API_KEY}
def validate_references_format(references: list[dict[str, Any]]) -> bool:
"""Validate the format of references list"""
if not isinstance(references, list):
print(f'❌ References should be a list, got {type(references)}')
return False
for i, ref in enumerate(references):
if not isinstance(ref, dict):
print(f'❌ Reference {i} should be a dict, got {type(ref)}')
return False
required_fields = ['reference_id', 'file_path']
for field in required_fields:
if field not in ref:
print(f'❌ Reference {i} missing required field: {field}')
return False
if not isinstance(ref[field], str):
print(f"❌ Reference {i} field '{field}' should be string, got {type(ref[field])}")
return False
return True
def parse_streaming_response(
response_text: str,
) -> tuple[list[dict] | None, list[str], list[str]]:
"""Parse streaming response and extract references, response chunks, and errors"""
references = None
response_chunks = []
errors = []
lines = response_text.strip().split('\n')
for line in lines:
line = line.strip()
if line.startswith('data: '):
line = line[6:] # Remove 'data: ' prefix
if not line:
continue
try:
data = json.loads(line)
if 'references' in data:
references = data['references']
if 'response' in data:
response_chunks.append(data['response'])
if 'error' in data:
errors.append(data['error'])
except json.JSONDecodeError:
# Skip non-JSON lines (like SSE comments)
continue
return references, response_chunks, errors
@pytest.mark.integration
@pytest.mark.requires_api
def test_query_endpoint_references():
"""Test /query endpoint references functionality"""
print('\n' + '=' * 60)
print('Testing /query endpoint references functionality')
print('=' * 60)
query_text = 'who authored LightRAG'
endpoint = f'{BASE_URL}/query'
# Test 1: References enabled (default)
print('\n🧪 Test 1: References enabled (default)')
print('-' * 40)
try:
response = requests.post(
endpoint,
json={'query': query_text, 'mode': 'mix', 'include_references': True},
headers=AUTH_HEADERS,
timeout=30,
)
if response.status_code == 200:
data = response.json()
# Check response structure
if 'response' not in data:
print("❌ Missing 'response' field")
return False
if 'references' not in data:
print("❌ Missing 'references' field when include_references=True")
return False
references = data['references']
if references is None:
print('❌ References should not be None when include_references=True')
return False
if not validate_references_format(references):
return False
print(f'✅ References enabled: Found {len(references)} references')
print(f' Response length: {len(data["response"])} characters')
# Display reference list
if references:
print(' 📚 Reference List:')
for i, ref in enumerate(references, 1):
ref_id = ref.get('reference_id', 'Unknown')
file_path = ref.get('file_path', 'Unknown')
print(f' {i}. ID: {ref_id} | File: {file_path}')
else:
print(f'❌ Request failed: {response.status_code}')
print(f' Error: {response.text}')
return False
except Exception as e:
print(f'❌ Test 1 failed: {e!s}')
return False
# Test 2: References disabled
print('\n🧪 Test 2: References disabled')
print('-' * 40)
try:
response = requests.post(
endpoint,
json={'query': query_text, 'mode': 'mix', 'include_references': False},
headers=AUTH_HEADERS,
timeout=30,
)
if response.status_code == 200:
data = response.json()
# Check response structure
if 'response' not in data:
print("❌ Missing 'response' field")
return False
references = data.get('references')
if references is not None:
print('❌ References should be None when include_references=False')
return False
print('✅ References disabled: No references field present')
print(f' Response length: {len(data["response"])} characters')
else:
print(f'❌ Request failed: {response.status_code}')
print(f' Error: {response.text}')
return False
except Exception as e:
print(f'❌ Test 2 failed: {e!s}')
return False
print('\n✅ /query endpoint references tests passed!')
return True
@pytest.mark.integration
@pytest.mark.requires_api
def test_query_stream_endpoint_references():
"""Test /query/stream endpoint references functionality"""
print('\n' + '=' * 60)
print('Testing /query/stream endpoint references functionality')
print('=' * 60)
query_text = 'who authored LightRAG'
endpoint = f'{BASE_URL}/query/stream'
# Test 1: Streaming with references enabled
print('\n🧪 Test 1: Streaming with references enabled')
print('-' * 40)
try:
response = requests.post(
endpoint,
json={'query': query_text, 'mode': 'mix', 'include_references': True},
headers=AUTH_HEADERS,
timeout=30,
stream=True,
)
if response.status_code == 200:
# Collect streaming response
full_response = ''
for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
if chunk:
# Ensure chunk is string type
if isinstance(chunk, bytes):
chunk = chunk.decode('utf-8')
full_response += chunk
# Parse streaming response
references, response_chunks, errors = parse_streaming_response(full_response)
if errors:
print(f'❌ Errors in streaming response: {errors}')
return False
if references is None:
print('❌ No references found in streaming response')
return False
if not validate_references_format(references):
return False
if not response_chunks:
print('❌ No response chunks found in streaming response')
return False
print(f'✅ Streaming with references: Found {len(references)} references')
print(f' Response chunks: {len(response_chunks)}')
print(f' Total response length: {sum(len(chunk) for chunk in response_chunks)} characters')
# Display reference list
if references:
print(' 📚 Reference List:')
for i, ref in enumerate(references, 1):
ref_id = ref.get('reference_id', 'Unknown')
file_path = ref.get('file_path', 'Unknown')
print(f' {i}. ID: {ref_id} | File: {file_path}')
else:
print(f'❌ Request failed: {response.status_code}')
print(f' Error: {response.text}')
return False
except Exception as e:
print(f'❌ Test 1 failed: {e!s}')
return False
# Test 2: Streaming with references disabled
print('\n🧪 Test 2: Streaming with references disabled')
print('-' * 40)
try:
response = requests.post(
endpoint,
json={'query': query_text, 'mode': 'mix', 'include_references': False},
headers=AUTH_HEADERS,
timeout=30,
stream=True,
)
if response.status_code == 200:
# Collect streaming response
full_response = ''
for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
if chunk:
# Ensure chunk is string type
if isinstance(chunk, bytes):
chunk = chunk.decode('utf-8')
full_response += chunk
# Parse streaming response
references, response_chunks, errors = parse_streaming_response(full_response)
if errors:
print(f'❌ Errors in streaming response: {errors}')
return False
if references is not None:
print('❌ References should be None when include_references=False')
return False
if not response_chunks:
print('❌ No response chunks found in streaming response')
return False
print('✅ Streaming without references: No references present')
print(f' Response chunks: {len(response_chunks)}')
print(f' Total response length: {sum(len(chunk) for chunk in response_chunks)} characters')
else:
print(f'❌ Request failed: {response.status_code}')
print(f' Error: {response.text}')
return False
except Exception as e:
print(f'❌ Test 2 failed: {e!s}')
return False
print('\n✅ /query/stream endpoint references tests passed!')
return True
@pytest.mark.integration
@pytest.mark.requires_api
def test_references_consistency():
"""Test references consistency across all endpoints"""
print('\n' + '=' * 60)
print('Testing references consistency across endpoints')
print('=' * 60)
query_text = 'who authored LightRAG'
query_params = {
'query': query_text,
'mode': 'mix',
'top_k': 10,
'chunk_top_k': 8,
'include_references': True,
}
references_data = {}
# Test /query endpoint
print('\n🧪 Testing /query endpoint')
print('-' * 40)
try:
response = requests.post(f'{BASE_URL}/query', json=query_params, headers=AUTH_HEADERS, timeout=30)
if response.status_code == 200:
data = response.json()
references_data['query'] = data.get('references', [])
print(f'✅ /query: {len(references_data["query"])} references')
else:
print(f'❌ /query failed: {response.status_code}')
return False
except Exception as e:
print(f'❌ /query test failed: {e!s}')
return False
# Test /query/stream endpoint
print('\n🧪 Testing /query/stream endpoint')
print('-' * 40)
try:
response = requests.post(
f'{BASE_URL}/query/stream',
json=query_params,
headers=AUTH_HEADERS,
timeout=30,
stream=True,
)
if response.status_code == 200:
full_response = ''
for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
if chunk:
# Ensure chunk is string type
if isinstance(chunk, bytes):
chunk = chunk.decode('utf-8')
full_response += chunk
references, _, errors = parse_streaming_response(full_response)
if errors:
print(f'❌ Errors: {errors}')
return False
references_data['stream'] = references or []
print(f'✅ /query/stream: {len(references_data["stream"])} references')
else:
print(f'❌ /query/stream failed: {response.status_code}')
return False
except Exception as e:
print(f'❌ /query/stream test failed: {e!s}')
return False
# Test /query/data endpoint
print('\n🧪 Testing /query/data endpoint')
print('-' * 40)
try:
response = requests.post(
f'{BASE_URL}/query/data',
json=query_params,
headers=AUTH_HEADERS,
timeout=30,
)
if response.status_code == 200:
data = response.json()
query_data = data.get('data', {})
references_data['data'] = query_data.get('references', [])
print(f'✅ /query/data: {len(references_data["data"])} references')
else:
print(f'❌ /query/data failed: {response.status_code}')
return False
except Exception as e:
print(f'❌ /query/data test failed: {e!s}')
return False
# Compare references consistency
print('\n🔍 Comparing references consistency')
print('-' * 40)
# Convert to sets of (reference_id, file_path) tuples for comparison
def refs_to_set(refs):
return {(ref.get('reference_id', ''), ref.get('file_path', '')) for ref in refs}
query_refs = refs_to_set(references_data['query'])
stream_refs = refs_to_set(references_data['stream'])
data_refs = refs_to_set(references_data['data'])
# Check consistency
consistency_passed = True
if query_refs != stream_refs:
print('❌ References mismatch between /query and /query/stream')
print(f' /query only: {query_refs - stream_refs}')
print(f' /query/stream only: {stream_refs - query_refs}')
consistency_passed = False
if query_refs != data_refs:
print('❌ References mismatch between /query and /query/data')
print(f' /query only: {query_refs - data_refs}')
print(f' /query/data only: {data_refs - query_refs}')
consistency_passed = False
if stream_refs != data_refs:
print('❌ References mismatch between /query/stream and /query/data')
print(f' /query/stream only: {stream_refs - data_refs}')
print(f' /query/data only: {data_refs - stream_refs}')
consistency_passed = False
if consistency_passed:
print('✅ All endpoints return consistent references')
print(f' Common references count: {len(query_refs)}')
# Display common reference list
if query_refs:
print(' 📚 Common Reference List:')
for i, (ref_id, file_path) in enumerate(sorted(query_refs), 1):
print(f' {i}. ID: {ref_id} | File: {file_path}')
return consistency_passed
@pytest.mark.integration
@pytest.mark.requires_api
def test_aquery_data_endpoint():
"""Test the /query/data endpoint"""
# Use unified configuration
endpoint = f'{BASE_URL}/query/data'
# Query request
query_request = {
'query': 'who authored LighRAG',
'mode': 'mix', # Use mixed mode to get the most comprehensive results
'top_k': 20,
'chunk_top_k': 15,
'max_entity_tokens': 4000,
'max_relation_tokens': 4000,
'max_total_tokens': 16000,
'enable_rerank': True,
}
print('=' * 60)
print('LightRAG aquery_data endpoint test')
print(' Returns structured data including entities, relationships and text chunks')
print(' Can be used for custom processing and analysis')
print('=' * 60)
print(f'Query content: {query_request["query"]}')
print(f'Query mode: {query_request["mode"]}')
print(f'API endpoint: {endpoint}')
print('-' * 60)
try:
# Send request
print('Sending request...')
start_time = time.time()
response = requests.post(endpoint, json=query_request, headers=AUTH_HEADERS, timeout=30)
end_time = time.time()
response_time = end_time - start_time
print(f'Response time: {response_time:.2f} seconds')
print(f'HTTP status code: {response.status_code}')
if response.status_code == 200:
data = response.json()
print_query_results(data)
else:
print(f'Request failed: {response.status_code}')
print(f'Error message: {response.text}')
except requests.exceptions.ConnectionError:
print('❌ Connection failed: Please ensure LightRAG API service is running')
print(' Start command: python -m lightrag.api.lightrag_server')
except requests.exceptions.Timeout:
print('❌ Request timeout: Query processing took too long')
except Exception as e:
print(f'❌ Error occurred: {e!s}')
def print_query_results(data: dict[str, Any]):
"""Format and print query results"""
# Check for new data format with status and message
status = data.get('status', 'unknown')
message = data.get('message', '')
print(f'\n📋 Query Status: {status}')
if message:
print(f'📋 Message: {message}')
# Handle new nested data format
query_data = data.get('data', {})
# Fallback to old format if new format is not present
if not query_data and any(key in data for key in ['entities', 'relationships', 'chunks']):
print(' (Using legacy data format)')
query_data = data
entities = query_data.get('entities', [])
relationships = query_data.get('relationships', [])
chunks = query_data.get('chunks', [])
references = query_data.get('references', [])
print('\n📊 Query result statistics:')
print(f' Entity count: {len(entities)}')
print(f' Relationship count: {len(relationships)}')
print(f' Text chunk count: {len(chunks)}')
print(f' Reference count: {len(references)}')
# Print metadata (now at top level in new format)
metadata = data.get('metadata', {})
if metadata:
print('\n🔍 Query metadata:')
print(f' Query mode: {metadata.get("query_mode", "unknown")}')
keywords = metadata.get('keywords', {})
if keywords:
high_level = keywords.get('high_level', [])
low_level = keywords.get('low_level', [])
if high_level:
print(f' High-level keywords: {", ".join(high_level)}')
if low_level:
print(f' Low-level keywords: {", ".join(low_level)}')
processing_info = metadata.get('processing_info', {})
if processing_info:
print(' Processing info:')
for key, value in processing_info.items():
print(f' {key}: {value}')
# Print entity information
if entities:
print('\n👥 Retrieved entities (first 5):')
for i, entity in enumerate(entities[:5]):
entity_name = entity.get('entity_name', 'Unknown')
entity_type = entity.get('entity_type', 'Unknown')
description = entity.get('description', 'No description')
file_path = entity.get('file_path', 'Unknown source')
reference_id = entity.get('reference_id', 'No reference')
print(f' {i + 1}. {entity_name} ({entity_type})')
print(f' Description: {description[:100]}{"..." if len(description) > 100 else ""}')
print(f' Source: {file_path}')
print(f' Reference ID: {reference_id}')
print()
# Print relationship information
if relationships:
print('🔗 Retrieved relationships (first 5):')
for i, rel in enumerate(relationships[:5]):
src = rel.get('src_id', 'Unknown')
tgt = rel.get('tgt_id', 'Unknown')
description = rel.get('description', 'No description')
keywords = rel.get('keywords', 'No keywords')
file_path = rel.get('file_path', 'Unknown source')
reference_id = rel.get('reference_id', 'No reference')
print(f' {i + 1}. {src}{tgt}')
print(f' Keywords: {keywords}')
print(f' Description: {description[:100]}{"..." if len(description) > 100 else ""}')
print(f' Source: {file_path}')
print(f' Reference ID: {reference_id}')
print()
# Print text chunk information
if chunks:
print('📄 Retrieved text chunks (first 3):')
for i, chunk in enumerate(chunks[:3]):
content = chunk.get('content', 'No content')
file_path = chunk.get('file_path', 'Unknown source')
chunk_id = chunk.get('chunk_id', 'Unknown ID')
reference_id = chunk.get('reference_id', 'No reference')
print(f' {i + 1}. Text chunk ID: {chunk_id}')
print(f' Source: {file_path}')
print(f' Reference ID: {reference_id}')
print(f' Content: {content[:200]}{"..." if len(content) > 200 else ""}')
print()
# Print references information (new in updated format)
if references:
print('📚 References:')
for i, ref in enumerate(references):
reference_id = ref.get('reference_id', 'Unknown ID')
file_path = ref.get('file_path', 'Unknown source')
print(f' {i + 1}. Reference ID: {reference_id}')
print(f' File Path: {file_path}')
print()
print('=' * 60)
@pytest.mark.integration
@pytest.mark.requires_api
def compare_with_regular_query():
"""Compare results between regular query and data query"""
query_text = 'LightRAG的作者是谁'
print('\n🔄 Comparison test: Regular query vs Data query')
print('-' * 60)
# Regular query
try:
print('1. Regular query (/query):')
regular_response = requests.post(
f'{BASE_URL}/query',
json={'query': query_text, 'mode': 'mix'},
headers=AUTH_HEADERS,
timeout=30,
)
if regular_response.status_code == 200:
regular_data = regular_response.json()
response_text = regular_data.get('response', 'No response')
print(f' Generated answer: {response_text[:300]}{"..." if len(response_text) > 300 else ""}')
else:
print(f' Regular query failed: {regular_response.status_code}')
if regular_response.status_code == 403:
print(' Authentication failed - Please check API Key configuration')
elif regular_response.status_code == 401:
print(' Unauthorized - Please check authentication information')
print(f' Error details: {regular_response.text}')
except Exception as e:
print(f' Regular query error: {e!s}')
@pytest.mark.integration
@pytest.mark.requires_api
def run_all_reference_tests():
"""Run all reference-related tests"""
print('\n' + '🚀' * 20)
print('LightRAG References Test Suite')
print('🚀' * 20)
all_tests_passed = True
# Test 1: /query endpoint references
try:
if not test_query_endpoint_references():
all_tests_passed = False
except Exception as e:
print(f'❌ /query endpoint test failed with exception: {e!s}')
all_tests_passed = False
# Test 2: /query/stream endpoint references
try:
if not test_query_stream_endpoint_references():
all_tests_passed = False
except Exception as e:
print(f'❌ /query/stream endpoint test failed with exception: {e!s}')
all_tests_passed = False
# Test 3: References consistency across endpoints
try:
if not test_references_consistency():
all_tests_passed = False
except Exception as e:
print(f'❌ References consistency test failed with exception: {e!s}')
all_tests_passed = False
# Final summary
print('\n' + '=' * 60)
print('TEST SUITE SUMMARY')
print('=' * 60)
if all_tests_passed:
print('🎉 ALL TESTS PASSED!')
print('✅ /query endpoint references functionality works correctly')
print('✅ /query/stream endpoint references functionality works correctly')
print('✅ References are consistent across all endpoints')
print('\n🔧 System is ready for production use with reference support!')
else:
print('❌ SOME TESTS FAILED!')
print('Please check the error messages above and fix the issues.')
print('\n🔧 System needs attention before production deployment.')
return all_tests_passed
if __name__ == '__main__':
import sys
if len(sys.argv) > 1 and sys.argv[1] == '--references-only':
# Run only the new reference tests
success = run_all_reference_tests()
sys.exit(0 if success else 1)
else:
# Run original tests plus new reference tests
print('Running original aquery_data endpoint test...')
test_aquery_data_endpoint()
print('\nRunning comparison test...')
compare_with_regular_query()
print('\nRunning new reference tests...')
run_all_reference_tests()
print('\n💡 Usage tips:')
print('1. Ensure LightRAG API service is running')
print('2. Adjust base_url and authentication information as needed')
print('3. Modify query parameters to test different retrieval strategies')
print('4. Data query results can be used for further analysis and processing')
print('5. Run with --references-only flag to test only reference functionality')