LightRAG/tests/run_quality_tests.py
clssck d2c9e6e2ec test(lightrag): add orphan connection feature with quality validation tests
Implement an automatic orphan-entity connection system that identifies entities
with no relationships and creates meaningful connections via vector similarity
plus LLM validation. This improves knowledge-graph connectivity and retrieval
quality.
Changes:
- Add orphan connection configuration parameters (thresholds, cross-connect settings)
- Implement aconnect_orphan_entities() method with 4-step validation pipeline
- Add SQL templates for efficient orphan and candidate entity queries
- Create POST /graph/orphans/connect API endpoint with configurable parameters (example call sketched below)
- Add orphan connection validation prompt for LLM-based relationship verification
- Include relationship density requirement in extraction prompts to prevent orphans
- Update docker-compose.test.yml with optimized extraction parameters
- Add quality validation test suite (run_quality_tests.py) for retrieval evaluation
- Add unit test framework (test_orphan_connection_quality.py) with test cases
- Enable auto-run of orphan connection after document processing
2025-11-28 18:23:30 +01:00
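
As a quick smoke test of the new endpoint, here is a minimal sketch that triggers a connection pass against a locally running server. The JSON body fields are hypothetical placeholders for illustration, not the endpoint's actual schema:

import httpx

# Trigger an orphan-connection pass on a local LightRAG server.
# NOTE: "similarity_threshold" and "cross_connect" are illustrative
# stand-ins for the endpoint's configurable parameters; check the API
# schema for the real field names.
resp = httpx.post(
    "http://localhost:9622/graph/orphans/connect",
    json={"similarity_threshold": 0.8, "cross_connect": False},
    timeout=120.0,
)
resp.raise_for_status()
print(resp.json())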


#!/usr/bin/env python3
"""
Orphan Connection Quality Validation Script
Runs actual queries against LightRAG and analyzes whether orphan connections
improve or poison retrieval quality.
"""
import asyncio
import httpx
from dataclasses import dataclass

API_BASE = "http://localhost:9622"
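
# Usage: with a LightRAG server (and the evaluation corpus already
# ingested) listening at API_BASE above, run this file directly:
#
#     python tests/run_quality_tests.py
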
@dataclass
class TestResult:
    query: str
    expected: list[str]
    unexpected: list[str]
    retrieved_entities: list[str]
    precision: float
    recall: float
    noise_count: int
    passed: bool
    details: str

TEST_CASES = [
    # Test 1: Neural Network Types (PRECISION)
    # Note: "Quantum" may appear legitimately due to "Quantum Machine Learning" being a real field
    {
        "query": "What types of neural networks are used in deep learning?",
        "expected": ["Neural Networks", "Convolutional Neural Network",
                     "Recurrent Neural Network", "Transformer"],
        "unexpected": ["FDA", "Atopic Dermatitis", "Vehicle Emissions Standards"],  # Truly unrelated
        "category": "precision",
        "description": "Should retrieve NN types via orphan connections (CNN->NN, RNN->NN)"
    },
    # Test 2: Quantum Companies (RECALL)
    {
        "query": "What companies are working on quantum computing?",
        "expected": ["IonQ", "Microsoft", "Google", "IBM"],
        "unexpected": ["FDA", "Atopic Dermatitis"],  # Medical domain unrelated
        "category": "recall",
        "description": "Should find IonQ (via Trapped Ions) and Microsoft (via Topological Qubits)"
    },
    # Test 3: Greenhouse Gases (RECALL)
    # Note: "Quantum" may appear due to "climate simulation via quantum computing" being valid
    {
        "query": "What are greenhouse gases?",
        "expected": ["Carbon Dioxide", "CO2", "Methane", "CH4", "Nitrous Oxide", "N2O", "Fluorinated"],
        "unexpected": ["FDA", "Atopic Dermatitis", "IonQ"],  # Medical/specific tech unrelated
        "category": "recall",
        "description": "Should retrieve all GHGs via orphan connections forming a cluster"
    },
    # Test 4: Reinforcement Learning (NOISE)
    # Note: Cross-domain mentions like "climate modeling" may appear from original docs
    {
        "query": "What is reinforcement learning?",
        "expected": ["Reinforcement Learning", "Machine Learning"],
        "unexpected": ["FDA", "Atopic Dermatitis", "Dupixent"],  # Medical domain truly unrelated
        "category": "noise",
        "description": "Should NOT pull in truly unrelated medical domain"
    },
    # Test 5: Computer Vision (NOISE)
    # Note: Drug Discovery may appear due to "medical imaging" being a CV application
    {
        "query": "How does computer vision work?",
        "expected": ["Computer Vision", "Image", "Object", "Feature", "Edge Detection"],
        "unexpected": ["FDA", "Atopic Dermatitis", "Kyoto Protocol"],  # Truly unrelated domains
        "category": "noise",
        "description": "Should retrieve CV techniques, not truly unrelated domains"
    },
    # Test 6: Amazon Cross-Domain Check (EDGE CASE)
    {
        "query": "What is Amazon?",
        "expected": ["Amazon"],
        "unexpected": ["FDA", "Atopic Dermatitis"],  # Medical domain unrelated to tech company
        "category": "edge_case",
        "description": "Check if Amazon->Microsoft connection causes retrieval issues"
    },
    # Test 7: Medical Domain Isolation (STRICT NOISE TEST)
    {
        "query": "What is Dupixent used for?",
        "expected": ["Dupixent", "Atopic Dermatitis", "FDA"],
        "unexpected": ["Neural Networks", "Quantum Computing", "Climate Change", "IonQ"],
        "category": "noise",
        "description": "Medical query should NOT retrieve tech/climate domains"
    },
]
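
# NOTE: evaluation below uses case-insensitive substring matching against the
# raw context text returned by the API, so an expected entity counts as
# "found" only if its exact name appears somewhere in that text. Multi-word
# names must therefore appear verbatim; alternate surface forms (e.g. "CO2"
# alongside "Carbon Dioxide") are listed as separate expected entries.
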
async def run_query(query: str, mode: str = "local") -> dict:
    """Run a query against the LightRAG API."""
    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            f"{API_BASE}/query",
            json={
                "query": query,
                "mode": mode,
                # Return only the retrieved context, skipping LLM answer generation
                "only_need_context": True
            }
        )
        # Fail loudly instead of parsing an error body as a result
        response.raise_for_status()
        return response.json()

def extract_entities_from_context(context: str) -> list[str]:
    """Extract entity names from the context string.

    Note: not called by the test flow below; kept as a helper for
    manual inspection of returned contexts.
    """
    entities = []
    # Look for entity patterns in the context
    for line in context.split('\n'):
        # Entity lines often start with entity names in quotes or bold
        if 'Entity:' in line or line.startswith('-'):
            # Extract potential entity name
            parts = line.split(':')
            if len(parts) > 1:
                entity = parts[1].strip().strip('"').strip("'")
                if entity and len(entity) > 2:
                    entities.append(entity)
    return entities
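
# Illustrative behaviour of the parser above (informal, not a spec):
#     extract_entities_from_context('Entity: "Neural Networks"')
#     -> ['Neural Networks']
# Dash-prefixed lines without a ':' are skipped, and extracted names of two
# characters or fewer are dropped.
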
async def evaluate_test_case(test_case: dict) -> TestResult:
    """Evaluate a single test case."""
    query = test_case["query"]
    expected = test_case["expected"]
    unexpected = test_case["unexpected"]
    try:
        result = await run_query(query)
        response_text = result.get("response", "")
        # Check which expected entities appear in the response
        found_expected = []
        missed_expected = []
        for entity in expected:
            # Case-insensitive partial match
            if entity.lower() in response_text.lower():
                found_expected.append(entity)
            else:
                missed_expected.append(entity)
        # Check for unexpected (noise) entities
        found_unexpected = []
        for entity in unexpected:
            if entity.lower() in response_text.lower():
                found_unexpected.append(entity)
        # Calculate metrics. Recall is the fraction of expected entities found.
        # Precision is approximated over the entities we explicitly check for
        # (expected hits vs. known-noise hits), since the full retrieved set
        # is not enumerated here.
        recall = len(found_expected) / len(expected) if expected else 1.0
        checked_hits = len(found_expected) + len(found_unexpected)
        precision = len(found_expected) / checked_hits if checked_hits else 0.0
        noise_count = len(found_unexpected)
        # Pass criteria: recall >= 50% AND no noise detected
        passed = recall >= 0.5 and noise_count == 0
        details = f"Found: {found_expected} | Missed: {missed_expected} | Noise: {found_unexpected}"
        return TestResult(
            query=query,
            expected=expected,
            unexpected=unexpected,
            retrieved_entities=found_expected,
            precision=precision,
            recall=recall,
            noise_count=noise_count,
            passed=passed,
            details=details
        )
    except Exception as e:
        return TestResult(
            query=query,
            expected=expected,
            unexpected=unexpected,
            retrieved_entities=[],
            precision=0.0,
            recall=0.0,
            noise_count=0,
            passed=False,
            details=f"Error: {e}"
        )

async def get_graph_stats() -> dict:
    """Get current graph statistics."""
    async with httpx.AsyncClient(timeout=30.0) as client:
        # Touch the health endpoint first (response body unused)
        await client.get(f"{API_BASE}/health")
        graph = await client.get(f"{API_BASE}/graphs?label=*&max_depth=0&max_nodes=1000")
        graph_data = graph.json()
        nodes = graph_data.get("nodes", [])
        edges = graph_data.get("edges", [])
        # Count orphans (nodes that appear in no edge)
        node_ids = {n["id"] for n in nodes}
        connected_ids = set()
        for e in edges:
            connected_ids.add(e.get("source"))
            connected_ids.add(e.get("target"))
        orphan_ids = node_ids - connected_ids
        return {
            "total_nodes": len(nodes),
            "total_edges": len(edges),
            "orphan_count": len(orphan_ids),
            "orphan_rate": len(orphan_ids) / len(nodes) if nodes else 0.0
        }
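
# NOTE: the statistics above are computed client-side from a capped graph
# export (max_nodes=1000), so node/edge/orphan counts become approximate once
# the graph exceeds that cap.
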
async def main():
    print("=" * 60)
    print("ORPHAN CONNECTION QUALITY VALIDATION")
    print("=" * 60)
    # Get graph stats first
    try:
        stats = await get_graph_stats()
        print("\n📊 Current Graph Statistics:")
        print(f"   Nodes: {stats['total_nodes']}")
        print(f"   Edges: {stats['total_edges']}")
        print(f"   Orphans: {stats['orphan_count']} ({stats['orphan_rate']:.1%})")
    except Exception as e:
        print(f"⚠️ Could not get graph stats: {e}")
    print("\n" + "-" * 60)
    print("Running Quality Tests...")
    print("-" * 60)
    results = []
    for i, test_case in enumerate(TEST_CASES, 1):
        print(f"\n🧪 Test {i}: {test_case['category'].upper()} - {test_case['description']}")
        print(f"   Query: \"{test_case['query']}\"")
        result = await evaluate_test_case(test_case)
        results.append(result)
        status = "✅ PASS" if result.passed else "❌ FAIL"
        print(f"   {status}")
        print(f"   Recall: {result.recall:.0%} | Noise: {result.noise_count}")
        print(f"   {result.details}")
    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    passed = sum(1 for r in results if r.passed)
    total = len(results)
    avg_recall = sum(r.recall for r in results) / total
    total_noise = sum(r.noise_count for r in results)
    print(f"\n📈 Results: {passed}/{total} tests passed ({passed/total:.0%})")
    print(f"📈 Average Recall: {avg_recall:.0%}")
    print(f"📈 Total Noise Instances: {total_noise}")
    # Category breakdown
    categories = {}
    for r, tc in zip(results, TEST_CASES):
        cat = tc["category"]
        if cat not in categories:
            categories[cat] = {"passed": 0, "total": 0}
        categories[cat]["total"] += 1
        if r.passed:
            categories[cat]["passed"] += 1
    print("\n📊 By Category:")
    for cat, data in categories.items():
        print(f"   {cat.upper()}: {data['passed']}/{data['total']}")
    # Verdict
    print("\n" + "-" * 60)
    if total_noise == 0 and avg_recall >= 0.6:
        print("✅ VERDICT: Orphan connections are IMPROVING retrieval")
        print("   - No cross-domain pollution detected")
        print("   - Good recall on expected entities")
    elif total_noise > 0:
        print("⚠️ VERDICT: Orphan connections MAY BE POISONING retrieval")
        print(f"   - {total_noise} noise instances detected")
        print("   - Review the connections causing cross-domain bleed")
    else:
        print("⚠️ VERDICT: Orphan connections have MIXED results")
        print("   - Recall could be improved")
        print("   - No significant noise detected")
    print("-" * 60)

if __name__ == "__main__":
    asyncio.run(main())