Add extensive test suites for API routes and utilities:
- Implement test_search_routes.py (406 lines) for search endpoint validation
- Implement test_upload_routes.py (724 lines) for document upload workflows
- Implement test_s3_client.py (618 lines) for S3 storage operations
- Implement test_citation_utils.py (352 lines) for citation extraction
- Implement test_chunking.py (216 lines) for text chunking validation

Add S3 storage client implementation:
- Create lightrag/storage/s3_client.py with S3 operations
- Add storage module initialization with exports
- Integrate S3 client with document upload handling

Enhance API routes and core functionality:
- Add search_routes.py with full-text and graph search endpoints
- Add upload_routes.py with multipart document upload support
- Update operate.py with bulk operations and health checks
- Enhance postgres_impl.py with bulk upsert and parameterized queries
- Update lightrag_server.py to register new API routes
- Improve utils.py with citation and formatting utilities

Update dependencies and configuration:
- Add S3 and test dependencies to pyproject.toml
- Update docker-compose.test.yml for testing environment
- Sync uv.lock with new dependencies

Apply code quality improvements across all modified files:
- Add type hints to function signatures
- Update imports and router initialization
- Fix logging and error handling
"""
|
|
A/B Test for Entity Extraction Prompts
|
|
|
|
Compares original vs optimized extraction prompts using real LLM calls.
|
|
Run with: pytest tests/test_extraction_prompt_ab.py -v --run-integration
|
|
Or directly: python tests/test_extraction_prompt_ab.py
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
import tiktoken
|
|
|
|
# Add project root to path for imports
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from lightrag.prompt_optimized import PROMPTS_OPTIMIZED
|
|
|
|
from lightrag.prompt import PROMPTS
|
|
|
|
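
# NOTE (assumption): lightrag.prompt_optimized is expected to expose the same keys
# as lightrag.prompt (DEFAULT_TUPLE_DELIMITER, DEFAULT_COMPLETION_DELIMITER,
# entity_extraction_examples, entity_extraction_system_prompt,
# entity_extraction_user_prompt), since format_prompt() below indexes both
# prompt dictionaries identically.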

# =============================================================================
# Sample Texts for Testing
# =============================================================================

SAMPLE_TEXTS = {
    "covid_medical": {
        "name": "COVID-19 Medical",
        "text": """
COVID-19, caused by the SARS-CoV-2 virus, emerged in Wuhan, China in late 2019.
The disease spreads primarily through respiratory droplets and can cause symptoms
ranging from mild fever and cough to severe pneumonia and acute respiratory distress
syndrome (ARDS). The World Health Organization declared it a pandemic on March 11, 2020.

Risk factors for severe disease include advanced age, obesity, and pre-existing
conditions such as diabetes and cardiovascular disease. Vaccines developed by Pfizer,
Moderna, and AstraZeneca have shown high efficacy in preventing severe illness.
""",
    },
    "financial_market": {
        "name": "Financial Markets",
        "text": """
Stock markets faced a sharp downturn today as tech giants saw significant declines,
with the global tech index dropping by 3.4% in midday trading. Analysts attribute
the selloff to investor concerns over rising interest rates and regulatory uncertainty.

Among the hardest hit, Nexon Technologies saw its stock plummet by 7.8% after
reporting lower-than-expected quarterly earnings. In contrast, Omega Energy posted
a modest 2.1% gain, driven by rising oil prices.

Meanwhile, commodity markets reflected a mixed sentiment. Gold futures rose by 1.5%,
reaching $2,080 per ounce, as investors sought safe-haven assets. Crude oil prices
continued their rally, climbing to $87.60 per barrel, supported by supply constraints.

The Federal Reserve's upcoming policy announcement is expected to influence investor
confidence and overall market stability.
""",
    },
    "legal_regulatory": {
        "name": "Legal/Regulatory",
        "text": """
The merger between Acme Corp and Beta Industries requires approval from the Federal
Trade Commission. Legal counsel advised that the deal may face antitrust scrutiny
due to market concentration concerns in the semiconductor industry.

The European Commission has also opened an investigation into the proposed acquisition,
citing potential impacts on competition in the EU market. Both companies have agreed
to divest certain assets to address regulatory concerns.

Industry analysts expect the approval process to take 12-18 months, with final
clearance dependent on remedies proposed by the merging parties.
""",
    },
    "narrative_fiction": {
        "name": "Narrative Fiction",
        "text": """
While Alex clenched his jaw, the buzz of frustration dull against the backdrop of
Taylor's authoritarian certainty. It was this competitive undercurrent that kept him
alert, the sense that his and Jordan's shared commitment to discovery was an unspoken
rebellion against Cruz's narrowing vision of control and order.

Then Taylor did something unexpected. They paused beside Jordan and, for a moment,
observed the device with something akin to reverence. "If this tech can be understood..."
Taylor said, their voice quieter, "It could change the game for us. For all of us."

Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's,
a wordless clash of wills softening into an uneasy truce.
""",
    },
}

# =============================================================================
# Data Classes
# =============================================================================


@dataclass
class ExtractionResult:
    """Result from a single extraction run."""

    entities: list[dict]
    relations: list[dict]
    raw_output: str
    input_tokens: int
    format_errors: int

    @property
    def entity_count(self) -> int:
        return len(self.entities)

    @property
    def relation_count(self) -> int:
        return len(self.relations)

    @property
    def orphan_count(self) -> int:
        """Entities with no relationships."""
        entity_names = {e["name"].lower() for e in self.entities}
        connected = set()
        for r in self.relations:
            connected.add(r["source"].lower())
            connected.add(r["target"].lower())
        return len(entity_names - connected)

    @property
    def orphan_ratio(self) -> float:
        if self.entity_count == 0:
            return 0.0
        return self.orphan_count / self.entity_count


@dataclass
class ComparisonResult:
    """Comparison between two extraction results."""

    sample_name: str
    original: ExtractionResult
    optimized: ExtractionResult

    def entity_diff_pct(self) -> float:
        if self.original.entity_count == 0:
            return 0.0
        return (
            (self.optimized.entity_count - self.original.entity_count)
            / self.original.entity_count
            * 100
        )

    def relation_diff_pct(self) -> float:
        if self.original.relation_count == 0:
            return 0.0
        return (
            (self.optimized.relation_count - self.original.relation_count)
            / self.original.relation_count
            * 100
        )

    def token_diff_pct(self) -> float:
        if self.original.input_tokens == 0:
            return 0.0
        return (
            (self.optimized.input_tokens - self.original.input_tokens)
            / self.original.input_tokens
            * 100
        )

# =============================================================================
# Helper Functions
# =============================================================================


def format_prompt(
    prompts: dict,
    text: str,
    entity_types: str = "person, organization, location, concept, product, event, category, method",
) -> tuple[str, str]:
    """Format system and user prompts with the given text."""
    tuple_delimiter = prompts["DEFAULT_TUPLE_DELIMITER"]
    completion_delimiter = prompts["DEFAULT_COMPLETION_DELIMITER"]

    # Format examples
    examples = "\n".join(prompts["entity_extraction_examples"])
    examples = examples.format(
        tuple_delimiter=tuple_delimiter,
        completion_delimiter=completion_delimiter,
    )

    context = {
        "tuple_delimiter": tuple_delimiter,
        "completion_delimiter": completion_delimiter,
        "entity_types": entity_types,
        "language": "English",
        "examples": examples,
        "input_text": text,
    }

    system_prompt = prompts["entity_extraction_system_prompt"].format(**context)
    user_prompt = prompts["entity_extraction_user_prompt"].format(**context)

    return system_prompt, user_prompt


def count_tokens(text: str) -> int:
    """Count tokens using tiktoken."""
    enc = tiktoken.encoding_for_model("gpt-4")
    return len(enc.encode(text))
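
# NOTE: call_llm below relies on the OpenAI SDK's default client configuration;
# the API key is read from the OPENAI_API_KEY environment variable and is never
# passed explicitly in this module.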
async def call_llm(system_prompt: str, user_prompt: str, model: str = "gpt-4o-mini") -> str:
    """Call OpenAI API with the given prompts."""
    import openai

    client = openai.AsyncOpenAI()

    response = await client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.0,
    )

    return response.choices[0].message.content
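
# Expected raw model output: one record per line, fields joined by the tuple
# delimiter, terminated by the completion delimiter. Illustrative example only
# (entity and relation values are hypothetical):
#   entity<|#|>Pfizer<|#|>organization<|#|>Company that developed a COVID-19 vaccine
#   relation<|#|>Pfizer<|#|>COVID-19<|#|>vaccine, prevention<|#|>Pfizer developed a vaccine against COVID-19
#   <|COMPLETE|>
# The actual delimiter comes from the prompt dictionaries; "<|#|>" is only the
# default assumed by parse_extraction below.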
def parse_extraction(output: str, tuple_delimiter: str = "<|#|>") -> tuple[list[dict], list[dict], int]:
    """Parse extraction output into entities and relations."""
    entities = []
    relations = []
    format_errors = 0

    for line in output.strip().split("\n"):
        line = line.strip()
        if not line or line.startswith("<|COMPLETE|>"):
            continue

        parts = line.split(tuple_delimiter)

        if len(parts) >= 4 and parts[0].lower() == "entity":
            entities.append(
                {
                    "name": parts[1].strip(),
                    "type": parts[2].strip(),
                    "description": parts[3].strip() if len(parts) > 3 else "",
                }
            )
        elif len(parts) >= 5 and parts[0].lower() == "relation":
            relations.append(
                {
                    "source": parts[1].strip(),
                    "target": parts[2].strip(),
                    "keywords": parts[3].strip(),
                    "description": parts[4].strip() if len(parts) > 4 else "",
                }
            )
        elif line and not line.startswith("**") and tuple_delimiter in line:
            # Line looks like it should be parsed but failed
            format_errors += 1

    return entities, relations, format_errors

async def run_extraction(prompts: dict, text: str) -> ExtractionResult:
    """Run extraction with the given prompts on the text."""
    system_prompt, user_prompt = format_prompt(prompts, text)
    input_tokens = count_tokens(system_prompt) + count_tokens(user_prompt)

    output = await call_llm(system_prompt, user_prompt)
    # Parse with the same tuple delimiter the prompts were formatted with, so a
    # delimiter mismatch is not silently reported as zero entities.
    entities, relations, format_errors = parse_extraction(
        output, prompts["DEFAULT_TUPLE_DELIMITER"]
    )

    return ExtractionResult(
        entities=entities,
        relations=relations,
        raw_output=output,
        input_tokens=input_tokens,
        format_errors=format_errors,
    )

def print_comparison_table(results: list[ComparisonResult]) -> None:
    """Print a formatted comparison table."""
    print("\n" + "=" * 80)
    print("ENTITY EXTRACTION PROMPT A/B COMPARISON")
    print("=" * 80)

    total_orig_entities = 0
    total_opt_entities = 0
    total_orig_relations = 0
    total_opt_relations = 0
    total_orig_tokens = 0
    total_opt_tokens = 0

    for r in results:
        print(f"\n--- {r.sample_name} ---")
        print(f"{'Metric':<20} {'Original':>12} {'Optimized':>12} {'Diff':>12}")
        print("-" * 56)

        print(f"{'Entities':<20} {r.original.entity_count:>12} {r.optimized.entity_count:>12} {r.entity_diff_pct():>+11.0f}%")
        print(f"{'Relations':<20} {r.original.relation_count:>12} {r.optimized.relation_count:>12} {r.relation_diff_pct():>+11.0f}%")
        print(f"{'Orphan Ratio':<20} {r.original.orphan_ratio:>11.0%} {r.optimized.orphan_ratio:>11.0%} {'':>12}")
        print(f"{'Format Errors':<20} {r.original.format_errors:>12} {r.optimized.format_errors:>12}")
        print(f"{'Input Tokens':<20} {r.original.input_tokens:>12,} {r.optimized.input_tokens:>12,} {r.token_diff_pct():>+11.0f}%")

        total_orig_entities += r.original.entity_count
        total_opt_entities += r.optimized.entity_count
        total_orig_relations += r.original.relation_count
        total_opt_relations += r.optimized.relation_count
        total_orig_tokens += r.original.input_tokens
        total_opt_tokens += r.optimized.input_tokens

    # Aggregate
    print("\n" + "=" * 80)
    print("AGGREGATE RESULTS")
    print("=" * 80)
    print(f"{'Metric':<20} {'Original':>12} {'Optimized':>12} {'Diff':>12}")
    print("-" * 56)

    ent_diff = (total_opt_entities - total_orig_entities) / total_orig_entities * 100 if total_orig_entities else 0
    rel_diff = (total_opt_relations - total_orig_relations) / total_orig_relations * 100 if total_orig_relations else 0
    tok_diff = (total_opt_tokens - total_orig_tokens) / total_orig_tokens * 100 if total_orig_tokens else 0

    print(f"{'Total Entities':<20} {total_orig_entities:>12} {total_opt_entities:>12} {ent_diff:>+11.0f}%")
    print(f"{'Total Relations':<20} {total_orig_relations:>12} {total_opt_relations:>12} {rel_diff:>+11.0f}%")
    print(f"{'Total Input Tokens':<20} {total_orig_tokens:>12,} {total_opt_tokens:>12,} {tok_diff:>+11.0f}%")

    # Recommendation
    print("\n" + "-" * 56)
    if tok_diff < -30 and ent_diff >= -10:
        print("RECOMMENDATION: Use OPTIMIZED prompt (significant token savings, comparable extraction)")
    elif ent_diff > 20 and tok_diff < 0:
        print("RECOMMENDATION: Use OPTIMIZED prompt (better extraction AND token savings)")
    elif ent_diff < -20:
        print("RECOMMENDATION: Keep ORIGINAL prompt (optimized extracts significantly fewer entities)")
    else:
        print("RECOMMENDATION: Both prompts are comparable - consider token cost vs extraction breadth")

    print("=" * 80 + "\n")

# =============================================================================
# Pytest Tests
# =============================================================================
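
# NOTE (assumption): the `integration` marker and the `--run-integration` flag
# mentioned in the module docstring are not defined in this file; they are
# presumed to be registered in tests/conftest.py, roughly like:
#
#     def pytest_addoption(parser):
#         parser.addoption("--run-integration", action="store_true", default=False)
#
#     def pytest_collection_modifyitems(config, items):
#         if not config.getoption("--run-integration"):
#             skip = pytest.mark.skip(reason="needs --run-integration")
#             for item in items:
#                 if "integration" in item.keywords:
#                     item.add_marker(skip)
#
# `@pytest.mark.asyncio` additionally requires the pytest-asyncio plugin.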
class TestExtractionPromptAB:
    """A/B testing for entity extraction prompts."""

    @pytest.mark.integration
    @pytest.mark.asyncio
    async def test_compare_all_samples(self) -> None:
        """Compare prompts across all sample texts."""
        results = []

        for _key, sample in SAMPLE_TEXTS.items():
            print(f"\nProcessing: {sample['name']}...")

            original = await run_extraction(PROMPTS, sample["text"])
            optimized = await run_extraction(PROMPTS_OPTIMIZED, sample["text"])

            results.append(
                ComparisonResult(
                    sample_name=sample["name"],
                    original=original,
                    optimized=optimized,
                )
            )

        print_comparison_table(results)

        # Basic assertions
        for r in results:
            assert r.original.format_errors == 0, f"Original had format errors on {r.sample_name}"
            assert r.optimized.format_errors == 0, f"Optimized had format errors on {r.sample_name}"

    @pytest.mark.integration
    @pytest.mark.asyncio
    async def test_single_sample(self) -> None:
        """Quick test with just one sample."""
        sample = SAMPLE_TEXTS["covid_medical"]

        original = await run_extraction(PROMPTS, sample["text"])
        optimized = await run_extraction(PROMPTS_OPTIMIZED, sample["text"])

        result = ComparisonResult(
            sample_name=sample["name"],
            original=original,
            optimized=optimized,
        )

        print_comparison_table([result])

        assert original.entity_count > 0, "Original should extract entities"
        assert optimized.entity_count > 0, "Optimized should extract entities"
        assert optimized.format_errors == 0, "Optimized should have no format errors"

# =============================================================================
# CLI Runner
# =============================================================================


async def main() -> None:
    """Run A/B comparison from command line."""
    print("Starting Entity Extraction Prompt A/B Test...")
    print("This will make real API calls to OpenAI.\n")

    results = []

    for _key, sample in SAMPLE_TEXTS.items():
        print(f"Processing: {sample['name']}...")

        original = await run_extraction(PROMPTS, sample["text"])
        optimized = await run_extraction(PROMPTS_OPTIMIZED, sample["text"])

        results.append(
            ComparisonResult(
                sample_name=sample["name"],
                original=original,
                optimized=optimized,
            )
        )

    print_comparison_table(results)


if __name__ == "__main__":
    asyncio.run(main())