Improve edge case handling for max_tokens=1

Co-authored-by: netbrah <162479981+netbrah@users.noreply.github.com>
(cherry picked from commit 8835fc244a)
copilot-swe-agent[bot] 2025-11-24 03:43:05 +00:00 committed by Raphaël MANSUY
parent 26602f3e20
commit b28a701532
2 changed files with 326 additions and 7 deletions


@@ -2,7 +2,7 @@ from __future__ import annotations
import os
import aiohttp
-from typing import Any, List, Dict, Optional
+from typing import Any, List, Dict, Optional, Tuple
from tenacity import (
retry,
stop_after_attempt,
@@ -19,6 +19,158 @@ from dotenv import load_dotenv
load_dotenv(dotenv_path=".env", override=False)
def chunk_documents_for_rerank(
documents: List[str],
max_tokens: int = 480,
overlap_tokens: int = 32,
tokenizer_model: str = "gpt-4o-mini",
) -> Tuple[List[str], List[int]]:
"""
Chunk documents that exceed the token limit for reranking.
Args:
documents: List of document strings to chunk
max_tokens: Maximum tokens per chunk (default 480 to leave margin for 512 limit)
overlap_tokens: Number of tokens to overlap between chunks
tokenizer_model: Model name for tiktoken tokenizer
Returns:
Tuple of (chunked_documents, original_doc_indices)
- chunked_documents: List of document chunks (may be more than input)
- original_doc_indices: Maps each chunk back to its original document index
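Example (illustrative; a document shorter than max_tokens is returned unchanged):
>>> chunks, indices = chunk_documents_for_rerank(["short doc"], max_tokens=480)
>>> chunks
['short doc']
>>> indices
[0]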
"""
# Clamp overlap_tokens to ensure the loop always advances
# If overlap_tokens >= max_tokens, the chunking loop would hang
if overlap_tokens >= max_tokens:
original_overlap = overlap_tokens
# Ensure overlap is at least 1 token less than max to guarantee progress
# For very small max_tokens (e.g., 1), set overlap to 0
overlap_tokens = max(0, max_tokens - 1)
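# e.g. max_tokens=30, overlap_tokens=32 -> clamped to 29; max_tokens=1, overlap_tokens=5 -> clamped to 0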
logger.warning(
f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). "
f"Clamping to {overlap_tokens} to prevent infinite loop."
)
try:
from .utils import TiktokenTokenizer
tokenizer = TiktokenTokenizer(model_name=tokenizer_model)
except Exception as e:
logger.warning(
f"Failed to initialize tokenizer: {e}. Using character-based approximation."
)
# Fallback: approximate 1 token ≈ 4 characters
max_chars = max_tokens * 4
overlap_chars = overlap_tokens * 4
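# e.g. max_tokens=480 -> max_chars=1920; overlap_tokens=32 -> overlap_chars=128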
chunked_docs = []
doc_indices = []
for idx, doc in enumerate(documents):
if len(doc) <= max_chars:
chunked_docs.append(doc)
doc_indices.append(idx)
else:
# Split into overlapping chunks
start = 0
while start < len(doc):
end = min(start + max_chars, len(doc))
chunk = doc[start:end]
chunked_docs.append(chunk)
doc_indices.append(idx)
if end >= len(doc):
break
start = end - overlap_chars
return chunked_docs, doc_indices
# Use tokenizer for accurate chunking
chunked_docs = []
doc_indices = []
for idx, doc in enumerate(documents):
tokens = tokenizer.encode(doc)
if len(tokens) <= max_tokens:
# Document fits in one chunk
chunked_docs.append(doc)
doc_indices.append(idx)
else:
# Split into overlapping chunks
start = 0
while start < len(tokens):
end = min(start + max_tokens, len(tokens))
chunk_tokens = tokens[start:end]
chunk_text = tokenizer.decode(chunk_tokens)
chunked_docs.append(chunk_text)
doc_indices.append(idx)
if end >= len(tokens):
break
start = end - overlap_tokens
return chunked_docs, doc_indices
def aggregate_chunk_scores(
chunk_results: List[Dict[str, Any]],
doc_indices: List[int],
num_original_docs: int,
aggregation: str = "max",
) -> List[Dict[str, Any]]:
"""
Aggregate rerank scores from document chunks back to original documents.
Args:
chunk_results: Rerank results for chunks [{"index": chunk_idx, "relevance_score": score}, ...]
doc_indices: Maps each chunk index to original document index
num_original_docs: Total number of original documents
aggregation: Strategy for aggregating scores ("max", "mean", "first")
Returns:
List of results for original documents [{"index": doc_idx, "relevance_score": score}, ...]
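Example (illustrative; chunks 0 and 1 come from doc 0, chunk 2 from doc 1):
>>> aggregate_chunk_scores(
...     [{"index": 0, "relevance_score": 0.2},
...      {"index": 1, "relevance_score": 0.9},
...      {"index": 2, "relevance_score": 0.5}],
...     doc_indices=[0, 0, 1],
...     num_original_docs=2,
...     aggregation="max",
... )
[{'index': 0, 'relevance_score': 0.9}, {'index': 1, 'relevance_score': 0.5}]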
"""
# Group scores by original document index
doc_scores: Dict[int, List[float]] = {i: [] for i in range(num_original_docs)}
for result in chunk_results:
chunk_idx = result["index"]
score = result["relevance_score"]
if 0 <= chunk_idx < len(doc_indices):
original_doc_idx = doc_indices[chunk_idx]
doc_scores[original_doc_idx].append(score)
# Aggregate scores
aggregated_results = []
for doc_idx, scores in doc_scores.items():
if not scores:
continue
if aggregation == "max":
final_score = max(scores)
elif aggregation == "mean":
final_score = sum(scores) / len(scores)
elif aggregation == "first":
final_score = scores[0]
else:
logger.warning(f"Unknown aggregation strategy: {aggregation}, using max")
final_score = max(scores)
aggregated_results.append(
{
"index": doc_idx,
"relevance_score": final_score,
}
)
# Sort by relevance score (descending)
aggregated_results.sort(key=lambda x: x["relevance_score"], reverse=True)
return aggregated_results
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -38,6 +190,8 @@ async def generic_rerank_api(
extra_body: Optional[Dict[str, Any]] = None,
response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
enable_chunking: bool = False,
max_tokens_per_doc: int = 480,
) -> List[Dict[str, Any]]:
"""
Generic rerank API call for Jina/Cohere/Aliyun models.
@@ -52,6 +206,9 @@ async def generic_rerank_api(
return_documents: Whether to return document text (Jina only)
extra_body: Additional body parameters
response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
request_format: Request format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
enable_chunking: Whether to chunk documents exceeding token limit
max_tokens_per_doc: Maximum tokens per document for chunking
Returns:
List of dicts: [{"index": int, "relevance_score": float}, ...]
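Example return value (illustrative placeholder scores):
[{"index": 2, "relevance_score": 0.93}, {"index": 0, "relevance_score": 0.41}]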
@@ -63,6 +220,17 @@ async def generic_rerank_api(
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
# Handle document chunking if enabled
original_documents = documents
doc_indices = None
if enable_chunking:
documents, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=max_tokens_per_doc
)
logger.debug(
f"Chunked {len(original_documents)} documents into {len(documents)} chunks"
)
# Build request payload based on request format
if request_format == "aliyun":
# Aliyun format: nested input/parameters structure
@@ -86,7 +254,7 @@ async def generic_rerank_api(
if extra_body:
payload["parameters"].update(extra_body)
else:
-# Standard format for Jina/Cohere
+# Standard format for Jina/Cohere/OpenAI
payload = {
"model": model,
"query": query,
@@ -98,7 +266,7 @@ async def generic_rerank_api(
payload["top_n"] = top_n
# Only Jina API supports return_documents parameter
-if return_documents is not None:
+if return_documents is not None and response_format in ("standard",):
payload["return_documents"] = return_documents
# Add extra parameters
@@ -147,7 +315,6 @@ async def generic_rerank_api(
f"Expected 'output.results' to be list, got {type(results)}: {results}"
)
results = []
elif response_format == "standard":
# Standard format: {"results": [...]}
results = response_json.get("results", [])
@@ -158,16 +325,28 @@ async def generic_rerank_api(
results = []
else:
raise ValueError(f"Unsupported response format: {response_format}")
if not results:
logger.warning("Rerank API returned empty results")
return []
# Standardize return format
-return [
+standardized_results = [
{"index": result["index"], "relevance_score": result["relevance_score"]}
for result in results
]
# Aggregate chunk scores back to original documents if chunking was enabled
if enable_chunking and doc_indices:
standardized_results = aggregate_chunk_scores(
standardized_results,
doc_indices,
len(original_documents),
aggregation="max",
)
return standardized_results
async def cohere_rerank(
query: str,
@@ -177,21 +356,46 @@ async def cohere_rerank(
model: str = "rerank-v3.5",
base_url: str = "https://api.cohere.com/v2/rerank",
extra_body: Optional[Dict[str, Any]] = None,
enable_chunking: bool = False,
max_tokens_per_doc: int = 4096,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Cohere API.
Supports both the standard Cohere API and Cohere-compatible proxies.
Args:
query: The search query
documents: List of strings to rerank
top_n: Number of top results to return
-api_key: API key
-model: rerank model name
+api_key: API key for authentication
+model: rerank model name (default: rerank-v3.5)
base_url: API endpoint
extra_body: Additional body for the HTTP request (reserved for extra params)
enable_chunking: Whether to chunk documents exceeding max_tokens_per_doc
max_tokens_per_doc: Maximum tokens per document (default: 4096 for Cohere v3.5)
Returns:
List of dicts: [{"index": int, "relevance_score": float}, ...]
Example:
>>> # Standard Cohere API
>>> results = await cohere_rerank(
... query="What is the meaning of life?",
... documents=["Doc1", "Doc2"],
... api_key="your-cohere-key"
... )
>>> # LiteLLM proxy with user authentication
>>> results = await cohere_rerank(
... query="What is vector search?",
... documents=["Doc1", "Doc2"],
... model="answerai-colbert-small-v1",
... base_url="https://llm-proxy.example.com/v2/rerank",
... api_key="your-proxy-key",
... enable_chunking=True,
... max_tokens_per_doc=480
... )
"""
if api_key is None:
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
@@ -206,6 +410,8 @@ async def cohere_rerank(
return_documents=None, # Cohere doesn't support this parameter
extra_body=extra_body,
response_format="standard",
enable_chunking=enable_chunking,
max_tokens_per_doc=max_tokens_per_doc,
)


@@ -0,0 +1,113 @@
"""
Test for overlap_tokens validation to prevent infinite loop.
This test validates the fix for the bug where overlap_tokens >= max_tokens
causes an infinite loop in the chunking function.
"""
from lightrag.rerank import chunk_documents_for_rerank
class TestOverlapValidation:
"""Test suite for overlap_tokens validation"""
def test_overlap_greater_than_max_tokens(self):
"""Test that overlap_tokens > max_tokens is clamped and doesn't hang"""
documents = [" ".join([f"word{i}" for i in range(100)])]
# This should clamp overlap_tokens to 29 (max_tokens - 1)
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=30, overlap_tokens=32
)
# Should complete without hanging
assert len(chunked_docs) > 0
assert all(idx == 0 for idx in doc_indices)
def test_overlap_equal_to_max_tokens(self):
"""Test that overlap_tokens == max_tokens is clamped and doesn't hang"""
documents = [" ".join([f"word{i}" for i in range(100)])]
# This should clamp overlap_tokens to 29 (max_tokens - 1)
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=30, overlap_tokens=30
)
# Should complete without hanging
assert len(chunked_docs) > 0
assert all(idx == 0 for idx in doc_indices)
def test_overlap_slightly_less_than_max_tokens(self):
"""Test that overlap_tokens < max_tokens works normally"""
documents = [" ".join([f"word{i}" for i in range(100)])]
# This should work without clamping
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=30, overlap_tokens=29
)
# Should complete successfully
assert len(chunked_docs) > 0
assert all(idx == 0 for idx in doc_indices)
def test_small_max_tokens_with_large_overlap(self):
"""Test edge case with very small max_tokens"""
documents = [" ".join([f"word{i}" for i in range(50)])]
# max_tokens=5, overlap_tokens=10 should clamp to 4
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=5, overlap_tokens=10
)
# Should complete without hanging
assert len(chunked_docs) > 0
assert all(idx == 0 for idx in doc_indices)
def test_multiple_documents_with_invalid_overlap(self):
"""Test multiple documents with overlap_tokens >= max_tokens"""
documents = [
" ".join([f"word{i}" for i in range(50)]),
"short document",
" ".join([f"word{i}" for i in range(75)]),
]
# overlap_tokens > max_tokens
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=25, overlap_tokens=30
)
# Should complete successfully and chunk the long documents
assert len(chunked_docs) >= len(documents)
# Short document should not be chunked
assert "short document" in chunked_docs
def test_normal_operation_unaffected(self):
"""Test that normal cases continue to work correctly"""
documents = [
" ".join([f"word{i}" for i in range(100)]),
"short doc",
]
# Normal case: overlap_tokens (10) < max_tokens (50)
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=50, overlap_tokens=10
)
# Long document should be chunked, short one should not
assert len(chunked_docs) > 2 # At least 3 chunks (2 from long doc + 1 short)
assert "short doc" in chunked_docs
# Verify doc_indices maps correctly
assert doc_indices[-1] == 1 # Last chunk is from second document
def test_edge_case_max_tokens_one(self):
"""Test edge case where max_tokens=1"""
documents = [" ".join([f"word{i}" for i in range(20)])]
# max_tokens=1, overlap_tokens=5 should clamp to 0
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=1, overlap_tokens=5
)
# Should complete without hanging
assert len(chunked_docs) > 0
assert all(idx == 0 for idx in doc_indices)