Raphaël MANSUY 2025-12-04 19:19:00 +08:00
parent c53b7cba76
commit af49572886


@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 """
-RAGAS Evaluation Script for LightRAG System
+RAGAS Evaluation Script for Portfolio RAG System

 Evaluates RAG response quality using RAGAS metrics:
 - Faithfulness: Is the answer factually accurate based on context?
@@ -9,35 +9,15 @@ Evaluates RAG response quality using RAGAS metrics:
 - Context Precision: Is retrieved context clean without noise?

 Usage:
-    # Use defaults (sample_dataset.json, http://localhost:9621)
     python lightrag/evaluation/eval_rag_quality.py
-    # Specify custom dataset
-    python lightrag/evaluation/eval_rag_quality.py --dataset my_test.json
-    python lightrag/evaluation/eval_rag_quality.py -d my_test.json
-    # Specify custom RAG endpoint
-    python lightrag/evaluation/eval_rag_quality.py --ragendpoint http://my-server.com:9621
-    python lightrag/evaluation/eval_rag_quality.py -r http://my-server.com:9621
-    # Specify both
-    python lightrag/evaluation/eval_rag_quality.py -d my_test.json -r http://localhost:9621
-    # Get help
-    python lightrag/evaluation/eval_rag_quality.py --help
+    python lightrag/evaluation/eval_rag_quality.py http://localhost:9621
+    python lightrag/evaluation/eval_rag_quality.py http://your-rag-server.com:9621

 Results are saved to: lightrag/evaluation/results/
 - results_YYYYMMDD_HHMMSS.csv (CSV export for analysis)
 - results_YYYYMMDD_HHMMSS.json (Full results with details)

-Technical Notes:
-- Uses stable RAGAS API (LangchainLLMWrapper) for maximum compatibility
-- Supports custom OpenAI-compatible endpoints via EVAL_LLM_BINDING_HOST
-- Enables bypass_n mode for endpoints that don't support 'n' parameter
-- Deprecation warnings are suppressed for cleaner output
 """

-import argparse
 import asyncio
 import csv
 import json
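Both sides of this diff load their test cases from a JSON dataset (sample_dataset.json by default, overridable with --dataset on the pre-change side). The dataset schema itself is not shown anywhere in the diff; the sketch below is an assumption based only on the fields the script is seen reading (question, ground_truth, and a project/project_context label copied into the result rows):

# Hypothetical sample_dataset.json contents, written from Python for illustration.
# Only "question" and "ground_truth" are clearly required by the code in this diff;
# the list-of-objects layout and the "project" field are assumptions.
import json

test_cases = [
    {
        "question": "Which storage backends does the system support?",
        "ground_truth": "It supports several key-value, vector and graph storage backends.",
        "project": "lightrag",  # the newer side reads "project_context" instead
    },
]

with open("my_test.json", "w", encoding="utf-8") as f:
    json.dump(test_cases, f, indent=2, ensure_ascii=False)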
@@ -45,7 +25,6 @@ import math
 import os
 import sys
 import time
-import warnings
 from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, List
@@ -54,52 +33,49 @@ import httpx
 from dotenv import load_dotenv
 from lightrag.utils import logger

-# Suppress LangchainLLMWrapper deprecation warning
-# We use LangchainLLMWrapper for stability and compatibility with all RAGAS versions
-warnings.filterwarnings(
-    "ignore",
-    message=".*LangchainLLMWrapper is deprecated.*",
-    category=DeprecationWarning,
-)
-
-# Suppress token usage warning for custom OpenAI-compatible endpoints
-# Custom endpoints (vLLM, SGLang, etc.) often don't return usage information
-# This is non-critical as token tracking is not required for RAGAS evaluation
-warnings.filterwarnings(
-    "ignore",
-    message=".*Unexpected type for token usage.*",
-    category=UserWarning,
-)
-
 # Add parent directory to path
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))

-# use the .env that is inside the current folder
-# allows to use different .env file for each lightrag instance
-# the OS environment variables take precedence over the .env file
-load_dotenv(dotenv_path=".env", override=False)
+# Load .env from project root
+project_root = Path(__file__).parent.parent.parent
+load_dotenv(project_root / ".env")

-# Setup OpenAI API key (required for RAGAS evaluation)
-# Use LLM_BINDING_API_KEY when running with the OpenAI binding
-llm_binding = os.getenv("LLM_BINDING", "").lower()
-llm_binding_key = os.getenv("LLM_BINDING_API_KEY")
-
-# Validate LLM_BINDING is set to openai
-if llm_binding != "openai":
-    logger.error(
-        "❌ LLM_BINDING must be set to 'openai'. Current value: '%s'",
-        llm_binding or "(not set)",
-    )
-    sys.exit(1)
-
-# Validate LLM_BINDING_API_KEY exists
-if not llm_binding_key:
-    logger.error("❌ LLM_BINDING_API_KEY is not set. Cannot run RAGAS evaluation.")
-    sys.exit(1)
-
-# Set OPENAI_API_KEY from LLM_BINDING_API_KEY
-os.environ["OPENAI_API_KEY"] = llm_binding_key
-logger.info("✅ LLM_BINDING: openai")
-
-# Conditional imports - will raise ImportError if dependencies not installed
 try:
     from datasets import Dataset
     from ragas import evaluate
     from ragas.metrics import (
-        AnswerRelevancy,
-        ContextPrecision,
-        ContextRecall,
-        Faithfulness,
+        answer_relevancy,
+        context_precision,
+        context_recall,
+        faithfulness,
     )
-    from ragas.llms import LangchainLLMWrapper
-    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-    from tqdm.auto import tqdm
-
-    RAGAS_AVAILABLE = True
-except ImportError:
-    RAGAS_AVAILABLE = False
-    Dataset = None
-    evaluate = None
-    LangchainLLMWrapper = None
+except ImportError as e:
+    logger.error("❌ RAGAS import error: %s", e)
+    logger.error(" Install with: pip install ragas datasets")
+    sys.exit(1)

 CONNECT_TIMEOUT_SECONDS = 180.0
@@ -123,94 +99,7 @@ class RAGEvaluator:
test_dataset_path: Path to test dataset JSON file test_dataset_path: Path to test dataset JSON file
rag_api_url: Base URL of LightRAG API (e.g., http://localhost:9621) rag_api_url: Base URL of LightRAG API (e.g., http://localhost:9621)
If None, will try to read from environment or use default If None, will try to read from environment or use default
Environment Variables:
EVAL_LLM_MODEL: LLM model for evaluation (default: gpt-4o-mini)
EVAL_EMBEDDING_MODEL: Embedding model for evaluation (default: text-embedding-3-large)
EVAL_LLM_BINDING_API_KEY: API key for LLM (fallback to OPENAI_API_KEY)
EVAL_LLM_BINDING_HOST: Custom endpoint URL for LLM (optional)
EVAL_EMBEDDING_BINDING_API_KEY: API key for embeddings (fallback: EVAL_LLM_BINDING_API_KEY -> OPENAI_API_KEY)
EVAL_EMBEDDING_BINDING_HOST: Custom endpoint URL for embeddings (fallback: EVAL_LLM_BINDING_HOST)
Raises:
ImportError: If ragas or datasets packages are not installed
EnvironmentError: If EVAL_LLM_BINDING_API_KEY and OPENAI_API_KEY are both not set
""" """
# Validate RAGAS dependencies are installed
if not RAGAS_AVAILABLE:
raise ImportError(
"RAGAS dependencies not installed. "
"Install with: pip install ragas datasets"
)
# Configure evaluation LLM (for RAGAS scoring)
eval_llm_api_key = os.getenv("EVAL_LLM_BINDING_API_KEY") or os.getenv(
"OPENAI_API_KEY"
)
if not eval_llm_api_key:
raise EnvironmentError(
"EVAL_LLM_BINDING_API_KEY or OPENAI_API_KEY is required for evaluation. "
"Set EVAL_LLM_BINDING_API_KEY to use a custom API key, "
"or ensure OPENAI_API_KEY is set."
)
eval_model = os.getenv("EVAL_LLM_MODEL", "gpt-4o-mini")
eval_llm_base_url = os.getenv("EVAL_LLM_BINDING_HOST")
# Configure evaluation embeddings (for RAGAS scoring)
# Fallback chain: EVAL_EMBEDDING_BINDING_API_KEY -> EVAL_LLM_BINDING_API_KEY -> OPENAI_API_KEY
eval_embedding_api_key = (
os.getenv("EVAL_EMBEDDING_BINDING_API_KEY")
or os.getenv("EVAL_LLM_BINDING_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
eval_embedding_model = os.getenv(
"EVAL_EMBEDDING_MODEL", "text-embedding-3-large"
)
# Fallback chain: EVAL_EMBEDDING_BINDING_HOST -> EVAL_LLM_BINDING_HOST -> None
eval_embedding_base_url = os.getenv("EVAL_EMBEDDING_BINDING_HOST") or os.getenv(
"EVAL_LLM_BINDING_HOST"
)
# Create LLM and Embeddings instances for RAGAS
llm_kwargs = {
"model": eval_model,
"api_key": eval_llm_api_key,
"max_retries": int(os.getenv("EVAL_LLM_MAX_RETRIES", "5")),
"request_timeout": int(os.getenv("EVAL_LLM_TIMEOUT", "180")),
}
embedding_kwargs = {
"model": eval_embedding_model,
"api_key": eval_embedding_api_key,
}
if eval_llm_base_url:
llm_kwargs["base_url"] = eval_llm_base_url
if eval_embedding_base_url:
embedding_kwargs["base_url"] = eval_embedding_base_url
# Create base LangChain LLM
base_llm = ChatOpenAI(**llm_kwargs)
self.eval_embeddings = OpenAIEmbeddings(**embedding_kwargs)
# Wrap LLM with LangchainLLMWrapper and enable bypass_n mode for custom endpoints
# This ensures compatibility with endpoints that don't support the 'n' parameter
# by generating multiple outputs through repeated prompts instead of using 'n' parameter
try:
self.eval_llm = LangchainLLMWrapper(
langchain_llm=base_llm,
bypass_n=True, # Enable bypass_n to avoid passing 'n' to OpenAI API
)
logger.debug("Successfully configured bypass_n mode for LLM wrapper")
except Exception as e:
logger.warning(
"Could not configure LangchainLLMWrapper with bypass_n: %s. "
"Using base LLM directly, which may cause warnings with custom endpoints.",
e,
)
self.eval_llm = base_llm
if test_dataset_path is None: if test_dataset_path is None:
test_dataset_path = Path(__file__).parent / "sample_dataset.json" test_dataset_path = Path(__file__).parent / "sample_dataset.json"
@@ -225,58 +114,6 @@ class RAGEvaluator:
# Load test dataset # Load test dataset
self.test_cases = self._load_test_dataset() self.test_cases = self._load_test_dataset()
# Store configuration values for display
self.eval_model = eval_model
self.eval_embedding_model = eval_embedding_model
self.eval_llm_base_url = eval_llm_base_url
self.eval_embedding_base_url = eval_embedding_base_url
self.eval_max_retries = llm_kwargs["max_retries"]
self.eval_timeout = llm_kwargs["request_timeout"]
# Display configuration
self._display_configuration()
def _display_configuration(self):
"""Display all evaluation configuration settings"""
logger.info("Evaluation Models:")
logger.info(" • LLM Model: %s", self.eval_model)
logger.info(" • Embedding Model: %s", self.eval_embedding_model)
# Display LLM endpoint
if self.eval_llm_base_url:
logger.info(" • LLM Endpoint: %s", self.eval_llm_base_url)
logger.info(
" • Bypass N-Parameter: Enabled (use LangchainLLMWrapper for compatibility)"
)
else:
logger.info(" • LLM Endpoint: OpenAI Official API")
# Display Embedding endpoint (only if different from LLM)
if self.eval_embedding_base_url:
if self.eval_embedding_base_url != self.eval_llm_base_url:
logger.info(
" • Embedding Endpoint: %s", self.eval_embedding_base_url
)
# If same as LLM endpoint, no need to display separately
elif not self.eval_llm_base_url:
# Both using OpenAI - already displayed above
pass
else:
# LLM uses custom endpoint, but embeddings use OpenAI
logger.info(" • Embedding Endpoint: OpenAI Official API")
logger.info("Concurrency & Rate Limiting:")
query_top_k = int(os.getenv("EVAL_QUERY_TOP_K", "10"))
logger.info(" • Query Top-K: %s Entities/Relations", query_top_k)
logger.info(" • LLM Max Retries: %s", self.eval_max_retries)
logger.info(" • LLM Timeout: %s seconds", self.eval_timeout)
logger.info("Test Configuration:")
logger.info(" • Total Test Cases: %s", len(self.test_cases))
logger.info(" • Test Dataset: %s", self.test_dataset_path.name)
logger.info(" • LightRAG API: %s", self.rag_api_url)
logger.info(" • Results Directory: %s", self.results_dir.name)
def _load_test_dataset(self) -> List[Dict[str, str]]: def _load_test_dataset(self) -> List[Dict[str, str]]:
"""Load test cases from JSON file""" """Load test cases from JSON file"""
if not self.test_dataset_path.exists(): if not self.test_dataset_path.exists():
@@ -313,22 +150,13 @@ class RAGEvaluator:
"include_references": True, "include_references": True,
"include_chunk_content": True, # NEW: Request chunk content in references "include_chunk_content": True, # NEW: Request chunk content in references
"response_type": "Multiple Paragraphs", "response_type": "Multiple Paragraphs",
"top_k": int(os.getenv("EVAL_QUERY_TOP_K", "10")), "top_k": 10,
} }
# Get API key from environment for authentication
api_key = os.getenv("LIGHTRAG_API_KEY")
# Prepare headers with optional authentication
headers = {}
if api_key:
headers["X-API-Key"] = api_key
# Single optimized API call - gets both answer AND chunk content # Single optimized API call - gets both answer AND chunk content
response = await client.post( response = await client.post(
f"{self.rag_api_url}/query", f"{self.rag_api_url}/query",
json=payload, json=payload,
headers=headers if headers else None,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@@ -342,26 +170,14 @@ class RAGEvaluator:
first_ref = references[0] first_ref = references[0]
logger.debug("🔍 First Reference Keys: %s", list(first_ref.keys())) logger.debug("🔍 First Reference Keys: %s", list(first_ref.keys()))
if "content" in first_ref: if "content" in first_ref:
content_preview = first_ref["content"] logger.debug(
if isinstance(content_preview, list) and content_preview: "🔍 Content Preview: %s...", first_ref["content"][:100]
logger.debug( )
"🔍 Content Preview (first chunk): %s...",
content_preview[0][:100],
)
elif isinstance(content_preview, str):
logger.debug("🔍 Content Preview: %s...", content_preview[:100])
# Extract chunk content from enriched references # Extract chunk content from enriched references
# Note: content is now a list of chunks per reference (one file may have multiple chunks) contexts = [
contexts = [] ref.get("content", "") for ref in references if ref.get("content")
for ref in references: ]
content = ref.get("content", [])
if isinstance(content, list):
# Flatten the list: each chunk becomes a separate context
contexts.extend(content)
elif isinstance(content, str):
# Backward compatibility: if content is still a string (shouldn't happen)
contexts.append(content)
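A worked example of the flattening above, with invented reference data rather than an actual LightRAG response: one reference whose content is a list of two chunks plus one legacy reference whose content is a plain string yield three contexts, one per chunk.

# Illustration only - hypothetical references, mirroring the loop above.
references = [
    {"content": ["chunk 1 text", "chunk 2 text"]},   # current shape: list of chunks
    {"content": "legacy single chunk"},              # backward-compatible string shape
]
contexts = []
for ref in references:
    content = ref.get("content", [])
    if isinstance(content, list):
        contexts.extend(content)
    elif isinstance(content, str):
        contexts.append(content)
assert contexts == ["chunk 1 text", "chunk 2 text", "legacy single chunk"]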
return { return {
"answer": answer, "answer": answer,
@@ -392,55 +208,46 @@ class RAGEvaluator:
self, self,
idx: int, idx: int,
test_case: Dict[str, str], test_case: Dict[str, str],
rag_semaphore: asyncio.Semaphore, semaphore: asyncio.Semaphore,
eval_semaphore: asyncio.Semaphore,
client: httpx.AsyncClient, client: httpx.AsyncClient,
progress_counter: Dict[str, int],
position_pool: asyncio.Queue,
pbar_creation_lock: asyncio.Lock,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Evaluate a single test case with two-stage pipeline concurrency control Evaluate a single test case with concurrency control
Args: Args:
idx: Test case index (1-based) idx: Test case index (1-based)
test_case: Test case dictionary with question and ground_truth test_case: Test case dictionary with question and ground_truth
rag_semaphore: Semaphore to control overall concurrency (covers entire function) semaphore: Semaphore to control concurrency
eval_semaphore: Semaphore to control RAGAS evaluation concurrency (Stage 2)
client: Shared httpx AsyncClient for connection pooling client: Shared httpx AsyncClient for connection pooling
progress_counter: Shared dictionary for progress tracking
position_pool: Queue of available tqdm position indices
pbar_creation_lock: Lock to serialize tqdm creation and prevent race conditions
Returns: Returns:
Evaluation result dictionary Evaluation result dictionary
""" """
# rag_semaphore controls the entire evaluation process to prevent total_cases = len(self.test_cases)
# all RAG responses from being generated at once when eval is slow
async with rag_semaphore: async with semaphore:
question = test_case["question"] question = test_case["question"]
ground_truth = test_case["ground_truth"] ground_truth = test_case["ground_truth"]
# Stage 1: Generate RAG response logger.info("[%s/%s] Evaluating: %s...", idx, total_cases, question[:60])
try:
rag_response = await self.generate_rag_response( # Generate RAG response by calling actual LightRAG API
question=question, client=client rag_response = await self.generate_rag_response(
) question=question, client=client
except Exception as e: )
logger.error("Error generating response for test %s: %s", idx, str(e))
progress_counter["completed"] += 1
return {
"test_number": idx,
"question": question,
"error": str(e),
"metrics": {},
"ragas_score": 0,
"timestamp": datetime.now().isoformat(),
}
# *** CRITICAL FIX: Use actual retrieved contexts, NOT ground_truth *** # *** CRITICAL FIX: Use actual retrieved contexts, NOT ground_truth ***
retrieved_contexts = rag_response["contexts"] retrieved_contexts = rag_response["contexts"]
# DEBUG: Print what was actually retrieved
logger.debug("📝 Retrieved %s contexts", len(retrieved_contexts))
if retrieved_contexts:
logger.debug(
"📄 First context preview: %s...", retrieved_contexts[0][:100]
)
else:
logger.warning("⚠️ No contexts retrieved!")
# Prepare dataset for RAGAS evaluation with CORRECT contexts # Prepare dataset for RAGAS evaluation with CORRECT contexts
eval_dataset = Dataset.from_dict( eval_dataset = Dataset.from_dict(
{ {
@@ -451,141 +258,88 @@ class RAGEvaluator:
} }
) )
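The dictionary keys passed to Dataset.from_dict are cut off by the hunk boundary above. For the four metrics used here, a single-row RAGAS evaluation dataset conventionally looks like the sketch below; the column names are an assumption, since the diff does not show them:

# Assumed column layout for a one-row RAGAS evaluation dataset (names not visible in this diff).
eval_dataset = Dataset.from_dict(
    {
        "question": [question],                  # the test question
        "answer": [rag_response["answer"]],      # answer produced by the RAG system
        "contexts": [retrieved_contexts],        # list of retrieved chunks for this question
        "ground_truth": [ground_truth],          # reference answer from the dataset
    }
)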
# Stage 2: Run RAGAS evaluation (controlled by eval_semaphore) # Run RAGAS evaluation
# IMPORTANT: Create fresh metric instances for each evaluation to avoid try:
# concurrent state conflicts when multiple tasks run in parallel eval_results = evaluate(
async with eval_semaphore: dataset=eval_dataset,
pbar = None metrics=[
position = None faithfulness,
try: answer_relevancy,
# Acquire a position from the pool for this tqdm progress bar context_recall,
position = await position_pool.get() context_precision,
],
)
# Serialize tqdm creation to prevent race conditions # Convert to DataFrame (RAGAS v0.3+ API)
# Multiple tasks creating tqdm simultaneously can cause display conflicts df = eval_results.to_pandas()
async with pbar_creation_lock:
# Create tqdm progress bar with assigned position to avoid overlapping
# leave=False ensures the progress bar is cleared after completion,
# preventing accumulation of completed bars and allowing position reuse
pbar = tqdm(
total=4,
desc=f"Eval-{idx:02d}",
position=position,
leave=False,
)
# Give tqdm time to initialize and claim its screen position
await asyncio.sleep(0.05)
eval_results = evaluate( # Extract scores from first row
dataset=eval_dataset, scores_row = df.iloc[0]
metrics=[
Faithfulness(),
AnswerRelevancy(),
ContextRecall(),
ContextPrecision(),
],
llm=self.eval_llm,
embeddings=self.eval_embeddings,
_pbar=pbar,
)
# Convert to DataFrame (RAGAS v0.3+ API) # Extract scores (RAGAS v0.3+ uses .to_pandas())
df = eval_results.to_pandas() result = {
"question": question,
"answer": rag_response["answer"][:200] + "..."
if len(rag_response["answer"]) > 200
else rag_response["answer"],
"ground_truth": ground_truth[:200] + "..."
if len(ground_truth) > 200
else ground_truth,
"project": test_case.get("project_context", "unknown"),
"metrics": {
"faithfulness": float(scores_row.get("faithfulness", 0)),
"answer_relevance": float(
scores_row.get("answer_relevancy", 0)
),
"context_recall": float(scores_row.get("context_recall", 0)),
"context_precision": float(
scores_row.get("context_precision", 0)
),
},
"timestamp": datetime.now().isoformat(),
}
# Extract scores from first row # Calculate RAGAS score (average of all metrics)
scores_row = df.iloc[0] metrics = result["metrics"]
ragas_score = sum(metrics.values()) / len(metrics) if metrics else 0
result["ragas_score"] = round(ragas_score, 4)
# Extract scores (RAGAS v0.3+ uses .to_pandas()) logger.info("✅ Faithfulness: %.4f", metrics["faithfulness"])
result = { logger.info("✅ Answer Relevance: %.4f", metrics["answer_relevance"])
"test_number": idx, logger.info("✅ Context Recall: %.4f", metrics["context_recall"])
"question": question, logger.info("✅ Context Precision: %.4f", metrics["context_precision"])
"answer": rag_response["answer"][:200] + "..." logger.info("📊 RAGAS Score: %.4f", result["ragas_score"])
if len(rag_response["answer"]) > 200
else rag_response["answer"],
"ground_truth": ground_truth[:200] + "..."
if len(ground_truth) > 200
else ground_truth,
"project": test_case.get("project", "unknown"),
"metrics": {
"faithfulness": float(scores_row.get("faithfulness", 0)),
"answer_relevance": float(
scores_row.get("answer_relevancy", 0)
),
"context_recall": float(
scores_row.get("context_recall", 0)
),
"context_precision": float(
scores_row.get("context_precision", 0)
),
},
"timestamp": datetime.now().isoformat(),
}
# Calculate RAGAS score (average of all metrics, excluding NaN values) return result
metrics = result["metrics"]
valid_metrics = [v for v in metrics.values() if not _is_nan(v)]
ragas_score = (
sum(valid_metrics) / len(valid_metrics) if valid_metrics else 0
)
result["ragas_score"] = round(ragas_score, 4)
# Update progress counter except Exception as e:
progress_counter["completed"] += 1 logger.exception("❌ Error evaluating: %s", e)
return {
return result "question": question,
"error": str(e),
except Exception as e: "metrics": {},
logger.error("Error evaluating test %s: %s", idx, str(e)) "ragas_score": 0,
progress_counter["completed"] += 1 "timestamp": datetime.now().isoformat(),
return { }
"test_number": idx,
"question": question,
"error": str(e),
"metrics": {},
"ragas_score": 0,
"timestamp": datetime.now().isoformat(),
}
finally:
# Force close progress bar to ensure completion
if pbar is not None:
pbar.close()
# Release the position back to the pool for reuse
if position is not None:
await position_pool.put(position)
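The composite ragas_score on the pre-change side is a plain mean over whichever metrics came back as numbers; a metric that RAGAS reports as NaN is left out of both the numerator and the denominator. A small worked example with invented values:

# Worked example of the NaN-excluding average used for ragas_score (values are invented).
import math

metrics = {
    "faithfulness": 0.90,
    "answer_relevance": 0.80,
    "context_recall": float("nan"),   # excluded from the average
    "context_precision": 0.70,
}
valid = [v for v in metrics.values() if not math.isnan(v)]
ragas_score = round(sum(valid) / len(valid), 4)
print(ragas_score)  # (0.90 + 0.80 + 0.70) / 3 = 0.8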
async def evaluate_responses(self) -> List[Dict[str, Any]]: async def evaluate_responses(self) -> List[Dict[str, Any]]:
""" """
Evaluate all test cases in parallel with two-stage pipeline and return metrics Evaluate all test cases in parallel and return metrics
Returns: Returns:
List of evaluation results with metrics List of evaluation results with metrics
""" """
# Get evaluation concurrency from environment (default to 2 for parallel evaluation) # Get MAX_ASYNC from environment (default to 4 if not set)
max_async = int(os.getenv("EVAL_MAX_CONCURRENT", "2")) max_async = int(os.getenv("MAX_ASYNC", "4"))
logger.info("")
logger.info("%s", "=" * 70) logger.info("%s", "=" * 70)
logger.info("🚀 Starting RAGAS Evaluation of LightRAG System") logger.info("🚀 Starting RAGAS Evaluation of Portfolio RAG System")
logger.info("🔧 RAGAS Evaluation (Stage 2): %s concurrent", max_async) logger.info("🔧 Parallel evaluations: %s", max_async)
logger.info("%s", "=" * 70) logger.info("%s", "=" * 70)
# Create two-stage pipeline semaphores # Create semaphore to limit concurrent evaluations
# Stage 1: RAG generation - allow x2 concurrency to keep evaluation fed semaphore = asyncio.Semaphore(max_async)
rag_semaphore = asyncio.Semaphore(max_async * 2)
# Stage 2: RAGAS evaluation - primary bottleneck
eval_semaphore = asyncio.Semaphore(max_async)
# Create progress counter (shared across all tasks)
progress_counter = {"completed": 0}
# Create position pool for tqdm progress bars
# Positions range from 0 to max_async-1, ensuring no overlapping displays
position_pool = asyncio.Queue()
for i in range(max_async):
await position_pool.put(i)
# Create lock to serialize tqdm creation and prevent race conditions
# This ensures progress bars are created one at a time, avoiding display conflicts
pbar_creation_lock = asyncio.Lock()
# Create shared HTTP client with connection pooling and proper timeouts # Create shared HTTP client with connection pooling and proper timeouts
# Timeout: 3 minutes for connect, 5 minutes for read (LLM can be slow) # Timeout: 3 minutes for connect, 5 minutes for read (LLM can be slow)
@@ -595,27 +349,18 @@ class RAGEvaluator:
read=READ_TIMEOUT_SECONDS, read=READ_TIMEOUT_SECONDS,
) )
limits = httpx.Limits( limits = httpx.Limits(
max_connections=(max_async + 1) * 2, # Allow buffer for RAG stage max_connections=max_async * 2, # Allow some buffer
max_keepalive_connections=max_async + 1, max_keepalive_connections=max_async,
) )
async with httpx.AsyncClient(timeout=timeout, limits=limits) as client: async with httpx.AsyncClient(timeout=timeout, limits=limits) as client:
# Create tasks for all test cases # Create tasks for all test cases
tasks = [ tasks = [
self.evaluate_single_case( self.evaluate_single_case(idx, test_case, semaphore, client)
idx,
test_case,
rag_semaphore,
eval_semaphore,
client,
progress_counter,
position_pool,
pbar_creation_lock,
)
for idx, test_case in enumerate(self.test_cases, 1) for idx, test_case in enumerate(self.test_cases, 1)
] ]
# Run all evaluations in parallel (limited by two-stage semaphores) # Run all evaluations in parallel (limited by semaphore)
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
return list(results) return list(results)
@@ -680,95 +425,6 @@ class RAGEvaluator:
return csv_path return csv_path
def _format_metric(self, value: float, width: int = 6) -> str:
"""
Format a metric value for display, handling NaN gracefully
Args:
value: The metric value to format
width: The width of the formatted string
Returns:
Formatted string (e.g., "0.8523" or " N/A ")
"""
if _is_nan(value):
return "N/A".center(width)
return f"{value:.4f}".rjust(width)
def _display_results_table(self, results: List[Dict[str, Any]]):
"""
Display evaluation results in a formatted table
Args:
results: List of evaluation results
"""
logger.info("")
logger.info("%s", "=" * 115)
logger.info("📊 EVALUATION RESULTS SUMMARY")
logger.info("%s", "=" * 115)
# Table header
logger.info(
"%-4s | %-50s | %6s | %7s | %6s | %7s | %6s | %6s",
"#",
"Question",
"Faith",
"AnswRel",
"CtxRec",
"CtxPrec",
"RAGAS",
"Status",
)
logger.info("%s", "-" * 115)
# Table rows
for result in results:
test_num = result.get("test_number", 0)
question = result.get("question", "")
# Truncate question to 50 chars
question_display = (
(question[:47] + "...") if len(question) > 50 else question
)
metrics = result.get("metrics", {})
if metrics:
# Success case - format each metric, handling NaN values
faith = metrics.get("faithfulness", 0)
ans_rel = metrics.get("answer_relevance", 0)
ctx_rec = metrics.get("context_recall", 0)
ctx_prec = metrics.get("context_precision", 0)
ragas = result.get("ragas_score", 0)
status = ""
logger.info(
"%-4d | %-50s | %s | %s | %s | %s | %s | %6s",
test_num,
question_display,
self._format_metric(faith, 6),
self._format_metric(ans_rel, 7),
self._format_metric(ctx_rec, 6),
self._format_metric(ctx_prec, 7),
self._format_metric(ragas, 6),
status,
)
else:
# Error case
error = result.get("error", "Unknown error")
error_display = (error[:20] + "...") if len(error) > 23 else error
logger.info(
"%-4d | %-50s | %6s | %7s | %6s | %7s | %6s | ✗ %s",
test_num,
question_display,
"N/A",
"N/A",
"N/A",
"N/A",
"N/A",
error_display,
)
logger.info("%s", "=" * 115)
def _calculate_benchmark_stats( def _calculate_benchmark_stats(
self, results: List[Dict[str, Any]] self, results: List[Dict[str, Any]]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -795,55 +451,45 @@ class RAGEvaluator:
"success_rate": 0.0, "success_rate": 0.0,
} }
# Calculate averages for each metric (handling NaN values correctly) # Calculate averages for each metric (handling NaN values)
# Track both sum and count for each metric to handle NaN values properly metrics_sum = {
metrics_data = { "faithfulness": 0.0,
"faithfulness": {"sum": 0.0, "count": 0}, "answer_relevance": 0.0,
"answer_relevance": {"sum": 0.0, "count": 0}, "context_recall": 0.0,
"context_recall": {"sum": 0.0, "count": 0}, "context_precision": 0.0,
"context_precision": {"sum": 0.0, "count": 0}, "ragas_score": 0.0,
"ragas_score": {"sum": 0.0, "count": 0},
} }
for result in valid_results: for result in valid_results:
metrics = result.get("metrics", {}) metrics = result.get("metrics", {})
# Skip NaN values when summing
# For each metric, sum non-NaN values and count them
faithfulness = metrics.get("faithfulness", 0) faithfulness = metrics.get("faithfulness", 0)
if not _is_nan(faithfulness): if not _is_nan(faithfulness):
metrics_data["faithfulness"]["sum"] += faithfulness metrics_sum["faithfulness"] += faithfulness
metrics_data["faithfulness"]["count"] += 1
answer_relevance = metrics.get("answer_relevance", 0) answer_relevance = metrics.get("answer_relevance", 0)
if not _is_nan(answer_relevance): if not _is_nan(answer_relevance):
metrics_data["answer_relevance"]["sum"] += answer_relevance metrics_sum["answer_relevance"] += answer_relevance
metrics_data["answer_relevance"]["count"] += 1
context_recall = metrics.get("context_recall", 0) context_recall = metrics.get("context_recall", 0)
if not _is_nan(context_recall): if not _is_nan(context_recall):
metrics_data["context_recall"]["sum"] += context_recall metrics_sum["context_recall"] += context_recall
metrics_data["context_recall"]["count"] += 1
context_precision = metrics.get("context_precision", 0) context_precision = metrics.get("context_precision", 0)
if not _is_nan(context_precision): if not _is_nan(context_precision):
metrics_data["context_precision"]["sum"] += context_precision metrics_sum["context_precision"] += context_precision
metrics_data["context_precision"]["count"] += 1
ragas_score = result.get("ragas_score", 0) ragas_score = result.get("ragas_score", 0)
if not _is_nan(ragas_score): if not _is_nan(ragas_score):
metrics_data["ragas_score"]["sum"] += ragas_score metrics_sum["ragas_score"] += ragas_score
metrics_data["ragas_score"]["count"] += 1
# Calculate averages using actual counts for each metric # Calculate averages
n = len(valid_results)
avg_metrics = {} avg_metrics = {}
for metric_name, data in metrics_data.items(): for k, v in metrics_sum.items():
if data["count"] > 0: avg_val = v / n if n > 0 else 0
avg_val = data["sum"] / data["count"] # Handle NaN in average
avg_metrics[metric_name] = ( avg_metrics[k] = round(avg_val, 4) if not _is_nan(avg_val) else 0.0
round(avg_val, 4) if not _is_nan(avg_val) else 0.0
)
else:
avg_metrics[metric_name] = 0.0
# Find min and max RAGAS scores (filter out NaN) # Find min and max RAGAS scores (filter out NaN)
ragas_scores = [] ragas_scores = []
@@ -888,9 +534,6 @@ class RAGEvaluator:
"results": results, "results": results,
} }
# Display results table
self._display_results_table(results)
# Save JSON results # Save JSON results
json_path = ( json_path = (
self.results_dir self.results_dir
@@ -898,9 +541,11 @@ class RAGEvaluator:
) )
with open(json_path, "w") as f: with open(json_path, "w") as f:
json.dump(summary, f, indent=2) json.dump(summary, f, indent=2)
logger.info("✅ JSON results saved to: %s", json_path)
# Export to CSV # Export to CSV
csv_path = self._export_to_csv(results) csv_path = self._export_to_csv(results)
logger.info("✅ CSV results saved to: %s", csv_path)
# Print summary # Print summary
logger.info("") logger.info("")
@@ -925,7 +570,7 @@ class RAGEvaluator:
logger.info("Average Context Recall: %.4f", avg["context_recall"]) logger.info("Average Context Recall: %.4f", avg["context_recall"])
logger.info("Average Context Precision: %.4f", avg["context_precision"]) logger.info("Average Context Precision: %.4f", avg["context_precision"])
logger.info("Average RAGAS Score: %.4f", avg["ragas_score"]) logger.info("Average RAGAS Score: %.4f", avg["ragas_score"])
logger.info("%s", "-" * 70) logger.info("")
logger.info( logger.info(
"Min RAGAS Score: %.4f", "Min RAGAS Score: %.4f",
benchmark_stats["min_ragas_score"], benchmark_stats["min_ragas_score"],
@@ -951,61 +596,28 @@ async def main():
""" """
Main entry point for RAGAS evaluation Main entry point for RAGAS evaluation
Command-line arguments:
--dataset, -d: Path to test dataset JSON file (default: sample_dataset.json)
--ragendpoint, -r: LightRAG API endpoint URL (default: http://localhost:9621 or $LIGHTRAG_API_URL)
Usage: Usage:
python lightrag/evaluation/eval_rag_quality.py python lightrag/evaluation/eval_rag_quality.py
python lightrag/evaluation/eval_rag_quality.py --dataset my_test.json python lightrag/evaluation/eval_rag_quality.py http://localhost:9621
python lightrag/evaluation/eval_rag_quality.py -d my_test.json -r http://localhost:9621 python lightrag/evaluation/eval_rag_quality.py http://your-server.com:9621
""" """
try: try:
# Parse command-line arguments # Get RAG API URL from command line or environment
parser = argparse.ArgumentParser( rag_api_url = None
description="RAGAS Evaluation Script for LightRAG System", if len(sys.argv) > 1:
formatter_class=argparse.RawDescriptionHelpFormatter, rag_api_url = sys.argv[1]
epilog="""
Examples:
# Use defaults
python lightrag/evaluation/eval_rag_quality.py
# Specify custom dataset
python lightrag/evaluation/eval_rag_quality.py --dataset my_test.json
# Specify custom RAG endpoint
python lightrag/evaluation/eval_rag_quality.py --ragendpoint http://my-server.com:9621
# Specify both
python lightrag/evaluation/eval_rag_quality.py -d my_test.json -r http://localhost:9621
""",
)
parser.add_argument(
"--dataset",
"-d",
type=str,
default=None,
help="Path to test dataset JSON file (default: sample_dataset.json in evaluation directory)",
)
parser.add_argument(
"--ragendpoint",
"-r",
type=str,
default=None,
help="LightRAG API endpoint URL (default: http://localhost:9621 or $LIGHTRAG_API_URL environment variable)",
)
args = parser.parse_args()
logger.info("")
logger.info("%s", "=" * 70) logger.info("%s", "=" * 70)
logger.info("🔍 RAGAS Evaluation - Using Real LightRAG API") logger.info("🔍 RAGAS Evaluation - Using Real LightRAG API")
logger.info("%s", "=" * 70) logger.info("%s", "=" * 70)
if rag_api_url:
logger.info("📡 RAG API URL: %s", rag_api_url)
else:
logger.info("📡 RAG API URL: http://localhost:9621 (default)")
logger.info("%s", "=" * 70)
evaluator = RAGEvaluator( evaluator = RAGEvaluator(rag_api_url=rag_api_url)
test_dataset_path=args.dataset, rag_api_url=args.ragendpoint
)
await evaluator.run() await evaluator.run()
except Exception as e: except Exception as e:
logger.exception("❌ Error: %s", e) logger.exception("❌ Error: %s", e)