LightRAG/tests/test_search_routes.py
clssck 082a5a8fad test(lightrag,api): add comprehensive test coverage and S3 support
Add extensive test suites for API routes and utilities:
- Implement test_search_routes.py (406 lines) for search endpoint validation
- Implement test_upload_routes.py (724 lines) for document upload workflows
- Implement test_s3_client.py (618 lines) for S3 storage operations
- Implement test_citation_utils.py (352 lines) for citation extraction
- Implement test_chunking.py (216 lines) for text chunking validation
Add S3 storage client implementation:
- Create lightrag/storage/s3_client.py with S3 operations
- Add storage module initialization with exports
- Integrate S3 client with document upload handling
Enhance API routes and core functionality:
- Add search_routes.py with full-text and graph search endpoints
- Add upload_routes.py with multipart document upload support
- Update operate.py with bulk operations and health checks
- Enhance postgres_impl.py with bulk upsert and parameterized queries
- Update lightrag_server.py to register new API routes
- Improve utils.py with citation and formatting utilities
Update dependencies and configuration:
- Add S3 and test dependencies to pyproject.toml
- Update docker-compose.test.yml for testing environment
- Sync uv.lock with new dependencies
Apply code quality improvements across all modified files:
- Add type hints to function signatures
- Update imports and router initialization
- Fix logging and error handling
2025-12-05 23:13:39 +01:00

406 lines
14 KiB
Python

"""Tests for search routes in lightrag/api/routers/search_routes.py.
This module tests the BM25 full-text search endpoint using httpx AsyncClient
and FastAPI's TestClient pattern with mocked PostgreSQLDB.
"""
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
# Register stand-ins for lightrag.api.config and lightrag.api.auth in
# sys.modules BEFORE anything can import the real modules, so that argparse
# never tries to parse pytest's arguments as server arguments.
# NOTE(review): these entries are never removed from sys.modules, so test
# modules imported later in the same session presumably see the stubs as
# well - confirm that this is intended.
mock_global_args = MagicMock(
    token_secret='test-secret',
    jwt_secret_key='test-jwt-secret',
    jwt_algorithm='HS256',
    jwt_expire_hours=24,
    username=None,
    password=None,
    guest_token=None,
)

# Stub config module exposing the fake global_args defined above.
mock_config_module = MagicMock(global_args=mock_global_args)

# Stub auth module, mocked to prevent initialization issues.
mock_auth_module = MagicMock(auth_handler=MagicMock())

sys.modules.update(
    {
        'lightrag.api.config': mock_config_module,
        'lightrag.api.auth': mock_auth_module,
    }
)
# Now import FastAPI components (after mocking)
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from pydantic import BaseModel, Field
from typing import Annotated, Any, ClassVar
# Re-create the essential parts of search_routes locally instead of importing
# the module, so that its full initialization is never triggered.
# Response schema for one matching chunk, declared locally for these tests.
# The class docstring and the Field descriptions are passed to pydantic, so
# they are left untouched; explanatory notes live in `#` comments instead.
class SearchResult(BaseModel):
    """A single search result (chunk match)."""

    # Required fields: construction fails validation if any of them is absent.
    id: str = Field(description='Chunk ID')
    full_doc_id: str = Field(description='Parent document ID')
    chunk_order_index: int = Field(description='Position in document')
    tokens: int = Field(description='Token count')
    content: str = Field(description='Chunk content')
    # Optional provenance fields: default to None when not supplied.
    file_path: str | None = Field(default=None, description='Source file path')
    s3_key: str | None = Field(default=None, description='S3 key for source document')
    char_start: int | None = Field(default=None, description='Character offset start')
    char_end: int | None = Field(default=None, description='Character offset end')
    # Required relevance score.
    score: float = Field(description='BM25 relevance score')
# Envelope returned by the test search route; used as its response_model.
# The class docstring and Field descriptions are passed to pydantic, so they
# are kept verbatim.
class SearchResponse(BaseModel):
    """Response model for search endpoint."""

    # The query string and workspace are echoed back next to the hits.
    query: str = Field(description='Original search query')
    results: list[SearchResult] = Field(description='Matching chunks')
    # The route fills this with len(results).
    count: int = Field(description='Number of results returned')
    workspace: str = Field(description='Workspace searched')
def create_test_search_routes(db: Any, api_key: str | None = None):
    """Build a stand-alone ``/search`` router for the tests.

    This is a simplified version of the search route that carries no
    authentication dependency.

    Args:
        db: Object exposing an awaitable
            ``full_text_search(query=..., workspace=..., limit=...)`` whose
            result is an iterable of mappings supporting ``.get``.
        api_key: Not used anywhere in this function body.

    Returns:
        An ``APIRouter`` with prefix ``/search`` exposing a single GET route.
    """
    from fastapi import APIRouter, HTTPException, Query

    router = APIRouter(prefix='/search', tags=['search'])

    def _row_to_result(row):
        # Missing required keys fall back to neutral defaults ('' / 0);
        # missing optional keys become None.
        return SearchResult(
            id=row.get('id', ''),
            full_doc_id=row.get('full_doc_id', ''),
            chunk_order_index=row.get('chunk_order_index', 0),
            tokens=row.get('tokens', 0),
            content=row.get('content', ''),
            file_path=row.get('file_path'),
            s3_key=row.get('s3_key'),
            char_start=row.get('char_start'),
            char_end=row.get('char_end'),
            score=float(row.get('score', 0)),
        )

    @router.get('', response_model=SearchResponse)
    async def search(
        q: Annotated[str, Query(description='Search query', min_length=1)],
        limit: Annotated[int, Query(description='Max results', ge=1, le=100)] = 10,
        workspace: Annotated[str, Query(description='Workspace')] = 'default',
    ) -> SearchResponse:
        """Perform BM25 full-text search on chunks."""
        # The DB call, the row conversion and the response construction all
        # stay inside the try so that any failure is reported as HTTP 500.
        try:
            rows = await db.full_text_search(
                query=q,
                workspace=workspace,
                limit=limit,
            )
            hits = [_row_to_result(row) for row in rows]
            return SearchResponse(
                query=q,
                results=hits,
                count=len(hits),
                workspace=workspace,
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f'Search failed: {e}') from e

    return router
@pytest.fixture
def mock_db():
    """Provide a mock database whose ``full_text_search`` is awaitable.

    Tests set ``full_text_search.return_value`` or ``side_effect`` to script
    what the search route receives.
    """
    return MagicMock(full_text_search=AsyncMock())
@pytest.fixture
def app(mock_db):
    """Provide a FastAPI application serving only the test search router."""
    application = FastAPI()
    # The router is bound to the mocked database and built without an API key.
    application.include_router(
        create_test_search_routes(db=mock_db, api_key=None)
    )
    return application
@pytest.fixture
async def client(app):
    """Yield an httpx ``AsyncClient`` that sends requests to ``app``.

    Requests are dispatched to the application through ``ASGITransport``;
    the client is closed once the consuming test has finished.
    """
    # NOTE(review): an async generator fixture declared with plain
    # ``pytest.fixture`` presumably relies on pytest-asyncio running in auto
    # mode - confirm against the project's pytest configuration.
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url='http://test') as http_client:
        yield http_client
# ============================================================================
# Search Endpoint Tests
# ============================================================================
@pytest.mark.offline
class TestSearchEndpoint:
    """Tests for GET /search endpoint.

    Each test scripts ``mock_db.full_text_search`` and then checks either the
    JSON payload returned by the route or the arguments the route forwarded
    to the database.
    """

    @pytest.mark.asyncio
    async def test_search_returns_results(self, client, mock_db):
        """Test that search returns properly formatted results."""
        mock_db.full_text_search.return_value = [
            {
                'id': 'chunk-1',
                'full_doc_id': 'doc-1',
                'chunk_order_index': 0,
                'tokens': 100,
                'content': 'This is test content about machine learning.',
                'file_path': '/path/to/doc.pdf',
                's3_key': 'archive/default/doc-1/doc.pdf',
                'char_start': 0,
                'char_end': 100,
                'score': 0.85,
            }
        ]

        response = await client.get('/search', params={'q': 'machine learning'})

        assert response.status_code == 200
        data = response.json()
        assert data['query'] == 'machine learning'
        assert data['count'] == 1
        assert data['workspace'] == 'default'
        assert len(data['results']) == 1

        result = data['results'][0]
        assert result['id'] == 'chunk-1'
        assert result['content'] == 'This is test content about machine learning.'
        assert result['score'] == 0.85

    @pytest.mark.asyncio
    async def test_search_includes_char_positions(self, client, mock_db):
        """Test that char_start/char_end are included in results."""
        mock_db.full_text_search.return_value = [
            {
                'id': 'chunk-1',
                'full_doc_id': 'doc-1',
                'chunk_order_index': 0,
                'tokens': 50,
                'content': 'Test content',
                'char_start': 100,
                'char_end': 200,
                'score': 0.75,
            }
        ]

        response = await client.get('/search', params={'q': 'test'})

        assert response.status_code == 200
        result = response.json()['results'][0]
        assert result['char_start'] == 100
        assert result['char_end'] == 200

    @pytest.mark.asyncio
    async def test_search_includes_s3_key(self, client, mock_db):
        """Test that s3_key is included in results when present."""
        mock_db.full_text_search.return_value = [
            {
                'id': 'chunk-1',
                'full_doc_id': 'doc-1',
                'chunk_order_index': 0,
                'tokens': 50,
                'content': 'Test content',
                's3_key': 'archive/default/doc-1/report.pdf',
                'file_path': 's3://bucket/archive/default/doc-1/report.pdf',
                'score': 0.75,
            }
        ]

        response = await client.get('/search', params={'q': 'test'})

        assert response.status_code == 200
        result = response.json()['results'][0]
        assert result['s3_key'] == 'archive/default/doc-1/report.pdf'
        assert result['file_path'] == 's3://bucket/archive/default/doc-1/report.pdf'

    @pytest.mark.asyncio
    async def test_search_null_s3_key(self, client, mock_db):
        """Test that null s3_key is handled correctly."""
        mock_db.full_text_search.return_value = [
            {
                'id': 'chunk-1',
                'full_doc_id': 'doc-1',
                'chunk_order_index': 0,
                'tokens': 50,
                'content': 'Test content',
                's3_key': None,
                'file_path': '/local/path/doc.pdf',
                'score': 0.75,
            }
        ]

        response = await client.get('/search', params={'q': 'test'})

        assert response.status_code == 200
        result = response.json()['results'][0]
        assert result['s3_key'] is None
        assert result['file_path'] == '/local/path/doc.pdf'

    @pytest.mark.asyncio
    async def test_search_empty_results(self, client, mock_db):
        """Test that empty results are returned correctly."""
        mock_db.full_text_search.return_value = []

        response = await client.get('/search', params={'q': 'nonexistent'})

        assert response.status_code == 200
        data = response.json()
        assert data['count'] == 0
        assert data['results'] == []

    @pytest.mark.asyncio
    async def test_search_limit_parameter(self, client, mock_db):
        """Test that limit parameter is passed to database."""
        mock_db.full_text_search.return_value = []

        response = await client.get('/search', params={'q': 'test', 'limit': 25})

        # Check the response too, so a request that fails is not mistaken for
        # a successful pass-through.
        assert response.status_code == 200
        # assert_awaited_* also proves the coroutine was awaited, not merely
        # called.
        mock_db.full_text_search.assert_awaited_once_with(
            query='test',
            workspace='default',
            limit=25,
        )

    @pytest.mark.asyncio
    async def test_search_workspace_parameter(self, client, mock_db):
        """Test that workspace parameter is passed to database."""
        mock_db.full_text_search.return_value = []

        response = await client.get(
            '/search', params={'q': 'test', 'workspace': 'custom'}
        )

        assert response.status_code == 200
        mock_db.full_text_search.assert_awaited_once_with(
            query='test',
            workspace='custom',
            limit=10,  # default
        )

    @pytest.mark.asyncio
    async def test_search_multiple_results(self, client, mock_db):
        """Test that multiple results are returned correctly."""
        mock_db.full_text_search.return_value = [
            {
                'id': f'chunk-{i}',
                'full_doc_id': f'doc-{i}',
                'chunk_order_index': i,
                'tokens': 100,
                'content': f'Content {i}',
                'score': 0.9 - i * 0.1,
            }
            for i in range(5)
        ]

        response = await client.get('/search', params={'q': 'test', 'limit': 5})

        assert response.status_code == 200
        data = response.json()
        assert data['count'] == 5
        assert len(data['results']) == 5
        # Results should be in order by score (as returned by db)
        assert data['results'][0]['id'] == 'chunk-0'
        assert data['results'][4]['id'] == 'chunk-4'
# ============================================================================
# Validation Tests
# ============================================================================
@pytest.mark.offline
class TestSearchValidation:
    """Tests for search endpoint validation.

    Rejected requests must answer 422 and must never reach the database;
    requests at the limit boundaries must be accepted and forwarded.
    """

    @pytest.mark.asyncio
    async def test_search_empty_query_rejected(self, client, mock_db):
        """Test that empty query string is rejected with 422."""
        response = await client.get('/search', params={'q': ''})

        assert response.status_code == 422
        # Validation failures must not reach the database.
        mock_db.full_text_search.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_search_missing_query_rejected(self, client, mock_db):
        """Test that missing query parameter is rejected with 422."""
        response = await client.get('/search')

        assert response.status_code == 422
        mock_db.full_text_search.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_search_limit_too_low_rejected(self, client, mock_db):
        """Test that limit < 1 is rejected."""
        response = await client.get('/search', params={'q': 'test', 'limit': 0})

        assert response.status_code == 422
        mock_db.full_text_search.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_search_limit_too_high_rejected(self, client, mock_db):
        """Test that limit > 100 is rejected."""
        response = await client.get('/search', params={'q': 'test', 'limit': 101})

        assert response.status_code == 422
        mock_db.full_text_search.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_search_limit_boundary_valid(self, client, mock_db):
        """Test that limit at boundaries (1 and 100) are accepted."""
        mock_db.full_text_search.return_value = []

        # Test lower boundary
        response = await client.get('/search', params={'q': 'test', 'limit': 1})
        assert response.status_code == 200
        # assert_awaited_with inspects the most recent await only.
        mock_db.full_text_search.assert_awaited_with(
            query='test',
            workspace='default',
            limit=1,
        )

        # Test upper boundary
        response = await client.get('/search', params={'q': 'test', 'limit': 100})
        assert response.status_code == 200
        mock_db.full_text_search.assert_awaited_with(
            query='test',
            workspace='default',
            limit=100,
        )
# ============================================================================
# Error Handling Tests
# ============================================================================
@pytest.mark.offline
class TestSearchErrors:
    """Tests for search endpoint error handling."""

    @pytest.mark.asyncio
    async def test_search_database_error(self, client, mock_db):
        """Test that database errors return 500."""
        mock_db.full_text_search.side_effect = Exception('Database connection failed')

        response = await client.get('/search', params={'q': 'test'})

        assert response.status_code == 500
        detail = response.json()['detail']
        # The route prefixes the original error text with 'Search failed: '.
        assert detail.startswith('Search failed: ')
        assert 'Database connection failed' in detail

    @pytest.mark.asyncio
    async def test_search_handles_missing_fields(self, client, mock_db):
        """Test that missing fields in db results are handled with defaults."""
        mock_db.full_text_search.return_value = [
            {
                'id': 'chunk-1',
                'full_doc_id': 'doc-1',
                # Missing many fields
            }
        ]

        response = await client.get('/search', params={'q': 'test'})

        assert response.status_code == 200
        result = response.json()['results'][0]
        # Supplied values are passed through unchanged.
        assert result['id'] == 'chunk-1'
        assert result['full_doc_id'] == 'doc-1'
        # Should have defaults
        assert result['chunk_order_index'] == 0
        assert result['tokens'] == 0
        assert result['content'] == ''
        assert result['score'] == 0.0
        # Optional fields that were not supplied are serialized as null.
        for optional_field in ('file_path', 's3_key', 'char_start', 'char_end'):
            assert result[optional_field] is None