LightRAG/tests/test_search_routes.py
clssck 59e89772de refactor: consolidate to PostgreSQL-only backend and modernize stack
Remove legacy storage implementations and deprecated examples:
- Delete FAISS, JSON, Memgraph, Milvus, MongoDB, Nano Vector DB, Neo4j, NetworkX, Qdrant, Redis storage backends
- Remove Kubernetes deployment manifests and installation scripts
- Delete unofficial examples for deprecated backends and offline deployment docs
Streamline core infrastructure:
- Consolidate storage layer to PostgreSQL-only implementation
- Add full-text search caching with FTS cache module
- Implement metrics collection and monitoring pipeline
- Add explain and metrics API routes
Modernize frontend and tooling:
- Switch web UI to Bun with bun.lock, remove npm and pnpm lockfiles
- Update Dockerfile for PostgreSQL-only deployment
- Add Makefile for common development tasks
- Update environment and configuration examples
Enhance evaluation and testing capabilities:
- Add prompt optimization with DSPy and auto-tuning
- Implement ground truth regeneration and variant testing
- Add prompt debugging and response comparison utilities
- Expand test coverage with new integration scenarios
Simplify dependencies and configuration:
- Remove offline-specific requirement files
- Update pyproject.toml with streamlined dependencies
- Add Python version pinning with .python-version
- Create project guidelines in CLAUDE.md and AGENTS.md
2025-12-12 16:28:49 +01:00

407 lines
14 KiB
Python

"""Tests for search routes in lightrag/api/routers/search_routes.py.
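This module tests the BM25 full-text search endpoint using an httpx AsyncClient
over ASGITransport against a FastAPI app, with a mocked PostgreSQLDB. The route
under test is re-created locally by create_test_search_routes rather than
imported from the real module.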
"""
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
# Stub out lightrag.api.config BEFORE anything imports it, so that argparse
# does not try to parse pytest's arguments as server arguments.
# NOTE(review): nothing in this block removes the stubs from sys.modules again,
# so they remain registered once this module has been imported — confirm that
# is acceptable for other test modules collected in the same session.
mock_global_args = MagicMock(
    token_secret='test-secret',
    jwt_secret_key='test-jwt-secret',
    jwt_algorithm='HS256',
    jwt_expire_hours=24,
    username=None,
    password=None,
    guest_token=None,
)

# Register a stand-in config module that exposes the fake global_args.
mock_config_module = MagicMock(global_args=mock_global_args)
sys.modules['lightrag.api.config'] = mock_config_module

# Replace the auth module as well, to prevent initialization issues.
mock_auth_module = MagicMock(auth_handler=MagicMock())
sys.modules['lightrag.api.auth'] = mock_auth_module
# Now import FastAPI components (after mocking)
from typing import Annotated, Any
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from pydantic import BaseModel, Field
# Do not import search_routes itself, to avoid triggering its full initialization;
# instead, re-create the essential models and route below for testing.
class SearchResult(BaseModel):
    """A single search result (chunk match)."""

    # Required fields: chunk identity, position in the parent document, payload.
    id: str = Field(description='Chunk ID')
    full_doc_id: str = Field(description='Parent document ID')
    chunk_order_index: int = Field(description='Position in document')
    tokens: int = Field(description='Token count')
    content: str = Field(description='Chunk content')

    # Optional provenance fields; each defaults to None when not supplied.
    file_path: str | None = Field(
        None,
        description='Source file path',
    )
    s3_key: str | None = Field(
        None,
        description='S3 key for source document',
    )
    char_start: int | None = Field(
        None,
        description='Character offset start',
    )
    char_end: int | None = Field(
        None,
        description='Character offset end',
    )

    # Required relevance score attached to the match.
    score: float = Field(description='BM25 relevance score')
class SearchResponse(BaseModel):
    """Response model for search endpoint."""

    # The query string exactly as received by the endpoint.
    query: str = Field(description='Original search query')
    # One SearchResult per row returned by the database search.
    results: list[SearchResult] = Field(description='Matching chunks')
    # Length of `results`; the route fills this with len(search_results).
    count: int = Field(description='Number of results returned')
    # Workspace name the search was run against.
    workspace: str = Field(description='Workspace searched')
def create_test_search_routes(db: Any, api_key: str | None = None):
    """Build a minimal ``/search`` router for tests (no auth dependency).

    Args:
        db: Object exposing an awaitable ``full_text_search(query=, workspace=,
            limit=)`` that returns an iterable of mapping-like rows.
        api_key: Accepted for signature compatibility; not used by this
            simplified router.

    Returns:
        An ``APIRouter`` with a single ``GET /search`` endpoint.
    """
    from fastapi import APIRouter, HTTPException, Query

    def _row_to_result(row) -> SearchResult:
        # Missing required keys fall back to empty/zero defaults; the optional
        # fields are passed through as-is (None when the key is absent).
        return SearchResult(
            id=row.get('id', ''),
            full_doc_id=row.get('full_doc_id', ''),
            chunk_order_index=row.get('chunk_order_index', 0),
            tokens=row.get('tokens', 0),
            content=row.get('content', ''),
            file_path=row.get('file_path'),
            s3_key=row.get('s3_key'),
            char_start=row.get('char_start'),
            char_end=row.get('char_end'),
            score=float(row.get('score', 0)),
        )

    router = APIRouter(prefix='/search', tags=['search'])

    @router.get('', response_model=SearchResponse)
    async def search(
        q: Annotated[str, Query(description='Search query', min_length=1)],
        limit: Annotated[int, Query(description='Max results', ge=1, le=100)] = 10,
        workspace: Annotated[str, Query(description='Workspace')] = 'default',
    ) -> SearchResponse:
        """Perform BM25 full-text search on chunks."""
        try:
            rows = await db.full_text_search(query=q, workspace=workspace, limit=limit)
            hits = [_row_to_result(row) for row in rows]
            return SearchResponse(
                query=q,
                results=hits,
                count=len(hits),
                workspace=workspace,
            )
        except Exception as e:
            # Any failure (database or row conversion) is reported as a 500.
            raise HTTPException(status_code=500, detail=f'Search failed: {e}') from e

    return router
@pytest.fixture
def mock_db():
    """Provide a stand-in PostgreSQLDB whose ``full_text_search`` is awaitable."""
    # full_text_search must be an AsyncMock because the route awaits it.
    return MagicMock(full_text_search=AsyncMock())
@pytest.fixture
def app(mock_db):
    """Build a FastAPI application that serves only the test search router."""
    application = FastAPI()
    application.include_router(create_test_search_routes(db=mock_db, api_key=None))
    return application
@pytest.fixture
async def client(app):
    """Yield an httpx AsyncClient wired directly to the ASGI app.

    NOTE(review): an async generator fixture declared with plain
    ``@pytest.fixture`` presumably relies on pytest-asyncio's auto mode —
    confirm against the project's pytest configuration.
    """
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url='http://test') as http_client:
        yield http_client
# ============================================================================
# Search Endpoint Tests
# ============================================================================
@pytest.mark.offline
class TestSearchEndpoint:
    """Tests for GET /search endpoint.

    Each test primes ``mock_db.full_text_search`` with canned rows, issues a
    request through the ASGI client and checks the JSON payload or the
    arguments forwarded to the database.
    """

    @pytest.mark.asyncio
    async def test_search_returns_results(self, client, mock_db):
        """Test that search returns properly formatted results."""
        mock_db.full_text_search.return_value = [
            {
                'id': 'chunk-1',
                'full_doc_id': 'doc-1',
                'chunk_order_index': 0,
                'tokens': 100,
                'content': 'This is test content about machine learning.',
                'file_path': '/path/to/doc.pdf',
                's3_key': 'archive/default/doc-1/doc.pdf',
                'char_start': 0,
                'char_end': 100,
                'score': 0.85,
            }
        ]

        response = await client.get('/search', params={'q': 'machine learning'})

        assert response.status_code == 200
        data = response.json()
        assert data['query'] == 'machine learning'
        assert data['count'] == 1
        assert data['workspace'] == 'default'
        assert len(data['results']) == 1
        result = data['results'][0]
        assert result['id'] == 'chunk-1'
        assert result['content'] == 'This is test content about machine learning.'
        assert result['score'] == 0.85

    @pytest.mark.asyncio
    async def test_search_includes_char_positions(self, client, mock_db):
        """Test that char_start/char_end are included in results."""
        mock_db.full_text_search.return_value = [
            {
                'id': 'chunk-1',
                'full_doc_id': 'doc-1',
                'chunk_order_index': 0,
                'tokens': 50,
                'content': 'Test content',
                'char_start': 100,
                'char_end': 200,
                'score': 0.75,
            }
        ]

        response = await client.get('/search', params={'q': 'test'})

        assert response.status_code == 200
        result = response.json()['results'][0]
        assert result['char_start'] == 100
        assert result['char_end'] == 200

    @pytest.mark.asyncio
    async def test_search_includes_s3_key(self, client, mock_db):
        """Test that s3_key is included in results when present."""
        mock_db.full_text_search.return_value = [
            {
                'id': 'chunk-1',
                'full_doc_id': 'doc-1',
                'chunk_order_index': 0,
                'tokens': 50,
                'content': 'Test content',
                's3_key': 'archive/default/doc-1/report.pdf',
                'file_path': 's3://bucket/archive/default/doc-1/report.pdf',
                'score': 0.75,
            }
        ]

        response = await client.get('/search', params={'q': 'test'})

        assert response.status_code == 200
        result = response.json()['results'][0]
        assert result['s3_key'] == 'archive/default/doc-1/report.pdf'
        assert result['file_path'] == 's3://bucket/archive/default/doc-1/report.pdf'

    @pytest.mark.asyncio
    async def test_search_null_s3_key(self, client, mock_db):
        """Test that null s3_key is handled correctly."""
        mock_db.full_text_search.return_value = [
            {
                'id': 'chunk-1',
                'full_doc_id': 'doc-1',
                'chunk_order_index': 0,
                'tokens': 50,
                'content': 'Test content',
                's3_key': None,
                'file_path': '/local/path/doc.pdf',
                'score': 0.75,
            }
        ]

        response = await client.get('/search', params={'q': 'test'})

        assert response.status_code == 200
        result = response.json()['results'][0]
        assert result['s3_key'] is None
        assert result['file_path'] == '/local/path/doc.pdf'

    @pytest.mark.asyncio
    async def test_search_empty_results(self, client, mock_db):
        """Test that empty results are returned correctly."""
        mock_db.full_text_search.return_value = []

        response = await client.get('/search', params={'q': 'nonexistent'})

        assert response.status_code == 200
        data = response.json()
        assert data['count'] == 0
        assert data['results'] == []

    @pytest.mark.asyncio
    async def test_search_limit_parameter(self, client, mock_db):
        """Test that limit parameter is passed to database."""
        mock_db.full_text_search.return_value = []

        response = await client.get('/search', params={'q': 'test', 'limit': 25})

        assert response.status_code == 200
        # assert_awaited_* (not assert_called_*) so the test also fails if the
        # coroutine was created but never awaited.
        mock_db.full_text_search.assert_awaited_once_with(
            query='test',
            workspace='default',
            limit=25,
        )

    @pytest.mark.asyncio
    async def test_search_workspace_parameter(self, client, mock_db):
        """Test that workspace parameter is passed to database."""
        mock_db.full_text_search.return_value = []

        response = await client.get('/search', params={'q': 'test', 'workspace': 'custom'})

        assert response.status_code == 200
        mock_db.full_text_search.assert_awaited_once_with(
            query='test',
            workspace='custom',
            limit=10,  # default
        )

    @pytest.mark.asyncio
    async def test_search_multiple_results(self, client, mock_db):
        """Test that multiple results are returned correctly."""
        mock_db.full_text_search.return_value = [
            {
                'id': f'chunk-{i}',
                'full_doc_id': f'doc-{i}',
                'chunk_order_index': i,
                'tokens': 100,
                'content': f'Content {i}',
                'score': 0.9 - i * 0.1,
            }
            for i in range(5)
        ]

        response = await client.get('/search', params={'q': 'test', 'limit': 5})

        assert response.status_code == 200
        data = response.json()
        assert data['count'] == 5
        assert len(data['results']) == 5
        # Results should be in order by score (as returned by db)
        assert data['results'][0]['id'] == 'chunk-0'
        assert data['results'][4]['id'] == 'chunk-4'
# ============================================================================
# Validation Tests
# ============================================================================
@pytest.mark.offline
class TestSearchValidation:
    """Tests for search endpoint validation.

    Rejected requests must return 422 and must not reach the database, so
    each rejection test also checks that ``full_text_search`` was never called.
    """

    @pytest.mark.asyncio
    async def test_search_empty_query_rejected(self, client, mock_db):
        """Test that empty query string is rejected with 422."""
        response = await client.get('/search', params={'q': ''})

        assert response.status_code == 422
        mock_db.full_text_search.assert_not_called()

    @pytest.mark.asyncio
    async def test_search_missing_query_rejected(self, client, mock_db):
        """Test that missing query parameter is rejected with 422."""
        response = await client.get('/search')

        assert response.status_code == 422
        mock_db.full_text_search.assert_not_called()

    @pytest.mark.asyncio
    async def test_search_limit_too_low_rejected(self, client, mock_db):
        """Test that limit < 1 is rejected."""
        response = await client.get('/search', params={'q': 'test', 'limit': 0})

        assert response.status_code == 422
        mock_db.full_text_search.assert_not_called()

    @pytest.mark.asyncio
    async def test_search_limit_too_high_rejected(self, client, mock_db):
        """Test that limit > 100 is rejected."""
        response = await client.get('/search', params={'q': 'test', 'limit': 101})

        assert response.status_code == 422
        mock_db.full_text_search.assert_not_called()

    @pytest.mark.asyncio
    async def test_search_limit_boundary_valid(self, client, mock_db):
        """Test that limit at boundaries (1 and 100) are accepted."""
        mock_db.full_text_search.return_value = []

        # Test lower boundary
        response = await client.get('/search', params={'q': 'test', 'limit': 1})
        assert response.status_code == 200

        # Test upper boundary
        response = await client.get('/search', params={'q': 'test', 'limit': 100})
        assert response.status_code == 200
# ============================================================================
# Error Handling Tests
# ============================================================================
@pytest.mark.offline
class TestSearchErrors:
    """Tests for search endpoint error handling."""

    @pytest.mark.asyncio
    async def test_search_database_error(self, client, mock_db):
        """Test that database errors return 500."""
        mock_db.full_text_search.side_effect = Exception('Database connection failed')

        response = await client.get('/search', params={'q': 'test'})

        assert response.status_code == 500
        assert 'Database connection failed' in response.json()['detail']

    @pytest.mark.asyncio
    async def test_search_handles_missing_fields(self, client, mock_db):
        """Test that missing fields in db results are handled with defaults."""
        mock_db.full_text_search.return_value = [
            {
                'id': 'chunk-1',
                'full_doc_id': 'doc-1',
                # Missing many fields
            }
        ]

        response = await client.get('/search', params={'q': 'test'})

        assert response.status_code == 200
        result = response.json()['results'][0]
        # Provided fields pass through untouched
        assert result['id'] == 'chunk-1'
        assert result['full_doc_id'] == 'doc-1'
        # Should have defaults
        assert result['chunk_order_index'] == 0
        assert result['tokens'] == 0
        assert result['content'] == ''
        assert result['score'] == 0.0
        # Optional fields fall back to null rather than being dropped
        assert result['file_path'] is None
        assert result['s3_key'] is None
        assert result['char_start'] is None
        assert result['char_end'] is None