LightRAG/tests/test_metrics.py
clssck 59e89772de refactor: consolidate to PostgreSQL-only backend and modernize stack
Remove legacy storage implementations and deprecated examples:
- Delete FAISS, JSON, Memgraph, Milvus, MongoDB, Nano Vector DB, Neo4j, NetworkX, Qdrant, Redis storage backends
- Remove Kubernetes deployment manifests and installation scripts
- Delete unofficial examples for deprecated backends and offline deployment docs
Streamline core infrastructure:
- Consolidate storage layer to PostgreSQL-only implementation
- Add full-text search caching with FTS cache module
- Implement metrics collection and monitoring pipeline
- Add explain and metrics API routes
Modernize frontend and tooling:
- Switch web UI to Bun with bun.lock, remove npm and pnpm lockfiles
- Update Dockerfile for PostgreSQL-only deployment
- Add Makefile for common development tasks
- Update environment and configuration examples
Enhance evaluation and testing capabilities:
- Add prompt optimization with DSPy and auto-tuning
- Implement ground truth regeneration and variant testing
- Add prompt debugging and response comparison utilities
- Expand test coverage with new integration scenarios
Simplify dependencies and configuration:
- Remove offline-specific requirement files
- Update pyproject.toml with streamlined dependencies
- Add Python version pinning with .python-version
- Create project guidelines in CLAUDE.md and AGENTS.md
2025-12-12 16:28:49 +01:00

315 lines
9.5 KiB
Python

"""
Unit tests for metrics collection infrastructure.
"""
import time
import pytest
from lightrag.metrics import (
MetricsCollector,
QueryMetric,
get_metrics_collector,
record_query_metric,
)
class TestMetricsCollector:
    """Tests for MetricsCollector counter and history bookkeeping."""

    @staticmethod
    def _metric(*, cache_hit, duration_ms=100.0):
        # Fixed retrieval counts; only cache_hit/duration vary per test.
        return QueryMetric(
            timestamp=time.time(),
            duration_ms=duration_ms,
            mode='mix',
            cache_hit=cache_hit,
            entities_count=10,
            relations_count=20,
            chunks_count=5,
            tokens_used=1000,
        )

    def test_initialization(self):
        """A fresh collector starts with every counter at zero and no history."""
        mc = MetricsCollector()
        assert mc.total_queries == 0
        assert mc.total_llm_calls == 0
        assert mc.total_llm_cache_hits == 0
        assert mc.total_embed_calls == 0
        assert len(mc.query_history) == 0

    def test_record_query(self):
        """A cache-miss query bumps the query, LLM-call, and history counts."""
        mc = MetricsCollector()
        mc.record_query(self._metric(cache_hit=False))
        assert mc.total_queries == 1
        assert mc.total_llm_calls == 1
        assert len(mc.query_history) == 1

    def test_record_query_cache_hit(self):
        """A cache-hit query counts as a hit, not as an LLM call."""
        mc = MetricsCollector()
        mc.record_query(self._metric(cache_hit=True, duration_ms=50.0))
        assert mc.total_queries == 1
        assert mc.total_llm_cache_hits == 1
        assert mc.total_llm_calls == 0

    def test_record_llm_call(self):
        """record_llm_call routes misses and hits to separate counters."""
        mc = MetricsCollector()
        mc.record_llm_call(cache_hit=False)
        assert mc.total_llm_calls == 1
        assert mc.total_llm_cache_hits == 0
        mc.record_llm_call(cache_hit=True)
        assert mc.total_llm_calls == 1
        assert mc.total_llm_cache_hits == 1

    def test_record_embed_call(self):
        """Each record_embed_call increments the embedding counter by one."""
        mc = MetricsCollector()
        mc.record_embed_call()
        mc.record_embed_call()
        assert mc.total_embed_calls == 2
class TestPercentileCalculation:
    """Tests for MetricsCollector._get_percentile."""

    def test_empty_values(self):
        """An empty sample yields 0.0 for any percentile."""
        mc = MetricsCollector()
        assert mc._get_percentile([], 50) == 0.0

    def test_single_value(self):
        """With a single sample, every percentile is that sample."""
        mc = MetricsCollector()
        for pct in (50, 99):
            assert mc._get_percentile([100.0], pct) == 100.0

    def test_p50_median(self):
        """P50 of an odd-length sample is the middle value after sorting."""
        mc = MetricsCollector()
        unsorted = [30, 10, 50, 20, 40]  # sorted: [10, 20, 30, 40, 50]
        assert mc._get_percentile(unsorted, 50) == 30.0

    def test_p99_high(self):
        """P99 of the sample 1..100 lands near the top of the range."""
        mc = MetricsCollector()
        sample = [float(v) for v in range(1, 101)]  # 1.0-100.0
        assert mc._get_percentile(sample, 99) >= 99.0
class TestComputeStats:
    """Tests for MetricsCollector.compute_stats aggregation.

    The QueryMetric construction that was previously repeated verbatim in
    three tests is centralized in _make_metric; only the fields these tests
    actually vary (mode, duration, timestamp) are exposed as parameters.
    """

    @staticmethod
    def _make_metric(*, mode='mix', duration_ms=100.0, timestamp=None):
        """Build a cache-miss QueryMetric with fixed retrieval counts.

        timestamp defaults to "now" unless a test needs an old entry for
        window-filtering checks.
        """
        return QueryMetric(
            timestamp=time.time() if timestamp is None else timestamp,
            duration_ms=duration_ms,
            mode=mode,
            cache_hit=False,
            entities_count=10,
            relations_count=20,
            chunks_count=5,
            tokens_used=1000,
        )

    def test_empty_stats(self):
        """An empty collector reports zero counts and no latency percentiles."""
        collector = MetricsCollector()
        stats = collector.compute_stats()
        assert stats['total_queries'] == 0
        assert stats['queries_in_window'] == 0
        assert stats['latency_percentiles'] is None
        assert stats['cache_stats']['hit_rate'] == 0.0

    def test_stats_with_queries(self):
        """Recorded queries appear in totals and produce latency percentiles."""
        collector = MetricsCollector()
        for i in range(5):
            collector.record_query(self._make_metric(duration_ms=100.0 + i * 10))
        stats = collector.compute_stats()
        assert stats['total_queries'] == 5
        assert stats['queries_in_window'] == 5
        assert stats['latency_percentiles'] is not None
        assert stats['latency_percentiles']['p50'] >= 100.0

    def test_stats_window_filtering(self):
        """Queries older than the window are excluded from windowed counts."""
        collector = MetricsCollector()
        # 2 hours old -> falls outside the 1-hour window requested below.
        collector.record_query(
            self._make_metric(timestamp=time.time() - 7200, duration_ms=100.0)
        )
        collector.record_query(self._make_metric(duration_ms=200.0))
        stats = collector.compute_stats(window_seconds=3600)  # 1 hour window
        assert stats['total_queries'] == 2
        assert stats['queries_in_window'] == 1  # only the recent query

    def test_cache_hit_rate(self):
        """hit_rate is hits / (hits + misses) over the LLM counters."""
        collector = MetricsCollector()
        # 2 cache hits, 3 cache misses
        collector.total_llm_cache_hits = 2
        collector.total_llm_calls = 3
        stats = collector.compute_stats()
        expected_rate = 2 / 5  # 2 hits out of 5 total
        assert abs(stats['cache_stats']['hit_rate'] - expected_rate) < 0.01

    def test_mode_distribution(self):
        """Per-mode query counts are tallied in mode_distribution."""
        collector = MetricsCollector()
        for mode, count in [('mix', 3), ('local', 2), ('global', 1)]:
            for _ in range(count):
                collector.record_query(self._make_metric(mode=mode))
        stats = collector.compute_stats()
        assert stats['mode_distribution']['mix'] == 3
        assert stats['mode_distribution']['local'] == 2
        assert stats['mode_distribution']['global'] == 1
class TestRecentQueries:
    """Tests for MetricsCollector.get_recent_queries."""

    @staticmethod
    def _metric(entities, *, timestamp=None, duration_ms=100.0):
        # entities_count doubles as a per-query marker so ordering and
        # limiting can be verified from the returned dicts.
        return QueryMetric(
            timestamp=time.time() if timestamp is None else timestamp,
            duration_ms=duration_ms,
            mode='mix',
            cache_hit=False,
            entities_count=entities,
            relations_count=20,
            chunks_count=5,
            tokens_used=1000,
        )

    def test_recent_queries_empty(self):
        """A collector with no history returns an empty list."""
        assert MetricsCollector().get_recent_queries() == []

    def test_recent_queries_limited(self):
        """The limit parameter caps how many entries come back."""
        mc = MetricsCollector()
        for idx in range(10):
            mc.record_query(self._metric(idx))
        assert len(mc.get_recent_queries(limit=3)) == 3

    def test_recent_queries_order(self):
        """Entries come back newest-first."""
        mc = MetricsCollector()
        base = time.time()
        for idx in range(3):
            mc.record_query(
                self._metric(idx, timestamp=base + idx, duration_ms=100.0 + idx)
            )
        newest_first = mc.get_recent_queries()
        # The query recorded with idx=2 has the latest timestamp.
        assert [q['entities_count'] for q in newest_first] == [2, 1, 0]
class TestGlobalMetricsCollector:
    """Tests for the process-wide collector singleton and its helper."""

    @pytest.mark.asyncio
    async def test_get_collector_singleton(self):
        """get_metrics_collector hands back the same instance every call."""
        first = await get_metrics_collector()
        second = await get_metrics_collector()
        assert first is second

    @pytest.mark.asyncio
    async def test_record_query_metric_helper(self):
        """record_query_metric bumps the shared collector's query total."""
        shared = await get_metrics_collector()
        # Compare against the pre-call total: the singleton may already
        # carry state from earlier tests in the session.
        before = shared.total_queries
        await record_query_metric(
            duration_ms=100.0,
            mode='mix',
            cache_hit=False,
            entities_count=10,
            relations_count=20,
            chunks_count=5,
            tokens_used=1000,
        )
        assert shared.total_queries == before + 1