#!/usr/bin/env python3
"""
Test script: Demonstrates usage of aquery_data FastAPI endpoint
Query content: Who is the author of LightRAG

Updated to handle the new data format where:
- Response includes status, message, data, and metadata fields at top level
- Actual query results (entities, relationships, chunks, references) are nested under 'data' field
- Includes backward compatibility with legacy format
"""

import json
import time
from typing import Any

import pytest
import requests

# API configuration
API_KEY = 'your-secure-api-key-here-123'
BASE_URL = 'http://localhost:9621'

# Unified authentication headers
AUTH_HEADERS = {'Content-Type': 'application/json', 'X-API-Key': API_KEY}


def validate_references_format(references: list[dict[str, Any]]) -> bool:
    """Validate the format of a references list.

    Args:
        references: Value expected to be a list of dicts, each carrying string
            ``reference_id`` and ``file_path`` entries.

    Returns:
        True when every entry is well-formed (an empty list is valid); False
        after printing the first problem found.
    """
    if not isinstance(references, list):
        print(f'❌ References should be a list, got {type(references)}')
        return False

    required_fields = ['reference_id', 'file_path']
    for i, ref in enumerate(references):
        if not isinstance(ref, dict):
            print(f'❌ Reference {i} should be a dict, got {type(ref)}')
            return False

        for field in required_fields:
            if field not in ref:
                print(f'❌ Reference {i} missing required field: {field}')
                return False
            if not isinstance(ref[field], str):
                print(f"❌ Reference {i} field '{field}' should be string, got {type(ref[field])}")
                return False

    return True


def parse_streaming_response(
    response_text: str,
) -> tuple[list[dict] | None, list[str], list[str]]:
    """Parse a streaming response body into references, response chunks, and errors.

    Each non-empty line is treated as one JSON document, optionally prefixed
    with ``data: ``. Lines that are not JSON, or whose JSON value is not an
    object, are skipped.

    Args:
        response_text: The complete text of the streamed response.

    Returns:
        A ``(references, response_chunks, errors)`` tuple. ``references`` is the
        value of the last ``references`` key seen, or None if no line had one.
        ``response_chunks`` and ``errors`` collect every ``response`` and
        ``error`` value in order of appearance.
    """
    references = None
    response_chunks = []
    errors = []

    for raw_line in response_text.strip().split('\n'):
        line = raw_line.strip().removeprefix('data: ')
        if not line:
            continue

        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            # Skip non-JSON lines (like SSE comments)
            continue

        # Valid JSON that is not an object (null, a number, a list, ...) cannot
        # carry the keys below; a membership test on it could raise TypeError.
        if not isinstance(data, dict):
            continue

        if 'references' in data:
            references = data['references']
        if 'response' in data:
            response_chunks.append(data['response'])
        if 'error' in data:
            errors.append(data['error'])

    return references, response_chunks, errors


def _read_stream_text(response: requests.Response) -> str:
    """Read a streamed HTTP body to completion and return it as UTF-8 text."""
    # Decode once, after all bytes have arrived: decoding each 1024-byte chunk
    # separately fails when a multi-byte character straddles a chunk boundary.
    return b''.join(response.iter_content(chunk_size=1024)).decode('utf-8')


def _print_reference_list(references: list[dict[str, Any]]) -> None:
    """Print a numbered ``ID | File`` listing; print nothing for an empty list."""
    if not references:
        return

    print(' 📚 Reference List:')
    for i, ref in enumerate(references, 1):
        ref_id = ref.get('reference_id', 'Unknown')
        file_path = ref.get('file_path', 'Unknown')
        print(f' {i}. ID: {ref_id} | File: {file_path}')


@pytest.mark.integration
@pytest.mark.requires_api
def test_query_endpoint_references():
    """Test /query endpoint references functionality.

    Issues the same query twice, with ``include_references`` True and then
    False, and checks that the ``references`` field is present and well-formed
    in the first case and absent/None in the second.

    Returns:
        True when both sub-tests pass, False otherwise. The boolean is consumed
        by ``run_all_reference_tests``.
        NOTE(review): failures are reported by return value, not by assertion,
        so a pytest run may not register them as failures - confirm intent.
    """
    print('\n' + '=' * 60)
    print('Testing /query endpoint references functionality')
    print('=' * 60)

    query_text = 'who authored LightRAG'
    endpoint = f'{BASE_URL}/query'

    # Test 1: References enabled (default)
    print('\n🧪 Test 1: References enabled (default)')
    print('-' * 40)

    try:
        response = requests.post(
            endpoint,
            json={'query': query_text, 'mode': 'mix', 'include_references': True},
            headers=AUTH_HEADERS,
            timeout=30,
        )

        if response.status_code != 200:
            print(f'❌ Request failed: {response.status_code}')
            print(f' Error: {response.text}')
            return False

        data = response.json()

        # Check response structure
        if 'response' not in data:
            print("❌ Missing 'response' field")
            return False

        if 'references' not in data:
            print("❌ Missing 'references' field when include_references=True")
            return False

        references = data['references']
        if references is None:
            print('❌ References should not be None when include_references=True')
            return False

        if not validate_references_format(references):
            return False

        print(f'✅ References enabled: Found {len(references)} references')
        print(f' Response length: {len(data["response"])} characters')

        # Display reference list
        _print_reference_list(references)

    except Exception as e:
        print(f'❌ Test 1 failed: {e!s}')
        return False

    # Test 2: References disabled
    print('\n🧪 Test 2: References disabled')
    print('-' * 40)

    try:
        response = requests.post(
            endpoint,
            json={'query': query_text, 'mode': 'mix', 'include_references': False},
            headers=AUTH_HEADERS,
            timeout=30,
        )

        if response.status_code != 200:
            print(f'❌ Request failed: {response.status_code}')
            print(f' Error: {response.text}')
            return False

        data = response.json()

        # Check response structure
        if 'response' not in data:
            print("❌ Missing 'response' field")
            return False

        references = data.get('references')
        if references is not None:
            print('❌ References should be None when include_references=False')
            return False

        print('✅ References disabled: No references field present')
        print(f' Response length: {len(data["response"])} characters')

    except Exception as e:
        print(f'❌ Test 2 failed: {e!s}')
        return False

    print('\n✅ /query endpoint references tests passed!')
    return True


@pytest.mark.integration
@pytest.mark.requires_api
def test_query_stream_endpoint_references():
    """Test /query/stream endpoint references functionality.

    Streams the same query twice, with ``include_references`` True and then
    False, and checks the parsed stream for errors, for the presence/absence
    of references, and for at least one response chunk.

    Returns:
        True when both sub-tests pass, False otherwise. The boolean is consumed
        by ``run_all_reference_tests``.
        NOTE(review): failures are reported by return value, not by assertion,
        so a pytest run may not register them as failures - confirm intent.
    """
    print('\n' + '=' * 60)
    print('Testing /query/stream endpoint references functionality')
    print('=' * 60)

    query_text = 'who authored LightRAG'
    endpoint = f'{BASE_URL}/query/stream'

    # Test 1: Streaming with references enabled
    print('\n🧪 Test 1: Streaming with references enabled')
    print('-' * 40)

    try:
        response = requests.post(
            endpoint,
            json={'query': query_text, 'mode': 'mix', 'include_references': True},
            headers=AUTH_HEADERS,
            timeout=30,
            stream=True,
        )

        if response.status_code != 200:
            print(f'❌ Request failed: {response.status_code}')
            print(f' Error: {response.text}')
            return False

        # Collect and parse streaming response
        full_response = _read_stream_text(response)
        references, response_chunks, errors = parse_streaming_response(full_response)

        if errors:
            print(f'❌ Errors in streaming response: {errors}')
            return False

        if references is None:
            print('❌ No references found in streaming response')
            return False

        if not validate_references_format(references):
            return False

        if not response_chunks:
            print('❌ No response chunks found in streaming response')
            return False

        print(f'✅ Streaming with references: Found {len(references)} references')
        print(f' Response chunks: {len(response_chunks)}')
        print(f' Total response length: {sum(len(chunk) for chunk in response_chunks)} characters')

        # Display reference list
        _print_reference_list(references)

    except Exception as e:
        print(f'❌ Test 1 failed: {e!s}')
        return False

    # Test 2: Streaming with references disabled
    print('\n🧪 Test 2: Streaming with references disabled')
    print('-' * 40)

    try:
        response = requests.post(
            endpoint,
            json={'query': query_text, 'mode': 'mix', 'include_references': False},
            headers=AUTH_HEADERS,
            timeout=30,
            stream=True,
        )

        if response.status_code != 200:
            print(f'❌ Request failed: {response.status_code}')
            print(f' Error: {response.text}')
            return False

        # Collect and parse streaming response
        full_response = _read_stream_text(response)
        references, response_chunks, errors = parse_streaming_response(full_response)

        if errors:
            print(f'❌ Errors in streaming response: {errors}')
            return False

        if references is not None:
            print('❌ References should be None when include_references=False')
            return False

        if not response_chunks:
            print('❌ No response chunks found in streaming response')
            return False

        print('✅ Streaming without references: No references present')
        print(f' Response chunks: {len(response_chunks)}')
        print(f' Total response length: {sum(len(chunk) for chunk in response_chunks)} characters')

    except Exception as e:
        print(f'❌ Test 2 failed: {e!s}')
        return False

    print('\n✅ /query/stream endpoint references tests passed!')
    return True


@pytest.mark.integration
@pytest.mark.requires_api
def test_references_consistency():
    """Test references consistency across all endpoints.

    Sends one identical request to /query, /query/stream and /query/data and
    compares the sets of ``(reference_id, file_path)`` pairs each returns.

    Returns:
        True when all three endpoints return the same reference set, False on
        any mismatch, HTTP failure, stream error, or exception. The boolean is
        consumed by ``run_all_reference_tests``.
        NOTE(review): failures are reported by return value, not by assertion,
        so a pytest run may not register them as failures - confirm intent.
    """
    print('\n' + '=' * 60)
    print('Testing references consistency across endpoints')
    print('=' * 60)

    query_text = 'who authored LightRAG'
    query_params = {
        'query': query_text,
        'mode': 'mix',
        'top_k': 10,
        'chunk_top_k': 8,
        'include_references': True,
    }

    references_data = {}

    # Test /query endpoint
    print('\n🧪 Testing /query endpoint')
    print('-' * 40)

    try:
        response = requests.post(f'{BASE_URL}/query', json=query_params, headers=AUTH_HEADERS, timeout=30)

        if response.status_code != 200:
            print(f'❌ /query failed: {response.status_code}')
            return False

        data = response.json()
        # `or []` also covers an explicit JSON null, matching the stream branch
        references_data['query'] = data.get('references') or []
        print(f'✅ /query: {len(references_data["query"])} references')

    except Exception as e:
        print(f'❌ /query test failed: {e!s}')
        return False

    # Test /query/stream endpoint
    print('\n🧪 Testing /query/stream endpoint')
    print('-' * 40)

    try:
        response = requests.post(
            f'{BASE_URL}/query/stream',
            json=query_params,
            headers=AUTH_HEADERS,
            timeout=30,
            stream=True,
        )

        if response.status_code != 200:
            print(f'❌ /query/stream failed: {response.status_code}')
            return False

        full_response = _read_stream_text(response)
        references, _, errors = parse_streaming_response(full_response)

        if errors:
            print(f'❌ Errors: {errors}')
            return False

        references_data['stream'] = references or []
        print(f'✅ /query/stream: {len(references_data["stream"])} references')

    except Exception as e:
        print(f'❌ /query/stream test failed: {e!s}')
        return False

    # Test /query/data endpoint
    print('\n🧪 Testing /query/data endpoint')
    print('-' * 40)

    try:
        response = requests.post(
            f'{BASE_URL}/query/data',
            json=query_params,
            headers=AUTH_HEADERS,
            timeout=30,
        )

        if response.status_code != 200:
            print(f'❌ /query/data failed: {response.status_code}')
            return False

        data = response.json()
        # Results are nested under 'data'; tolerate a null payload or list
        query_data = data.get('data') or {}
        references_data['data'] = query_data.get('references') or []
        print(f'✅ /query/data: {len(references_data["data"])} references')

    except Exception as e:
        print(f'❌ /query/data test failed: {e!s}')
        return False

    # Compare references consistency
    print('\n🔍 Comparing references consistency')
    print('-' * 40)

    # Convert to sets of (reference_id, file_path) tuples for comparison
    def refs_to_set(refs):
        return {(ref.get('reference_id', ''), ref.get('file_path', '')) for ref in refs}

    query_refs = refs_to_set(references_data['query'])
    stream_refs = refs_to_set(references_data['stream'])
    data_refs = refs_to_set(references_data['data'])

    # Check consistency
    consistency_passed = True

    if query_refs != stream_refs:
        print('❌ References mismatch between /query and /query/stream')
        print(f' /query only: {query_refs - stream_refs}')
        print(f' /query/stream only: {stream_refs - query_refs}')
        consistency_passed = False

    if query_refs != data_refs:
        print('❌ References mismatch between /query and /query/data')
        print(f' /query only: {query_refs - data_refs}')
        print(f' /query/data only: {data_refs - query_refs}')
        consistency_passed = False

    if stream_refs != data_refs:
        print('❌ References mismatch between /query/stream and /query/data')
        print(f' /query/stream only: {stream_refs - data_refs}')
        print(f' /query/data only: {data_refs - stream_refs}')
        consistency_passed = False

    if consistency_passed:
        print('✅ All endpoints return consistent references')
        print(f' Common references count: {len(query_refs)}')

        # Display common reference list
        if query_refs:
            print(' 📚 Common Reference List:')
            for i, (ref_id, file_path) in enumerate(sorted(query_refs), 1):
                print(f' {i}. ID: {ref_id} | File: {file_path}')

    return consistency_passed


@pytest.mark.integration
@pytest.mark.requires_api
def test_aquery_data_endpoint():
    """Test the /query/data endpoint.

    Sends one 'mix'-mode query, prints the round-trip time and HTTP status, and
    on a 200 response passes the decoded JSON to ``print_query_results``.
    Connection errors, timeouts and any other exception are printed rather
    than raised; the function returns None in every case.
    """
    # Use unified configuration
    endpoint = f'{BASE_URL}/query/data'

    # Query request
    query_request = {
        'query': 'who authored LightRAG',
        'mode': 'mix',  # Use mixed mode to get the most comprehensive results
        'top_k': 20,
        'chunk_top_k': 15,
        'max_entity_tokens': 4000,
        'max_relation_tokens': 4000,
        'max_total_tokens': 16000,
        'enable_rerank': True,
    }

    print('=' * 60)
    print('LightRAG aquery_data endpoint test')
    print(' Returns structured data including entities, relationships and text chunks')
    print(' Can be used for custom processing and analysis')
    print('=' * 60)
    print(f'Query content: {query_request["query"]}')
    print(f'Query mode: {query_request["mode"]}')
    print(f'API endpoint: {endpoint}')
    print('-' * 60)

    try:
        # Send request
        print('Sending request...')
        start_time = time.time()

        response = requests.post(endpoint, json=query_request, headers=AUTH_HEADERS, timeout=30)

        end_time = time.time()
        response_time = end_time - start_time

        print(f'Response time: {response_time:.2f} seconds')
        print(f'HTTP status code: {response.status_code}')

        if response.status_code == 200:
            data = response.json()
            print_query_results(data)
        else:
            print(f'Request failed: {response.status_code}')
            print(f'Error message: {response.text}')

    except requests.exceptions.ConnectionError:
        print('❌ Connection failed: Please ensure LightRAG API service is running')
        print(' Start command: python -m lightrag.api.lightrag_server')
    except requests.exceptions.Timeout:
        print('❌ Request timeout: Query processing took too long')
    except Exception as e:
        print(f'❌ Error occurred: {e!s}')


def print_query_results(data: dict[str, Any]): """Format and print query results""" # Check for new data format with status and message status = data.get('status', 'unknown') message = data.get('message', '') print(f'\nπŸ“‹ Query Status: {status}') if message: print(f'πŸ“‹ Message: {message}') # Handle new nested data format query_data = data.get('data', {}) # Fallback to old format if new format is not present if not query_data and any(key in data for key in ['entities', 'relationships', 'chunks']): print(' (Using legacy data format)') query_data = data entities = query_data.get('entities', []) relationships = query_data.get('relationships', []) chunks = query_data.get('chunks', []) references = query_data.get('references', []) print('\nπŸ“Š Query result statistics:') print(f' Entity count: {len(entities)}') print(f' Relationship count: {len(relationships)}') print(f' Text chunk count: {len(chunks)}') print(f' Reference count: {len(references)}') # Print metadata (now at top level in new format) metadata = data.get('metadata', {}) if metadata: print('\nπŸ” Query metadata:') print(f' Query mode: {metadata.get("query_mode", "unknown")}') keywords = metadata.get('keywords', {}) if keywords: high_level = keywords.get('high_level', []) low_level = keywords.get('low_level', []) if high_level: print(f' High-level keywords: {", ".join(high_level)}') if low_level: print(f' Low-level keywords: {", ".join(low_level)}') processing_info = metadata.get('processing_info', {}) if processing_info: print(' Processing info:') for key, value in processing_info.items(): print(f' {key}: {value}') # Print entity information if entities: print('\nπŸ‘₯ Retrieved entities (first 5):') for i, entity in enumerate(entities[:5]): entity_name = entity.get('entity_name', 'Unknown') entity_type = entity.get('entity_type', 'Unknown') description = entity.get('description', 'No description') file_path = entity.get('file_path', 'Unknown source') reference_id = entity.get('reference_id', 'No 
reference') print(f' {i + 1}. {entity_name} ({entity_type})') print(f' Description: {description[:100]}{"..." if len(description) > 100 else ""}') print(f' Source: {file_path}') print(f' Reference ID: {reference_id}') print() # Print relationship information if relationships: print('πŸ”— Retrieved relationships (first 5):') for i, rel in enumerate(relationships[:5]): src = rel.get('src_id', 'Unknown') tgt = rel.get('tgt_id', 'Unknown') description = rel.get('description', 'No description') keywords = rel.get('keywords', 'No keywords') file_path = rel.get('file_path', 'Unknown source') reference_id = rel.get('reference_id', 'No reference') print(f' {i + 1}. {src} β†’ {tgt}') print(f' Keywords: {keywords}') print(f' Description: {description[:100]}{"..." if len(description) > 100 else ""}') print(f' Source: {file_path}') print(f' Reference ID: {reference_id}') print() # Print text chunk information if chunks: print('πŸ“„ Retrieved text chunks (first 3):') for i, chunk in enumerate(chunks[:3]): content = chunk.get('content', 'No content') file_path = chunk.get('file_path', 'Unknown source') chunk_id = chunk.get('chunk_id', 'Unknown ID') reference_id = chunk.get('reference_id', 'No reference') print(f' {i + 1}. Text chunk ID: {chunk_id}') print(f' Source: {file_path}') print(f' Reference ID: {reference_id}') print(f' Content: {content[:200]}{"..." if len(content) > 200 else ""}') print() # Print references information (new in updated format) if references: print('πŸ“š References:') for i, ref in enumerate(references): reference_id = ref.get('reference_id', 'Unknown ID') file_path = ref.get('file_path', 'Unknown source') print(f' {i + 1}. 
Reference ID: {reference_id}') print(f' File Path: {file_path}') print() print('=' * 60) @pytest.mark.integration @pytest.mark.requires_api def compare_with_regular_query(): """Compare results between regular query and data query""" query_text = 'LightRAGηš„δ½œθ€…ζ˜―θ°' print('\nπŸ”„ Comparison test: Regular query vs Data query') print('-' * 60) # Regular query try: print('1. Regular query (/query):') regular_response = requests.post( f'{BASE_URL}/query', json={'query': query_text, 'mode': 'mix'}, headers=AUTH_HEADERS, timeout=30, ) if regular_response.status_code == 200: regular_data = regular_response.json() response_text = regular_data.get('response', 'No response') print(f' Generated answer: {response_text[:300]}{"..." if len(response_text) > 300 else ""}') else: print(f' Regular query failed: {regular_response.status_code}') if regular_response.status_code == 403: print(' Authentication failed - Please check API Key configuration') elif regular_response.status_code == 401: print(' Unauthorized - Please check authentication information') print(f' Error details: {regular_response.text}') except Exception as e: print(f' Regular query error: {e!s}') @pytest.mark.integration @pytest.mark.requires_api def run_all_reference_tests(): """Run all reference-related tests""" print('\n' + 'πŸš€' * 20) print('LightRAG References Test Suite') print('πŸš€' * 20) all_tests_passed = True # Test 1: /query endpoint references try: if not test_query_endpoint_references(): all_tests_passed = False except Exception as e: print(f'❌ /query endpoint test failed with exception: {e!s}') all_tests_passed = False # Test 2: /query/stream endpoint references try: if not test_query_stream_endpoint_references(): all_tests_passed = False except Exception as e: print(f'❌ /query/stream endpoint test failed with exception: {e!s}') all_tests_passed = False # Test 3: References consistency across endpoints try: if not test_references_consistency(): all_tests_passed = False except Exception as e: 
print(f'❌ References consistency test failed with exception: {e!s}') all_tests_passed = False # Final summary print('\n' + '=' * 60) print('TEST SUITE SUMMARY') print('=' * 60) if all_tests_passed: print('πŸŽ‰ ALL TESTS PASSED!') print('βœ… /query endpoint references functionality works correctly') print('βœ… /query/stream endpoint references functionality works correctly') print('βœ… References are consistent across all endpoints') print('\nπŸ”§ System is ready for production use with reference support!') else: print('❌ SOME TESTS FAILED!') print('Please check the error messages above and fix the issues.') print('\nπŸ”§ System needs attention before production deployment.') return all_tests_passed if __name__ == '__main__': import sys if len(sys.argv) > 1 and sys.argv[1] == '--references-only': # Run only the new reference tests success = run_all_reference_tests() sys.exit(0 if success else 1) else: # Run original tests plus new reference tests print('Running original aquery_data endpoint test...') test_aquery_data_endpoint() print('\nRunning comparison test...') compare_with_regular_query() print('\nRunning new reference tests...') run_all_reference_tests() print('\nπŸ’‘ Usage tips:') print('1. Ensure LightRAG API service is running') print('2. Adjust base_url and authentication information as needed') print('3. Modify query parameters to test different retrieval strategies') print('4. Data query results can be used for further analysis and processing') print('5. Run with --references-only flag to test only reference functionality')