From 9009abed3ecd61605f3ec43dcda8ada1787bd3a6 Mon Sep 17 00:00:00 2001 From: yangdx Date: Wed, 3 Dec 2025 13:08:26 +0800 Subject: [PATCH] Fix top_n behavior with chunking to limit documents not chunks - Disable API-level top_n when chunking - Apply top_n to aggregated documents - Add comprehensive test coverage --- lightrag/rerank.py | 17 ++++ tests/test_rerank_chunking.py | 174 ++++++++++++++++++++++++++++++++++ 2 files changed, 191 insertions(+) diff --git a/lightrag/rerank.py b/lightrag/rerank.py index 2e22f19a..12950fe6 100644 --- a/lightrag/rerank.py +++ b/lightrag/rerank.py @@ -223,6 +223,8 @@ async def generic_rerank_api( # Handle document chunking if enabled original_documents = documents doc_indices = None + original_top_n = top_n # Save original top_n for post-aggregation limiting + if enable_chunking: documents, doc_indices = chunk_documents_for_rerank( documents, max_tokens=max_tokens_per_doc @@ -230,6 +232,14 @@ async def generic_rerank_api( logger.debug( f"Chunked {len(original_documents)} documents into {len(documents)} chunks" ) + # When chunking is enabled, disable top_n at API level to get all chunk scores + # This ensures proper document-level coverage after aggregation + # We'll apply top_n to aggregated document results instead + if top_n is not None: + logger.debug( + f"Chunking enabled: disabled API-level top_n={top_n} to ensure complete document coverage" + ) + top_n = None # Build request payload based on request format if request_format == "aliyun": @@ -344,6 +354,13 @@ async def generic_rerank_api( len(original_documents), aggregation="max", ) + # Apply original top_n limit at document level (post-aggregation) + # This preserves document-level semantics: top_n limits documents, not chunks + if ( + original_top_n is not None + and len(standardized_results) > original_top_n + ): + standardized_results = standardized_results[:original_top_n] return standardized_results diff --git a/tests/test_rerank_chunking.py b/tests/test_rerank_chunking.py index 1700988a..09f1816b 100644 --- a/tests/test_rerank_chunking.py +++ b/tests/test_rerank_chunking.py @@ -234,6 +234,180 @@ class TestAggregateChunkScores: assert aggregated[0]["relevance_score"] == 0.8 +@pytest.mark.offline +class TestTopNWithChunking: + """Tests for top_n behavior when chunking is enabled (Bug fix verification)""" + + @pytest.mark.asyncio + async def test_top_n_limits_documents_not_chunks(self): + """ + Test that top_n correctly limits documents (not chunks) when chunking is enabled. + + Bug scenario: 10 docs expand to 50 chunks. With old behavior, top_n=5 would + return scores for only 5 chunks (possibly all from 1-2 docs). After aggregation, + fewer than 5 documents would be returned. + + Fixed behavior: top_n=5 should return exactly 5 documents after aggregation. + """ + # Setup: 5 documents, each producing multiple chunks when chunked + # Using small max_tokens to force chunking + long_docs = [" ".join([f"doc{i}_word{j}" for j in range(50)]) for i in range(5)] + query = "test query" + + # First, determine how many chunks will be created by actual chunking + _, doc_indices = chunk_documents_for_rerank( + long_docs, max_tokens=50, overlap_tokens=10 + ) + num_chunks = len(doc_indices) + + # Mock API returns scores for ALL chunks (simulating disabled API-level top_n) + # Give different scores to ensure doc 0 gets highest, doc 1 second, etc. + # Assign scores based on original document index (lower doc index = higher score) + mock_chunk_scores = [] + for i in range(num_chunks): + original_doc = doc_indices[i] + # Higher score for lower doc index, with small variation per chunk + base_score = 0.9 - (original_doc * 0.1) + mock_chunk_scores.append({"index": i, "relevance_score": base_score}) + + mock_response = Mock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value={"results": mock_chunk_scores}) + mock_response.request_info = None + mock_response.history = None + mock_response.headers = {} + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + mock_session = Mock() + mock_session.post = Mock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session): + result = await cohere_rerank( + query=query, + documents=long_docs, + api_key="test-key", + base_url="http://test.com/rerank", + enable_chunking=True, + max_tokens_per_doc=50, # Match chunking above + top_n=3, # Request top 3 documents + ) + + # Verify: should get exactly 3 documents (not unlimited chunks) + assert len(result) == 3 + # All results should have valid document indices (0-4) + assert all(0 <= r["index"] < 5 for r in result) + # Results should be sorted by score (descending) + assert all( + result[i]["relevance_score"] >= result[i + 1]["relevance_score"] + for i in range(len(result) - 1) + ) + # The top 3 docs should be 0, 1, 2 (highest scores) + result_indices = [r["index"] for r in result] + assert set(result_indices) == {0, 1, 2} + + @pytest.mark.asyncio + async def test_api_receives_no_top_n_when_chunking_enabled(self): + """ + Test that the API request does NOT include top_n when chunking is enabled. + + This ensures all chunk scores are retrieved for proper aggregation. + """ + documents = [" ".join([f"word{i}" for i in range(100)]), "short doc"] + query = "test query" + + captured_payload = {} + + mock_response = Mock() + mock_response.status = 200 + mock_response.json = AsyncMock( + return_value={ + "results": [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.7}, + ] + } + ) + mock_response.request_info = None + mock_response.history = None + mock_response.headers = {} + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + def capture_post(*args, **kwargs): + captured_payload.update(kwargs.get("json", {})) + return mock_response + + mock_session = Mock() + mock_session.post = Mock(side_effect=capture_post) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session): + await cohere_rerank( + query=query, + documents=documents, + api_key="test-key", + base_url="http://test.com/rerank", + enable_chunking=True, + max_tokens_per_doc=30, + top_n=1, # User wants top 1 document + ) + + # Verify: API payload should NOT have top_n (disabled for chunking) + assert "top_n" not in captured_payload + + @pytest.mark.asyncio + async def test_top_n_not_modified_when_chunking_disabled(self): + """ + Test that top_n is passed through to API when chunking is disabled. + """ + documents = ["doc1", "doc2"] + query = "test query" + + captured_payload = {} + + mock_response = Mock() + mock_response.status = 200 + mock_response.json = AsyncMock( + return_value={ + "results": [ + {"index": 0, "relevance_score": 0.9}, + ] + } + ) + mock_response.request_info = None + mock_response.history = None + mock_response.headers = {} + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + def capture_post(*args, **kwargs): + captured_payload.update(kwargs.get("json", {})) + return mock_response + + mock_session = Mock() + mock_session.post = Mock(side_effect=capture_post) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session): + await cohere_rerank( + query=query, + documents=documents, + api_key="test-key", + base_url="http://test.com/rerank", + enable_chunking=False, # Chunking disabled + top_n=1, + ) + + # Verify: API payload should have top_n when chunking is disabled + assert captured_payload.get("top_n") == 1 + + @pytest.mark.offline class TestCohereRerankChunking: """Integration tests for cohere_rerank with chunking enabled"""