Fix top_n behavior with chunking to limit documents not chunks
- Disable API-level top_n when chunking
- Apply top_n to aggregated documents
- Add comprehensive test coverage
This commit is contained in:
parent
561ba4e4b5
commit
9009abed3e
2 changed files with 191 additions and 0 deletions
|
|
@ -223,6 +223,8 @@ async def generic_rerank_api(
|
||||||
# Handle document chunking if enabled
|
# Handle document chunking if enabled
|
||||||
original_documents = documents
|
original_documents = documents
|
||||||
doc_indices = None
|
doc_indices = None
|
||||||
|
original_top_n = top_n # Save original top_n for post-aggregation limiting
|
||||||
|
|
||||||
if enable_chunking:
|
if enable_chunking:
|
||||||
documents, doc_indices = chunk_documents_for_rerank(
|
documents, doc_indices = chunk_documents_for_rerank(
|
||||||
documents, max_tokens=max_tokens_per_doc
|
documents, max_tokens=max_tokens_per_doc
|
||||||
|
|
@ -230,6 +232,14 @@ async def generic_rerank_api(
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Chunked {len(original_documents)} documents into {len(documents)} chunks"
|
f"Chunked {len(original_documents)} documents into {len(documents)} chunks"
|
||||||
)
|
)
|
||||||
|
# When chunking is enabled, disable top_n at API level to get all chunk scores
|
||||||
|
# This ensures proper document-level coverage after aggregation
|
||||||
|
# We'll apply top_n to aggregated document results instead
|
||||||
|
if top_n is not None:
|
||||||
|
logger.debug(
|
||||||
|
f"Chunking enabled: disabled API-level top_n={top_n} to ensure complete document coverage"
|
||||||
|
)
|
||||||
|
top_n = None
|
||||||
|
|
||||||
# Build request payload based on request format
|
# Build request payload based on request format
|
||||||
if request_format == "aliyun":
|
if request_format == "aliyun":
|
||||||
|
|
@ -344,6 +354,13 @@ async def generic_rerank_api(
|
||||||
len(original_documents),
|
len(original_documents),
|
||||||
aggregation="max",
|
aggregation="max",
|
||||||
)
|
)
|
||||||
|
# Apply original top_n limit at document level (post-aggregation)
|
||||||
|
# This preserves document-level semantics: top_n limits documents, not chunks
|
||||||
|
if (
|
||||||
|
original_top_n is not None
|
||||||
|
and len(standardized_results) > original_top_n
|
||||||
|
):
|
||||||
|
standardized_results = standardized_results[:original_top_n]
|
||||||
|
|
||||||
return standardized_results
|
return standardized_results
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -234,6 +234,180 @@ class TestAggregateChunkScores:
|
||||||
assert aggregated[0]["relevance_score"] == 0.8
|
assert aggregated[0]["relevance_score"] == 0.8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.offline
class TestTopNWithChunking:
    """Tests for top_n behavior when chunking is enabled (Bug fix verification)"""

    @staticmethod
    def _build_mock_session(results):
        """Build a fake aiohttp session whose POST yields ``results``.

        Args:
            results: List of ``{"index": ..., "relevance_score": ...}`` dicts
                returned under the ``"results"`` key of the mocked JSON body.

        Returns:
            A ``(mock_session, captured_payload)`` tuple. ``captured_payload``
            is a dict updated with the ``json`` keyword argument of every
            ``mock_session.post(...)`` call, so tests can inspect the request
            body that was sent.
        """
        captured_payload = {}

        mock_response = Mock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={"results": results})
        mock_response.request_info = None
        mock_response.history = None
        mock_response.headers = {}
        # The response is used as an async context manager by the caller.
        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
        mock_response.__aexit__ = AsyncMock(return_value=None)

        def capture_post(*args, **kwargs):
            captured_payload.update(kwargs.get("json", {}))
            return mock_response

        mock_session = Mock()
        mock_session.post = Mock(side_effect=capture_post)
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=None)

        return mock_session, captured_payload

    @pytest.mark.asyncio
    async def test_top_n_limits_documents_not_chunks(self):
        """
        Test that top_n correctly limits documents (not chunks) when chunking is enabled.

        Bug scenario: 10 docs expand to 50 chunks. With old behavior, top_n=5 would
        return scores for only 5 chunks (possibly all from 1-2 docs). After aggregation,
        fewer than 5 documents would be returned.

        Fixed behavior: top_n=5 should return exactly 5 documents after aggregation.
        """
        # Setup: 5 documents, each producing multiple chunks when chunked.
        # A small max_tokens forces chunking.
        max_tokens = 50
        long_docs = [
            " ".join(f"doc{i}_word{j}" for j in range(50)) for i in range(5)
        ]
        query = "test query"

        # Determine the chunk -> document mapping with the same arguments the
        # code under test uses (only ``max_tokens``), so every mocked chunk
        # index lines up with a chunk that is really produced.
        _, doc_indices = chunk_documents_for_rerank(long_docs, max_tokens=max_tokens)

        # Mock API returns scores for ALL chunks (simulating disabled API-level
        # top_n). Every chunk of a document gets the same score, and a lower
        # document index means a higher score: doc 0 -> 0.9, doc 1 -> 0.8, ...
        mock_chunk_scores = [
            {"index": chunk_idx, "relevance_score": 0.9 - (doc_idx * 0.1)}
            for chunk_idx, doc_idx in enumerate(doc_indices)
        ]

        mock_session, _ = self._build_mock_session(mock_chunk_scores)

        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
            result = await cohere_rerank(
                query=query,
                documents=long_docs,
                api_key="test-key",
                base_url="http://test.com/rerank",
                enable_chunking=True,
                max_tokens_per_doc=max_tokens,  # Match chunking above
                top_n=3,  # Request top 3 documents
            )

        # Verify: should get exactly 3 documents (not unlimited chunks)
        assert len(result) == 3
        # All results should have valid document indices (0-4)
        assert all(0 <= r["index"] < 5 for r in result)
        # Results should be sorted by score (descending)
        assert all(
            result[i]["relevance_score"] >= result[i + 1]["relevance_score"]
            for i in range(len(result) - 1)
        )
        # The top 3 docs should be 0, 1, 2 (highest scores)
        result_indices = [r["index"] for r in result]
        assert set(result_indices) == {0, 1, 2}

    @pytest.mark.asyncio
    async def test_api_receives_no_top_n_when_chunking_enabled(self):
        """
        Test that the API request does NOT include top_n when chunking is enabled.

        This ensures all chunk scores are retrieved for proper aggregation.
        """
        documents = [" ".join([f"word{i}" for i in range(100)]), "short doc"]
        query = "test query"

        mock_session, captured_payload = self._build_mock_session(
            [
                {"index": 0, "relevance_score": 0.9},
                {"index": 1, "relevance_score": 0.8},
                {"index": 2, "relevance_score": 0.7},
            ]
        )

        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
            await cohere_rerank(
                query=query,
                documents=documents,
                api_key="test-key",
                base_url="http://test.com/rerank",
                enable_chunking=True,
                max_tokens_per_doc=30,
                top_n=1,  # User wants top 1 document
            )

        # Verify: API payload should NOT have top_n (disabled for chunking)
        assert "top_n" not in captured_payload

    @pytest.mark.asyncio
    async def test_top_n_not_modified_when_chunking_disabled(self):
        """
        Test that top_n is passed through to API when chunking is disabled.
        """
        documents = ["doc1", "doc2"]
        query = "test query"

        mock_session, captured_payload = self._build_mock_session(
            [
                {"index": 0, "relevance_score": 0.9},
            ]
        )

        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
            await cohere_rerank(
                query=query,
                documents=documents,
                api_key="test-key",
                base_url="http://test.com/rerank",
                enable_chunking=False,  # Chunking disabled
                top_n=1,
            )

        # Verify: API payload should have top_n when chunking is disabled
        assert captured_payload.get("top_n") == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.offline
|
@pytest.mark.offline
|
||||||
class TestCohereRerankChunking:
|
class TestCohereRerankChunking:
|
||||||
"""Integration tests for cohere_rerank with chunking enabled"""
|
"""Integration tests for cohere_rerank with chunking enabled"""
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue