Remove legacy storage implementations and deprecated examples:
- Delete FAISS, JSON, Memgraph, Milvus, MongoDB, Nano Vector DB, Neo4j, NetworkX, Qdrant, Redis storage backends
- Remove Kubernetes deployment manifests and installation scripts
- Delete unofficial examples for deprecated backends and offline deployment docs

Streamline core infrastructure:
- Consolidate storage layer to PostgreSQL-only implementation
- Add full-text search caching with FTS cache module
- Implement metrics collection and monitoring pipeline
- Add explain and metrics API routes

Modernize frontend and tooling:
- Switch web UI to Bun with bun.lock, remove npm and pnpm lockfiles
- Update Dockerfile for PostgreSQL-only deployment
- Add Makefile for common development tasks
- Update environment and configuration examples

Enhance evaluation and testing capabilities:
- Add prompt optimization with DSPy and auto-tuning (see the sketch below)
- Implement ground truth regeneration and variant testing
- Add prompt debugging and response comparison utilities
- Expand test coverage with new integration scenarios

Simplify dependencies and configuration:
- Remove offline-specific requirement files
- Update pyproject.toml with streamlined dependencies
- Add Python version pinning with .python-version
- Create project guidelines in CLAUDE.md and AGENTS.md
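
Of the changes above, the DSPy-based prompt auto-tuning is the least self-explanatory. The sketch below shows the general shape of such a tuning loop for the keyword-extraction prompt exercised by the accuracy tests that follow; it is illustrative only, and the signature, metric, and training example are hypothetical stand-ins rather than code from this repository.

```python
# Hypothetical sketch of DSPy auto-tuning for keyword extraction.
# All names below are illustrative; none come from the repo itself.
import dspy

dspy.configure(lm=dspy.LM('openai/gpt-4o-mini'))


class ExtractKeywords(dspy.Signature):
    """Extract high- and low-level keywords from a user query."""

    query: str = dspy.InputField()
    high_level_keywords: list[str] = dspy.OutputField()
    low_level_keywords: list[str] = dspy.OutputField()


def keyword_recall(example, pred, trace=None):
    """Fraction of expected keywords that the prediction recovered."""
    expected = set(example.high_level_keywords + example.low_level_keywords)
    got = set(pred.high_level_keywords + pred.low_level_keywords)
    return len(expected & got) / max(len(expected), 1)


trainset = [
    dspy.Example(
        query='What are the main causes of climate change?',
        high_level_keywords=['climate change', 'causes'],
        low_level_keywords=['greenhouse gases'],
    ).with_inputs('query'),
]

optimizer = dspy.BootstrapFewShot(metric=keyword_recall)
tuned = optimizer.compile(dspy.Predict(ExtractKeywords), trainset=trainset)
```

The test file below then serves as the acceptance gate for whichever prompt variant such tuning produces.
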
tests/test_prompt_accuracy.py · 373 lines · 12 KiB · Python
"""
|
|
Accuracy tests for optimized prompts.
|
|
Validates that optimized prompts produce correct, parseable outputs.
|
|
|
|
Run with: uv run --extra test python tests/test_prompt_accuracy.py
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from lightrag.prompt import PROMPTS
|
|
|
|
# =============================================================================
# Test Data
# =============================================================================

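# Conventions for the fixtures below (inferred from how the tests use them):
#   - 'expected_high' holds abstract themes, 'expected_low' holds concrete
#     entities/specifics; a match on either side counts toward a pass.
#   - Empty expected lists mark a trivial query that should yield no keywords.
#   - 'must_not_contain' flags artifacts (e.g. citation markers) that the
#     prompt is expected to suppress.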
KEYWORD_TEST_QUERIES = [
    {
        'query': 'What are the main causes of climate change and how do they affect polar ice caps?',
        'expected_high': ['climate change', 'causes', 'effects'],
        'expected_low': ['polar ice caps', 'greenhouse'],
    },
    {
        'query': "How did Apple's iPhone sales compare to Samsung Galaxy in Q3 2024?",
        'expected_high': ['sales comparison', 'smartphone'],
        'expected_low': ['Apple', 'iPhone', 'Samsung', 'Galaxy', 'Q3 2024'],
    },
    {
        'query': 'hello',  # Trivial - should return empty
        'expected_high': [],
        'expected_low': [],
    },
]

ORPHAN_TEST_CASES = [
    {
        'orphan': {
            'name': 'Pfizer',
            'type': 'organization',
            'desc': 'Pharmaceutical company that developed COVID-19 vaccine',
        },
        'candidate': {
            'name': 'Moderna',
            'type': 'organization',
            'desc': 'Biotechnology company that developed mRNA COVID-19 vaccine',
        },
        'should_connect': True,
        'reason': 'Both are COVID-19 vaccine developers',
    },
    {
        'orphan': {
            'name': 'Mount Everest',
            'type': 'location',
            'desc': 'Highest mountain in the world, located in the Himalayas',
        },
        'candidate': {
            'name': 'Python Programming',
            'type': 'concept',
            'desc': 'Popular programming language used for data science',
        },
        'should_connect': False,
        'reason': 'No logical connection between mountain and programming language',
    },
]

SUMMARIZATION_TEST_CASES = [
    {
        'name': 'Albert Einstein',
        'type': 'Entity',
        'descriptions': [
            '{"description": "Albert Einstein was a German-born theoretical physicist."}',
            '{"description": "Einstein developed the theory of relativity and won the Nobel Prize in Physics in 1921."}',
            '{"description": "He is widely regarded as one of the most influential scientists of the 20th century."}',
        ],
        'must_contain': ['physicist', 'relativity', 'Nobel Prize', 'influential'],
    },
]

RAG_TEST_CASES = [
    {
        'query': 'What is the capital of France?',
        'context': 'Paris is the capital and largest city of France. It has a population of over 2 million people.',
        'must_contain': ['Paris'],
        'must_not_contain': ['[1]', '[2]', 'References'],
    },
]


# =============================================================================
# Helper Functions
# =============================================================================


async def call_llm(prompt: str, model: str = 'gpt-4o-mini') -> str:
    """Call OpenAI API with a single prompt.

    Requires OPENAI_API_KEY in the environment; temperature is pinned to 0.0
    to keep outputs as reproducible as possible.
    """
    import openai

    client = openai.AsyncOpenAI()
    response = await client.chat.completions.create(
        model=model,
        messages=[{'role': 'user', 'content': prompt}],
        temperature=0.0,
    )
    # message.content is Optional[str]; normalize None to '' to honor the
    # declared return type.
    return response.choices[0].message.content or ''


@dataclass
class TestResult:
    name: str
    passed: bool
    details: str
    raw_output: str = ''


# =============================================================================
# Test Functions
# =============================================================================


async def test_keywords_extraction() -> list[TestResult]:
    """Test keywords extraction prompt."""
    results = []

    examples = '\n'.join(PROMPTS['keywords_extraction_examples'])

    for case in KEYWORD_TEST_QUERIES:
        prompt = PROMPTS['keywords_extraction'].format(examples=examples, query=case['query'])

        output = await call_llm(prompt)

        # Try to parse JSON
        try:
            # Clean potential markdown
            clean = output.strip()
            if clean.startswith('```'):
                clean = clean.split('```')[1]
            if clean.startswith('json'):
                clean = clean[4:]

            parsed = json.loads(clean)

            has_high = 'high_level_keywords' in parsed
            has_low = 'low_level_keywords' in parsed
            is_list_high = isinstance(parsed.get('high_level_keywords'), list)
            is_list_low = isinstance(parsed.get('low_level_keywords'), list)

            if has_high and has_low and is_list_high and is_list_low:
                # Check if trivial query returns empty
                if case['expected_high'] == [] and case['expected_low'] == []:
                    passed = len(parsed['high_level_keywords']) == 0 and len(parsed['low_level_keywords']) == 0
                    details = 'Empty lists returned for trivial query' if passed else f'Non-empty for trivial: {parsed}'
                else:
                    # Check that some expected keywords are present (case-insensitive)
                    high_lower = [k.lower() for k in parsed['high_level_keywords']]
                    low_lower = [k.lower() for k in parsed['low_level_keywords']]
                    all_keywords = ' '.join(high_lower + low_lower)

                    found_high = sum(1 for exp in case['expected_high'] if exp.lower() in all_keywords)
                    found_low = sum(1 for exp in case['expected_low'] if exp.lower() in all_keywords)

                    passed = found_high > 0 or found_low > 0
                    details = f'Found {found_high}/{len(case["expected_high"])} high, {found_low}/{len(case["expected_low"])} low'
            else:
                passed = False
                details = f'Missing keys or wrong types: has_high={has_high}, has_low={has_low}'

        except json.JSONDecodeError as e:
            passed = False
            details = f'JSON parse error: {e}'

        results.append(
            TestResult(
                name=f'Keywords: {case["query"][:40]}...', passed=passed, details=details, raw_output=output[:200]
            )
        )

    return results


async def test_orphan_validation() -> list[TestResult]:
    """Test orphan connection validation prompt."""
    results = []

    for case in ORPHAN_TEST_CASES:
        prompt = PROMPTS['orphan_connection_validation'].format(
            orphan_name=case['orphan']['name'],
            orphan_type=case['orphan']['type'],
            orphan_description=case['orphan']['desc'],
            candidate_name=case['candidate']['name'],
            candidate_type=case['candidate']['type'],
            candidate_description=case['candidate']['desc'],
            similarity_score=0.85,
        )

        output = await call_llm(prompt)

        try:
            # Clean potential markdown
            clean = output.strip()
            if clean.startswith('```'):
                clean = clean.split('```')[1]
            if clean.startswith('json'):
                clean = clean[4:]

            parsed = json.loads(clean)

            has_should_connect = 'should_connect' in parsed
            has_confidence = 'confidence' in parsed
            has_reasoning = 'reasoning' in parsed

            if has_should_connect and has_confidence and has_reasoning:
                correct_decision = parsed['should_connect'] == case['should_connect']
                valid_confidence = 0.0 <= parsed['confidence'] <= 1.0

                passed = correct_decision and valid_confidence
                details = f'Decision: {parsed["should_connect"]} (expected {case["should_connect"]}), confidence: {parsed["confidence"]:.2f}'
            else:
                passed = False
                details = f'Missing keys: should_connect={has_should_connect}, confidence={has_confidence}, reasoning={has_reasoning}'

        except json.JSONDecodeError as e:
            passed = False
            details = f'JSON parse error: {e}'

        results.append(
            TestResult(
                name=f'Orphan: {case["orphan"]["name"]} ↔ {case["candidate"]["name"]}',
                passed=passed,
                details=details,
                raw_output=output[:200],
            )
        )

    return results


async def test_entity_summarization() -> list[TestResult]:
    """Test entity summarization prompt."""
    results = []

    for case in SUMMARIZATION_TEST_CASES:
        prompt = PROMPTS['summarize_entity_descriptions'].format(
            description_name=case['name'],
            description_type=case['type'],
            description_list='\n'.join(case['descriptions']),
            summary_length=200,
            language='English',
        )

        output = await call_llm(prompt)

        # Check if required terms are present
        output_lower = output.lower()
        found = [term for term in case['must_contain'] if term.lower() in output_lower]
        missing = [term for term in case['must_contain'] if term.lower() not in output_lower]

        # Check it's not empty and mentions the entity
        has_content = len(output.strip()) > 50
        mentions_entity = case['name'].lower() in output_lower

        passed = len(found) >= len(case['must_contain']) // 2 and has_content and mentions_entity
        details = f'Found {len(found)}/{len(case["must_contain"])} terms, mentions entity: {mentions_entity}'
        if missing:
            details += f', missing: {missing}'

        results.append(
            TestResult(name=f'Summarize: {case["name"]}', passed=passed, details=details, raw_output=output[:200])
        )

    return results


async def test_naive_rag_response() -> list[TestResult]:
    """Test naive RAG response prompt."""
    results = []

    for case in RAG_TEST_CASES:
        prompt = PROMPTS['naive_rag_response'].format(
            response_type='concise paragraph',
            user_prompt=case['query'],
            content_data=case['context'],
        )

        output = await call_llm(prompt)

        # Check must_contain
        output_lower = output.lower()
        found = [term for term in case['must_contain'] if term.lower() in output_lower]

        # Check must_not_contain (citation markers)
        violations = [term for term in case['must_not_contain'] if term in output]

        passed = len(found) == len(case['must_contain']) and len(violations) == 0
        details = f'Found {len(found)}/{len(case["must_contain"])} required terms'
        if violations:
            details += f', VIOLATIONS: {violations}'

        results.append(
            TestResult(name=f'RAG: {case["query"][:40]}', passed=passed, details=details, raw_output=output[:200])
        )

    return results


# =============================================================================
# Main
# =============================================================================


async def main() -> None:
    """Run all accuracy tests."""
    print('\n' + '=' * 70)
    print(' PROMPT ACCURACY TESTS')
    print('=' * 70)

    all_results = []

    # Run tests in parallel
    print('\nRunning tests...')

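    # The four suites run concurrently; each makes live OpenAI calls, so even
    # at temperature 0.0 outputs can vary slightly between runs and the
    # occasional flaky failure is expected.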
    keywords_results, orphan_results, summarize_results, rag_results = await asyncio.gather(
        test_keywords_extraction(),
        test_orphan_validation(),
        test_entity_summarization(),
        test_naive_rag_response(),
    )

    all_results.extend(keywords_results)
    all_results.extend(orphan_results)
    all_results.extend(summarize_results)
    all_results.extend(rag_results)

    # Print results
    print('\n' + '-' * 70)
    print(' RESULTS')
    print('-' * 70)

    passed = 0
    failed = 0

    for result in all_results:
        status = '✓ PASS' if result.passed else '✗ FAIL'
        print(f'\n{status}: {result.name}')
        print(f'  {result.details}')
        if not result.passed:
            print(f'  Output: {result.raw_output}...')

        if result.passed:
            passed += 1
        else:
            failed += 1

    # Summary
    print('\n' + '=' * 70)
    print(f' SUMMARY: {passed}/{passed + failed} tests passed')
    print('=' * 70)

    if failed > 0:
        print('\n⚠️ Some tests failed - review prompt changes')
        sys.exit(1)
    else:
        print('\n✓ All prompts producing correct outputs!')


if __name__ == '__main__':
    asyncio.run(main())