LightRAG/tests/test_prompt_accuracy.py
clssck 082a5a8fad test(lightrag,api): add comprehensive test coverage and S3 support
Add extensive test suites for API routes and utilities:
- Implement test_search_routes.py (406 lines) for search endpoint validation
- Implement test_upload_routes.py (724 lines) for document upload workflows
- Implement test_s3_client.py (618 lines) for S3 storage operations
- Implement test_citation_utils.py (352 lines) for citation extraction
- Implement test_chunking.py (216 lines) for text chunking validation
Add S3 storage client implementation:
- Create lightrag/storage/s3_client.py with S3 operations
- Add storage module initialization with exports
- Integrate S3 client with document upload handling
Enhance API routes and core functionality:
- Add search_routes.py with full-text and graph search endpoints
- Add upload_routes.py with multipart document upload support
- Update operate.py with bulk operations and health checks
- Enhance postgres_impl.py with bulk upsert and parameterized queries
- Update lightrag_server.py to register new API routes
- Improve utils.py with citation and formatting utilities
Update dependencies and configuration:
- Add S3 and test dependencies to pyproject.toml
- Update docker-compose.test.yml for testing environment
- Sync uv.lock with new dependencies
Apply code quality improvements across all modified files:
- Add type hints to function signatures
- Update imports and router initialization
- Fix logging and error handling
2025-12-05 23:13:39 +01:00

364 lines
12 KiB
Python

"""
Accuracy tests for optimized prompts.
Validates that optimized prompts produce correct, parseable outputs.
Run with: uv run --extra test python tests/test_prompt_accuracy.py
"""
from __future__ import annotations
import asyncio
import json
import sys
from dataclasses import dataclass
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from lightrag.prompt import PROMPTS
# =============================================================================
# Test Data
# =============================================================================
# Queries for the keywords-extraction prompt.
# ``expected_high`` / ``expected_low`` are loose hints rather than exact
# answers: test_keywords_extraction pools all returned keywords and passes a
# case when at least one expected entry appears in them (case-insensitive
# substring match). A case with both lists empty marks a trivial query for
# which the model must return two empty keyword lists.
KEYWORD_TEST_QUERIES = [
    {
        "query": "What are the main causes of climate change and how do they affect polar ice caps?",
        "expected_high": ["climate change", "causes", "effects"],
        "expected_low": ["polar ice caps", "greenhouse"],
    },
    {
        "query": "How did Apple's iPhone sales compare to Samsung Galaxy in Q3 2024?",
        "expected_high": ["sales comparison", "smartphone"],
        "expected_low": ["Apple", "iPhone", "Samsung", "Galaxy", "Q3 2024"],
    },
    {
        "query": "hello",  # Trivial greeting: both keyword lists must come back empty
        "expected_high": [],
        "expected_low": [],
    },
]
# Orphan/candidate entity pairs for the orphan-connection-validation prompt.
# Each side provides the name/type/description fields substituted into the
# prompt. ``should_connect`` is the decision the model is expected to return.
# ``reason`` documents that expectation for humans only; the test code does
# not read it.
ORPHAN_TEST_CASES = [
    {
        "orphan": {"name": "Pfizer", "type": "organization", "desc": "Pharmaceutical company that developed COVID-19 vaccine"},
        "candidate": {"name": "Moderna", "type": "organization", "desc": "Biotechnology company that developed mRNA COVID-19 vaccine"},
        "should_connect": True,
        "reason": "Both are COVID-19 vaccine developers",
    },
    {
        "orphan": {"name": "Mount Everest", "type": "location", "desc": "Highest mountain in the world, located in the Himalayas"},
        "candidate": {"name": "Python Programming", "type": "concept", "desc": "Popular programming language used for data science"},
        "should_connect": False,
        "reason": "No logical connection between mountain and programming language",
    },
]
# Entity-summarization cases.
# ``descriptions`` are JSON-encoded strings; the test joins them with newlines
# to build the prompt's ``description_list``. The resulting summary must
# mention the entity ``name`` and at least half of the ``must_contain`` terms
# (case-insensitive).
SUMMARIZATION_TEST_CASES = [
    {
        "name": "Albert Einstein",
        "type": "Entity",
        "descriptions": [
            '{"description": "Albert Einstein was a German-born theoretical physicist."}',
            '{"description": "Einstein developed the theory of relativity and won the Nobel Prize in Physics in 1921."}',
            '{"description": "He is widely regarded as one of the most influential scientists of the 20th century."}',
        ],
        "must_contain": ["physicist", "relativity", "Nobel Prize", "influential"],
    },
]
# Naive-RAG answer cases.
# The answer must include every ``must_contain`` term (matched
# case-insensitively) and none of the ``must_not_contain`` markers (matched
# case-sensitively). The markers here are numbered citations and a
# "References" heading, which the answer must not contain.
RAG_TEST_CASES = [
    {
        "query": "What is the capital of France?",
        "context": "Paris is the capital and largest city of France. It has a population of over 2 million people.",
        "must_contain": ["Paris"],
        "must_not_contain": ["[1]", "[2]", "References"],
    },
]
# =============================================================================
# Helper Functions
# =============================================================================
async def call_llm(prompt: str, model: str = "gpt-4o-mini") -> str:
    """Send ``prompt`` as a single user message and return the reply text.

    Uses the OpenAI chat-completions API with ``temperature=0.0``. No explicit
    credentials are passed, so the SDK's defaults apply (presumably
    ``OPENAI_API_KEY`` from the environment).

    Args:
        prompt: Fully formatted prompt text.
        model: Chat model name to query.

    Returns:
        The assistant message content, or ``""`` when the API returned no text
        content.
    """
    # Imported inside the function, so importing this module does not require
    # the SDK.
    import openai

    # Use the client as an async context manager so its HTTP resources are
    # closed after the request, instead of leaving one open client per call.
    async with openai.AsyncOpenAI() as client:
        response = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
        )
    # ``message.content`` is typed Optional in the SDK. Callers strip, lowercase
    # and slice the result, so normalise ``None`` to an empty string.
    return response.choices[0].message.content or ""
@dataclass
class TestResult:
    """Outcome of a single prompt accuracy check.

    Attributes:
        name: Label printed in the report.
        passed: Whether the check succeeded.
        details: Short explanation of what was found or why the check failed.
        raw_output: Raw model output (callers in this module store the first
            200 characters); the report prints it only for failed checks.
    """

    # The class name starts with "Test" and lives in a test_*.py module, so
    # pytest would try to collect it as a test class and warn about its
    # generated __init__. Opt it out of collection explicitly. The attribute is
    # unannotated, so it is not a dataclass field.
    __test__ = False

    name: str
    passed: bool
    details: str
    raw_output: str = ""
# =============================================================================
# Test Functions
# =============================================================================
def _check_keywords(parsed: object, case: dict) -> tuple[bool, str]:
    """Grade one parsed keywords-extraction reply against its test case."""
    # Valid JSON such as ``[]`` or ``3`` is not an object. Report it as a
    # format failure instead of crashing on ``in`` / ``.get``.
    if not isinstance(parsed, dict):
        return False, f"Expected a JSON object, got {type(parsed).__name__}"

    has_high = "high_level_keywords" in parsed
    has_low = "low_level_keywords" in parsed
    high = parsed.get("high_level_keywords")
    low = parsed.get("low_level_keywords")
    if not (isinstance(high, list) and isinstance(low, list)):
        return False, f"Missing keys or wrong types: has_high={has_high}, has_low={has_low}"

    # Both expectation lists empty marks a trivial query: nothing may come back.
    if not case["expected_high"] and not case["expected_low"]:
        passed = not high and not low
        details = "Empty lists returned for trivial query" if passed else f"Non-empty for trivial: {parsed}"
        return passed, details

    # Expected keywords are matched as case-insensitive substrings of all
    # returned keywords pooled together, regardless of tier. ``str()`` guards
    # against non-string list entries.
    all_keywords = " ".join(str(k).lower() for k in high + low)
    found_high = sum(1 for exp in case["expected_high"] if exp.lower() in all_keywords)
    found_low = sum(1 for exp in case["expected_low"] if exp.lower() in all_keywords)
    passed = found_high > 0 or found_low > 0
    details = f"Found {found_high}/{len(case['expected_high'])} high, {found_low}/{len(case['expected_low'])} low"
    return passed, details


async def test_keywords_extraction() -> list[TestResult]:
    """Check that the keywords-extraction prompt yields well-formed, relevant JSON.

    For each entry in ``KEYWORD_TEST_QUERIES``, the prompt is formatted with
    ``PROMPTS["keywords_extraction_examples"]`` (joined with newlines) and the
    query, then sent to the LLM. A case passes when the reply parses to a JSON
    object with list-valued ``high_level_keywords`` and ``low_level_keywords``.
    The reply may optionally be wrapped in a Markdown code fence. In addition:

    * for a trivial query (both expectation lists empty), both lists are empty;
    * otherwise, at least one expected keyword appears among the returned ones.

    Returns:
        One ``TestResult`` per query.
    """
    results = []
    examples = "\n".join(PROMPTS["keywords_extraction_examples"])
    for case in KEYWORD_TEST_QUERIES:
        prompt = PROMPTS["keywords_extraction"].format(
            examples=examples,
            query=case["query"],
        )
        # Guard against a None reply so the string handling below cannot crash.
        output = await call_llm(prompt) or ""

        # Strip an optional Markdown code fence (``` or ```json) around the payload.
        clean = output.strip()
        if clean.startswith("```"):
            clean = clean.split("```")[1]
            if clean.startswith("json"):
                clean = clean[4:]

        try:
            parsed = json.loads(clean)
        except json.JSONDecodeError as e:
            passed, details = False, f"JSON parse error: {e}"
        else:
            passed, details = _check_keywords(parsed, case)

        results.append(TestResult(
            name=f"Keywords: {case['query'][:40]}...",
            passed=passed,
            details=details,
            raw_output=output[:200],
        ))
    return results
def _check_orphan_decision(parsed: object, case: dict) -> tuple[bool, str]:
    """Grade one parsed orphan-validation reply against its expected decision."""
    # Valid JSON that is not an object (a list, number, ...) is a format
    # failure, not something to crash on.
    if not isinstance(parsed, dict):
        return False, f"Expected a JSON object, got {type(parsed).__name__}"

    has_should_connect = "should_connect" in parsed
    has_confidence = "confidence" in parsed
    has_reasoning = "reasoning" in parsed
    if not (has_should_connect and has_confidence and has_reasoning):
        return False, (
            f"Missing keys: should_connect={has_should_connect}, "
            f"confidence={has_confidence}, reasoning={has_reasoning}"
        )

    confidence = parsed["confidence"]
    # A string such as "0.9" would raise in the range check and in ``:.2f``.
    # ``bool`` is excluded explicitly because it is a subclass of ``int``.
    if isinstance(confidence, bool) or not isinstance(confidence, (int, float)):
        return False, f"Non-numeric confidence: {confidence!r}"

    correct_decision = parsed["should_connect"] == case["should_connect"]
    valid_confidence = 0.0 <= confidence <= 1.0
    details = (
        f"Decision: {parsed['should_connect']} (expected {case['should_connect']}), "
        f"confidence: {confidence:.2f}"
    )
    return correct_decision and valid_confidence, details


async def test_orphan_validation() -> list[TestResult]:
    """Check the orphan-connection-validation prompt's decisions and format.

    Each orphan/candidate pair in ``ORPHAN_TEST_CASES`` is rendered into the
    prompt with a fixed ``similarity_score`` of 0.85. A case passes when all of
    the following hold:

    * the reply (optionally fenced in Markdown) parses to a JSON object with
      ``should_connect``, ``confidence`` and ``reasoning``;
    * ``should_connect`` equals the expected decision;
    * ``confidence`` is a number in [0, 1].

    Returns:
        One ``TestResult`` per pair.
    """
    results = []
    for case in ORPHAN_TEST_CASES:
        prompt = PROMPTS["orphan_connection_validation"].format(
            orphan_name=case["orphan"]["name"],
            orphan_type=case["orphan"]["type"],
            orphan_description=case["orphan"]["desc"],
            candidate_name=case["candidate"]["name"],
            candidate_type=case["candidate"]["type"],
            candidate_description=case["candidate"]["desc"],
            similarity_score=0.85,
        )
        # Guard against a None reply so the string handling below cannot crash.
        output = await call_llm(prompt) or ""

        # Strip an optional Markdown code fence (``` or ```json) around the payload.
        clean = output.strip()
        if clean.startswith("```"):
            clean = clean.split("```")[1]
            if clean.startswith("json"):
                clean = clean[4:]

        try:
            parsed = json.loads(clean)
        except json.JSONDecodeError as e:
            passed, details = False, f"JSON parse error: {e}"
        else:
            passed, details = _check_orphan_decision(parsed, case)

        results.append(TestResult(
            # Separate the two entity names; without a separator the label
            # read e.g. "PfizerModerna".
            name=f"Orphan: {case['orphan']['name']} → {case['candidate']['name']}",
            passed=passed,
            details=details,
            raw_output=output[:200],
        ))
    return results
async def test_entity_summarization() -> list[TestResult]:
    """Check that the entity-summarization prompt keeps the key facts.

    For each case in ``SUMMARIZATION_TEST_CASES``, the JSON-encoded
    descriptions are joined with newlines and summarised (``summary_length=200``,
    English). A case passes when all of the following hold:

    * the stripped output is longer than 50 characters;
    * it mentions the entity name;
    * it contains at least half (rounded up) of the ``must_contain`` terms.

    All matching is case-insensitive.

    Returns:
        One ``TestResult`` per case.
    """
    results = []
    for case in SUMMARIZATION_TEST_CASES:
        prompt = PROMPTS["summarize_entity_descriptions"].format(
            description_name=case["name"],
            description_type=case["type"],
            description_list="\n".join(case["descriptions"]),
            summary_length=200,
            language="English",
        )
        # Guard against a None reply so the string handling below cannot crash.
        output = await call_llm(prompt) or ""

        output_lower = output.lower()
        terms = case["must_contain"]
        found = [term for term in terms if term.lower() in output_lower]
        missing = [term for term in terms if term.lower() not in output_lower]

        # Round the "half of the terms" threshold up. Floor division would let
        # a 1-term case pass with zero matches and a 3-term case with one.
        required = (len(terms) + 1) // 2
        has_content = len(output.strip()) > 50
        mentions_entity = case["name"].lower() in output_lower

        passed = len(found) >= required and has_content and mentions_entity
        details = (
            f"Found {len(found)}/{len(terms)} terms (need {required}), "
            f"mentions entity: {mentions_entity}"
        )
        if not has_content:
            details += ", output too short"
        if missing:
            details += f", missing: {missing}"

        results.append(TestResult(
            name=f"Summarize: {case['name']}",
            passed=passed,
            details=details,
            raw_output=output[:200],
        ))
    return results
async def test_naive_rag_response() -> list[TestResult]:
    """Check the naive-RAG answer prompt against a known context.

    Each case in ``RAG_TEST_CASES`` is answered as a "concise paragraph". A case
    passes when both of the following hold:

    * every ``must_contain`` term appears in the answer (case-insensitive);
    * no ``must_not_contain`` marker appears (case-sensitive), e.g. numbered
      citations or a "References" heading.

    Returns:
        One ``TestResult`` per case.
    """
    results: list[TestResult] = []
    for case in RAG_TEST_CASES:
        answer = await call_llm(
            PROMPTS["naive_rag_response"].format(
                response_type="concise paragraph",
                user_prompt=case["query"],
                content_data=case["context"],
            )
        )

        answer_lower = answer.lower()
        required = case["must_contain"]
        hits = [term for term in required if term.lower() in answer_lower]
        # Forbidden markers are matched exactly as written (case-sensitive).
        violations = [marker for marker in case["must_not_contain"] if marker in answer]

        summary = f"Found {len(hits)}/{len(required)} required terms"
        if violations:
            summary += f", VIOLATIONS: {violations}"

        results.append(
            TestResult(
                name=f"RAG: {case['query'][:40]}",
                passed=len(hits) == len(required) and not violations,
                details=summary,
                raw_output=answer[:200],
            )
        )
    return results
# =============================================================================
# Main
# =============================================================================
async def main() -> None:
    """Run every prompt accuracy suite concurrently and print a report.

    Prints a PASS/FAIL line per result, with the stored raw model output for
    failures, followed by a summary. Exits the process with status 1 when any
    result failed.
    """
    banner = "=" * 70
    rule = "-" * 70

    print("\n" + banner)
    print(" PROMPT ACCURACY TESTS")
    print(banner)

    # The four suites are independent coroutines, so gather them together.
    print("\nRunning tests...")
    suites = await asyncio.gather(
        test_keywords_extraction(),
        test_orphan_validation(),
        test_entity_summarization(),
        test_naive_rag_response(),
    )
    all_results = [result for suite in suites for result in suite]

    print("\n" + rule)
    print(" RESULTS")
    print(rule)

    for result in all_results:
        print(f"\n{'✓ PASS' if result.passed else '✗ FAIL'}: {result.name}")
        print(f" {result.details}")
        if not result.passed:
            print(f" Output: {result.raw_output}...")

    passed = sum(1 for result in all_results if result.passed)
    total = len(all_results)

    print("\n" + banner)
    print(f" SUMMARY: {passed}/{total} tests passed")
    print(banner)

    if passed < total:
        print("\n⚠️ Some tests failed - review prompt changes")
        sys.exit(1)
    print("\n✓ All prompts producing correct outputs!")


if __name__ == "__main__":
    asyncio.run(main())