diff --git a/lightrag/rerank.py b/lightrag/rerank.py index 81632b71..2e22f19a 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -50,7 +50,7 @@ def chunk_documents_for_rerank( f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). " f"Clamping to {overlap_tokens} to prevent infinite loop." ) - + try: from .utils import TiktokenTokenizer diff --git a/tests/test_overlap_validation.py b/tests/test_overlap_validation.py index 7f84a3cf..4e7c9cbd 100644 --- a/tests/test_overlap_validation.py +++ b/tests/test_overlap_validation.py @@ -14,12 +14,12 @@ class TestOverlapValidation: def test_overlap_greater_than_max_tokens(self): """Test that overlap_tokens > max_tokens is clamped and doesn't hang""" documents = [" ".join([f"word{i}" for i in range(100)])] - + # This should clamp overlap_tokens to 29 (max_tokens - 1) chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=30, overlap_tokens=32 ) - + # Should complete without hanging assert len(chunked_docs) > 0 assert all(idx == 0 for idx in doc_indices) @@ -27,12 +27,12 @@ class TestOverlapValidation: def test_overlap_equal_to_max_tokens(self): """Test that overlap_tokens == max_tokens is clamped and doesn't hang""" documents = [" ".join([f"word{i}" for i in range(100)])] - + # This should clamp overlap_tokens to 29 (max_tokens - 1) chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=30, overlap_tokens=30 ) - + # Should complete without hanging assert len(chunked_docs) > 0 assert all(idx == 0 for idx in doc_indices) @@ -40,12 +40,12 @@ class TestOverlapValidation: def test_overlap_slightly_less_than_max_tokens(self): """Test that overlap_tokens < max_tokens works normally""" documents = [" ".join([f"word{i}" for i in range(100)])] - + # This should work without clamping chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=30, overlap_tokens=29 ) - + # Should complete successfully assert len(chunked_docs) > 0 assert all(idx == 0 for idx in doc_indices) @@ -53,12 +53,12 @@ class TestOverlapValidation: def test_small_max_tokens_with_large_overlap(self): """Test edge case with very small max_tokens""" documents = [" ".join([f"word{i}" for i in range(50)])] - + # max_tokens=5, overlap_tokens=10 should clamp to 4 chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=5, overlap_tokens=10 ) - + # Should complete without hanging assert len(chunked_docs) > 0 assert all(idx == 0 for idx in doc_indices) @@ -70,12 +70,12 @@ class TestOverlapValidation: "short document", " ".join([f"word{i}" for i in range(75)]), ] - + # overlap_tokens > max_tokens chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=25, overlap_tokens=30 ) - + # Should complete successfully and chunk the long documents assert len(chunked_docs) >= len(documents) # Short document should not be chunked @@ -87,12 +87,12 @@ class TestOverlapValidation: " ".join([f"word{i}" for i in range(100)]), "short doc", ] - + # Normal case: overlap_tokens (10) < max_tokens (50) chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=50, overlap_tokens=10 ) - + # Long document should be chunked, short one should not assert len(chunked_docs) > 2 # At least 3 chunks (2 from long doc + 1 short) assert "short doc" in chunked_docs @@ -102,12 +102,12 @@ class TestOverlapValidation: def test_edge_case_max_tokens_one(self): """Test edge case where max_tokens=1""" documents = [" ".join([f"word{i}" for i in range(20)])] - + # max_tokens=1, overlap_tokens=5 should clamp to 0 chunked_docs, doc_indices = chunk_documents_for_rerank( documents, max_tokens=1, overlap_tokens=5 ) - + # Should complete without hanging assert len(chunked_docs) > 0 assert all(idx == 0 for idx in doc_indices) diff --git a/tests/test_rerank_chunking.py b/tests/test_rerank_chunking.py new file mode 100644 index 00000000..1700988a --- /dev/null +++ b/tests/test_rerank_chunking.py @@ -0,0 +1,390 @@ +""" +Unit tests for rerank document chunking functionality. + +Tests the chunk_documents_for_rerank and aggregate_chunk_scores functions +in lightrag/rerank.py to ensure proper document splitting and score aggregation. +""" + +import pytest +from unittest.mock import Mock, patch, AsyncMock +from lightrag.rerank import ( + chunk_documents_for_rerank, + aggregate_chunk_scores, + cohere_rerank, +) + + +class TestChunkDocumentsForRerank: + """Test suite for chunk_documents_for_rerank function""" + + def test_no_chunking_needed_for_short_docs(self): + """Documents shorter than max_tokens should not be chunked""" + documents = [ + "Short doc 1", + "Short doc 2", + "Short doc 3", + ] + + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=100, overlap_tokens=10 + ) + + # No chunking should occur + assert len(chunked_docs) == 3 + assert chunked_docs == documents + assert doc_indices == [0, 1, 2] + + def test_chunking_with_character_fallback(self): + """Test chunking falls back to character-based when tokenizer unavailable""" + # Create a very long document that exceeds character limit + long_doc = "a" * 2000 # 2000 characters + documents = [long_doc, "short doc"] + + with patch("lightrag.utils.TiktokenTokenizer", side_effect=ImportError): + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, + max_tokens=100, # 100 tokens = ~400 chars + overlap_tokens=10, # 10 tokens = ~40 chars + ) + + # First doc should be split into chunks, second doc stays whole + assert len(chunked_docs) > 2 # At least one chunk from first doc + second doc + assert chunked_docs[-1] == "short doc" # Last chunk is the short doc + # Verify doc_indices maps chunks to correct original document + assert doc_indices[-1] == 1 # Last chunk maps to document 1 + + def test_chunking_with_tiktoken_tokenizer(self): + """Test chunking with actual tokenizer""" + # Create document with known token count + # Approximate: "word " = ~1 token, so 200 words ~ 200 tokens + long_doc = " ".join([f"word{i}" for i in range(200)]) + documents = [long_doc, "short"] + + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=50, overlap_tokens=10 + ) + + # Long doc should be split, short doc should remain + assert len(chunked_docs) > 2 + assert doc_indices[-1] == 1 # Last chunk is from second document + + # Verify overlapping chunks contain overlapping content + if len(chunked_docs) > 2: + # Check that consecutive chunks from same doc have some overlap + for i in range(len(doc_indices) - 1): + if doc_indices[i] == doc_indices[i + 1] == 0: + # Both chunks from first doc, should have overlap + chunk1_words = chunked_docs[i].split() + chunk2_words = chunked_docs[i + 1].split() + # At least one word should be common due to overlap + assert any(word in chunk2_words for word in chunk1_words[-5:]) + + def test_empty_documents(self): + """Test handling of empty document list""" + documents = [] + chunked_docs, doc_indices = chunk_documents_for_rerank(documents) + + assert chunked_docs == [] + assert doc_indices == [] + + def test_single_document_chunking(self): + """Test chunking of a single long document""" + # Create document with ~100 tokens + long_doc = " ".join([f"token{i}" for i in range(100)]) + documents = [long_doc] + + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=30, overlap_tokens=5 + ) + + # Should create multiple chunks + assert len(chunked_docs) > 1 + # All chunks should map to document 0 + assert all(idx == 0 for idx in doc_indices) + + +class TestAggregateChunkScores: + """Test suite for aggregate_chunk_scores function""" + + def test_no_chunking_simple_aggregation(self): + """Test aggregation when no chunking occurred (1:1 mapping)""" + chunk_results = [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + {"index": 2, "relevance_score": 0.5}, + ] + doc_indices = [0, 1, 2] # 1:1 mapping + num_original_docs = 3 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="max" + ) + + # Results should be sorted by score + assert len(aggregated) == 3 + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == 0.9 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.7 + assert aggregated[2]["index"] == 2 + assert aggregated[2]["relevance_score"] == 0.5 + + def test_max_aggregation_with_chunks(self): + """Test max aggregation strategy with multiple chunks per document""" + # 5 chunks: first 3 from doc 0, last 2 from doc 1 + chunk_results = [ + {"index": 0, "relevance_score": 0.5}, + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.6}, + {"index": 3, "relevance_score": 0.7}, + {"index": 4, "relevance_score": 0.4}, + ] + doc_indices = [0, 0, 0, 1, 1] + num_original_docs = 2 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="max" + ) + + # Should take max score for each document + assert len(aggregated) == 2 + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == 0.8 # max of 0.5, 0.8, 0.6 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.7 # max of 0.7, 0.4 + + def test_mean_aggregation_with_chunks(self): + """Test mean aggregation strategy""" + chunk_results = [ + {"index": 0, "relevance_score": 0.6}, + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.4}, + ] + doc_indices = [0, 0, 1] # First two chunks from doc 0, last from doc 1 + num_original_docs = 2 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="mean" + ) + + assert len(aggregated) == 2 + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == pytest.approx(0.7) # (0.6 + 0.8) / 2 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.4 + + def test_first_aggregation_with_chunks(self): + """Test first aggregation strategy""" + chunk_results = [ + {"index": 0, "relevance_score": 0.6}, + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.4}, + ] + doc_indices = [0, 0, 1] + num_original_docs = 2 + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="first" + ) + + assert len(aggregated) == 2 + # First should use first score seen for each doc + assert aggregated[0]["index"] == 0 + assert aggregated[0]["relevance_score"] == 0.6 # First score for doc 0 + assert aggregated[1]["index"] == 1 + assert aggregated[1]["relevance_score"] == 0.4 + + def test_empty_chunk_results(self): + """Test handling of empty results""" + aggregated = aggregate_chunk_scores([], [], 3, aggregation="max") + assert aggregated == [] + + def test_documents_with_no_scores(self): + """Test when some documents have no chunks/scores""" + chunk_results = [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + ] + doc_indices = [0, 0] # Both chunks from document 0 + num_original_docs = 3 # But we have 3 documents total + + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="max" + ) + + # Only doc 0 should appear in results + assert len(aggregated) == 1 + assert aggregated[0]["index"] == 0 + + def test_unknown_aggregation_strategy(self): + """Test that unknown strategy falls back to max""" + chunk_results = [ + {"index": 0, "relevance_score": 0.6}, + {"index": 1, "relevance_score": 0.8}, + ] + doc_indices = [0, 0] + num_original_docs = 1 + + # Use invalid strategy + aggregated = aggregate_chunk_scores( + chunk_results, doc_indices, num_original_docs, aggregation="invalid" + ) + + # Should fall back to max + assert aggregated[0]["relevance_score"] == 0.8 + + +@pytest.mark.offline +class TestCohereRerankChunking: + """Integration tests for cohere_rerank with chunking enabled""" + + @pytest.mark.asyncio + async def test_cohere_rerank_with_chunking_disabled(self): + """Test that chunking can be disabled""" + documents = ["doc1", "doc2"] + query = "test query" + + # Mock the generic_rerank_api + with patch( + "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock + ) as mock_api: + mock_api.return_value = [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + ] + + result = await cohere_rerank( + query=query, + documents=documents, + api_key="test-key", + enable_chunking=False, + max_tokens_per_doc=100, + ) + + # Verify generic_rerank_api was called with correct parameters + mock_api.assert_called_once() + call_kwargs = mock_api.call_args[1] + assert call_kwargs["enable_chunking"] is False + assert call_kwargs["max_tokens_per_doc"] == 100 + # Result should mirror mocked scores + assert len(result) == 2 + assert result[0]["index"] == 0 + assert result[0]["relevance_score"] == 0.9 + assert result[1]["index"] == 1 + assert result[1]["relevance_score"] == 0.7 + + @pytest.mark.asyncio + async def test_cohere_rerank_with_chunking_enabled(self): + """Test that chunking parameters are passed through""" + documents = ["doc1", "doc2"] + query = "test query" + + with patch( + "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock + ) as mock_api: + mock_api.return_value = [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.7}, + ] + + result = await cohere_rerank( + query=query, + documents=documents, + api_key="test-key", + enable_chunking=True, + max_tokens_per_doc=480, + ) + + # Verify parameters were passed + call_kwargs = mock_api.call_args[1] + assert call_kwargs["enable_chunking"] is True + assert call_kwargs["max_tokens_per_doc"] == 480 + # Result should mirror mocked scores + assert len(result) == 2 + assert result[0]["index"] == 0 + assert result[0]["relevance_score"] == 0.9 + assert result[1]["index"] == 1 + assert result[1]["relevance_score"] == 0.7 + + @pytest.mark.asyncio + async def test_cohere_rerank_default_parameters(self): + """Test default parameter values for cohere_rerank""" + documents = ["doc1"] + query = "test" + + with patch( + "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock + ) as mock_api: + mock_api.return_value = [{"index": 0, "relevance_score": 0.9}] + + result = await cohere_rerank( + query=query, documents=documents, api_key="test-key" + ) + + # Verify default values + call_kwargs = mock_api.call_args[1] + assert call_kwargs["enable_chunking"] is False + assert call_kwargs["max_tokens_per_doc"] == 4096 + assert call_kwargs["model"] == "rerank-v3.5" + # Result should mirror mocked scores + assert len(result) == 1 + assert result[0]["index"] == 0 + assert result[0]["relevance_score"] == 0.9 + + +@pytest.mark.offline +class TestEndToEndChunking: + """End-to-end tests for chunking workflow""" + + @pytest.mark.asyncio + async def test_end_to_end_chunking_workflow(self): + """Test complete chunking workflow from documents to aggregated results""" + # Create documents where first one needs chunking + long_doc = " ".join([f"word{i}" for i in range(100)]) + documents = [long_doc, "short doc"] + query = "test query" + + # Mock the HTTP call inside generic_rerank_api + mock_response = Mock() + mock_response.status = 200 + mock_response.json = AsyncMock( + return_value={ + "results": [ + {"index": 0, "relevance_score": 0.5}, # chunk 0 from doc 0 + {"index": 1, "relevance_score": 0.8}, # chunk 1 from doc 0 + {"index": 2, "relevance_score": 0.6}, # chunk 2 from doc 0 + {"index": 3, "relevance_score": 0.7}, # doc 1 (short) + ] + } + ) + mock_response.request_info = None + mock_response.history = None + mock_response.headers = {} + # Make mock_response an async context manager (for `async with session.post() as response`) + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_session = Mock() + # session.post() returns an async context manager, so return mock_response which is now one + mock_session.post = Mock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session): + result = await cohere_rerank( + query=query, + documents=documents, + api_key="test-key", + base_url="http://test.com/rerank", + enable_chunking=True, + max_tokens_per_doc=30, # Force chunking of long doc + ) + + # Should get 2 results (one per original document) + # The long doc's chunks should be aggregated + assert len(result) <= len(documents) + # Results should be sorted by score + assert all( + result[i]["relevance_score"] >= result[i + 1]["relevance_score"] + for i in range(len(result) - 1) + )