# NOTE(review): a commit-message and file-size chrome block was accidentally
# prepended to this module by an export tool; it was not valid Python and has
# been reduced to this comment so the file parses.
"""
|
|
Unit tests for rerank document chunking functionality.
|
|
|
|
Tests the chunk_documents_for_rerank and aggregate_chunk_scores functions
|
|
in lightrag/rerank.py to ensure proper document splitting and score aggregation.
|
|
"""
|
|
|
|
from unittest.mock import AsyncMock, Mock, patch

import pytest

from lightrag.rerank import (
    aggregate_chunk_scores,
    chunk_documents_for_rerank,
    cohere_rerank,
)


class TestChunkDocumentsForRerank:
    """Tests for chunk_documents_for_rerank: splitting documents for rerank APIs."""

    def test_no_chunking_needed_for_short_docs(self):
        """Documents shorter than max_tokens should pass through untouched."""
        docs = ['Short doc 1', 'Short doc 2', 'Short doc 3']

        chunks, mapping = chunk_documents_for_rerank(docs, max_tokens=100, overlap_tokens=10)

        # Every document survives as a single chunk with an identity mapping.
        assert len(chunks) == 3
        assert chunks == docs
        assert mapping == [0, 1, 2]

    def test_chunking_with_character_fallback(self):
        """When the tokenizer import fails, chunking falls back to characters."""
        oversized = 'a' * 2000  # 2000 characters, well past the limit
        docs = [oversized, 'short doc']

        with patch('lightrag.rerank.TiktokenTokenizer', side_effect=ImportError):
            chunks, mapping = chunk_documents_for_rerank(
                docs,
                max_tokens=100,  # 100 tokens = ~400 chars
                overlap_tokens=10,  # 10 tokens = ~40 chars
            )

        # The oversized doc is split into several pieces; the short one stays
        # whole and ends up as the final chunk.
        assert len(chunks) > 2
        assert chunks[-1] == 'short doc'
        # The final chunk must map back to the second original document.
        assert mapping[-1] == 1

    def test_chunking_with_tiktoken_tokenizer(self):
        """Chunking with the real tokenizer splits long docs and keeps overlap."""
        # Roughly one token per word, so ~200 tokens total.
        oversized = ' '.join(f'word{i}' for i in range(200))
        docs = [oversized, 'short']

        chunks, mapping = chunk_documents_for_rerank(docs, max_tokens=50, overlap_tokens=10)

        assert len(chunks) > 2
        assert mapping[-1] == 1  # trailing chunk belongs to the second document

        # Consecutive chunks of the same source document should share some
        # words because of the configured overlap.
        if len(chunks) > 2:
            for pos in range(len(mapping) - 1):
                if mapping[pos] == mapping[pos + 1] == 0:
                    tail_words = chunks[pos].split()[-5:]
                    next_words = chunks[pos + 1].split()
                    assert any(word in next_words for word in tail_words)

    def test_empty_documents(self):
        """An empty input list yields empty chunks and an empty mapping."""
        chunks, mapping = chunk_documents_for_rerank([])

        assert chunks == []
        assert mapping == []

    def test_single_document_chunking(self):
        """A single long document is split into several chunks, all mapping to 0."""
        oversized = ' '.join(f'token{i}' for i in range(100))  # ~100 tokens

        chunks, mapping = chunk_documents_for_rerank([oversized], max_tokens=30, overlap_tokens=5)

        assert len(chunks) > 1
        assert all(source == 0 for source in mapping)


class TestAggregateChunkScores:
    """Tests for aggregate_chunk_scores: folding chunk scores back to documents."""

    def test_no_chunking_simple_aggregation(self):
        """With a 1:1 chunk-to-document mapping, scores pass through sorted."""
        scores = [
            {'index': 0, 'relevance_score': 0.9},
            {'index': 1, 'relevance_score': 0.7},
            {'index': 2, 'relevance_score': 0.5},
        ]

        ranked = aggregate_chunk_scores(scores, [0, 1, 2], 3, aggregation='max')

        # Output is sorted by descending relevance with indices preserved.
        assert len(ranked) == 3
        assert ranked[0]['index'] == 0
        assert ranked[0]['relevance_score'] == 0.9
        assert ranked[1]['index'] == 1
        assert ranked[1]['relevance_score'] == 0.7
        assert ranked[2]['index'] == 2
        assert ranked[2]['relevance_score'] == 0.5

    def test_max_aggregation_with_chunks(self):
        """'max' keeps the best chunk score for each source document."""
        scores = [
            {'index': 0, 'relevance_score': 0.5},
            {'index': 1, 'relevance_score': 0.8},
            {'index': 2, 'relevance_score': 0.6},
            {'index': 3, 'relevance_score': 0.7},
            {'index': 4, 'relevance_score': 0.4},
        ]
        mapping = [0, 0, 0, 1, 1]  # three chunks from doc 0, two from doc 1

        ranked = aggregate_chunk_scores(scores, mapping, 2, aggregation='max')

        assert len(ranked) == 2
        assert ranked[0]['index'] == 0
        assert ranked[0]['relevance_score'] == 0.8  # max of 0.5, 0.8, 0.6
        assert ranked[1]['index'] == 1
        assert ranked[1]['relevance_score'] == 0.7  # max of 0.7, 0.4

    def test_mean_aggregation_with_chunks(self):
        """'mean' averages all chunk scores belonging to a document."""
        scores = [
            {'index': 0, 'relevance_score': 0.6},
            {'index': 1, 'relevance_score': 0.8},
            {'index': 2, 'relevance_score': 0.4},
        ]

        # First two chunks come from doc 0, the last one from doc 1.
        ranked = aggregate_chunk_scores(scores, [0, 0, 1], 2, aggregation='mean')

        assert len(ranked) == 2
        assert ranked[0]['index'] == 0
        assert ranked[0]['relevance_score'] == pytest.approx(0.7)  # (0.6 + 0.8) / 2
        assert ranked[1]['index'] == 1
        assert ranked[1]['relevance_score'] == 0.4

    def test_first_aggregation_with_chunks(self):
        """'first' keeps the first chunk score seen for each document."""
        scores = [
            {'index': 0, 'relevance_score': 0.6},
            {'index': 1, 'relevance_score': 0.8},
            {'index': 2, 'relevance_score': 0.4},
        ]

        ranked = aggregate_chunk_scores(scores, [0, 0, 1], 2, aggregation='first')

        assert len(ranked) == 2
        assert ranked[0]['index'] == 0
        assert ranked[0]['relevance_score'] == 0.6  # first score seen for doc 0
        assert ranked[1]['index'] == 1
        assert ranked[1]['relevance_score'] == 0.4

    def test_empty_chunk_results(self):
        """No chunk scores at all yields an empty result list."""
        assert aggregate_chunk_scores([], [], 3, aggregation='max') == []

    def test_documents_with_no_scores(self):
        """Documents that produced no scored chunks are absent from the output."""
        scores = [
            {'index': 0, 'relevance_score': 0.9},
            {'index': 1, 'relevance_score': 0.7},
        ]

        # Both chunks come from document 0, yet three documents exist overall;
        # only doc 0 can appear in the aggregated ranking.
        ranked = aggregate_chunk_scores(scores, [0, 0], 3, aggregation='max')

        assert len(ranked) == 1
        assert ranked[0]['index'] == 0

    def test_unknown_aggregation_strategy(self):
        """An unrecognized strategy name silently falls back to 'max'."""
        scores = [
            {'index': 0, 'relevance_score': 0.6},
            {'index': 1, 'relevance_score': 0.8},
        ]

        ranked = aggregate_chunk_scores(scores, [0, 0], 1, aggregation='invalid')

        assert ranked[0]['relevance_score'] == 0.8


@pytest.mark.offline
class TestTopNWithChunking:
    """Tests for top_n behavior when chunking is enabled (Bug fix verification)"""

    @pytest.mark.asyncio
    async def test_top_n_limits_documents_not_chunks(self):
        """
        Test that top_n correctly limits documents (not chunks) when chunking is enabled.

        Bug scenario: 10 docs expand to 50 chunks. With old behavior, top_n=5 would
        return scores for only 5 chunks (possibly all from 1-2 docs). After aggregation,
        fewer than 5 documents would be returned.

        Fixed behavior: top_n=5 should return exactly 5 documents after aggregation.
        """
        # Setup: 5 documents, each producing multiple chunks when chunked
        # Using small max_tokens to force chunking
        long_docs = [' '.join([f'doc{i}_word{j}' for j in range(50)]) for i in range(5)]
        query = 'test query'

        # First, determine how many chunks will be created by actual chunking,
        # so the mocked API response has exactly one score per chunk.
        _, doc_indices = chunk_documents_for_rerank(long_docs, max_tokens=50, overlap_tokens=10)
        num_chunks = len(doc_indices)

        # Mock API returns scores for ALL chunks (simulating disabled API-level top_n).
        # Assign scores based on original document index (lower doc index = higher
        # score) so that after aggregation doc 0 ranks first, doc 1 second, etc.
        mock_chunk_scores = []
        for i in range(num_chunks):
            original_doc = doc_indices[i]
            # Every chunk of a document shares that document's base score;
            # no per-chunk variation is applied.
            base_score = 0.9 - (original_doc * 0.1)
            mock_chunk_scores.append({'index': i, 'relevance_score': base_score})

        mock_response = Mock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={'results': mock_chunk_scores})
        # aiohttp response attributes that error-handling paths may read.
        mock_response.request_info = None
        mock_response.history = None
        mock_response.headers = {}
        # Make the mock usable as `async with session.post(...) as response:`.
        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
        mock_response.__aexit__ = AsyncMock(return_value=None)

        mock_session = Mock()
        mock_session.post = Mock(return_value=mock_response)
        # Support `async with aiohttp.ClientSession() as session:`.
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=None)

        with patch('lightrag.rerank.aiohttp.ClientSession', return_value=mock_session):
            result = await cohere_rerank(
                query=query,
                documents=long_docs,
                api_key='test-key',
                base_url='http://test.com/rerank',
                enable_chunking=True,
                max_tokens_per_doc=50,  # Match chunking above
                top_n=3,  # Request top 3 documents
            )

        # Verify: should get exactly 3 documents (not unlimited chunks)
        assert len(result) == 3
        # All results should have valid document indices (0-4)
        assert all(0 <= r['index'] < 5 for r in result)
        # Results should be sorted by score (descending)
        assert all(result[i]['relevance_score'] >= result[i + 1]['relevance_score'] for i in range(len(result) - 1))
        # The top 3 docs should be 0, 1, 2 (highest scores)
        result_indices = [r['index'] for r in result]
        assert set(result_indices) == {0, 1, 2}

    @pytest.mark.asyncio
    async def test_api_receives_no_top_n_when_chunking_enabled(self):
        """
        Test that the API request does NOT include top_n when chunking is enabled.

        This ensures all chunk scores are retrieved for proper aggregation.
        """
        documents = [' '.join([f'word{i}' for i in range(100)]), 'short doc']
        query = 'test query'

        # Filled in by capture_post with the JSON body of the outgoing request.
        captured_payload = {}

        mock_response = Mock()
        mock_response.status = 200
        mock_response.json = AsyncMock(
            return_value={
                'results': [
                    {'index': 0, 'relevance_score': 0.9},
                    {'index': 1, 'relevance_score': 0.8},
                    {'index': 2, 'relevance_score': 0.7},
                ]
            }
        )
        # aiohttp response attributes that error-handling paths may read.
        mock_response.request_info = None
        mock_response.history = None
        mock_response.headers = {}
        # Make the mock usable as `async with session.post(...) as response:`.
        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
        mock_response.__aexit__ = AsyncMock(return_value=None)

        def capture_post(*args, **kwargs):
            # Record the request body so the assertion below can inspect it.
            captured_payload.update(kwargs.get('json', {}))
            return mock_response

        mock_session = Mock()
        mock_session.post = Mock(side_effect=capture_post)
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=None)

        with patch('lightrag.rerank.aiohttp.ClientSession', return_value=mock_session):
            await cohere_rerank(
                query=query,
                documents=documents,
                api_key='test-key',
                base_url='http://test.com/rerank',
                enable_chunking=True,
                max_tokens_per_doc=30,
                top_n=1,  # User wants top 1 document
            )

        # Verify: API payload should NOT have top_n (disabled for chunking)
        assert 'top_n' not in captured_payload

    @pytest.mark.asyncio
    async def test_top_n_not_modified_when_chunking_disabled(self):
        """
        Test that top_n is passed through to API when chunking is disabled.
        """
        documents = ['doc1', 'doc2']
        query = 'test query'

        # Filled in by capture_post with the JSON body of the outgoing request.
        captured_payload = {}

        mock_response = Mock()
        mock_response.status = 200
        mock_response.json = AsyncMock(
            return_value={
                'results': [
                    {'index': 0, 'relevance_score': 0.9},
                ]
            }
        )
        # aiohttp response attributes that error-handling paths may read.
        mock_response.request_info = None
        mock_response.history = None
        mock_response.headers = {}
        # Make the mock usable as `async with session.post(...) as response:`.
        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
        mock_response.__aexit__ = AsyncMock(return_value=None)

        def capture_post(*args, **kwargs):
            # Record the request body so the assertion below can inspect it.
            captured_payload.update(kwargs.get('json', {}))
            return mock_response

        mock_session = Mock()
        mock_session.post = Mock(side_effect=capture_post)
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=None)

        with patch('lightrag.rerank.aiohttp.ClientSession', return_value=mock_session):
            await cohere_rerank(
                query=query,
                documents=documents,
                api_key='test-key',
                base_url='http://test.com/rerank',
                enable_chunking=False,  # Chunking disabled
                top_n=1,
            )

        # Verify: API payload should have top_n when chunking is disabled
        assert captured_payload.get('top_n') == 1


@pytest.mark.offline
class TestCohereRerankChunking:
    """Integration tests for cohere_rerank's pass-through of chunking options."""

    @pytest.mark.asyncio
    async def test_cohere_rerank_with_chunking_disabled(self):
        """enable_chunking=False must be forwarded verbatim to the rerank API."""
        with patch('lightrag.rerank.generic_rerank_api', new_callable=AsyncMock) as fake_api:
            fake_api.return_value = [
                {'index': 0, 'relevance_score': 0.9},
                {'index': 1, 'relevance_score': 0.7},
            ]

            outcome = await cohere_rerank(
                query='test query',
                documents=['doc1', 'doc2'],
                api_key='test-key',
                enable_chunking=False,
                max_tokens_per_doc=100,
            )

        # The underlying API is invoked exactly once with our flags intact.
        fake_api.assert_called_once()
        forwarded = fake_api.call_args[1]
        assert forwarded['enable_chunking'] is False
        assert forwarded['max_tokens_per_doc'] == 100
        # The mocked scores come back untouched.
        assert len(outcome) == 2
        assert outcome[0]['index'] == 0
        assert outcome[0]['relevance_score'] == 0.9
        assert outcome[1]['index'] == 1
        assert outcome[1]['relevance_score'] == 0.7

    @pytest.mark.asyncio
    async def test_cohere_rerank_with_chunking_enabled(self):
        """Chunking options are forwarded when chunking is turned on."""
        with patch('lightrag.rerank.generic_rerank_api', new_callable=AsyncMock) as fake_api:
            fake_api.return_value = [
                {'index': 0, 'relevance_score': 0.9},
                {'index': 1, 'relevance_score': 0.7},
            ]

            outcome = await cohere_rerank(
                query='test query',
                documents=['doc1', 'doc2'],
                api_key='test-key',
                enable_chunking=True,
                max_tokens_per_doc=480,
            )

        # Both chunking parameters reach the underlying API unchanged.
        forwarded = fake_api.call_args[1]
        assert forwarded['enable_chunking'] is True
        assert forwarded['max_tokens_per_doc'] == 480
        # The mocked scores come back untouched.
        assert len(outcome) == 2
        assert outcome[0]['index'] == 0
        assert outcome[0]['relevance_score'] == 0.9
        assert outcome[1]['index'] == 1
        assert outcome[1]['relevance_score'] == 0.7

    @pytest.mark.asyncio
    async def test_cohere_rerank_default_parameters(self):
        """Defaults: chunking off, 4096-token cap, and the rerank-v3.5 model."""
        with patch('lightrag.rerank.generic_rerank_api', new_callable=AsyncMock) as fake_api:
            fake_api.return_value = [{'index': 0, 'relevance_score': 0.9}]

            outcome = await cohere_rerank(query='test', documents=['doc1'], api_key='test-key')

        # Verify the documented default values are what actually gets sent.
        forwarded = fake_api.call_args[1]
        assert forwarded['enable_chunking'] is False
        assert forwarded['max_tokens_per_doc'] == 4096
        assert forwarded['model'] == 'rerank-v3.5'
        # The mocked score comes back untouched.
        assert len(outcome) == 1
        assert outcome[0]['index'] == 0
        assert outcome[0]['relevance_score'] == 0.9


@pytest.mark.offline
class TestEndToEndChunking:
    """End-to-end tests for chunking workflow"""

    @pytest.mark.asyncio
    async def test_end_to_end_chunking_workflow(self):
        """Test complete chunking workflow from documents to aggregated results"""
        # Create documents where first one needs chunking
        long_doc = ' '.join([f'word{i}' for i in range(100)])
        documents = [long_doc, 'short doc']
        query = 'test query'

        # Mock the HTTP call inside generic_rerank_api.
        # NOTE(review): the mocked response assumes the long doc splits into
        # exactly 3 chunks at max_tokens=30 (4 scores total) — confirm against
        # chunk_documents_for_rerank if the chunking parameters change.
        mock_response = Mock()
        mock_response.status = 200
        mock_response.json = AsyncMock(
            return_value={
                'results': [
                    {'index': 0, 'relevance_score': 0.5},  # chunk 0 from doc 0
                    {'index': 1, 'relevance_score': 0.8},  # chunk 1 from doc 0
                    {'index': 2, 'relevance_score': 0.6},  # chunk 2 from doc 0
                    {'index': 3, 'relevance_score': 0.7},  # doc 1 (short)
                ]
            }
        )
        # aiohttp response attributes that error-handling paths may read.
        mock_response.request_info = None
        mock_response.history = None
        mock_response.headers = {}
        # Make mock_response an async context manager (for `async with session.post() as response`)
        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
        mock_response.__aexit__ = AsyncMock(return_value=None)

        mock_session = Mock()
        # session.post() returns an async context manager, so return mock_response which is now one
        mock_session.post = Mock(return_value=mock_response)
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=None)

        with patch('lightrag.rerank.aiohttp.ClientSession', return_value=mock_session):
            result = await cohere_rerank(
                query=query,
                documents=documents,
                api_key='test-key',
                base_url='http://test.com/rerank',
                enable_chunking=True,
                max_tokens_per_doc=30,  # Force chunking of long doc
            )

        # Should get 2 results (one per original document)
        # The long doc's chunks should be aggregated
        assert len(result) <= len(documents)
        # Results should be sorted by score
        assert all(result[i]['relevance_score'] >= result[i + 1]['relevance_score'] for i in range(len(result) - 1))