LightRAG/tests/test_prompt_accuracy.py
clssck 59e89772de refactor: consolidate to PostgreSQL-only backend and modernize stack
Remove legacy storage implementations and deprecated examples:
- Delete FAISS, JSON, Memgraph, Milvus, MongoDB, Nano Vector DB, Neo4j, NetworkX, Qdrant, Redis storage backends
- Remove Kubernetes deployment manifests and installation scripts
- Delete unofficial examples for deprecated backends and offline deployment docs
Streamline core infrastructure:
- Consolidate storage layer to PostgreSQL-only implementation
- Add full-text search caching with FTS cache module
- Implement metrics collection and monitoring pipeline
- Add explain and metrics API routes
Modernize frontend and tooling:
- Switch web UI to Bun with bun.lock, remove npm and pnpm lockfiles
- Update Dockerfile for PostgreSQL-only deployment
- Add Makefile for common development tasks
- Update environment and configuration examples
Enhance evaluation and testing capabilities:
- Add prompt optimization with DSPy and auto-tuning
- Implement ground truth regeneration and variant testing
- Add prompt debugging and response comparison utilities
- Expand test coverage with new integration scenarios
Simplify dependencies and configuration:
- Remove offline-specific requirement files
- Update pyproject.toml with streamlined dependencies
- Add Python version pinning with .python-version
- Create project guidelines in CLAUDE.md and AGENTS.md
2025-12-12 16:28:49 +01:00


"""
Accuracy tests for optimized prompts.
Validates that optimized prompts produce correct, parseable outputs.
Run with: uv run --extra test python tests/test_prompt_accuracy.py
"""
from __future__ import annotations
import asyncio
import json
import sys
from dataclasses import dataclass
from pathlib import Path
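
# Make the repo root importable so this script can run directly without installing lightrag.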
sys.path.insert(0, str(Path(__file__).parent.parent))
from lightrag.prompt import PROMPTS
# =============================================================================
# Test Data
# =============================================================================
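# Each case pairs prompt inputs with loose expectations; the checks below are
# substring matches (mostly case-insensitive), so exact model wording is not required.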
KEYWORD_TEST_QUERIES = [
    {
        'query': 'What are the main causes of climate change and how do they affect polar ice caps?',
        'expected_high': ['climate change', 'causes', 'effects'],
        'expected_low': ['polar ice caps', 'greenhouse'],
    },
    {
        'query': "How did Apple's iPhone sales compare to Samsung Galaxy in Q3 2024?",
        'expected_high': ['sales comparison', 'smartphone'],
        'expected_low': ['Apple', 'iPhone', 'Samsung', 'Galaxy', 'Q3 2024'],
    },
    {
        'query': 'hello',  # Trivial - should return empty
        'expected_high': [],
        'expected_low': [],
    },
]

ORPHAN_TEST_CASES = [
    {
        'orphan': {
            'name': 'Pfizer',
            'type': 'organization',
            'desc': 'Pharmaceutical company that developed COVID-19 vaccine',
        },
        'candidate': {
            'name': 'Moderna',
            'type': 'organization',
            'desc': 'Biotechnology company that developed mRNA COVID-19 vaccine',
        },
        'should_connect': True,
        'reason': 'Both are COVID-19 vaccine developers',
    },
    {
        'orphan': {
            'name': 'Mount Everest',
            'type': 'location',
            'desc': 'Highest mountain in the world, located in the Himalayas',
        },
        'candidate': {
            'name': 'Python Programming',
            'type': 'concept',
            'desc': 'Popular programming language used for data science',
        },
        'should_connect': False,
        'reason': 'No logical connection between mountain and programming language',
    },
]

SUMMARIZATION_TEST_CASES = [
    {
        'name': 'Albert Einstein',
        'type': 'Entity',
        'descriptions': [
            '{"description": "Albert Einstein was a German-born theoretical physicist."}',
            '{"description": "Einstein developed the theory of relativity and won the Nobel Prize in Physics in 1921."}',
            '{"description": "He is widely regarded as one of the most influential scientists of the 20th century."}',
        ],
        'must_contain': ['physicist', 'relativity', 'Nobel Prize', 'influential'],
    },
]

RAG_TEST_CASES = [
    {
        'query': 'What is the capital of France?',
        'context': 'Paris is the capital and largest city of France. It has a population of over 2 million people.',
        'must_contain': ['Paris'],
        'must_not_contain': ['[1]', '[2]', 'References'],
    },
]

# =============================================================================
# Helper Functions
# =============================================================================
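# Each call creates its own AsyncOpenAI client (which reads OPENAI_API_KEY from
# the environment) and sends a single user message at temperature 0, so results
# are as deterministic as the API allows.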
async def call_llm(prompt: str, model: str = 'gpt-4o-mini') -> str:
    """Call OpenAI API with a single prompt."""
    import openai

    client = openai.AsyncOpenAI()
    response = await client.chat.completions.create(
        model=model,
        messages=[{'role': 'user', 'content': prompt}],
        temperature=0.0,
    )
    return response.choices[0].message.content


@dataclass
class TestResult:
    name: str
    passed: bool
    details: str
    raw_output: str = ''


# =============================================================================
# Test Functions
# =============================================================================
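# Models sometimes wrap JSON output in a markdown code fence, so the
# JSON-producing tests strip an optional ```json fence before parsing.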
async def test_keywords_extraction() -> list[TestResult]:
    """Test keywords extraction prompt."""
    results = []
    examples = '\n'.join(PROMPTS['keywords_extraction_examples'])
    for case in KEYWORD_TEST_QUERIES:
        prompt = PROMPTS['keywords_extraction'].format(examples=examples, query=case['query'])
        output = await call_llm(prompt)
        # Try to parse JSON
        try:
            # Clean potential markdown
            clean = output.strip()
            if clean.startswith('```'):
                clean = clean.split('```')[1]
                if clean.startswith('json'):
                    clean = clean[4:]
            parsed = json.loads(clean)
            has_high = 'high_level_keywords' in parsed
            has_low = 'low_level_keywords' in parsed
            is_list_high = isinstance(parsed.get('high_level_keywords'), list)
            is_list_low = isinstance(parsed.get('low_level_keywords'), list)
            if has_high and has_low and is_list_high and is_list_low:
                # Check if trivial query returns empty
                if case['expected_high'] == [] and case['expected_low'] == []:
                    passed = len(parsed['high_level_keywords']) == 0 and len(parsed['low_level_keywords']) == 0
                    details = 'Empty lists returned for trivial query' if passed else f'Non-empty for trivial: {parsed}'
                else:
                    # Check that some expected keywords are present (case-insensitive)
                    high_lower = [k.lower() for k in parsed['high_level_keywords']]
                    low_lower = [k.lower() for k in parsed['low_level_keywords']]
                    all_keywords = ' '.join(high_lower + low_lower)
                    found_high = sum(1 for exp in case['expected_high'] if exp.lower() in all_keywords)
                    found_low = sum(1 for exp in case['expected_low'] if exp.lower() in all_keywords)
                    passed = found_high > 0 or found_low > 0
                    details = f'Found {found_high}/{len(case["expected_high"])} high, {found_low}/{len(case["expected_low"])} low'
            else:
                passed = False
                details = f'Missing keys or wrong types: has_high={has_high}, has_low={has_low}'
        except json.JSONDecodeError as e:
            passed = False
            details = f'JSON parse error: {e}'
        results.append(
            TestResult(
                name=f'Keywords: {case["query"][:40]}...', passed=passed, details=details, raw_output=output[:200]
            )
        )
    return results


async def test_orphan_validation() -> list[TestResult]:
    """Test orphan connection validation prompt."""
    results = []
    for case in ORPHAN_TEST_CASES:
        prompt = PROMPTS['orphan_connection_validation'].format(
            orphan_name=case['orphan']['name'],
            orphan_type=case['orphan']['type'],
            orphan_description=case['orphan']['desc'],
            candidate_name=case['candidate']['name'],
            candidate_type=case['candidate']['type'],
            candidate_description=case['candidate']['desc'],
            similarity_score=0.85,
        )
        output = await call_llm(prompt)
        try:
            # Clean potential markdown
            clean = output.strip()
            if clean.startswith('```'):
                clean = clean.split('```')[1]
                if clean.startswith('json'):
                    clean = clean[4:]
            parsed = json.loads(clean)
            has_should_connect = 'should_connect' in parsed
            has_confidence = 'confidence' in parsed
            has_reasoning = 'reasoning' in parsed
            if has_should_connect and has_confidence and has_reasoning:
                correct_decision = parsed['should_connect'] == case['should_connect']
                valid_confidence = 0.0 <= parsed['confidence'] <= 1.0
                passed = correct_decision and valid_confidence
                details = f'Decision: {parsed["should_connect"]} (expected {case["should_connect"]}), confidence: {parsed["confidence"]:.2f}'
            else:
                passed = False
                details = f'Missing keys: should_connect={has_should_connect}, confidence={has_confidence}, reasoning={has_reasoning}'
        except json.JSONDecodeError as e:
            passed = False
            details = f'JSON parse error: {e}'
        results.append(
            TestResult(
                name=f'Orphan: {case["orphan"]["name"]} -> {case["candidate"]["name"]}',
                passed=passed,
                details=details,
                raw_output=output[:200],
            )
        )
    return results


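# The summarization check is deliberately lenient: it requires at least half of
# the must_contain terms (rounded down), a non-trivial response, and a mention
# of the entity name.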
async def test_entity_summarization() -> list[TestResult]:
    """Test entity summarization prompt."""
    results = []
    for case in SUMMARIZATION_TEST_CASES:
        prompt = PROMPTS['summarize_entity_descriptions'].format(
            description_name=case['name'],
            description_type=case['type'],
            description_list='\n'.join(case['descriptions']),
            summary_length=200,
            language='English',
        )
        output = await call_llm(prompt)
        # Check if required terms are present
        output_lower = output.lower()
        found = [term for term in case['must_contain'] if term.lower() in output_lower]
        missing = [term for term in case['must_contain'] if term.lower() not in output_lower]
        # Check it's not empty and mentions the entity
        has_content = len(output.strip()) > 50
        mentions_entity = case['name'].lower() in output_lower
        passed = len(found) >= len(case['must_contain']) // 2 and has_content and mentions_entity
        details = f'Found {len(found)}/{len(case["must_contain"])} terms, mentions entity: {mentions_entity}'
        if missing:
            details += f', missing: {missing}'
        results.append(
            TestResult(name=f'Summarize: {case["name"]}', passed=passed, details=details, raw_output=output[:200])
        )
    return results


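# The naive RAG test requires every must_contain term to appear and fails if
# any must_not_contain marker (e.g. '[1]' or 'References') shows up in the raw output.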
async def test_naive_rag_response() -> list[TestResult]:
    """Test naive RAG response prompt."""
    results = []
    for case in RAG_TEST_CASES:
        prompt = PROMPTS['naive_rag_response'].format(
            response_type='concise paragraph',
            user_prompt=case['query'],
            content_data=case['context'],
        )
        output = await call_llm(prompt)
        # Check must_contain
        output_lower = output.lower()
        found = [term for term in case['must_contain'] if term.lower() in output_lower]
        # Check must_not_contain (citation markers)
        violations = [term for term in case['must_not_contain'] if term in output]
        passed = len(found) == len(case['must_contain']) and len(violations) == 0
        details = f'Found {len(found)}/{len(case["must_contain"])} required terms'
        if violations:
            details += f', VIOLATIONS: {violations}'
        results.append(
            TestResult(name=f'RAG: {case["query"][:40]}', passed=passed, details=details, raw_output=output[:200])
        )
    return results


# =============================================================================
# Main
# =============================================================================
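# The four test groups run concurrently via asyncio.gather; within each group
# the LLM calls are awaited sequentially.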
async def main() -> None:
    """Run all accuracy tests."""
    print('\n' + '=' * 70)
    print(' PROMPT ACCURACY TESTS')
    print('=' * 70)
    all_results = []

    # Run tests in parallel
    print('\nRunning tests...')
    keywords_results, orphan_results, summarize_results, rag_results = await asyncio.gather(
        test_keywords_extraction(),
        test_orphan_validation(),
        test_entity_summarization(),
        test_naive_rag_response(),
    )
    all_results.extend(keywords_results)
    all_results.extend(orphan_results)
    all_results.extend(summarize_results)
    all_results.extend(rag_results)

    # Print results
    print('\n' + '-' * 70)
    print(' RESULTS')
    print('-' * 70)
    passed = 0
    failed = 0
    for result in all_results:
        status = '✓ PASS' if result.passed else '✗ FAIL'
        print(f'\n{status}: {result.name}')
        print(f' {result.details}')
        if not result.passed:
            print(f' Output: {result.raw_output}...')
        if result.passed:
            passed += 1
        else:
            failed += 1

    # Summary
    print('\n' + '=' * 70)
    print(f' SUMMARY: {passed}/{passed + failed} tests passed')
    print('=' * 70)
    if failed > 0:
        print('\n⚠️ Some tests failed - review prompt changes')
        sys.exit(1)
    else:
        print('\n✓ All prompts producing correct outputs!')


if __name__ == '__main__':
    asyncio.run(main())