Remove legacy storage implementations and deprecated examples:
- Delete FAISS, JSON, Memgraph, Milvus, MongoDB, Nano Vector DB, Neo4j, NetworkX, Qdrant, Redis storage backends
- Remove Kubernetes deployment manifests and installation scripts
- Delete unofficial examples for deprecated backends and offline deployment docs

Streamline core infrastructure:
- Consolidate storage layer to PostgreSQL-only implementation
- Add full-text search caching with FTS cache module
- Implement metrics collection and monitoring pipeline
- Add explain and metrics API routes

Modernize frontend and tooling:
- Switch web UI to Bun with bun.lock, remove npm and pnpm lockfiles
- Update Dockerfile for PostgreSQL-only deployment
- Add Makefile for common development tasks
- Update environment and configuration examples

Enhance evaluation and testing capabilities:
- Add prompt optimization with DSPy and auto-tuning
- Implement ground truth regeneration and variant testing
- Add prompt debugging and response comparison utilities
- Expand test coverage with new integration scenarios

Simplify dependencies and configuration:
- Remove offline-specific requirement files
- Update pyproject.toml with streamlined dependencies
- Add Python version pinning with .python-version
- Create project guidelines in CLAUDE.md and AGENTS.md
tests/test_extraction_prompt_ab.py · 427 lines · 15 KiB · Python
"""
|
|
A/B Test for Entity Extraction Prompts
|
|
|
|
Compares original vs optimized extraction prompts using real LLM calls.
|
|
Run with: pytest tests/test_extraction_prompt_ab.py -v --run-integration
|
|
Or directly: python tests/test_extraction_prompt_ab.py
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from dataclasses import dataclass
|
|
|
|
import pytest
|
|
import tiktoken
|
|
from lightrag.prompt_optimized import PROMPTS_OPTIMIZED
|
|
|
|
from lightrag.prompt import PROMPTS
|
|
|
|
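
# NOTE: Everything below issues real OpenAI API calls. call_llm() constructs
# openai.AsyncOpenAI() without an explicit key, so the SDK reads credentials
# from the OPENAI_API_KEY environment variable.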

# =============================================================================
# Sample Texts for Testing
# =============================================================================
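
# Four contrasting domains (medical, financial, legal/regulatory, narrative fiction),
# covering both fact-dense and character-driven text.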

SAMPLE_TEXTS = {
    'covid_medical': {
        'name': 'COVID-19 Medical',
        'text': """
COVID-19, caused by the SARS-CoV-2 virus, emerged in Wuhan, China in late 2019.
The disease spreads primarily through respiratory droplets and can cause symptoms
ranging from mild fever and cough to severe pneumonia and acute respiratory distress
syndrome (ARDS). The World Health Organization declared it a pandemic on March 11, 2020.

Risk factors for severe disease include advanced age, obesity, and pre-existing
conditions such as diabetes and cardiovascular disease. Vaccines developed by Pfizer,
Moderna, and AstraZeneca have shown high efficacy in preventing severe illness.
""",
    },
    'financial_market': {
        'name': 'Financial Markets',
        'text': """
Stock markets faced a sharp downturn today as tech giants saw significant declines,
with the global tech index dropping by 3.4% in midday trading. Analysts attribute
the selloff to investor concerns over rising interest rates and regulatory uncertainty.

Among the hardest hit, Nexon Technologies saw its stock plummet by 7.8% after
reporting lower-than-expected quarterly earnings. In contrast, Omega Energy posted
a modest 2.1% gain, driven by rising oil prices.

Meanwhile, commodity markets reflected a mixed sentiment. Gold futures rose by 1.5%,
reaching $2,080 per ounce, as investors sought safe-haven assets. Crude oil prices
continued their rally, climbing to $87.60 per barrel, supported by supply constraints.

The Federal Reserve's upcoming policy announcement is expected to influence investor
confidence and overall market stability.
""",
    },
    'legal_regulatory': {
        'name': 'Legal/Regulatory',
        'text': """
The merger between Acme Corp and Beta Industries requires approval from the Federal
Trade Commission. Legal counsel advised that the deal may face antitrust scrutiny
due to market concentration concerns in the semiconductor industry.

The European Commission has also opened an investigation into the proposed acquisition,
citing potential impacts on competition in the EU market. Both companies have agreed
to divest certain assets to address regulatory concerns.

Industry analysts expect the approval process to take 12-18 months, with final
clearance dependent on remedies proposed by the merging parties.
""",
    },
    'narrative_fiction': {
        'name': 'Narrative Fiction',
        'text': """
While Alex clenched his jaw, the buzz of frustration dull against the backdrop of
Taylor's authoritarian certainty. It was this competitive undercurrent that kept him
alert, the sense that his and Jordan's shared commitment to discovery was an unspoken
rebellion against Cruz's narrowing vision of control and order.

Then Taylor did something unexpected. They paused beside Jordan and, for a moment,
observed the device with something akin to reverence. "If this tech can be understood..."
Taylor said, their voice quieter, "It could change the game for us. For all of us."

Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's,
a wordless clash of wills softening into an uneasy truce.
""",
    },
}


# =============================================================================
# Data Classes
# =============================================================================


@dataclass
class ExtractionResult:
    """Result from a single extraction run."""

    entities: list[dict]
    relations: list[dict]
    raw_output: str
    input_tokens: int
    format_errors: int

    @property
    def entity_count(self) -> int:
        return len(self.entities)

    @property
    def relation_count(self) -> int:
        return len(self.relations)

    @property
    def orphan_count(self) -> int:
        """Entities with no relationships."""
        entity_names = {e['name'].lower() for e in self.entities}
        connected = set()
        for r in self.relations:
            connected.add(r['source'].lower())
            connected.add(r['target'].lower())
        return len(entity_names - connected)

    @property
    def orphan_ratio(self) -> float:
        if self.entity_count == 0:
            return 0.0
        return self.orphan_count / self.entity_count


@dataclass
class ComparisonResult:
    """Comparison between two extraction results."""

    sample_name: str
    original: ExtractionResult
    optimized: ExtractionResult

    def entity_diff_pct(self) -> float:
        if self.original.entity_count == 0:
            return 0.0
        return (self.optimized.entity_count - self.original.entity_count) / self.original.entity_count * 100

    def relation_diff_pct(self) -> float:
        if self.original.relation_count == 0:
            return 0.0
        return (self.optimized.relation_count - self.original.relation_count) / self.original.relation_count * 100

    def token_diff_pct(self) -> float:
        if self.original.input_tokens == 0:
            return 0.0
        return (self.optimized.input_tokens - self.original.input_tokens) / self.original.input_tokens * 100


# =============================================================================
# Helper Functions
# =============================================================================


def format_prompt(
    prompts: dict,
    text: str,
    entity_types: str = 'person, organization, location, concept, product, event, category, method',
) -> tuple[str, str]:
    """Format system and user prompts with the given text."""
    tuple_delimiter = prompts['DEFAULT_TUPLE_DELIMITER']
    completion_delimiter = prompts['DEFAULT_COMPLETION_DELIMITER']

    # Format examples
    examples = '\n'.join(prompts['entity_extraction_examples'])
    examples = examples.format(
        tuple_delimiter=tuple_delimiter,
        completion_delimiter=completion_delimiter,
    )

    context = {
        'tuple_delimiter': tuple_delimiter,
        'completion_delimiter': completion_delimiter,
        'entity_types': entity_types,
        'language': 'English',
        'examples': examples,
        'input_text': text,
    }

    system_prompt = prompts['entity_extraction_system_prompt'].format(**context)
    user_prompt = prompts['entity_extraction_user_prompt'].format(**context)

    return system_prompt, user_prompt


def count_tokens(text: str) -> int:
    """Count tokens using tiktoken."""
    enc = tiktoken.encoding_for_model('gpt-4')
    return len(enc.encode(text))
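

# Token counts in this comparison cover prompt input only (system + user prompt);
# completion tokens returned by the API are not measured.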


async def call_llm(system_prompt: str, user_prompt: str, model: str = 'gpt-4o-mini') -> str:
    """Call OpenAI API with the given prompts."""
    import openai

    client = openai.AsyncOpenAI()

    response = await client.chat.completions.create(
        model=model,
        messages=[
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': user_prompt},
        ],
        temperature=0.0,
    )

    if not response.choices or response.choices[0].message is None:
        raise ValueError(f'Invalid response structure: {response}')
    content = response.choices[0].message.content
    if content is None:
        raise ValueError(f'Empty content in response: {response}')
    return content
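

# Illustrative records in the shape parse_extraction() accepts (example values only).
# Note: run_extraction() below relies on parse_extraction()'s default '<|#|>' delimiter
# matching PROMPTS['DEFAULT_TUPLE_DELIMITER'].
#   entity<|#|>Omega Energy<|#|>organization<|#|>Energy company that posted a modest gain
#   relation<|#|>Omega Energy<|#|>oil prices<|#|>market driver<|#|>Gain driven by rising oil prices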


def parse_extraction(output: str, tuple_delimiter: str = '<|#|>') -> tuple[list[dict], list[dict], int]:
    """Parse extraction output into entities and relations."""
    entities = []
    relations = []
    format_errors = 0

    for line in output.strip().split('\n'):
        line = line.strip()
        if not line or line.startswith('<|COMPLETE|>'):
            continue

        parts = line.split(tuple_delimiter)

        if len(parts) >= 4 and parts[0].lower() == 'entity':
            entities.append(
                {
                    'name': parts[1].strip(),
                    'type': parts[2].strip(),
                    'description': parts[3].strip() if len(parts) > 3 else '',
                }
            )
        elif len(parts) >= 5 and parts[0].lower() == 'relation':
            relations.append(
                {
                    'source': parts[1].strip(),
                    'target': parts[2].strip(),
                    'keywords': parts[3].strip(),
                    'description': parts[4].strip() if len(parts) > 4 else '',
                }
            )
        elif line and not line.startswith('**') and tuple_delimiter in line:
            # Line looks like it should be parsed but failed
            format_errors += 1

    return entities, relations, format_errors


async def run_extraction(prompts: dict, text: str) -> ExtractionResult:
    """Run extraction with the given prompts on the text."""
    system_prompt, user_prompt = format_prompt(prompts, text)
    input_tokens = count_tokens(system_prompt) + count_tokens(user_prompt)

    output = await call_llm(system_prompt, user_prompt)
    entities, relations, format_errors = parse_extraction(output)

    return ExtractionResult(
        entities=entities,
        relations=relations,
        raw_output=output,
        input_tokens=input_tokens,
        format_errors=format_errors,
    )


def print_comparison_table(results: list[ComparisonResult]) -> None:
    """Print a formatted comparison table."""
    print('\n' + '=' * 80)
    print('ENTITY EXTRACTION PROMPT A/B COMPARISON')
    print('=' * 80)

    total_orig_entities = 0
    total_opt_entities = 0
    total_orig_relations = 0
    total_opt_relations = 0
    total_orig_tokens = 0
    total_opt_tokens = 0

    for r in results:
        print(f'\n--- {r.sample_name} ---')
        print(f'{"Metric":<20} {"Original":>12} {"Optimized":>12} {"Diff":>12}')
        print('-' * 56)

        print(
            f'{"Entities":<20} {r.original.entity_count:>12} {r.optimized.entity_count:>12} {r.entity_diff_pct():>+11.0f}%'
        )
        print(
            f'{"Relations":<20} {r.original.relation_count:>12} {r.optimized.relation_count:>12} {r.relation_diff_pct():>+11.0f}%'
        )
        print(f'{"Orphan Ratio":<20} {r.original.orphan_ratio:>11.0%} {r.optimized.orphan_ratio:>11.0%} {"":>12}')
        print(f'{"Format Errors":<20} {r.original.format_errors:>12} {r.optimized.format_errors:>12}')
        print(
            f'{"Input Tokens":<20} {r.original.input_tokens:>12,} {r.optimized.input_tokens:>12,} {r.token_diff_pct():>+11.0f}%'
        )

        total_orig_entities += r.original.entity_count
        total_opt_entities += r.optimized.entity_count
        total_orig_relations += r.original.relation_count
        total_opt_relations += r.optimized.relation_count
        total_orig_tokens += r.original.input_tokens
        total_opt_tokens += r.optimized.input_tokens

    # Aggregate
    print('\n' + '=' * 80)
    print('AGGREGATE RESULTS')
    print('=' * 80)
    print(f'{"Metric":<20} {"Original":>12} {"Optimized":>12} {"Diff":>12}')
    print('-' * 56)

    ent_diff = (total_opt_entities - total_orig_entities) / total_orig_entities * 100 if total_orig_entities else 0
    rel_diff = (total_opt_relations - total_orig_relations) / total_orig_relations * 100 if total_orig_relations else 0
    tok_diff = (total_opt_tokens - total_orig_tokens) / total_orig_tokens * 100 if total_orig_tokens else 0

    print(f'{"Total Entities":<20} {total_orig_entities:>12} {total_opt_entities:>12} {ent_diff:>+11.0f}%')
    print(f'{"Total Relations":<20} {total_orig_relations:>12} {total_opt_relations:>12} {rel_diff:>+11.0f}%')
    print(f'{"Total Input Tokens":<20} {total_orig_tokens:>12,} {total_opt_tokens:>12,} {tok_diff:>+11.0f}%')

    # Recommendation: prefer the optimized prompt when it cuts input tokens by more
    # than 30% while keeping entity counts within ~10% of the original; flag a clear
    # regression when it extracts over 20% fewer entities.
    print('\n' + '-' * 56)
    if tok_diff < -30 and ent_diff >= -10:
        print('RECOMMENDATION: Use OPTIMIZED prompt (significant token savings, comparable extraction)')
    elif ent_diff > 20 and tok_diff < 0:
        print('RECOMMENDATION: Use OPTIMIZED prompt (better extraction AND token savings)')
    elif ent_diff < -20:
        print('RECOMMENDATION: Keep ORIGINAL prompt (optimized extracts significantly fewer entities)')
    else:
        print('RECOMMENDATION: Both prompts are comparable - consider token cost vs extraction breadth')

    print('=' * 80 + '\n')


# =============================================================================
# Pytest Tests
# =============================================================================
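
# These tests are marked `integration`: per the module docstring they are meant to
# run only with the --run-integration flag (presumably wired up in the project's
# pytest configuration), since they hit the live OpenAI API.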


class TestExtractionPromptAB:
    """A/B testing for entity extraction prompts."""

    @pytest.mark.integration
    @pytest.mark.asyncio
    async def test_compare_all_samples(self) -> None:
        """Compare prompts across all sample texts."""
        results = []

        for _key, sample in SAMPLE_TEXTS.items():
            print(f'\nProcessing: {sample["name"]}...')

            original = await run_extraction(PROMPTS, sample['text'])
            optimized = await run_extraction(PROMPTS_OPTIMIZED, sample['text'])

            results.append(
                ComparisonResult(
                    sample_name=sample['name'],
                    original=original,
                    optimized=optimized,
                )
            )

        print_comparison_table(results)

        # Basic assertions
        for r in results:
            assert r.original.format_errors == 0, f'Original had format errors on {r.sample_name}'
            assert r.optimized.format_errors == 0, f'Optimized had format errors on {r.sample_name}'

    @pytest.mark.integration
    @pytest.mark.asyncio
    async def test_single_sample(self) -> None:
        """Quick test with just one sample."""
        sample = SAMPLE_TEXTS['covid_medical']

        original = await run_extraction(PROMPTS, sample['text'])
        optimized = await run_extraction(PROMPTS_OPTIMIZED, sample['text'])

        result = ComparisonResult(
            sample_name=sample['name'],
            original=original,
            optimized=optimized,
        )

        print_comparison_table([result])

        assert original.entity_count > 0, 'Original should extract entities'
        assert optimized.entity_count > 0, 'Optimized should extract entities'
        assert optimized.format_errors == 0, 'Optimized should have no format errors'


# =============================================================================
# CLI Runner
# =============================================================================


async def main() -> None:
    """Run A/B comparison from command line."""
    print('Starting Entity Extraction Prompt A/B Test...')
    print('This will make real API calls to OpenAI.\n')

    results = []

    for _key, sample in SAMPLE_TEXTS.items():
        print(f'Processing: {sample["name"]}...')

        original = await run_extraction(PROMPTS, sample['text'])
        optimized = await run_extraction(PROMPTS_OPTIMIZED, sample['text'])

        results.append(
            ComparisonResult(
                sample_name=sample['name'],
                original=original,
                optimized=optimized,
            )
        )

    print_comparison_table(results)


if __name__ == '__main__':
    asyncio.run(main())