Remove legacy storage implementations and deprecated examples:
- Delete FAISS, JSON, Memgraph, Milvus, MongoDB, Nano Vector DB, Neo4j, NetworkX, Qdrant, Redis storage backends
- Remove Kubernetes deployment manifests and installation scripts
- Delete unofficial examples for deprecated backends and offline deployment docs

Streamline core infrastructure:
- Consolidate storage layer to PostgreSQL-only implementation
- Add full-text search caching with FTS cache module
- Implement metrics collection and monitoring pipeline
- Add explain and metrics API routes

Modernize frontend and tooling:
- Switch web UI to Bun with bun.lock, remove npm and pnpm lockfiles
- Update Dockerfile for PostgreSQL-only deployment
- Add Makefile for common development tasks
- Update environment and configuration examples

Enhance evaluation and testing capabilities:
- Add prompt optimization with DSPy and auto-tuning
- Implement ground truth regeneration and variant testing
- Add prompt debugging and response comparison utilities
- Expand test coverage with new integration scenarios

Simplify dependencies and configuration:
- Remove offline-specific requirement files
- Update pyproject.toml with streamlined dependencies
- Add Python version pinning with .python-version
- Create project guidelines in CLAUDE.md and AGENTS.md
567 lines
18 KiB
Python
#!/usr/bin/env python3
"""
DSPy-based Prompt Optimization for LightRAG.

This script automatically optimizes RAG response prompts using:
1. DSPy's MIPROv2 optimizer for instruction tuning
2. RAGAS metrics (Faithfulness + Answer Relevancy) as objectives
3. Your pharma test dataset for training/validation

The optimized prompt can then be copied to lightrag/prompt.py.

Usage:
    # Quick optimization (3 queries, light mode)
    python optimize_prompt_dspy.py --mode light --num-queries 3

    # Full optimization (all queries, medium mode)
    python optimize_prompt_dspy.py --mode medium

    # Export optimized prompt
    python optimize_prompt_dspy.py --export-prompt optimized_prompt.txt
"""

import argparse
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Literal

import dspy
from dotenv import load_dotenv

# Load environment
load_dotenv()

# Ensure we can import lightrag
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
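# Note: the sys.path insert above assumes this script sits three directories below
# the repository root, so the `lightrag` package resolves without installation;
# adjust the number of .parent hops if the file is moved.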

# ============================================================================
# DSPy Signatures
# ============================================================================

class RAGResponse(dspy.Signature):
    """Generate a grounded answer based on retrieved context.

    The answer must be faithful to the context (no hallucination) and
    directly relevant to the question asked.
    """

    context: str = dspy.InputField(desc="Retrieved knowledge graph context with entities, relations, and source excerpts")
    question: str = dspy.InputField(desc="User's question to answer")
    answer: str = dspy.OutputField(desc="Grounded answer based only on the context provided")


class RAGResponseWithReasoning(dspy.Signature):
    """Generate a grounded answer with explicit reasoning about context usage.

    First identify relevant facts from context, then synthesize the answer.
    """

    context: str = dspy.InputField(desc="Retrieved knowledge graph context")
    question: str = dspy.InputField(desc="User's question")
    relevant_facts: str = dspy.OutputField(desc="Key facts from context relevant to the question")
    answer: str = dspy.OutputField(desc="Answer synthesized from the relevant facts only")

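# In DSPy, a Signature's docstring is the task instruction and each field's desc=
# string documents that input/output to the LM. MIPROv2 below tunes the instruction
# text; the field layout itself stays fixed.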

# ============================================================================
# DSPy Modules
# ============================================================================

class SimpleRAG(dspy.Module):
    """Simple RAG response module."""

    def __init__(self):
        super().__init__()
        self.respond = dspy.Predict(RAGResponse)

    def forward(self, context: str, question: str) -> dspy.Prediction:
        return self.respond(context=context, question=question)


class ChainOfThoughtRAG(dspy.Module):
    """RAG with chain-of-thought reasoning."""

    def __init__(self):
        super().__init__()
        self.respond = dspy.ChainOfThought(RAGResponse)

    def forward(self, context: str, question: str) -> dspy.Prediction:
        return self.respond(context=context, question=question)


class ReasoningRAG(dspy.Module):
    """RAG that first extracts relevant facts then answers."""

    def __init__(self):
        super().__init__()
        self.respond = dspy.ChainOfThought(RAGResponseWithReasoning)

    def forward(self, context: str, question: str) -> dspy.Prediction:
        return self.respond(context=context, question=question)

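# These three architectures map onto the --module CLI flag in main():
#   simple    -> SimpleRAG          (plain Predict, no intermediate reasoning)
#   cot       -> ChainOfThoughtRAG  (ChainOfThought adds a hidden rationale step)
#   reasoning -> ReasoningRAG       (emits relevant_facts explicitly before the answer)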

# ============================================================================
# RAGAS-based Metric
# ============================================================================

def ragas_metric(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Evaluate prediction using RAGAS metrics.

    Returns combined score of faithfulness and answer relevancy.
    """
    from datasets import Dataset
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    from ragas import evaluate
    from ragas.metrics import answer_relevancy, faithfulness

    question = example.question
    context = example.context
    ground_truth = getattr(example, 'ground_truth', '')
    answer = pred.answer

    # Skip if answer is empty
    if not answer or not answer.strip():
        return 0.0

    # Create RAGAS dataset
    data = {
        'question': [question],
        'answer': [answer],
        'contexts': [[context]],
        'ground_truth': [ground_truth] if ground_truth else [answer],
    }
    dataset = Dataset.from_dict(data)

    # Run evaluation
    try:
        llm = ChatOpenAI(model='gpt-4o-mini', temperature=0)
        embeddings = OpenAIEmbeddings(model='text-embedding-3-large')

        result = evaluate(
            dataset,
            metrics=[faithfulness, answer_relevancy],
            llm=llm,
            embeddings=embeddings,
        )

        # Extract scores (they come as lists for single samples)
        faith = result['faithfulness']
        relevance = result['answer_relevancy']

        if isinstance(faith, list):
            faith = faith[0]
        if isinstance(relevance, list):
            relevance = relevance[0]

        # Combined score (weight faithfulness slightly higher)
        score = 0.6 * float(faith) + 0.4 * float(relevance)

        if trace is not None:
            print(f" RAGAS: faith={faith:.3f} rel={relevance:.3f} -> {score:.3f}")

        return score

    except Exception as e:
        print(f" RAGAS eval failed: {e}")
        return 0.0


def fast_metric(example: dspy.Example, pred: dspy.Prediction, trace=None) -> float:
    """
    Fast heuristic metric for initial bootstrapping.

    Checks basic quality signals without calling RAGAS.
    """
    answer = pred.answer if hasattr(pred, 'answer') else ''
    context = example.context
    question = example.question

    if not answer or not answer.strip():
        return 0.0

    score = 0.0

    # Length check (not too short, not too long)
    words = len(answer.split())
    if 20 <= words <= 500:
        score += 0.2
    elif 10 <= words <= 1000:
        score += 0.1

    # Contains key terms from question
    q_words = set(question.lower().split())
    a_words = set(answer.lower().split())
    overlap = len(q_words & a_words) / max(len(q_words), 1)
    score += 0.3 * min(overlap * 2, 1.0)

    # References context (rough check)
    context_words = set(context.lower().split()[:200])  # First 200 words
    context_overlap = len(context_words & a_words) / max(len(a_words), 1)
    score += 0.5 * min(context_overlap * 3, 1.0)

    return min(score, 1.0)

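# Sketch (assumed debugging workflow, not called anywhere in this script): either
# metric can be probed on a hand-built example, e.g.
#
#   ex = dspy.Example(context="...", question="...", ground_truth="...").with_inputs("context", "question")
#   pred = dspy.Prediction(answer="...")
#   fast_metric(ex, pred)    # cheap lexical heuristic, no API calls
#   ragas_metric(ex, pred)   # needs OpenAI credentials for the judge LLM and embeddings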

# ============================================================================
# Data Loading
# ============================================================================

async def get_context_from_server(query: str, server_url: str = 'http://localhost:9621') -> str:
    """Get context from running LightRAG server."""
    import httpx

    async with httpx.AsyncClient(timeout=60) as client:
        response = await client.post(
            f'{server_url}/query',
            json={
                'query': query,
                'mode': 'mix',
                'only_need_context': True,
            }
        )
        data = response.json()
        return data.get('response', '')


def load_dataset(dataset_path: Path, num_queries: int | None = None) -> list[dict]:
    """Load test dataset."""
    with open(dataset_path) as f:
        data = json.load(f)

    if isinstance(data, dict) and 'test_cases' in data:
        dataset = data['test_cases']
    else:
        dataset = data

    if num_queries:
        dataset = dataset[:num_queries]

    return dataset


async def prepare_dspy_examples(
    dataset: list[dict],
    server_url: str = 'http://localhost:9621',
) -> list[dspy.Example]:
    """Convert dataset to DSPy examples with context."""
    examples = []

    print(f"Fetching context for {len(dataset)} queries...")

    for i, item in enumerate(dataset):
        question = item['question']
        ground_truth = item.get('ground_truth', '')

        print(f" [{i+1}/{len(dataset)}] {question[:50]}...")

        context = await get_context_from_server(question, server_url)

        if not context or context == 'No relevant context found for the query.':
            print(" Skipping - no context")
            continue

        example = dspy.Example(
            context=context,
            question=question,
            ground_truth=ground_truth,
        ).with_inputs('context', 'question')

        examples.append(example)

    print(f"Prepared {len(examples)} examples")
    return examples

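# The dataset file may be either a bare list of test cases or a dict with a
# "test_cases" key; each case needs a "question" and may carry a "ground_truth".
# Only 'context' and 'question' are marked as inputs via with_inputs(), so the LM
# never sees ground_truth -- it stays on the example for the metric to read.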

# ============================================================================
# Optimization
# ============================================================================

def optimize_with_mipro(
    module: dspy.Module,
    trainset: list[dspy.Example],
    metric,
    mode: Literal['light', 'medium', 'heavy'] = 'light',
) -> dspy.Module:
    """Optimize module using MIPROv2."""
    from dspy.teleprompt import MIPROv2

    print(f"\nRunning MIPROv2 optimization (mode={mode})...")

    optimizer = MIPROv2(
        metric=metric,
        auto=mode,
        num_threads=2,  # Conservative for API rate limits
    )

    optimized = optimizer.compile(
        module,
        trainset=trainset,
        requires_permission_to_run=False,
    )

    return optimized


def optimize_with_bootstrap(
    module: dspy.Module,
    trainset: list[dspy.Example],
    metric,
    max_demos: int = 4,
) -> dspy.Module:
    """Optimize module using BootstrapFewShot."""
    from dspy.teleprompt import BootstrapFewShot

    print("\nRunning BootstrapFewShot optimization...")

    optimizer = BootstrapFewShot(
        metric=metric,
        max_bootstrapped_demos=max_demos,
        max_labeled_demos=max_demos,
        max_rounds=2,
    )

    optimized = optimizer.compile(module, trainset=trainset)

    return optimized

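# BootstrapFewShot only collects few-shot demonstrations that pass the metric, while
# MIPROv2 additionally proposes rewritten instructions for each signature; with
# --optimizer both, the bootstrapped module is fed into MIPROv2 as its starting point.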

# ============================================================================
# Evaluation
# ============================================================================

def evaluate_module(
    module: dspy.Module,
    testset: list[dspy.Example],
    metric,
    name: str = "Module",
) -> float:
    """Evaluate a module on test set."""
    from dspy.evaluate import Evaluate

    print(f"\nEvaluating {name}...")

    evaluator = Evaluate(
        devset=testset,
        num_threads=2,
        display_progress=True,
        display_table=5,
    )

    result = evaluator(module, metric=metric)

    # Extract score from result (could be float or EvaluationResult object)
    if hasattr(result, 'score'):
        score = float(result.score)
    elif isinstance(result, (int, float)):
        score = float(result)
    else:
        # Try to extract from string representation
        score = float(str(result).split('%')[0].split()[-1]) / 100 if '%' in str(result) else 0.0

    print(f"{name} score: {score:.3f}")

    return score


# ============================================================================
# Prompt Export
# ============================================================================

def extract_optimized_prompt(module: dspy.Module) -> str:
    """Extract the optimized prompt template from a DSPy module."""
    # Get the predictor
    predictor = None
    for name, child in module.named_predictors():
        predictor = child
        break

    if predictor is None:
        return "Could not extract prompt - no predictor found"

    # Build prompt representation
    prompt_parts = []

    # Add signature docstring if available
    if hasattr(predictor, 'signature'):
        sig = predictor.signature
        if sig.__doc__:
            prompt_parts.append(f"# Task Description\n{sig.__doc__}\n")

        # Add input/output field descriptions
        prompt_parts.append("# Input Fields")
        for field_name, field in sig.input_fields.items():
            desc = getattr(field, 'desc', '') or ''
            prompt_parts.append(f"- {field_name}: {desc}")

        prompt_parts.append("\n# Output Fields")
        for field_name, field in sig.output_fields.items():
            desc = getattr(field, 'desc', '') or ''
            prompt_parts.append(f"- {field_name}: {desc}")

    # Add any demos/examples
    if hasattr(predictor, 'demos') and predictor.demos:
        prompt_parts.append("\n# Few-Shot Examples")
        for i, demo in enumerate(predictor.demos):
            prompt_parts.append(f"\n## Example {i+1}")
            for key, value in demo.items():
                if isinstance(value, str) and len(value) > 200:
                    value = value[:200] + "..."
                prompt_parts.append(f"{key}: {value}")

    # Add extended signature if available (from MIPRO)
    if hasattr(predictor, 'extended_signature'):
        prompt_parts.append("\n# Optimized Instructions")
        prompt_parts.append(str(predictor.extended_signature))

    return "\n".join(prompt_parts)


def format_as_lightrag_prompt(optimized_prompt: str) -> str:
    """Format the optimized prompt for use in lightrag/prompt.py."""
    template = '''"""
RAG Response Prompt - Optimized by DSPy
========================================

Copy this to lightrag/prompt.py as PROMPTS['rag_response']
"""

PROMPTS['rag_response'] = """Answer based ONLY on this context:

{{context_data}}

---

{optimized_instructions}

Question: {{user_prompt}}

Answer (grounded only in context above):"""
'''

    return template.format(optimized_instructions=optimized_prompt)

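# In the template above, {{context_data}} and {{user_prompt}} are doubled so that
# str.format() leaves literal {context_data} / {user_prompt} placeholders for LightRAG
# to fill at query time; only {optimized_instructions} is substituted here.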

# ============================================================================
# Main
# ============================================================================

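# Runtime configuration comes from the environment (loaded via .env): LLM_MODEL,
# LLM_BINDING_API_KEY (or OPENAI_API_KEY) and LLM_BINDING_HOST select the model DSPy
# optimizes against; the RAGAS metric additionally relies on OpenAI credentials for
# its gpt-4o-mini judge and text-embedding-3-large embeddings.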
async def main():
    parser = argparse.ArgumentParser(description='DSPy Prompt Optimization for LightRAG')
    parser.add_argument('--num-queries', '-n', type=int, help='Number of queries to use')
    parser.add_argument('--server', '-s', type=str, default='http://localhost:9621', help='LightRAG server URL')
    parser.add_argument('--mode', '-m', choices=['light', 'medium', 'heavy'], default='light',
                        help='Optimization intensity')
    parser.add_argument('--optimizer', '-o', choices=['mipro', 'bootstrap', 'both'], default='bootstrap',
                        help='Optimizer to use')
    parser.add_argument('--module', choices=['simple', 'cot', 'reasoning'], default='cot',
                        help='DSPy module architecture')
    parser.add_argument('--fast-metric', action='store_true', help='Use fast heuristic metric instead of RAGAS')
    parser.add_argument('--export-prompt', type=str, help='Export optimized prompt to file')
    parser.add_argument('--save-module', type=str, help='Save optimized module to JSON')
    args = parser.parse_args()

    # Configure DSPy
    lm = dspy.LM(
        model=os.getenv('LLM_MODEL', 'openai/gpt-4o-mini'),
        api_key=os.getenv('LLM_BINDING_API_KEY') or os.getenv('OPENAI_API_KEY'),
        api_base=os.getenv('LLM_BINDING_HOST', 'https://api.openai.com/v1'),
        temperature=0.1,
        max_tokens=32000,  # OpenRouter models support high limits
    )
    dspy.configure(lm=lm)

    print("=" * 70)
    print("DSPy Prompt Optimization for LightRAG")
    print("=" * 70)
    print(f"LLM: {os.getenv('LLM_MODEL', 'gpt-4o-mini')}")
    print(f"Optimizer: {args.optimizer}")
    print(f"Module: {args.module}")
    print(f"Mode: {args.mode}")
    print("=" * 70)

    # Load dataset
    dataset_path = Path(__file__).parent / 'pharma_test_dataset.json'
    dataset = load_dataset(dataset_path, args.num_queries)

    # Prepare examples
    examples = await prepare_dspy_examples(dataset, args.server)

    if len(examples) < 2:
        print("Error: Need at least 2 examples for optimization")
        return

    # Split into train/test
    split_idx = max(1, len(examples) - 2)
    trainset = examples[:split_idx]
    testset = examples[split_idx:]

    print(f"\nTrain set: {len(trainset)} examples")
    print(f"Test set: {len(testset)} examples")

    # Select module
    if args.module == 'simple':
        module = SimpleRAG()
    elif args.module == 'cot':
        module = ChainOfThoughtRAG()
    else:
        module = ReasoningRAG()

    # Select metric
    metric = fast_metric if args.fast_metric else ragas_metric

    # Evaluate baseline
    print("\n" + "=" * 70)
    print("BASELINE EVALUATION")
    print("=" * 70)
    baseline_score = evaluate_module(module, testset, metric, "Baseline")

    # Optimize
    print("\n" + "=" * 70)
    print("OPTIMIZATION")
    print("=" * 70)

    optimized_module = None

    if args.optimizer in ['bootstrap', 'both']:
        optimized_module = optimize_with_bootstrap(module, trainset, metric)
        bootstrap_score = evaluate_module(optimized_module, testset, metric, "Bootstrap")

    if args.optimizer in ['mipro', 'both']:
        base_for_mipro = optimized_module if optimized_module else module
        optimized_module = optimize_with_mipro(base_for_mipro, trainset, metric, args.mode)
        mipro_score = evaluate_module(optimized_module, testset, metric, "MIPRO")

    # Extract and display optimized prompt
    print("\n" + "=" * 70)
    print("OPTIMIZED PROMPT")
    print("=" * 70)

    optimized_prompt = extract_optimized_prompt(optimized_module)
    print(optimized_prompt)

    # Export if requested
    if args.export_prompt:
        formatted = format_as_lightrag_prompt(optimized_prompt)
        with open(args.export_prompt, 'w') as f:
            f.write(formatted)
        print(f"\nExported to: {args.export_prompt}")

    if args.save_module:
        optimized_module.save(args.save_module)
        print(f"Saved module to: {args.save_module}")

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print(f"Baseline score: {baseline_score:.3f}")
    if args.optimizer in ['bootstrap', 'both']:
        print(f"Bootstrap score: {bootstrap_score:.3f}")
    if args.optimizer in ['mipro', 'both']:
        print(f"MIPRO score: {mipro_score:.3f}")


if __name__ == '__main__':
    asyncio.run(main())
|