Fix trailing whitespace and update test mocking for rerank module
• Remove trailing whitespace
• Fix TiktokenTokenizer import patch
• Add async context manager mocks
• Update aiohttp.ClientSession patch
• Improve test reliability
(cherry picked from commit 561ba4e4b5)
This commit is contained in:
parent
f6c20faa16
commit
d56b4c856e
3 changed files with 405 additions and 15 deletions
|
|
@ -50,7 +50,7 @@ def chunk_documents_for_rerank(
|
|||
f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). "
|
||||
f"Clamping to {overlap_tokens} to prevent infinite loop."
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from .utils import TiktokenTokenizer
|
||||
|
||||
|
|
|
|||
|
|
@ -14,12 +14,12 @@ class TestOverlapValidation:
|
|||
def test_overlap_greater_than_max_tokens(self):
    """Test that overlap_tokens > max_tokens is clamped and doesn't hang"""
    doc = " ".join(f"word{n}" for n in range(100))

    # overlap_tokens=32 exceeds max_tokens=30; the implementation is
    # expected to clamp the overlap down to 29 (max_tokens - 1)
    chunks, indices = chunk_documents_for_rerank(
        [doc], max_tokens=30, overlap_tokens=32
    )

    # Returning at all proves the clamp prevented an infinite loop;
    # sanity-check that output was produced and maps back to doc 0
    assert len(chunks) > 0
    assert all(i == 0 for i in indices)
|
||||
|
|
@ -27,12 +27,12 @@ class TestOverlapValidation:
|
|||
def test_overlap_equal_to_max_tokens(self):
    """Test that overlap_tokens == max_tokens is clamped and doesn't hang"""
    doc = " ".join(f"word{n}" for n in range(100))

    # overlap_tokens equal to max_tokens must also be clamped (to 29),
    # otherwise chunking could never advance through the document
    chunks, indices = chunk_documents_for_rerank(
        [doc], max_tokens=30, overlap_tokens=30
    )

    # Completing the call demonstrates no hang occurred
    assert len(chunks) > 0
    assert all(i == 0 for i in indices)
|
||||
|
|
@ -40,12 +40,12 @@ class TestOverlapValidation:
|
|||
def test_overlap_slightly_less_than_max_tokens(self):
    """Test that overlap_tokens < max_tokens works normally"""
    doc = " ".join(f"word{n}" for n in range(100))

    # overlap_tokens=29 is already strictly below max_tokens=30,
    # so no clamping is expected on this path
    chunks, indices = chunk_documents_for_rerank(
        [doc], max_tokens=30, overlap_tokens=29
    )

    # Normal path: chunks produced, all mapping back to document 0
    assert len(chunks) > 0
    assert all(i == 0 for i in indices)
|
||||
|
|
@ -53,12 +53,12 @@ class TestOverlapValidation:
|
|||
def test_small_max_tokens_with_large_overlap(self):
    """Test edge case with very small max_tokens"""
    doc = " ".join(f"word{n}" for n in range(50))

    # max_tokens=5 with overlap_tokens=10 should clamp the overlap to 4
    chunks, indices = chunk_documents_for_rerank(
        [doc], max_tokens=5, overlap_tokens=10
    )

    # Finishing without a hang is the key property under test
    assert len(chunks) > 0
    assert all(i == 0 for i in indices)
|
||||
|
|
@ -70,12 +70,12 @@ class TestOverlapValidation:
|
|||
"short document",
|
||||
" ".join([f"word{i}" for i in range(75)]),
|
||||
]
|
||||
|
||||
|
||||
# overlap_tokens > max_tokens
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=25, overlap_tokens=30
|
||||
)
|
||||
|
||||
|
||||
# Should complete successfully and chunk the long documents
|
||||
assert len(chunked_docs) >= len(documents)
|
||||
# Short document should not be chunked
|
||||
|
|
@ -87,12 +87,12 @@ class TestOverlapValidation:
|
|||
" ".join([f"word{i}" for i in range(100)]),
|
||||
"short doc",
|
||||
]
|
||||
|
||||
|
||||
# Normal case: overlap_tokens (10) < max_tokens (50)
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=50, overlap_tokens=10
|
||||
)
|
||||
|
||||
|
||||
# Long document should be chunked, short one should not
|
||||
assert len(chunked_docs) > 2 # At least 3 chunks (2 from long doc + 1 short)
|
||||
assert "short doc" in chunked_docs
|
||||
|
|
@ -102,12 +102,12 @@ class TestOverlapValidation:
|
|||
def test_edge_case_max_tokens_one(self):
    """Test edge case where max_tokens=1"""
    doc = " ".join(f"word{n}" for n in range(20))

    # max_tokens=1 forces overlap_tokens=5 to clamp all the way to 0
    chunks, indices = chunk_documents_for_rerank(
        [doc], max_tokens=1, overlap_tokens=5
    )

    # Finishing without a hang is the key property under test
    assert len(chunks) > 0
    assert all(i == 0 for i in indices)
|
||||
|
|
|
|||
390
tests/test_rerank_chunking.py
Normal file
390
tests/test_rerank_chunking.py
Normal file
|
|
@ -0,0 +1,390 @@
|
|||
"""
|
||||
Unit tests for rerank document chunking functionality.
|
||||
|
||||
Tests the chunk_documents_for_rerank and aggregate_chunk_scores functions
|
||||
in lightrag/rerank.py to ensure proper document splitting and score aggregation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from lightrag.rerank import (
|
||||
chunk_documents_for_rerank,
|
||||
aggregate_chunk_scores,
|
||||
cohere_rerank,
|
||||
)
|
||||
|
||||
|
||||
class TestChunkDocumentsForRerank:
    """Test suite for chunk_documents_for_rerank function.

    Covers the no-chunking fast path, the character-based fallback when the
    tokenizer import fails, tokenizer-based chunking with overlap, and the
    empty/single-document edge cases.
    """

    def test_no_chunking_needed_for_short_docs(self):
        """Documents shorter than max_tokens should not be chunked"""
        documents = [
            "Short doc 1",
            "Short doc 2",
            "Short doc 3",
        ]

        chunked_docs, doc_indices = chunk_documents_for_rerank(
            documents, max_tokens=100, overlap_tokens=10
        )

        # No chunking should occur
        assert len(chunked_docs) == 3
        assert chunked_docs == documents
        assert doc_indices == [0, 1, 2]

    def test_chunking_with_character_fallback(self):
        """Test chunking falls back to character-based when tokenizer unavailable"""
        # Create a very long document that exceeds character limit
        long_doc = "a" * 2000  # 2000 characters
        documents = [long_doc, "short doc"]

        # Simulate tiktoken being unavailable so the fallback path runs
        with patch("lightrag.utils.TiktokenTokenizer", side_effect=ImportError):
            chunked_docs, doc_indices = chunk_documents_for_rerank(
                documents,
                max_tokens=100,  # 100 tokens = ~400 chars
                overlap_tokens=10,  # 10 tokens = ~40 chars
            )

        # First doc should be split into chunks, second doc stays whole
        assert len(chunked_docs) > 2  # At least one chunk from first doc + second doc
        assert chunked_docs[-1] == "short doc"  # Last chunk is the short doc
        # Verify doc_indices maps chunks to correct original document
        assert doc_indices[-1] == 1  # Last chunk maps to document 1

    def test_chunking_with_tiktoken_tokenizer(self):
        """Test chunking with actual tokenizer"""
        # Create document with known token count
        # Approximate: "word " = ~1 token, so 200 words ~ 200 tokens
        long_doc = " ".join([f"word{i}" for i in range(200)])
        documents = [long_doc, "short"]

        chunked_docs, doc_indices = chunk_documents_for_rerank(
            documents, max_tokens=50, overlap_tokens=10
        )

        # Long doc should be split, short doc should remain
        assert len(chunked_docs) > 2
        assert doc_indices[-1] == 1  # Last chunk is from second document

        # Verify overlapping chunks contain overlapping content
        if len(chunked_docs) > 2:
            # Check that consecutive chunks from same doc have some overlap
            for i in range(len(doc_indices) - 1):
                if doc_indices[i] == doc_indices[i + 1] == 0:
                    # Both chunks from first doc, should have overlap
                    chunk1_words = chunked_docs[i].split()
                    chunk2_words = chunked_docs[i + 1].split()
                    # At least one word should be common due to overlap
                    assert any(word in chunk2_words for word in chunk1_words[-5:])

    def test_empty_documents(self):
        """Test handling of empty document list"""
        documents = []
        chunked_docs, doc_indices = chunk_documents_for_rerank(documents)

        # Empty input yields empty chunk list and empty index mapping
        assert chunked_docs == []
        assert doc_indices == []

    def test_single_document_chunking(self):
        """Test chunking of a single long document"""
        # Create document with ~100 tokens
        long_doc = " ".join([f"token{i}" for i in range(100)])
        documents = [long_doc]

        chunked_docs, doc_indices = chunk_documents_for_rerank(
            documents, max_tokens=30, overlap_tokens=5
        )

        # Should create multiple chunks
        assert len(chunked_docs) > 1
        # All chunks should map to document 0
        assert all(idx == 0 for idx in doc_indices)
|
||||
|
||||
|
||||
class TestAggregateChunkScores:
    """Test suite for aggregate_chunk_scores function.

    Exercises the 1:1 (no-chunking) path plus the "max", "mean", and
    "first" aggregation strategies, empty input, documents with no scored
    chunks, and the fallback for an unknown strategy name.
    """

    def test_no_chunking_simple_aggregation(self):
        """Test aggregation when no chunking occurred (1:1 mapping)"""
        chunk_results = [
            {"index": 0, "relevance_score": 0.9},
            {"index": 1, "relevance_score": 0.7},
            {"index": 2, "relevance_score": 0.5},
        ]
        doc_indices = [0, 1, 2]  # 1:1 mapping
        num_original_docs = 3

        aggregated = aggregate_chunk_scores(
            chunk_results, doc_indices, num_original_docs, aggregation="max"
        )

        # Results should be sorted by score
        assert len(aggregated) == 3
        assert aggregated[0]["index"] == 0
        assert aggregated[0]["relevance_score"] == 0.9
        assert aggregated[1]["index"] == 1
        assert aggregated[1]["relevance_score"] == 0.7
        assert aggregated[2]["index"] == 2
        assert aggregated[2]["relevance_score"] == 0.5

    def test_max_aggregation_with_chunks(self):
        """Test max aggregation strategy with multiple chunks per document"""
        # 5 chunks: first 3 from doc 0, last 2 from doc 1
        chunk_results = [
            {"index": 0, "relevance_score": 0.5},
            {"index": 1, "relevance_score": 0.8},
            {"index": 2, "relevance_score": 0.6},
            {"index": 3, "relevance_score": 0.7},
            {"index": 4, "relevance_score": 0.4},
        ]
        doc_indices = [0, 0, 0, 1, 1]
        num_original_docs = 2

        aggregated = aggregate_chunk_scores(
            chunk_results, doc_indices, num_original_docs, aggregation="max"
        )

        # Should take max score for each document
        assert len(aggregated) == 2
        assert aggregated[0]["index"] == 0
        assert aggregated[0]["relevance_score"] == 0.8  # max of 0.5, 0.8, 0.6
        assert aggregated[1]["index"] == 1
        assert aggregated[1]["relevance_score"] == 0.7  # max of 0.7, 0.4

    def test_mean_aggregation_with_chunks(self):
        """Test mean aggregation strategy"""
        chunk_results = [
            {"index": 0, "relevance_score": 0.6},
            {"index": 1, "relevance_score": 0.8},
            {"index": 2, "relevance_score": 0.4},
        ]
        doc_indices = [0, 0, 1]  # First two chunks from doc 0, last from doc 1
        num_original_docs = 2

        aggregated = aggregate_chunk_scores(
            chunk_results, doc_indices, num_original_docs, aggregation="mean"
        )

        assert len(aggregated) == 2
        assert aggregated[0]["index"] == 0
        # pytest.approx guards against float representation error
        assert aggregated[0]["relevance_score"] == pytest.approx(0.7)  # (0.6 + 0.8) / 2
        assert aggregated[1]["index"] == 1
        assert aggregated[1]["relevance_score"] == 0.4

    def test_first_aggregation_with_chunks(self):
        """Test first aggregation strategy"""
        chunk_results = [
            {"index": 0, "relevance_score": 0.6},
            {"index": 1, "relevance_score": 0.8},
            {"index": 2, "relevance_score": 0.4},
        ]
        doc_indices = [0, 0, 1]
        num_original_docs = 2

        aggregated = aggregate_chunk_scores(
            chunk_results, doc_indices, num_original_docs, aggregation="first"
        )

        assert len(aggregated) == 2
        # First should use first score seen for each doc
        assert aggregated[0]["index"] == 0
        assert aggregated[0]["relevance_score"] == 0.6  # First score for doc 0
        assert aggregated[1]["index"] == 1
        assert aggregated[1]["relevance_score"] == 0.4

    def test_empty_chunk_results(self):
        """Test handling of empty results"""
        aggregated = aggregate_chunk_scores([], [], 3, aggregation="max")
        assert aggregated == []

    def test_documents_with_no_scores(self):
        """Test when some documents have no chunks/scores"""
        chunk_results = [
            {"index": 0, "relevance_score": 0.9},
            {"index": 1, "relevance_score": 0.7},
        ]
        doc_indices = [0, 0]  # Both chunks from document 0
        num_original_docs = 3  # But we have 3 documents total

        aggregated = aggregate_chunk_scores(
            chunk_results, doc_indices, num_original_docs, aggregation="max"
        )

        # Only doc 0 should appear in results
        assert len(aggregated) == 1
        assert aggregated[0]["index"] == 0

    def test_unknown_aggregation_strategy(self):
        """Test that unknown strategy falls back to max"""
        chunk_results = [
            {"index": 0, "relevance_score": 0.6},
            {"index": 1, "relevance_score": 0.8},
        ]
        doc_indices = [0, 0]
        num_original_docs = 1

        # Use invalid strategy
        aggregated = aggregate_chunk_scores(
            chunk_results, doc_indices, num_original_docs, aggregation="invalid"
        )

        # Should fall back to max
        assert aggregated[0]["relevance_score"] == 0.8
|
||||
|
||||
|
||||
@pytest.mark.offline
class TestCohereRerankChunking:
    """Integration tests for cohere_rerank with chunking enabled.

    generic_rerank_api is patched with an AsyncMock, so these tests verify
    only that cohere_rerank forwards the chunking-related kwargs and relays
    the mocked scores — no network traffic occurs.
    """

    @pytest.mark.asyncio
    async def test_cohere_rerank_with_chunking_disabled(self):
        """Test that chunking can be disabled"""
        documents = ["doc1", "doc2"]
        query = "test query"

        # Mock the generic_rerank_api
        with patch(
            "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
        ) as mock_api:
            mock_api.return_value = [
                {"index": 0, "relevance_score": 0.9},
                {"index": 1, "relevance_score": 0.7},
            ]

            result = await cohere_rerank(
                query=query,
                documents=documents,
                api_key="test-key",
                enable_chunking=False,
                max_tokens_per_doc=100,
            )

            # Verify generic_rerank_api was called with correct parameters
            mock_api.assert_called_once()
            call_kwargs = mock_api.call_args[1]
            assert call_kwargs["enable_chunking"] is False
            assert call_kwargs["max_tokens_per_doc"] == 100
            # Result should mirror mocked scores
            assert len(result) == 2
            assert result[0]["index"] == 0
            assert result[0]["relevance_score"] == 0.9
            assert result[1]["index"] == 1
            assert result[1]["relevance_score"] == 0.7

    @pytest.mark.asyncio
    async def test_cohere_rerank_with_chunking_enabled(self):
        """Test that chunking parameters are passed through"""
        documents = ["doc1", "doc2"]
        query = "test query"

        with patch(
            "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
        ) as mock_api:
            mock_api.return_value = [
                {"index": 0, "relevance_score": 0.9},
                {"index": 1, "relevance_score": 0.7},
            ]

            result = await cohere_rerank(
                query=query,
                documents=documents,
                api_key="test-key",
                enable_chunking=True,
                max_tokens_per_doc=480,
            )

            # Verify parameters were passed
            call_kwargs = mock_api.call_args[1]
            assert call_kwargs["enable_chunking"] is True
            assert call_kwargs["max_tokens_per_doc"] == 480
            # Result should mirror mocked scores
            assert len(result) == 2
            assert result[0]["index"] == 0
            assert result[0]["relevance_score"] == 0.9
            assert result[1]["index"] == 1
            assert result[1]["relevance_score"] == 0.7

    @pytest.mark.asyncio
    async def test_cohere_rerank_default_parameters(self):
        """Test default parameter values for cohere_rerank"""
        documents = ["doc1"]
        query = "test"

        with patch(
            "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
        ) as mock_api:
            mock_api.return_value = [{"index": 0, "relevance_score": 0.9}]

            result = await cohere_rerank(
                query=query, documents=documents, api_key="test-key"
            )

            # Verify default values
            call_kwargs = mock_api.call_args[1]
            assert call_kwargs["enable_chunking"] is False
            assert call_kwargs["max_tokens_per_doc"] == 4096
            assert call_kwargs["model"] == "rerank-v3.5"
            # Result should mirror mocked scores
            assert len(result) == 1
            assert result[0]["index"] == 0
            assert result[0]["relevance_score"] == 0.9
|
||||
|
||||
|
||||
@pytest.mark.offline
class TestEndToEndChunking:
    """End-to-end tests for chunking workflow.

    Mocks the aiohttp session/response pair (including async context manager
    protocol) so the full chunk → rerank → aggregate pipeline runs without
    any real HTTP traffic.
    """

    @pytest.mark.asyncio
    async def test_end_to_end_chunking_workflow(self):
        """Test complete chunking workflow from documents to aggregated results"""
        # Create documents where first one needs chunking
        long_doc = " ".join([f"word{i}" for i in range(100)])
        documents = [long_doc, "short doc"]
        query = "test query"

        # Mock the HTTP call inside generic_rerank_api
        mock_response = Mock()
        mock_response.status = 200
        mock_response.json = AsyncMock(
            return_value={
                "results": [
                    {"index": 0, "relevance_score": 0.5},  # chunk 0 from doc 0
                    {"index": 1, "relevance_score": 0.8},  # chunk 1 from doc 0
                    {"index": 2, "relevance_score": 0.6},  # chunk 2 from doc 0
                    {"index": 3, "relevance_score": 0.7},  # doc 1 (short)
                ]
            }
        )
        # Attributes aiohttp error handling may touch; None/{} is enough here
        mock_response.request_info = None
        mock_response.history = None
        mock_response.headers = {}
        # Make mock_response an async context manager (for `async with session.post() as response`)
        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
        mock_response.__aexit__ = AsyncMock(return_value=None)

        mock_session = Mock()
        # session.post() returns an async context manager, so return mock_response which is now one
        mock_session.post = Mock(return_value=mock_response)
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=None)

        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
            result = await cohere_rerank(
                query=query,
                documents=documents,
                api_key="test-key",
                base_url="http://test.com/rerank",
                enable_chunking=True,
                max_tokens_per_doc=30,  # Force chunking of long doc
            )

        # Should get 2 results (one per original document)
        # The long doc's chunks should be aggregated
        assert len(result) <= len(documents)
        # Results should be sorted by score
        assert all(
            result[i]["relevance_score"] >= result[i + 1]["relevance_score"]
            for i in range(len(result) - 1)
        )
|
||||
Loading…
Add table
Reference in a new issue