Remove legacy storage implementations and deprecated examples:
- Delete FAISS, JSON, Memgraph, Milvus, MongoDB, Nano Vector DB, Neo4j, NetworkX, Qdrant, Redis storage backends
- Remove Kubernetes deployment manifests and installation scripts
- Delete unofficial examples for deprecated backends and offline deployment docs

Streamline core infrastructure:
- Consolidate storage layer to PostgreSQL-only implementation
- Add full-text search caching with FTS cache module
- Implement metrics collection and monitoring pipeline
- Add explain and metrics API routes

Modernize frontend and tooling:
- Switch web UI to Bun with bun.lock, remove npm and pnpm lockfiles
- Update Dockerfile for PostgreSQL-only deployment
- Add Makefile for common development tasks
- Update environment and configuration examples

Enhance evaluation and testing capabilities:
- Add prompt optimization with DSPy and auto-tuning
- Implement ground truth regeneration and variant testing
- Add prompt debugging and response comparison utilities
- Expand test coverage with new integration scenarios

Simplify dependencies and configuration:
- Remove offline-specific requirement files
- Update pyproject.toml with streamlined dependencies
- Add Python version pinning with .python-version
- Create project guidelines in CLAUDE.md and AGENTS.md

#!/usr/bin/env python3
"""
Test HyDE (Hypothetical Document Embeddings) concept in isolation.

Compares retrieval quality between:
1. Standard: embed query directly
2. HyDE: embed hypothetical answers, average them

Uses the same embedding model and vector DB as LightRAG.
"""

import asyncio
import json
import os

import numpy as np
from openai import AsyncOpenAI

# Config
EMBEDDING_MODEL = 'text-embedding-3-large'
LLM_MODEL = 'gpt-4o-mini'
PG_HOST = os.getenv('POSTGRES_HOST', 'localhost')
PG_PORT = os.getenv('POSTGRES_PORT', '5433')
PG_USER = os.getenv('POSTGRES_USER', 'lightrag')
PG_PASS = os.getenv('POSTGRES_PASSWORD', 'lightrag_pass')
PG_DB = os.getenv('POSTGRES_DATABASE', 'lightrag')

client = AsyncOpenAI()
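
# Example invocation (assumes OPENAI_API_KEY is exported for AsyncOpenAI and
# the Postgres instance configured above is reachable; the filename here is
# illustrative, not taken from the repo):
#   OPENAI_API_KEY=sk-... POSTGRES_PORT=5433 python hyde_concept_test.py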


async def get_embedding(text: str) -> list[float]:
    """Get embedding for a single text."""
    response = await client.embeddings.create(
        model=EMBEDDING_MODEL,
        input=text,
        dimensions=1536,  # Match DB dimension (text-embedding-3-large at 1536)
    )
    return response.data[0].embedding


async def generate_hypothetical_answers(query: str, num_docs: int = 3) -> list[str]:
    """Generate hypothetical answers using LLM."""
    prompt = f"""Generate {num_docs} brief hypothetical answers to this question.
Each answer should be 2-3 sentences of factual-sounding content that would answer the question.
Write as if you have the information - no hedging or "I don't know".

Question: {query}

Output valid JSON:
{{"hypothetical_answers": ["answer1", "answer2", "answer3"]}}"""

    response = await client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{'role': 'user', 'content': prompt}],
        temperature=0.7,
    )

    content = response.choices[0].message.content
    # Parse JSON from response
    try:
        # Handle markdown code blocks
        if '```json' in content:
            content = content.split('```json')[1].split('```')[0]
        elif '```' in content:
            content = content.split('```')[1].split('```')[0]
        data = json.loads(content.strip())
        return data.get('hypothetical_answers', [])
    except json.JSONDecodeError as e:
        print(f'Failed to parse JSON: {e}')
        print(f'Raw content: {content}')
        return []
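

# A sketch of an alternative to the fence-stripping above: OpenAI's chat API
# supports response_format={'type': 'json_object'}, which constrains the model
# to emit bare JSON (the prompt must mention JSON, which ours does). The helper
# name below is illustrative and not part of the original script.
async def generate_hypothetical_answers_json_mode(query: str, num_docs: int = 3) -> list[str]:
    prompt = (
        f'Generate {num_docs} brief hypothetical answers to this question.\n'
        f'Question: {query}\n'
        'Output valid JSON: {"hypothetical_answers": ["answer1", "answer2", "answer3"]}'
    )
    response = await client.chat.completions.create(
        model=LLM_MODEL,
        messages=[{'role': 'user', 'content': prompt}],
        temperature=0.7,
        response_format={'type': 'json_object'},  # no markdown fences to strip
    )
    # JSON mode guarantees syntactic validity, not the schema, so keep .get()
    data = json.loads(response.choices[0].message.content)
    return data.get('hypothetical_answers', [])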


def average_embeddings(embeddings: list[list[float]]) -> list[float]:
    """Average multiple embedding vectors."""
    arr = np.array(embeddings)
    return arr.mean(axis=0).tolist()
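

# Variant sketch: the original HyDE formulation (Gao et al., 2022) also folds
# the query embedding itself into the mean, treating the query as one more
# "hypothesis". This helper is hypothetical and not wired into the comparison
# below; it's here to show how close the script is to the paper's method.
def average_embeddings_with_query(
    query_embedding: list[float],
    doc_embeddings: list[list[float]],
) -> list[float]:
    arr = np.array([query_embedding] + doc_embeddings)
    return arr.mean(axis=0).tolist()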


async def search_chunks(embedding: list[float], top_k: int = 5) -> list[dict]:
    """Search chunks in PostgreSQL using the embedding."""
    import asyncpg

    conn = await asyncpg.connect(
        host=PG_HOST,
        port=int(PG_PORT),
        user=PG_USER,
        password=PG_PASS,
        database=PG_DB,
    )

    # Pass the vector as a text parameter and cast it server-side, rather than
    # interpolating ~1536 floats into the SQL string twice.
    embedding_literal = '[' + ','.join(map(str, embedding)) + ']'

    # <=> is pgvector's cosine-distance operator (lower = more similar).
    query = """
        SELECT c.id,
               LEFT(c.content, 200) AS content_preview,
               c.content_vector <=> $2::text::vector AS distance
        FROM lightrag_vdb_chunks c
        WHERE c.workspace = 'default'
        ORDER BY c.content_vector <=> $2::text::vector
        LIMIT $1;
    """

    try:
        rows = await conn.fetch(query, top_k, embedding_literal)
    finally:
        await conn.close()

    return [
        {'id': r['id'], 'preview': r['content_preview'], 'distance': float(r['distance'])}
        for r in rows
    ]
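

# Note: if the pgvector-python package is available (assumption: it is not a
# dependency of this script), its asyncpg codec lets you pass embeddings as
# native parameters instead of text literals:
#
#   from pgvector.asyncpg import register_vector
#   await register_vector(conn)
#   rows = await conn.fetch(
#       'SELECT id FROM lightrag_vdb_chunks '
#       'ORDER BY content_vector <=> $1 LIMIT $2',
#       np.array(embedding), top_k,
#   )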


async def compare_retrieval(query: str):
    """Compare standard vs HyDE retrieval for a query."""
    print(f'\n{"=" * 80}')
    print(f'QUERY: {query}')
    print('=' * 80)

    # Standard retrieval
    print('\n📌 STANDARD (embed query directly):')
    query_embedding = await get_embedding(query)
    standard_results = await search_chunks(query_embedding, top_k=5)

    for i, r in enumerate(standard_results, 1):
        print(f'  {i}. [dist={r["distance"]:.4f}] {r["preview"][:100]}...')

    avg_standard_dist = np.mean([r['distance'] for r in standard_results])
    print(f'  → Avg distance: {avg_standard_dist:.4f}')

    # HyDE retrieval
    print('\n🔮 HYDE (embed hypothetical answers):')
    hypotheticals = await generate_hypothetical_answers(query, num_docs=3)

    if not hypotheticals:
        # JSON parsing failed upstream; fall back to the raw query so the
        # averaging step below doesn't crash on an empty list.
        print('  ⚠️ No hypotheticals generated; falling back to query text')
        hypotheticals = [query]

    print('  Generated hypotheticals:')
    for i, h in enumerate(hypotheticals, 1):
        print(f'  {i}. {h[:100]}...')

    # Embed hypotheticals concurrently, then average
    hyde_embeddings = await asyncio.gather(*[get_embedding(h) for h in hypotheticals])

    hyde_embedding = average_embeddings(hyde_embeddings)
    hyde_results = await search_chunks(hyde_embedding, top_k=5)

    print('\n  Results:')
    for i, r in enumerate(hyde_results, 1):
        print(f'  {i}. [dist={r["distance"]:.4f}] {r["preview"][:100]}...')

    avg_hyde_dist = np.mean([r['distance'] for r in hyde_results])
    print(f'  → Avg distance: {avg_hyde_dist:.4f}')

    # Compare
    print('\n📊 COMPARISON:')
    improvement = avg_standard_dist - avg_hyde_dist
    pct = (improvement / avg_standard_dist) * 100 if avg_standard_dist > 0 else 0

    if improvement > 0:
        print(f'  ✅ HyDE is BETTER by {improvement:.4f} ({pct:.1f}% closer)')
    else:
        print(f'  ❌ Standard is BETTER by {-improvement:.4f} ({-pct:.1f}% closer)')

    # Check overlap
    standard_ids = {r['id'] for r in standard_results}
    hyde_ids = {r['id'] for r in hyde_results}
    overlap = len(standard_ids & hyde_ids)
    print(f'  📎 Overlap: {overlap}/5 chunks in common')

    return {
        'query': query,
        'standard_avg_dist': avg_standard_dist,
        'hyde_avg_dist': avg_hyde_dist,
        'improvement': improvement,
        'overlap': overlap,
    }


async def main():
    # Test queries from pharma dataset
    test_queries = [
        'What were the key lessons learned from the Isatuximab monoclonal antibody drug development program in April 2020?',
        'What CMC dossier lessons were learned from the PKU IND submission in 2023?',
        'What risk management strategies were discussed in the 2017 Risk Review CIR for CMC development programs?',
        'What biopharmacy considerations were discussed in the February 2022 CMC Cross Sharing session?',
        'What were the main challenges and lessons learned from the COVID-19 mRNA vaccine development?',
    ]

    results = []
    for query in test_queries:
        result = await compare_retrieval(query)
        results.append(result)

    # Summary
    print('\n' + '=' * 80)
    print('SUMMARY')
    print('=' * 80)

    hyde_wins = sum(1 for r in results if r['improvement'] > 0)
    avg_improvement = np.mean([r['improvement'] for r in results])
    avg_overlap = np.mean([r['overlap'] for r in results])

    print(f'HyDE wins: {hyde_wins}/{len(results)} queries')
    print(f'Avg distance improvement: {avg_improvement:.4f}')
    print(f'Avg overlap with standard: {avg_overlap:.1f}/5 chunks')

    if hyde_wins >= len(results) / 2:
        print('\n✅ HyDE shows promise - worth implementing!')
    else:
        print("\n⚠️ HyDE doesn't help much for these queries")


if __name__ == '__main__':
    asyncio.run(main())