LightRAG/tests/test_extraction_prompt_ab.py
Commit 082a5a8fad by clssck (2025-12-05): test(lightrag,api): add comprehensive test coverage and S3 support
Add extensive test suites for API routes and utilities:
- Implement test_search_routes.py (406 lines) for search endpoint validation
- Implement test_upload_routes.py (724 lines) for document upload workflows
- Implement test_s3_client.py (618 lines) for S3 storage operations
- Implement test_citation_utils.py (352 lines) for citation extraction
- Implement test_chunking.py (216 lines) for text chunking validation
Add S3 storage client implementation:
- Create lightrag/storage/s3_client.py with S3 operations
- Add storage module initialization with exports
- Integrate S3 client with document upload handling
Enhance API routes and core functionality:
- Add search_routes.py with full-text and graph search endpoints
- Add upload_routes.py with multipart document upload support
- Update operate.py with bulk operations and health checks
- Enhance postgres_impl.py with bulk upsert and parameterized queries
- Update lightrag_server.py to register new API routes
- Improve utils.py with citation and formatting utilities
Update dependencies and configuration:
- Add S3 and test dependencies to pyproject.toml
- Update docker-compose.test.yml for testing environment
- Sync uv.lock with new dependencies
Apply code quality improvements across all modified files:
- Add type hints to function signatures
- Update imports and router initialization
- Fix logging and error handling


"""
A/B Test for Entity Extraction Prompts
Compares original vs optimized extraction prompts using real LLM calls.
Run with: pytest tests/test_extraction_prompt_ab.py -v --run-integration
Or directly: python tests/test_extraction_prompt_ab.py
"""
from __future__ import annotations
import asyncio
import sys
from dataclasses import dataclass
from pathlib import Path
import pytest
import tiktoken
# Add project root to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from lightrag.prompt_optimized import PROMPTS_OPTIMIZED
from lightrag.prompt import PROMPTS
# =============================================================================
# Sample Texts for Testing
# =============================================================================
SAMPLE_TEXTS = {
"covid_medical": {
"name": "COVID-19 Medical",
"text": """
COVID-19, caused by the SARS-CoV-2 virus, emerged in Wuhan, China in late 2019.
The disease spreads primarily through respiratory droplets and can cause symptoms
ranging from mild fever and cough to severe pneumonia and acute respiratory distress
syndrome (ARDS). The World Health Organization declared it a pandemic on March 11, 2020.
Risk factors for severe disease include advanced age, obesity, and pre-existing
conditions such as diabetes and cardiovascular disease. Vaccines developed by Pfizer,
Moderna, and AstraZeneca have shown high efficacy in preventing severe illness.
""",
},
"financial_market": {
"name": "Financial Markets",
"text": """
Stock markets faced a sharp downturn today as tech giants saw significant declines,
with the global tech index dropping by 3.4% in midday trading. Analysts attribute
the selloff to investor concerns over rising interest rates and regulatory uncertainty.
Among the hardest hit, Nexon Technologies saw its stock plummet by 7.8% after
reporting lower-than-expected quarterly earnings. In contrast, Omega Energy posted
a modest 2.1% gain, driven by rising oil prices.
Meanwhile, commodity markets reflected a mixed sentiment. Gold futures rose by 1.5%,
reaching $2,080 per ounce, as investors sought safe-haven assets. Crude oil prices
continued their rally, climbing to $87.60 per barrel, supported by supply constraints.
The Federal Reserve's upcoming policy announcement is expected to influence investor
confidence and overall market stability.
""",
},
"legal_regulatory": {
"name": "Legal/Regulatory",
"text": """
The merger between Acme Corp and Beta Industries requires approval from the Federal
Trade Commission. Legal counsel advised that the deal may face antitrust scrutiny
due to market concentration concerns in the semiconductor industry.
The European Commission has also opened an investigation into the proposed acquisition,
citing potential impacts on competition in the EU market. Both companies have agreed
to divest certain assets to address regulatory concerns.
Industry analysts expect the approval process to take 12-18 months, with final
clearance dependent on remedies proposed by the merging parties.
""",
},
"narrative_fiction": {
"name": "Narrative Fiction",
"text": """
While Alex clenched his jaw, the buzz of frustration dull against the backdrop of
Taylor's authoritarian certainty. It was this competitive undercurrent that kept him
alert, the sense that his and Jordan's shared commitment to discovery was an unspoken
rebellion against Cruz's narrowing vision of control and order.
Then Taylor did something unexpected. They paused beside Jordan and, for a moment,
observed the device with something akin to reverence. "If this tech can be understood..."
Taylor said, their voice quieter, "It could change the game for us. For all of us."
Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's,
a wordless clash of wills softening into an uneasy truce.
""",
},
}
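# The samples span four domains (medical, financial, legal/regulatory, and
# narrative fiction) so the A/B comparison covers both entity-dense and
# dialogue-heavy text.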
# =============================================================================
# Data Classes
# =============================================================================
@dataclass
class ExtractionResult:
"""Result from a single extraction run."""
entities: list[dict]
relations: list[dict]
raw_output: str
input_tokens: int
format_errors: int
@property
def entity_count(self) -> int:
return len(self.entities)
@property
def relation_count(self) -> int:
return len(self.relations)
@property
def orphan_count(self) -> int:
"""Entities with no relationships."""
entity_names = {e["name"].lower() for e in self.entities}
connected = set()
for r in self.relations:
connected.add(r["source"].lower())
connected.add(r["target"].lower())
return len(entity_names - connected)
@property
def orphan_ratio(self) -> float:
if self.entity_count == 0:
return 0.0
return self.orphan_count / self.entity_count
@dataclass
class ComparisonResult:
"""Comparison between two extraction results."""
sample_name: str
original: ExtractionResult
optimized: ExtractionResult
def entity_diff_pct(self) -> float:
if self.original.entity_count == 0:
return 0.0
return (
(self.optimized.entity_count - self.original.entity_count)
/ self.original.entity_count
* 100
)
def relation_diff_pct(self) -> float:
if self.original.relation_count == 0:
return 0.0
return (
(self.optimized.relation_count - self.original.relation_count)
/ self.original.relation_count
* 100
)
def token_diff_pct(self) -> float:
if self.original.input_tokens == 0:
return 0.0
return (
(self.optimized.input_tokens - self.original.input_tokens)
/ self.original.input_tokens
* 100
)
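# Sign convention: a positive diff percentage means the optimized prompt
# produced more entities/relations, or consumed more input tokens, than the
# original prompt.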
# =============================================================================
# Helper Functions
# =============================================================================
def format_prompt(
prompts: dict, text: str, entity_types: str = "person, organization, location, concept, product, event, category, method"
) -> tuple[str, str]:
"""Format system and user prompts with the given text."""
tuple_delimiter = prompts["DEFAULT_TUPLE_DELIMITER"]
completion_delimiter = prompts["DEFAULT_COMPLETION_DELIMITER"]
# Format examples
examples = "\n".join(prompts["entity_extraction_examples"])
examples = examples.format(
tuple_delimiter=tuple_delimiter,
completion_delimiter=completion_delimiter,
)
context = {
"tuple_delimiter": tuple_delimiter,
"completion_delimiter": completion_delimiter,
"entity_types": entity_types,
"language": "English",
"examples": examples,
"input_text": text,
}
system_prompt = prompts["entity_extraction_system_prompt"].format(**context)
user_prompt = prompts["entity_extraction_user_prompt"].format(**context)
return system_prompt, user_prompt
def count_tokens(text: str) -> int:
"""Count tokens using tiktoken."""
enc = tiktoken.encoding_for_model("gpt-4")
return len(enc.encode(text))
async def call_llm(system_prompt: str, user_prompt: str, model: str = "gpt-4o-mini") -> str:
"""Call OpenAI API with the given prompts."""
import openai
client = openai.AsyncOpenAI()
response = await client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=0.0,
)
    # message.content is Optional in the SDK typing; fall back to "" so callers never see None
    return response.choices[0].message.content or ""
def parse_extraction(output: str, tuple_delimiter: str = "<|#|>") -> tuple[list[dict], list[dict], int]:
"""Parse extraction output into entities and relations."""
entities = []
relations = []
format_errors = 0
for line in output.strip().split("\n"):
line = line.strip()
if not line or line.startswith("<|COMPLETE|>"):
continue
parts = line.split(tuple_delimiter)
if len(parts) >= 4 and parts[0].lower() == "entity":
entities.append(
{
"name": parts[1].strip(),
"type": parts[2].strip(),
"description": parts[3].strip() if len(parts) > 3 else "",
}
)
elif len(parts) >= 5 and parts[0].lower() == "relation":
relations.append(
{
"source": parts[1].strip(),
"target": parts[2].strip(),
"keywords": parts[3].strip(),
"description": parts[4].strip() if len(parts) > 4 else "",
}
)
elif line and not line.startswith("**") and tuple_delimiter in line:
# Line looks like it should be parsed but failed
format_errors += 1
return entities, relations, format_errors
async def run_extraction(prompts: dict, text: str) -> ExtractionResult:
"""Run extraction with the given prompts on the text."""
system_prompt, user_prompt = format_prompt(prompts, text)
input_tokens = count_tokens(system_prompt) + count_tokens(user_prompt)
output = await call_llm(system_prompt, user_prompt)
    # Parse with the same tuple delimiter the prompt set actually uses.
    entities, relations, format_errors = parse_extraction(output, prompts["DEFAULT_TUPLE_DELIMITER"])
return ExtractionResult(
entities=entities,
relations=relations,
raw_output=output,
input_tokens=input_tokens,
format_errors=format_errors,
)
def print_comparison_table(results: list[ComparisonResult]) -> None:
"""Print a formatted comparison table."""
print("\n" + "=" * 80)
print("ENTITY EXTRACTION PROMPT A/B COMPARISON")
print("=" * 80)
total_orig_entities = 0
total_opt_entities = 0
total_orig_relations = 0
total_opt_relations = 0
total_orig_tokens = 0
total_opt_tokens = 0
for r in results:
print(f"\n--- {r.sample_name} ---")
print(f"{'Metric':<20} {'Original':>12} {'Optimized':>12} {'Diff':>12}")
print("-" * 56)
print(f"{'Entities':<20} {r.original.entity_count:>12} {r.optimized.entity_count:>12} {r.entity_diff_pct():>+11.0f}%")
print(f"{'Relations':<20} {r.original.relation_count:>12} {r.optimized.relation_count:>12} {r.relation_diff_pct():>+11.0f}%")
print(f"{'Orphan Ratio':<20} {r.original.orphan_ratio:>11.0%} {r.optimized.orphan_ratio:>11.0%} {'':>12}")
print(f"{'Format Errors':<20} {r.original.format_errors:>12} {r.optimized.format_errors:>12}")
print(f"{'Input Tokens':<20} {r.original.input_tokens:>12,} {r.optimized.input_tokens:>12,} {r.token_diff_pct():>+11.0f}%")
total_orig_entities += r.original.entity_count
total_opt_entities += r.optimized.entity_count
total_orig_relations += r.original.relation_count
total_opt_relations += r.optimized.relation_count
total_orig_tokens += r.original.input_tokens
total_opt_tokens += r.optimized.input_tokens
# Aggregate
print("\n" + "=" * 80)
print("AGGREGATE RESULTS")
print("=" * 80)
print(f"{'Metric':<20} {'Original':>12} {'Optimized':>12} {'Diff':>12}")
print("-" * 56)
ent_diff = (total_opt_entities - total_orig_entities) / total_orig_entities * 100 if total_orig_entities else 0
rel_diff = (total_opt_relations - total_orig_relations) / total_orig_relations * 100 if total_orig_relations else 0
tok_diff = (total_opt_tokens - total_orig_tokens) / total_orig_tokens * 100 if total_orig_tokens else 0
print(f"{'Total Entities':<20} {total_orig_entities:>12} {total_opt_entities:>12} {ent_diff:>+11.0f}%")
print(f"{'Total Relations':<20} {total_orig_relations:>12} {total_opt_relations:>12} {rel_diff:>+11.0f}%")
print(f"{'Total Input Tokens':<20} {total_orig_tokens:>12,} {total_opt_tokens:>12,} {tok_diff:>+11.0f}%")
# Recommendation
print("\n" + "-" * 56)
if tok_diff < -30 and ent_diff >= -10:
print("RECOMMENDATION: Use OPTIMIZED prompt (significant token savings, comparable extraction)")
elif ent_diff > 20 and tok_diff < 0:
print("RECOMMENDATION: Use OPTIMIZED prompt (better extraction AND token savings)")
elif ent_diff < -20:
print("RECOMMENDATION: Keep ORIGINAL prompt (optimized extracts significantly fewer entities)")
else:
print("RECOMMENDATION: Both prompts are comparable - consider token cost vs extraction breadth")
print("=" * 80 + "\n")
# =============================================================================
# Pytest Tests
# =============================================================================
class TestExtractionPromptAB:
"""A/B testing for entity extraction prompts."""
@pytest.mark.integration
@pytest.mark.asyncio
async def test_compare_all_samples(self) -> None:
"""Compare prompts across all sample texts."""
results = []
for _key, sample in SAMPLE_TEXTS.items():
print(f"\nProcessing: {sample['name']}...")
original = await run_extraction(PROMPTS, sample["text"])
optimized = await run_extraction(PROMPTS_OPTIMIZED, sample["text"])
results.append(
ComparisonResult(
sample_name=sample["name"],
original=original,
optimized=optimized,
)
)
print_comparison_table(results)
# Basic assertions
for r in results:
assert r.original.format_errors == 0, f"Original had format errors on {r.sample_name}"
assert r.optimized.format_errors == 0, f"Optimized had format errors on {r.sample_name}"
@pytest.mark.integration
@pytest.mark.asyncio
async def test_single_sample(self) -> None:
"""Quick test with just one sample."""
sample = SAMPLE_TEXTS["covid_medical"]
original = await run_extraction(PROMPTS, sample["text"])
optimized = await run_extraction(PROMPTS_OPTIMIZED, sample["text"])
result = ComparisonResult(
sample_name=sample["name"],
original=original,
optimized=optimized,
)
print_comparison_table([result])
assert original.entity_count > 0, "Original should extract entities"
assert optimized.entity_count > 0, "Optimized should extract entities"
assert optimized.format_errors == 0, "Optimized should have no format errors"
# =============================================================================
# CLI Runner
# =============================================================================
async def main() -> None:
"""Run A/B comparison from command line."""
print("Starting Entity Extraction Prompt A/B Test...")
print("This will make real API calls to OpenAI.\n")
results = []
for _key, sample in SAMPLE_TEXTS.items():
print(f"Processing: {sample['name']}...")
original = await run_extraction(PROMPTS, sample["text"])
optimized = await run_extraction(PROMPTS_OPTIMIZED, sample["text"])
results.append(
ComparisonResult(
sample_name=sample["name"],
original=original,
optimized=optimized,
)
)
print_comparison_table(results)
if __name__ == "__main__":
asyncio.run(main())