LightRAG/lightrag/evaluation/optimize_prompt_dspy.py

#!/usr/bin/env python3
"""
DSPy-based Prompt Optimization for LightRAG.

This script automatically optimizes RAG response prompts using:

1. DSPy's MIPROv2 optimizer for instruction tuning
2. RAGAS metrics (Faithfulness + Answer Relevancy) as objectives
3. Your pharma test dataset for training/validation

The optimized prompt can then be copied to lightrag/prompt.py.

Usage:
    # Quick optimization (3 queries, light mode)
    python optimize_prompt_dspy.py --mode light --num-queries 3

    # Full optimization (all queries, medium mode)
    python optimize_prompt_dspy.py --mode medium

    # Export optimized prompt
    python optimize_prompt_dspy.py --export-prompt optimized_prompt.txt
"""

import argparse
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Literal

import dspy
from dotenv import load_dotenv

# Load environment
load_dotenv()

# Ensure we can import lightrag
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
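
# The LM configuration in main() reads these variables; a minimal .env sketch
# (names taken from the os.getenv calls below, values are illustrative):
#
#   LLM_MODEL=openai/gpt-4o-mini
#   LLM_BINDING_API_KEY=sk-...        # falls back to OPENAI_API_KEY
#   LLM_BINDING_HOST=https://api.openai.com/v1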

# ============================================================================
# DSPy Signatures
# ============================================================================


class RAGResponse(dspy.Signature):
    """Generate a grounded answer based on retrieved context.

    The answer must be faithful to the context (no hallucination) and
    directly relevant to the question asked.
    """

    context: str = dspy.InputField(desc="Retrieved knowledge graph context with entities, relations, and source excerpts")
    question: str = dspy.InputField(desc="User's question to answer")
    answer: str = dspy.OutputField(desc="Grounded answer based only on the context provided")


class RAGResponseWithReasoning(dspy.Signature):
    """Generate a grounded answer with explicit reasoning about context usage.

    First identify relevant facts from context, then synthesize the answer.
    """

    context: str = dspy.InputField(desc="Retrieved knowledge graph context")
    question: str = dspy.InputField(desc="User's question")
    relevant_facts: str = dspy.OutputField(desc="Key facts from context relevant to the question")
    answer: str = dspy.OutputField(desc="Answer synthesized from the relevant facts only")


# ============================================================================
# DSPy Modules
# ============================================================================


class SimpleRAG(dspy.Module):
    """Simple RAG response module."""

    def __init__(self):
        super().__init__()
        self.respond = dspy.Predict(RAGResponse)

    def forward(self, context: str, question: str) -> dspy.Prediction:
        return self.respond(context=context, question=question)


class ChainOfThoughtRAG(dspy.Module):
    """RAG with chain-of-thought reasoning."""

    def __init__(self):
        super().__init__()
        self.respond = dspy.ChainOfThought(RAGResponse)

    def forward(self, context: str, question: str) -> dspy.Prediction:
        return self.respond(context=context, question=question)


class ReasoningRAG(dspy.Module):
    """RAG that first extracts relevant facts then answers."""

    def __init__(self):
        super().__init__()
        self.respond = dspy.ChainOfThought(RAGResponseWithReasoning)

    def forward(self, context: str, question: str) -> dspy.Prediction:
        return self.respond(context=context, question=question)
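
# Minimal usage sketch for any of the three modules above (assumes
# dspy.configure(lm=...) has already run, as it does in main(); the
# context/question values are illustrative):
#
#   rag = ChainOfThoughtRAG()
#   pred = rag(context="<retrieved KG context>", question="What does drug X inhibit?")
#   print(pred.answer)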

# ============================================================================
# RAGAS-based Metric
# ============================================================================


def ragas_metric(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Evaluate prediction using RAGAS metrics.

    Returns combined score of faithfulness and answer relevancy.
    """
    from datasets import Dataset
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    from ragas import evaluate
    from ragas.metrics import answer_relevancy, faithfulness

    question = example.question
    context = example.context
    ground_truth = getattr(example, 'ground_truth', '')
    answer = pred.answer

    # Skip if answer is empty
    if not answer or not answer.strip():
        return 0.0

    # Create RAGAS dataset
    data = {
        'question': [question],
        'answer': [answer],
        'contexts': [[context]],
        'ground_truth': [ground_truth] if ground_truth else [answer],
    }
    dataset = Dataset.from_dict(data)

    # Run evaluation
    try:
        llm = ChatOpenAI(model='gpt-4o-mini', temperature=0)
        embeddings = OpenAIEmbeddings(model='text-embedding-3-large')
        result = evaluate(
            dataset,
            metrics=[faithfulness, answer_relevancy],
            llm=llm,
            embeddings=embeddings,
        )
        # Extract scores (they come as lists for single samples)
        faith = result['faithfulness']
        relevance = result['answer_relevancy']
        if isinstance(faith, list):
            faith = faith[0]
        if isinstance(relevance, list):
            relevance = relevance[0]
        # Combined score (weight faithfulness slightly higher)
        score = 0.6 * float(faith) + 0.4 * float(relevance)
        if trace is not None:
            print(f"  RAGAS: faith={faith:.3f} rel={relevance:.3f} -> {score:.3f}")
        return score
    except Exception as e:
        print(f"  RAGAS eval failed: {e}")
        return 0.0
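
# Worked example of the weighting above: faithfulness = 0.90 and
# answer_relevancy = 0.70 combine to 0.6 * 0.90 + 0.4 * 0.70 = 0.82.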

def fast_metric(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Fast heuristic metric for initial bootstrapping.

    Checks basic quality signals without calling RAGAS.
    """
    answer = pred.answer if hasattr(pred, 'answer') else ''
    context = example.context
    question = example.question

    if not answer or not answer.strip():
        return 0.0

    score = 0.0

    # Length check (not too short, not too long)
    words = len(answer.split())
    if 20 <= words <= 500:
        score += 0.2
    elif 10 <= words <= 1000:
        score += 0.1

    # Contains key terms from question
    q_words = set(question.lower().split())
    a_words = set(answer.lower().split())
    overlap = len(q_words & a_words) / max(len(q_words), 1)
    score += 0.3 * min(overlap * 2, 1.0)

    # References context (rough check)
    context_words = set(context.lower().split()[:200])  # First 200 words
    context_overlap = len(context_words & a_words) / max(len(a_words), 1)
    score += 0.5 * min(context_overlap * 3, 1.0)

    return min(score, 1.0)
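
# Worked example of the heuristic: an answer of ~100 words (+0.2) whose words
# cover half of the question's terms (0.3 * min(0.5 * 2, 1.0) = +0.3) and whose
# tokens overlap the context at a 1/3 rate (0.5 * min(1/3 * 3, 1.0) = +0.5)
# would score the full 1.0.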

# ============================================================================
# Data Loading
# ============================================================================


async def get_context_from_server(query: str, server_url: str = 'http://localhost:9621') -> str:
    """Get context from running LightRAG server."""
    import httpx

    async with httpx.AsyncClient(timeout=60) as client:
        response = await client.post(
            f'{server_url}/query',
            json={
                'query': query,
                'mode': 'mix',
                'only_need_context': True,
            }
        )
        # Fail loudly on HTTP errors instead of silently returning empty context
        response.raise_for_status()
        data = response.json()
        return data.get('response', '')


def load_dataset(dataset_path: Path, num_queries: int | None = None) -> list[dict]:
    """Load test dataset."""
    with open(dataset_path) as f:
        data = json.load(f)

    if isinstance(data, dict) and 'test_cases' in data:
        dataset = data['test_cases']
    else:
        dataset = data

    if num_queries:
        dataset = dataset[:num_queries]

    return dataset
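
# Illustrative shape of pharma_test_dataset.json as load_dataset expects it
# (key names are taken from the lookups above; the content is hypothetical):
#
#   {
#     "test_cases": [
#       {"question": "What is the mechanism of action of ...?",
#        "ground_truth": "It selectively inhibits ..."}
#     ]
#   }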

async def prepare_dspy_examples(
    dataset: list[dict],
    server_url: str = 'http://localhost:9621',
) -> list[dspy.Example]:
    """Convert dataset to DSPy examples with context."""
    examples = []
    print(f"Fetching context for {len(dataset)} queries...")
    for i, item in enumerate(dataset):
        question = item['question']
        ground_truth = item.get('ground_truth', '')
        print(f"  [{i+1}/{len(dataset)}] {question[:50]}...")
        context = await get_context_from_server(question, server_url)
        if not context or context == 'No relevant context found for the query.':
            print("    Skipping - no context")
            continue
        example = dspy.Example(
            context=context,
            question=question,
            ground_truth=ground_truth,
        ).with_inputs('context', 'question')
        examples.append(example)
    print(f"Prepared {len(examples)} examples")
    return examples

# ============================================================================
# Optimization
# ============================================================================


def optimize_with_mipro(
    module: dspy.Module,
    trainset: list[dspy.Example],
    metric,
    mode: Literal['light', 'medium', 'heavy'] = 'light',
) -> dspy.Module:
    """Optimize module using MIPROv2."""
    from dspy.teleprompt import MIPROv2

    print(f"\nRunning MIPROv2 optimization (mode={mode})...")
    optimizer = MIPROv2(
        metric=metric,
        auto=mode,
        num_threads=2,  # Conservative for API rate limits
    )
    optimized = optimizer.compile(
        module,
        trainset=trainset,
        requires_permission_to_run=False,
    )
    return optimized


def optimize_with_bootstrap(
    module: dspy.Module,
    trainset: list[dspy.Example],
    metric,
    max_demos: int = 4,
) -> dspy.Module:
    """Optimize module using BootstrapFewShot."""
    from dspy.teleprompt import BootstrapFewShot

    print("\nRunning BootstrapFewShot optimization...")
    optimizer = BootstrapFewShot(
        metric=metric,
        max_bootstrapped_demos=max_demos,
        max_labeled_demos=max_demos,
        max_rounds=2,
    )
    optimized = optimizer.compile(module, trainset=trainset)
    return optimized
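
# Note on the two optimizers: BootstrapFewShot only selects few-shot
# demonstrations that pass the metric, while MIPROv2 additionally rewrites
# the instruction text itself; that is why --optimizer both runs bootstrap
# first and MIPRO on top of it (see main()).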

# ============================================================================
# Evaluation
# ============================================================================


def evaluate_module(
    module: dspy.Module,
    testset: list[dspy.Example],
    metric,
    name: str = "Module",
) -> float:
    """Evaluate a module on test set."""
    from dspy.evaluate import Evaluate

    print(f"\nEvaluating {name}...")
    evaluator = Evaluate(
        devset=testset,
        num_threads=2,
        display_progress=True,
        display_table=5,
    )
    result = evaluator(module, metric=metric)
    # Extract score from result (could be float or EvaluationResult object)
    if hasattr(result, 'score'):
        score = float(result.score)
    elif isinstance(result, (int, float)):
        score = float(result)
    else:
        # Try to extract from string representation
        score = float(str(result).split('%')[0].split()[-1]) / 100 if '%' in str(result) else 0.0
    print(f"{name} score: {score:.3f}")
    return score

# ============================================================================
# Prompt Export
# ============================================================================


def extract_optimized_prompt(module: dspy.Module) -> str:
    """Extract the optimized prompt template from a DSPy module."""
    # Get the first predictor
    predictor = None
    for name, child in module.named_predictors():
        predictor = child
        break
    if predictor is None:
        return "Could not extract prompt - no predictor found"

    # Build prompt representation
    prompt_parts = []

    # Add signature docstring and field descriptions if available
    if hasattr(predictor, 'signature'):
        sig = predictor.signature
        if sig.__doc__:
            prompt_parts.append(f"# Task Description\n{sig.__doc__}\n")

        # Add input/output field descriptions
        prompt_parts.append("# Input Fields")
        for field_name, field in sig.input_fields.items():
            desc = getattr(field, 'desc', '') or ''
            prompt_parts.append(f"- {field_name}: {desc}")
        prompt_parts.append("\n# Output Fields")
        for field_name, field in sig.output_fields.items():
            desc = getattr(field, 'desc', '') or ''
            prompt_parts.append(f"- {field_name}: {desc}")

    # Add any demos/examples
    if hasattr(predictor, 'demos') and predictor.demos:
        prompt_parts.append("\n# Few-Shot Examples")
        for i, demo in enumerate(predictor.demos):
            prompt_parts.append(f"\n## Example {i+1}")
            for key, value in demo.items():
                if isinstance(value, str) and len(value) > 200:
                    value = value[:200] + "..."
                prompt_parts.append(f"{key}: {value}")

    # Add extended signature if available (from MIPRO)
    if hasattr(predictor, 'extended_signature'):
        prompt_parts.append("\n# Optimized Instructions")
        prompt_parts.append(str(predictor.extended_signature))

    return "\n".join(prompt_parts)

def format_as_lightrag_prompt(optimized_prompt: str) -> str:
    """Format the optimized prompt for use in lightrag/prompt.py."""
    template = '''"""
RAG Response Prompt - Optimized by DSPy
========================================
Copy this to lightrag/prompt.py as PROMPTS['rag_response']
"""

PROMPTS['rag_response'] = """Answer based ONLY on this context:

{{context_data}}

---

{optimized_instructions}

Question: {{user_prompt}}

Answer (grounded only in context above):"""
'''
    return template.format(optimized_instructions=optimized_prompt)

# ============================================================================
# Main
# ============================================================================


async def main():
    parser = argparse.ArgumentParser(description='DSPy Prompt Optimization for LightRAG')
    parser.add_argument('--num-queries', '-n', type=int, help='Number of queries to use')
    parser.add_argument('--server', '-s', type=str, default='http://localhost:9621', help='LightRAG server URL')
    parser.add_argument('--mode', '-m', choices=['light', 'medium', 'heavy'], default='light',
                        help='Optimization intensity')
    parser.add_argument('--optimizer', '-o', choices=['mipro', 'bootstrap', 'both'], default='bootstrap',
                        help='Optimizer to use')
    parser.add_argument('--module', choices=['simple', 'cot', 'reasoning'], default='cot',
                        help='DSPy module architecture')
    parser.add_argument('--fast-metric', action='store_true', help='Use fast heuristic metric instead of RAGAS')
    parser.add_argument('--export-prompt', type=str, help='Export optimized prompt to file')
    parser.add_argument('--save-module', type=str, help='Save optimized module to JSON')
    args = parser.parse_args()

    # Configure DSPy
    lm = dspy.LM(
        model=os.getenv('LLM_MODEL', 'openai/gpt-4o-mini'),
        api_key=os.getenv('LLM_BINDING_API_KEY') or os.getenv('OPENAI_API_KEY'),
        api_base=os.getenv('LLM_BINDING_HOST', 'https://api.openai.com/v1'),
        temperature=0.1,
        max_tokens=32000,  # OpenRouter models support high limits
    )
    dspy.configure(lm=lm)

    print("=" * 70)
    print("DSPy Prompt Optimization for LightRAG")
    print("=" * 70)
    print(f"LLM: {os.getenv('LLM_MODEL', 'openai/gpt-4o-mini')}")
    print(f"Optimizer: {args.optimizer}")
    print(f"Module: {args.module}")
    print(f"Mode: {args.mode}")
    print("=" * 70)

    # Load dataset
    dataset_path = Path(__file__).parent / 'pharma_test_dataset.json'
    dataset = load_dataset(dataset_path, args.num_queries)

    # Prepare examples
    examples = await prepare_dspy_examples(dataset, args.server)
    if len(examples) < 2:
        print("Error: Need at least 2 examples for optimization")
        return

    # Split into train/test: hold out the last two examples for evaluation
    split_idx = max(1, len(examples) - 2)
    trainset = examples[:split_idx]
    testset = examples[split_idx:]
    print(f"\nTrain set: {len(trainset)} examples")
    print(f"Test set: {len(testset)} examples")

    # Select module
    if args.module == 'simple':
        module = SimpleRAG()
    elif args.module == 'cot':
        module = ChainOfThoughtRAG()
    else:
        module = ReasoningRAG()

    # Select metric
    metric = fast_metric if args.fast_metric else ragas_metric

    # Evaluate baseline
    print("\n" + "=" * 70)
    print("BASELINE EVALUATION")
    print("=" * 70)
    baseline_score = evaluate_module(module, testset, metric, "Baseline")

    # Optimize
    print("\n" + "=" * 70)
    print("OPTIMIZATION")
    print("=" * 70)
    optimized_module = None
    if args.optimizer in ['bootstrap', 'both']:
        optimized_module = optimize_with_bootstrap(module, trainset, metric)
        bootstrap_score = evaluate_module(optimized_module, testset, metric, "Bootstrap")
    if args.optimizer in ['mipro', 'both']:
        base_for_mipro = optimized_module if optimized_module else module
        optimized_module = optimize_with_mipro(base_for_mipro, trainset, metric, args.mode)
        mipro_score = evaluate_module(optimized_module, testset, metric, "MIPRO")

    # Extract and display optimized prompt
    print("\n" + "=" * 70)
    print("OPTIMIZED PROMPT")
    print("=" * 70)
    optimized_prompt = extract_optimized_prompt(optimized_module)
    print(optimized_prompt)

    # Export if requested
    if args.export_prompt:
        formatted = format_as_lightrag_prompt(optimized_prompt)
        with open(args.export_prompt, 'w') as f:
            f.write(formatted)
        print(f"\nExported to: {args.export_prompt}")

    if args.save_module:
        optimized_module.save(args.save_module)
        print(f"Saved module to: {args.save_module}")

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print(f"Baseline score: {baseline_score:.3f}")
    if args.optimizer in ['bootstrap', 'both']:
        print(f"Bootstrap score: {bootstrap_score:.3f}")
    if args.optimizer in ['mipro', 'both']:
        print(f"MIPRO score: {mipro_score:.3f}")


if __name__ == '__main__':
    asyncio.run(main())
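
# To reuse a module saved via --save-module in a later session (standard
# dspy.Module save/load API; the path here is illustrative):
#
#   rag = ChainOfThoughtRAG()
#   rag.load('optimized_rag.json')
#   pred = rag(context="<retrieved KG context>", question="...")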