LightRAG/tests/test_extraction_prompt_ab.py
Commit 59e89772de by clssck: refactor: consolidate to PostgreSQL-only backend and modernize stack
Remove legacy storage implementations and deprecated examples:
- Delete FAISS, JSON, Memgraph, Milvus, MongoDB, Nano Vector DB, Neo4j, NetworkX, Qdrant, Redis storage backends
- Remove Kubernetes deployment manifests and installation scripts
- Delete unofficial examples for deprecated backends and offline deployment docs
Streamline core infrastructure:
- Consolidate storage layer to PostgreSQL-only implementation
- Add full-text search caching with FTS cache module
- Implement metrics collection and monitoring pipeline
- Add explain and metrics API routes
Modernize frontend and tooling:
- Switch web UI to Bun with bun.lock, remove npm and pnpm lockfiles
- Update Dockerfile for PostgreSQL-only deployment
- Add Makefile for common development tasks
- Update environment and configuration examples
Enhance evaluation and testing capabilities:
- Add prompt optimization with DSPy and auto-tuning
- Implement ground truth regeneration and variant testing
- Add prompt debugging and response comparison utilities
- Expand test coverage with new integration scenarios
Simplify dependencies and configuration:
- Remove offline-specific requirement files
- Update pyproject.toml with streamlined dependencies
- Add Python version pinning with .python-version
- Create project guidelines in CLAUDE.md and AGENTS.md
2025-12-12 16:28:49 +01:00


"""
A/B Test for Entity Extraction Prompts
Compares original vs optimized extraction prompts using real LLM calls.
Run with: pytest tests/test_extraction_prompt_ab.py -v --run-integration
Or directly: python tests/test_extraction_prompt_ab.py
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
import pytest
import tiktoken
from lightrag.prompt_optimized import PROMPTS_OPTIMIZED
from lightrag.prompt import PROMPTS
# =============================================================================
# Sample Texts for Testing
# =============================================================================
SAMPLE_TEXTS = {
'covid_medical': {
'name': 'COVID-19 Medical',
'text': """
COVID-19, caused by the SARS-CoV-2 virus, emerged in Wuhan, China in late 2019.
The disease spreads primarily through respiratory droplets and can cause symptoms
ranging from mild fever and cough to severe pneumonia and acute respiratory distress
syndrome (ARDS). The World Health Organization declared it a pandemic on March 11, 2020.
Risk factors for severe disease include advanced age, obesity, and pre-existing
conditions such as diabetes and cardiovascular disease. Vaccines developed by Pfizer,
Moderna, and AstraZeneca have shown high efficacy in preventing severe illness.
""",
},
'financial_market': {
'name': 'Financial Markets',
'text': """
Stock markets faced a sharp downturn today as tech giants saw significant declines,
with the global tech index dropping by 3.4% in midday trading. Analysts attribute
the selloff to investor concerns over rising interest rates and regulatory uncertainty.
Among the hardest hit, Nexon Technologies saw its stock plummet by 7.8% after
reporting lower-than-expected quarterly earnings. In contrast, Omega Energy posted
a modest 2.1% gain, driven by rising oil prices.
Meanwhile, commodity markets reflected a mixed sentiment. Gold futures rose by 1.5%,
reaching $2,080 per ounce, as investors sought safe-haven assets. Crude oil prices
continued their rally, climbing to $87.60 per barrel, supported by supply constraints.
The Federal Reserve's upcoming policy announcement is expected to influence investor
confidence and overall market stability.
""",
},
'legal_regulatory': {
'name': 'Legal/Regulatory',
'text': """
The merger between Acme Corp and Beta Industries requires approval from the Federal
Trade Commission. Legal counsel advised that the deal may face antitrust scrutiny
due to market concentration concerns in the semiconductor industry.
The European Commission has also opened an investigation into the proposed acquisition,
citing potential impacts on competition in the EU market. Both companies have agreed
to divest certain assets to address regulatory concerns.
Industry analysts expect the approval process to take 12-18 months, with final
clearance dependent on remedies proposed by the merging parties.
""",
},
'narrative_fiction': {
'name': 'Narrative Fiction',
'text': """
While Alex clenched his jaw, the buzz of frustration dull against the backdrop of
Taylor's authoritarian certainty. It was this competitive undercurrent that kept him
alert, the sense that his and Jordan's shared commitment to discovery was an unspoken
rebellion against Cruz's narrowing vision of control and order.
Then Taylor did something unexpected. They paused beside Jordan and, for a moment,
observed the device with something akin to reverence. "If this tech can be understood..."
Taylor said, their voice quieter, "It could change the game for us. For all of us."
Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's,
a wordless clash of wills softening into an uneasy truce.
""",
},
}
# =============================================================================
# Data Classes
# =============================================================================
@dataclass
class ExtractionResult:
"""Result from a single extraction run."""
entities: list[dict]
relations: list[dict]
raw_output: str
input_tokens: int
format_errors: int
@property
def entity_count(self) -> int:
return len(self.entities)
@property
def relation_count(self) -> int:
return len(self.relations)
@property
def orphan_count(self) -> int:
"""Entities with no relationships."""
entity_names = {e['name'].lower() for e in self.entities}
connected = set()
for r in self.relations:
connected.add(r['source'].lower())
connected.add(r['target'].lower())
return len(entity_names - connected)
@property
def orphan_ratio(self) -> float:
if self.entity_count == 0:
return 0.0
return self.orphan_count / self.entity_count
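# Orphan example for ExtractionResult (illustrative): entities {"Pfizer", "Moderna", "WHO"}
# with a single relation Pfizer -> WHO give orphan_count == 1 (Moderna) and
# orphan_ratio ~= 0.33, since an orphan is any entity that never appears as a
# relation source or target.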
@dataclass
class ComparisonResult:
"""Comparison between two extraction results."""
sample_name: str
original: ExtractionResult
optimized: ExtractionResult
def entity_diff_pct(self) -> float:
if self.original.entity_count == 0:
return 0.0
return (self.optimized.entity_count - self.original.entity_count) / self.original.entity_count * 100
def relation_diff_pct(self) -> float:
if self.original.relation_count == 0:
return 0.0
return (self.optimized.relation_count - self.original.relation_count) / self.original.relation_count * 100
def token_diff_pct(self) -> float:
if self.original.input_tokens == 0:
return 0.0
return (self.optimized.input_tokens - self.original.input_tokens) / self.original.input_tokens * 100
# =============================================================================
# Helper Functions
# =============================================================================
def format_prompt(
prompts: dict,
text: str,
entity_types: str = 'person, organization, location, concept, product, event, category, method',
) -> tuple[str, str]:
"""Format system and user prompts with the given text."""
tuple_delimiter = prompts['DEFAULT_TUPLE_DELIMITER']
completion_delimiter = prompts['DEFAULT_COMPLETION_DELIMITER']
# Format examples
examples = '\n'.join(prompts['entity_extraction_examples'])
examples = examples.format(
tuple_delimiter=tuple_delimiter,
completion_delimiter=completion_delimiter,
)
context = {
'tuple_delimiter': tuple_delimiter,
'completion_delimiter': completion_delimiter,
'entity_types': entity_types,
'language': 'English',
'examples': examples,
'input_text': text,
}
system_prompt = prompts['entity_extraction_system_prompt'].format(**context)
user_prompt = prompts['entity_extraction_user_prompt'].format(**context)
return system_prompt, user_prompt
def count_tokens(text: str) -> int:
"""Count tokens using tiktoken."""
enc = tiktoken.encoding_for_model('gpt-4')
return len(enc.encode(text))
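# NOTE: tiktoken's encoding_for_model('gpt-4') resolves to cl100k_base, while the
# default chat model below (gpt-4o-mini) uses o200k_base, so these counts approximate
# relative prompt size rather than exact usage. call_llm assumes OPENAI_API_KEY is set
# in the environment; AsyncOpenAI() reads it by default.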
async def call_llm(system_prompt: str, user_prompt: str, model: str = 'gpt-4o-mini') -> str:
"""Call OpenAI API with the given prompts."""
import openai
client = openai.AsyncOpenAI()
response = await client.chat.completions.create(
model=model,
messages=[
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': user_prompt},
],
temperature=0.0,
)
if not response.choices or response.choices[0].message is None:
raise ValueError(f'Invalid response structure: {response}')
content = response.choices[0].message.content
if content is None:
raise ValueError(f'Empty content in response: {response}')
return content
def parse_extraction(output: str, tuple_delimiter: str = '<|#|>') -> tuple[list[dict], list[dict], int]:
"""Parse extraction output into entities and relations."""
entities = []
relations = []
format_errors = 0
for line in output.strip().split('\n'):
line = line.strip()
if not line or line.startswith('<|COMPLETE|>'):
continue
parts = line.split(tuple_delimiter)
if len(parts) >= 4 and parts[0].lower() == 'entity':
entities.append(
{
'name': parts[1].strip(),
'type': parts[2].strip(),
'description': parts[3].strip() if len(parts) > 3 else '',
}
)
elif len(parts) >= 5 and parts[0].lower() == 'relation':
relations.append(
{
'source': parts[1].strip(),
'target': parts[2].strip(),
'keywords': parts[3].strip(),
'description': parts[4].strip() if len(parts) > 4 else '',
}
)
elif line and not line.startswith('**') and tuple_delimiter in line:
# Line looks like it should be parsed but failed
format_errors += 1
return entities, relations, format_errors
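# Output format handled by parse_extraction (illustrative lines):
#     entity<|#|>World Health Organization<|#|>organization<|#|>UN agency that declared the pandemic
#     relation<|#|>SARS-CoV-2<|#|>COVID-19<|#|>causation<|#|>The virus that causes the disease
#     <|COMPLETE|>
# Any other line containing the delimiter (unless it starts with '**') is counted as a format error.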
async def run_extraction(prompts: dict, text: str) -> ExtractionResult:
"""Run extraction with the given prompts on the text."""
system_prompt, user_prompt = format_prompt(prompts, text)
input_tokens = count_tokens(system_prompt) + count_tokens(user_prompt)
output = await call_llm(system_prompt, user_prompt)
    entities, relations, format_errors = parse_extraction(output, prompts['DEFAULT_TUPLE_DELIMITER'])
return ExtractionResult(
entities=entities,
relations=relations,
raw_output=output,
input_tokens=input_tokens,
format_errors=format_errors,
)
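# Usage sketch (illustrative, reusing this module's own names):
#     result = asyncio.run(run_extraction(PROMPTS, SAMPLE_TEXTS['covid_medical']['text']))
#     print(result.entity_count, result.relation_count, result.orphan_ratio)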
def print_comparison_table(results: list[ComparisonResult]) -> None:
"""Print a formatted comparison table."""
print('\n' + '=' * 80)
print('ENTITY EXTRACTION PROMPT A/B COMPARISON')
print('=' * 80)
total_orig_entities = 0
total_opt_entities = 0
total_orig_relations = 0
total_opt_relations = 0
total_orig_tokens = 0
total_opt_tokens = 0
for r in results:
print(f'\n--- {r.sample_name} ---')
print(f'{"Metric":<20} {"Original":>12} {"Optimized":>12} {"Diff":>12}')
print('-' * 56)
print(
f'{"Entities":<20} {r.original.entity_count:>12} {r.optimized.entity_count:>12} {r.entity_diff_pct():>+11.0f}%'
)
print(
f'{"Relations":<20} {r.original.relation_count:>12} {r.optimized.relation_count:>12} {r.relation_diff_pct():>+11.0f}%'
)
print(f'{"Orphan Ratio":<20} {r.original.orphan_ratio:>11.0%} {r.optimized.orphan_ratio:>11.0%} {"":>12}')
print(f'{"Format Errors":<20} {r.original.format_errors:>12} {r.optimized.format_errors:>12}')
print(
f'{"Input Tokens":<20} {r.original.input_tokens:>12,} {r.optimized.input_tokens:>12,} {r.token_diff_pct():>+11.0f}%'
)
total_orig_entities += r.original.entity_count
total_opt_entities += r.optimized.entity_count
total_orig_relations += r.original.relation_count
total_opt_relations += r.optimized.relation_count
total_orig_tokens += r.original.input_tokens
total_opt_tokens += r.optimized.input_tokens
# Aggregate
print('\n' + '=' * 80)
print('AGGREGATE RESULTS')
print('=' * 80)
print(f'{"Metric":<20} {"Original":>12} {"Optimized":>12} {"Diff":>12}')
print('-' * 56)
ent_diff = (total_opt_entities - total_orig_entities) / total_orig_entities * 100 if total_orig_entities else 0
rel_diff = (total_opt_relations - total_orig_relations) / total_orig_relations * 100 if total_orig_relations else 0
tok_diff = (total_opt_tokens - total_orig_tokens) / total_orig_tokens * 100 if total_orig_tokens else 0
print(f'{"Total Entities":<20} {total_orig_entities:>12} {total_opt_entities:>12} {ent_diff:>+11.0f}%')
print(f'{"Total Relations":<20} {total_orig_relations:>12} {total_opt_relations:>12} {rel_diff:>+11.0f}%')
print(f'{"Total Input Tokens":<20} {total_orig_tokens:>12,} {total_opt_tokens:>12,} {tok_diff:>+11.0f}%')
# Recommendation
print('\n' + '-' * 56)
if tok_diff < -30 and ent_diff >= -10:
print('RECOMMENDATION: Use OPTIMIZED prompt (significant token savings, comparable extraction)')
elif ent_diff > 20 and tok_diff < 0:
print('RECOMMENDATION: Use OPTIMIZED prompt (better extraction AND token savings)')
elif ent_diff < -20:
print('RECOMMENDATION: Keep ORIGINAL prompt (optimized extracts significantly fewer entities)')
else:
print('RECOMMENDATION: Both prompts are comparable - consider token cost vs extraction breadth')
print('=' * 80 + '\n')
# =============================================================================
# Pytest Tests
# =============================================================================
class TestExtractionPromptAB:
"""A/B testing for entity extraction prompts."""
@pytest.mark.integration
@pytest.mark.asyncio
async def test_compare_all_samples(self) -> None:
"""Compare prompts across all sample texts."""
results = []
for _key, sample in SAMPLE_TEXTS.items():
print(f'\nProcessing: {sample["name"]}...')
original = await run_extraction(PROMPTS, sample['text'])
optimized = await run_extraction(PROMPTS_OPTIMIZED, sample['text'])
results.append(
ComparisonResult(
sample_name=sample['name'],
original=original,
optimized=optimized,
)
)
print_comparison_table(results)
# Basic assertions
for r in results:
assert r.original.format_errors == 0, f'Original had format errors on {r.sample_name}'
assert r.optimized.format_errors == 0, f'Optimized had format errors on {r.sample_name}'
@pytest.mark.integration
@pytest.mark.asyncio
async def test_single_sample(self) -> None:
"""Quick test with just one sample."""
sample = SAMPLE_TEXTS['covid_medical']
original = await run_extraction(PROMPTS, sample['text'])
optimized = await run_extraction(PROMPTS_OPTIMIZED, sample['text'])
result = ComparisonResult(
sample_name=sample['name'],
original=original,
optimized=optimized,
)
print_comparison_table([result])
assert original.entity_count > 0, 'Original should extract entities'
assert optimized.entity_count > 0, 'Optimized should extract entities'
assert optimized.format_errors == 0, 'Optimized should have no format errors'
# =============================================================================
# CLI Runner
# =============================================================================
async def main() -> None:
"""Run A/B comparison from command line."""
print('Starting Entity Extraction Prompt A/B Test...')
print('This will make real API calls to OpenAI.\n')
results = []
for _key, sample in SAMPLE_TEXTS.items():
print(f'Processing: {sample["name"]}...')
original = await run_extraction(PROMPTS, sample['text'])
optimized = await run_extraction(PROMPTS_OPTIMIZED, sample['text'])
results.append(
ComparisonResult(
sample_name=sample['name'],
original=original,
optimized=optimized,
)
)
print_comparison_table(results)
if __name__ == '__main__':
asyncio.run(main())