""" Accuracy tests for optimized prompts. Validates that optimized prompts produce correct, parseable outputs. Run with: uv run --extra test python tests/test_prompt_accuracy.py """ from __future__ import annotations import asyncio import json import sys from dataclasses import dataclass from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from lightrag.prompt import PROMPTS # ============================================================================= # Test Data # ============================================================================= KEYWORD_TEST_QUERIES = [ { 'query': 'What are the main causes of climate change and how do they affect polar ice caps?', 'expected_high': ['climate change', 'causes', 'effects'], 'expected_low': ['polar ice caps', 'greenhouse'], }, { 'query': "How did Apple's iPhone sales compare to Samsung Galaxy in Q3 2024?", 'expected_high': ['sales comparison', 'smartphone'], 'expected_low': ['Apple', 'iPhone', 'Samsung', 'Galaxy', 'Q3 2024'], }, { 'query': 'hello', # Trivial - should return empty 'expected_high': [], 'expected_low': [], }, ] ORPHAN_TEST_CASES = [ { 'orphan': { 'name': 'Pfizer', 'type': 'organization', 'desc': 'Pharmaceutical company that developed COVID-19 vaccine', }, 'candidate': { 'name': 'Moderna', 'type': 'organization', 'desc': 'Biotechnology company that developed mRNA COVID-19 vaccine', }, 'should_connect': True, 'reason': 'Both are COVID-19 vaccine developers', }, { 'orphan': { 'name': 'Mount Everest', 'type': 'location', 'desc': 'Highest mountain in the world, located in the Himalayas', }, 'candidate': { 'name': 'Python Programming', 'type': 'concept', 'desc': 'Popular programming language used for data science', }, 'should_connect': False, 'reason': 'No logical connection between mountain and programming language', }, ] SUMMARIZATION_TEST_CASES = [ { 'name': 'Albert Einstein', 'type': 'Entity', 'descriptions': [ '{"description": "Albert Einstein was a German-born theoretical physicist."}', '{"description": "Einstein developed the theory of relativity and won the Nobel Prize in Physics in 1921."}', '{"description": "He is widely regarded as one of the most influential scientists of the 20th century."}', ], 'must_contain': ['physicist', 'relativity', 'Nobel Prize', 'influential'], }, ] RAG_TEST_CASES = [ { 'query': 'What is the capital of France?', 'context': 'Paris is the capital and largest city of France. It has a population of over 2 million people.', 'must_contain': ['Paris'], 'must_not_contain': ['[1]', '[2]', 'References'], }, ] # ============================================================================= # Helper Functions # ============================================================================= async def call_llm(prompt: str, model: str = 'gpt-4o-mini') -> str: """Call OpenAI API with a single prompt.""" import openai client = openai.AsyncOpenAI() response = await client.chat.completions.create( model=model, messages=[{'role': 'user', 'content': prompt}], temperature=0.0, ) return response.choices[0].message.content @dataclass class TestResult: name: str passed: bool details: str raw_output: str = '' # ============================================================================= # Test Functions # ============================================================================= async def test_keywords_extraction() -> list[TestResult]: """Test keywords extraction prompt.""" results = [] examples = '\n'.join(PROMPTS['keywords_extraction_examples']) for case in KEYWORD_TEST_QUERIES: prompt = PROMPTS['keywords_extraction'].format(examples=examples, query=case['query']) output = await call_llm(prompt) # Try to parse JSON try: # Clean potential markdown clean = output.strip() if clean.startswith('```'): clean = clean.split('```')[1] if clean.startswith('json'): clean = clean[4:] parsed = json.loads(clean) has_high = 'high_level_keywords' in parsed has_low = 'low_level_keywords' in parsed is_list_high = isinstance(parsed.get('high_level_keywords'), list) is_list_low = isinstance(parsed.get('low_level_keywords'), list) if has_high and has_low and is_list_high and is_list_low: # Check if trivial query returns empty if case['expected_high'] == [] and case['expected_low'] == []: passed = len(parsed['high_level_keywords']) == 0 and len(parsed['low_level_keywords']) == 0 details = 'Empty lists returned for trivial query' if passed else f'Non-empty for trivial: {parsed}' else: # Check that some expected keywords are present (case-insensitive) high_lower = [k.lower() for k in parsed['high_level_keywords']] low_lower = [k.lower() for k in parsed['low_level_keywords']] all_keywords = ' '.join(high_lower + low_lower) found_high = sum(1 for exp in case['expected_high'] if exp.lower() in all_keywords) found_low = sum(1 for exp in case['expected_low'] if exp.lower() in all_keywords) passed = found_high > 0 or found_low > 0 details = f'Found {found_high}/{len(case["expected_high"])} high, {found_low}/{len(case["expected_low"])} low' else: passed = False details = f'Missing keys or wrong types: has_high={has_high}, has_low={has_low}' except json.JSONDecodeError as e: passed = False details = f'JSON parse error: {e}' results.append( TestResult( name=f'Keywords: {case["query"][:40]}...', passed=passed, details=details, raw_output=output[:200] ) ) return results async def test_orphan_validation() -> list[TestResult]: """Test orphan connection validation prompt.""" results = [] for case in ORPHAN_TEST_CASES: prompt = PROMPTS['orphan_connection_validation'].format( orphan_name=case['orphan']['name'], orphan_type=case['orphan']['type'], orphan_description=case['orphan']['desc'], candidate_name=case['candidate']['name'], candidate_type=case['candidate']['type'], candidate_description=case['candidate']['desc'], similarity_score=0.85, ) output = await call_llm(prompt) try: # Clean potential markdown clean = output.strip() if clean.startswith('```'): clean = clean.split('```')[1] if clean.startswith('json'): clean = clean[4:] parsed = json.loads(clean) has_should_connect = 'should_connect' in parsed has_confidence = 'confidence' in parsed has_reasoning = 'reasoning' in parsed if has_should_connect and has_confidence and has_reasoning: correct_decision = parsed['should_connect'] == case['should_connect'] valid_confidence = 0.0 <= parsed['confidence'] <= 1.0 passed = correct_decision and valid_confidence details = f'Decision: {parsed["should_connect"]} (expected {case["should_connect"]}), confidence: {parsed["confidence"]:.2f}' else: passed = False details = f'Missing keys: should_connect={has_should_connect}, confidence={has_confidence}, reasoning={has_reasoning}' except json.JSONDecodeError as e: passed = False details = f'JSON parse error: {e}' results.append( TestResult( name=f'Orphan: {case["orphan"]["name"]} ↔ {case["candidate"]["name"]}', passed=passed, details=details, raw_output=output[:200], ) ) return results async def test_entity_summarization() -> list[TestResult]: """Test entity summarization prompt.""" results = [] for case in SUMMARIZATION_TEST_CASES: prompt = PROMPTS['summarize_entity_descriptions'].format( description_name=case['name'], description_type=case['type'], description_list='\n'.join(case['descriptions']), summary_length=200, language='English', ) output = await call_llm(prompt) # Check if required terms are present output_lower = output.lower() found = [term for term in case['must_contain'] if term.lower() in output_lower] missing = [term for term in case['must_contain'] if term.lower() not in output_lower] # Check it's not empty and mentions the entity has_content = len(output.strip()) > 50 mentions_entity = case['name'].lower() in output_lower passed = len(found) >= len(case['must_contain']) // 2 and has_content and mentions_entity details = f'Found {len(found)}/{len(case["must_contain"])} terms, mentions entity: {mentions_entity}' if missing: details += f', missing: {missing}' results.append( TestResult(name=f'Summarize: {case["name"]}', passed=passed, details=details, raw_output=output[:200]) ) return results async def test_naive_rag_response() -> list[TestResult]: """Test naive RAG response prompt.""" results = [] for case in RAG_TEST_CASES: prompt = PROMPTS['naive_rag_response'].format( response_type='concise paragraph', user_prompt=case['query'], content_data=case['context'], ) output = await call_llm(prompt) # Check must_contain output_lower = output.lower() found = [term for term in case['must_contain'] if term.lower() in output_lower] # Check must_not_contain (citation markers) violations = [term for term in case['must_not_contain'] if term in output] passed = len(found) == len(case['must_contain']) and len(violations) == 0 details = f'Found {len(found)}/{len(case["must_contain"])} required terms' if violations: details += f', VIOLATIONS: {violations}' results.append( TestResult(name=f'RAG: {case["query"][:40]}', passed=passed, details=details, raw_output=output[:200]) ) return results # ============================================================================= # Main # ============================================================================= async def main() -> None: """Run all accuracy tests.""" print('\n' + '=' * 70) print(' PROMPT ACCURACY TESTS') print('=' * 70) all_results = [] # Run tests in parallel print('\nRunning tests...') keywords_results, orphan_results, summarize_results, rag_results = await asyncio.gather( test_keywords_extraction(), test_orphan_validation(), test_entity_summarization(), test_naive_rag_response(), ) all_results.extend(keywords_results) all_results.extend(orphan_results) all_results.extend(summarize_results) all_results.extend(rag_results) # Print results print('\n' + '-' * 70) print(' RESULTS') print('-' * 70) passed = 0 failed = 0 for result in all_results: status = '✓ PASS' if result.passed else '✗ FAIL' print(f'\n{status}: {result.name}') print(f' {result.details}') if not result.passed: print(f' Output: {result.raw_output}...') if result.passed: passed += 1 else: failed += 1 # Summary print('\n' + '=' * 70) print(f' SUMMARY: {passed}/{passed + failed} tests passed') print('=' * 70) if failed > 0: print('\n⚠️ Some tests failed - review prompt changes') sys.exit(1) else: print('\n✓ All prompts producing correct outputs!') if __name__ == '__main__': asyncio.run(main())