Improve edge case handling for max_tokens=1
Co-authored-by: netbrah <162479981+netbrah@users.noreply.github.com>
(cherry picked from commit 8835fc244a)
This commit is contained in:
parent 26602f3e20
commit b28a701532
2 changed files with 326 additions and 7 deletions

@@ -2,7 +2,7 @@ from __future__ import annotations
 import os
 import aiohttp
-from typing import Any, List, Dict, Optional
+from typing import Any, List, Dict, Optional, Tuple
 from tenacity import (
     retry,
     stop_after_attempt,

@@ -19,6 +19,158 @@ from dotenv import load_dotenv
 load_dotenv(dotenv_path=".env", override=False)
+
+
+def chunk_documents_for_rerank(
+    documents: List[str],
+    max_tokens: int = 480,
+    overlap_tokens: int = 32,
+    tokenizer_model: str = "gpt-4o-mini",
+) -> Tuple[List[str], List[int]]:
+    """
+    Chunk documents that exceed token limit for reranking.
+
+    Args:
+        documents: List of document strings to chunk
+        max_tokens: Maximum tokens per chunk (default 480 to leave margin for 512 limit)
+        overlap_tokens: Number of tokens to overlap between chunks
+        tokenizer_model: Model name for tiktoken tokenizer
+
+    Returns:
+        Tuple of (chunked_documents, original_doc_indices)
+        - chunked_documents: List of document chunks (may be more than input)
+        - original_doc_indices: Maps each chunk back to its original document index
+    """
+    # Clamp overlap_tokens to ensure the loop always advances
+    # If overlap_tokens >= max_tokens, the chunking loop would hang
+    if overlap_tokens >= max_tokens:
+        original_overlap = overlap_tokens
+        # Ensure overlap is at least 1 token less than max to guarantee progress
+        # For very small max_tokens (e.g., 1), set overlap to 0
+        overlap_tokens = max(0, max_tokens - 1)
+        logger.warning(
+            f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). "
+            f"Clamping to {overlap_tokens} to prevent infinite loop."
+        )
+
+    try:
+        from .utils import TiktokenTokenizer
+
+        tokenizer = TiktokenTokenizer(model_name=tokenizer_model)
+    except Exception as e:
+        logger.warning(
+            f"Failed to initialize tokenizer: {e}. Using character-based approximation."
+        )
+        # Fallback: approximate 1 token ≈ 4 characters
+        max_chars = max_tokens * 4
+        overlap_chars = overlap_tokens * 4
+
+        chunked_docs = []
+        doc_indices = []
+
+        for idx, doc in enumerate(documents):
+            if len(doc) <= max_chars:
+                chunked_docs.append(doc)
+                doc_indices.append(idx)
+            else:
+                # Split into overlapping chunks
+                start = 0
+                while start < len(doc):
+                    end = min(start + max_chars, len(doc))
+                    chunk = doc[start:end]
+                    chunked_docs.append(chunk)
+                    doc_indices.append(idx)
+
+                    if end >= len(doc):
+                        break
+                    start = end - overlap_chars
+
+        return chunked_docs, doc_indices
+
+    # Use tokenizer for accurate chunking
+    chunked_docs = []
+    doc_indices = []
+
+    for idx, doc in enumerate(documents):
+        tokens = tokenizer.encode(doc)
+
+        if len(tokens) <= max_tokens:
+            # Document fits in one chunk
+            chunked_docs.append(doc)
+            doc_indices.append(idx)
+        else:
+            # Split into overlapping chunks
+            start = 0
+            while start < len(tokens):
+                end = min(start + max_tokens, len(tokens))
+                chunk_tokens = tokens[start:end]
+                chunk_text = tokenizer.decode(chunk_tokens)
+                chunked_docs.append(chunk_text)
+                doc_indices.append(idx)
+
+                if end >= len(tokens):
+                    break
+                start = end - overlap_tokens
+
+    return chunked_docs, doc_indices
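
As an illustration, a minimal sketch of the max_tokens=1 edge case this commit targets, assuming lightrag.rerank is importable and tiktoken is available (otherwise the character-based fallback approximates 1 token as 4 characters):

    from lightrag.rerank import chunk_documents_for_rerank

    # overlap_tokens (5) >= max_tokens (1), so the clamp sets overlap to 0
    # and every token becomes its own chunk instead of looping forever.
    chunks, indices = chunk_documents_for_rerank(
        ["a short document"], max_tokens=1, overlap_tokens=5
    )
    assert len(chunks) >= 1      # one chunk per token under the tokenizer path
    assert set(indices) == {0}   # every chunk maps back to document 0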
+
+
+def aggregate_chunk_scores(
+    chunk_results: List[Dict[str, Any]],
+    doc_indices: List[int],
+    num_original_docs: int,
+    aggregation: str = "max",
+) -> List[Dict[str, Any]]:
+    """
+    Aggregate rerank scores from document chunks back to original documents.
+
+    Args:
+        chunk_results: Rerank results for chunks [{"index": chunk_idx, "relevance_score": score}, ...]
+        doc_indices: Maps each chunk index to original document index
+        num_original_docs: Total number of original documents
+        aggregation: Strategy for aggregating scores ("max", "mean", "first")
+
+    Returns:
+        List of results for original documents [{"index": doc_idx, "relevance_score": score}, ...]
+    """
+    # Group scores by original document index
+    doc_scores: Dict[int, List[float]] = {i: [] for i in range(num_original_docs)}
+
+    for result in chunk_results:
+        chunk_idx = result["index"]
+        score = result["relevance_score"]
+
+        if 0 <= chunk_idx < len(doc_indices):
+            original_doc_idx = doc_indices[chunk_idx]
+            doc_scores[original_doc_idx].append(score)
+
+    # Aggregate scores
+    aggregated_results = []
+    for doc_idx, scores in doc_scores.items():
+        if not scores:
+            continue
+
+        if aggregation == "max":
+            final_score = max(scores)
+        elif aggregation == "mean":
+            final_score = sum(scores) / len(scores)
+        elif aggregation == "first":
+            final_score = scores[0]
+        else:
+            logger.warning(f"Unknown aggregation strategy: {aggregation}, using max")
+            final_score = max(scores)
+
+        aggregated_results.append(
+            {
+                "index": doc_idx,
+                "relevance_score": final_score,
+            }
+        )
+
+    # Sort by relevance score (descending)
+    aggregated_results.sort(key=lambda x: x["relevance_score"], reverse=True)
+
+    return aggregated_results
 
 
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
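
And a sketch of how the aggregation step maps chunk scores back to documents (scores invented for the example):

    from lightrag.rerank import aggregate_chunk_scores

    # Chunks 0 and 1 came from document 0, chunk 2 from document 1.
    chunk_results = [
        {"index": 0, "relevance_score": 0.2},
        {"index": 1, "relevance_score": 0.9},
        {"index": 2, "relevance_score": 0.5},
    ]
    doc_indices = [0, 0, 1]

    results = aggregate_chunk_scores(chunk_results, doc_indices, num_original_docs=2)
    # The default "max" strategy keeps the best chunk score per document, sorted descending:
    # [{"index": 0, "relevance_score": 0.9}, {"index": 1, "relevance_score": 0.5}]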

@@ -38,6 +190,8 @@ async def generic_rerank_api(
     extra_body: Optional[Dict[str, Any]] = None,
     response_format: str = "standard",  # "standard" (Jina/Cohere) or "aliyun"
     request_format: str = "standard",  # "standard" (Jina/Cohere) or "aliyun"
+    enable_chunking: bool = False,
+    max_tokens_per_doc: int = 480,
 ) -> List[Dict[str, Any]]:
     """
     Generic rerank API call for Jina/Cohere/Aliyun models.

@@ -52,6 +206,9 @@ async def generic_rerank_api(
         return_documents: Whether to return document text (Jina only)
         extra_body: Additional body parameters
         response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
+        request_format: Request format type
+        enable_chunking: Whether to chunk documents exceeding token limit
+        max_tokens_per_doc: Maximum tokens per document for chunking
 
     Returns:
         List of dictionary of ["index": int, "relevance_score": float]

@@ -63,6 +220,17 @@ async def generic_rerank_api(
     if api_key is not None:
         headers["Authorization"] = f"Bearer {api_key}"
 
+    # Handle document chunking if enabled
+    original_documents = documents
+    doc_indices = None
+    if enable_chunking:
+        documents, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=max_tokens_per_doc
+        )
+        logger.debug(
+            f"Chunked {len(original_documents)} documents into {len(documents)} chunks"
+        )
+
     # Build request payload based on request format
     if request_format == "aliyun":
         # Aliyun format: nested input/parameters structure
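
With chunking enabled, the call path looks roughly like this; a sketch assuming keyword usage (the full signature is not shown in this hunk, and the endpoint, model, and key below are placeholders):

    import asyncio

    from lightrag.rerank import generic_rerank_api

    async def main():
        results = await generic_rerank_api(
            query="what is vector search?",
            documents=["short doc", "a very long document " * 2000],
            model="example-reranker",                  # placeholder model name
            base_url="https://example.com/v1/rerank",  # placeholder endpoint
            api_key="your-key",
            enable_chunking=True,     # long docs are split before the request,
            max_tokens_per_doc=480,   # and scores are aggregated back afterwards
        )
        print(results)  # [{"index": int, "relevance_score": float}, ...]

    asyncio.run(main())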

@@ -86,7 +254,7 @@ async def generic_rerank_api(
         if extra_body:
             payload["parameters"].update(extra_body)
     else:
-        # Standard format for Jina/Cohere
+        # Standard format for Jina/Cohere/OpenAI
         payload = {
             "model": model,
             "query": query,

@@ -98,7 +266,7 @@ async def generic_rerank_api(
         payload["top_n"] = top_n
 
     # Only Jina API supports return_documents parameter
-    if return_documents is not None:
+    if return_documents is not None and response_format in ("standard",):
         payload["return_documents"] = return_documents
 
     # Add extra parameters

@@ -147,7 +315,6 @@ async def generic_rerank_api(
                     f"Expected 'output.results' to be list, got {type(results)}: {results}"
                 )
                 results = []
-
             elif response_format == "standard":
                 # Standard format: {"results": [...]}
                 results = response_json.get("results", [])

@@ -158,16 +325,28 @@ async def generic_rerank_api(
                     results = []
             else:
                 raise ValueError(f"Unsupported response format: {response_format}")
 
             if not results:
                 logger.warning("Rerank API returned empty results")
                 return []
 
             # Standardize return format
-            return [
+            standardized_results = [
                 {"index": result["index"], "relevance_score": result["relevance_score"]}
                 for result in results
             ]
 
+            # Aggregate chunk scores back to original documents if chunking was enabled
+            if enable_chunking and doc_indices:
+                standardized_results = aggregate_chunk_scores(
+                    standardized_results,
+                    doc_indices,
+                    len(original_documents),
+                    aggregation="max",
+                )
+
+            return standardized_results
 
 
 async def cohere_rerank(
     query: str,

@@ -177,21 +356,46 @@ async def cohere_rerank(
     model: str = "rerank-v3.5",
     base_url: str = "https://api.cohere.com/v2/rerank",
     extra_body: Optional[Dict[str, Any]] = None,
+    enable_chunking: bool = False,
+    max_tokens_per_doc: int = 4096,
 ) -> List[Dict[str, Any]]:
     """
     Rerank documents using Cohere API.
 
+    Supports both standard Cohere API and Cohere-compatible proxies
+
     Args:
         query: The search query
         documents: List of strings to rerank
         top_n: Number of top results to return
-        api_key: API key
-        model: rerank model name
+        api_key: API key for authentication
+        model: rerank model name (default: rerank-v3.5)
         base_url: API endpoint
         extra_body: Additional body for http request (reserved for extra params)
+        enable_chunking: Whether to chunk documents exceeding max_tokens_per_doc
+        max_tokens_per_doc: Maximum tokens per document (default: 4096 for Cohere v3.5)
 
     Returns:
         List of dictionary of ["index": int, "relevance_score": float]
 
+    Example:
+        >>> # Standard Cohere API
+        >>> results = await cohere_rerank(
+        ...     query="What is the meaning of life?",
+        ...     documents=["Doc1", "Doc2"],
+        ...     api_key="your-cohere-key"
+        ... )
+
+        >>> # LiteLLM proxy with user authentication
+        >>> results = await cohere_rerank(
+        ...     query="What is vector search?",
+        ...     documents=["Doc1", "Doc2"],
+        ...     model="answerai-colbert-small-v1",
+        ...     base_url="https://llm-proxy.example.com/v2/rerank",
+        ...     api_key="your-proxy-key",
+        ...     enable_chunking=True,
+        ...     max_tokens_per_doc=480
+        ... )
     """
     if api_key is None:
         api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")

@@ -206,6 +410,8 @@ async def cohere_rerank(
         return_documents=None,  # Cohere doesn't support this parameter
         extra_body=extra_body,
         response_format="standard",
+        enable_chunking=enable_chunking,
+        max_tokens_per_doc=max_tokens_per_doc,
     )
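
Taken together, the two helpers round-trip cleanly even without an API call; a sketch with fabricated scores standing in for the rerank response:

    from lightrag.rerank import aggregate_chunk_scores, chunk_documents_for_rerank

    docs = ["long document " * 300, "short doc"]
    chunks, idx = chunk_documents_for_rerank(docs, max_tokens=64)

    # Fabricated per-chunk scores in place of a real rerank API response.
    fake_scores = [
        {"index": i, "relevance_score": 1.0 / (i + 1)} for i in range(len(chunks))
    ]

    ranked = aggregate_chunk_scores(fake_scores, idx, num_original_docs=len(docs))
    assert {r["index"] for r in ranked} <= {0, 1}  # indices refer to the original docs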

tests/test_overlap_validation.py (new file, 113 lines)

@@ -0,0 +1,113 @@
+"""
+Test for overlap_tokens validation to prevent infinite loop.
+
+This test validates the fix for the bug where overlap_tokens >= max_tokens
+causes an infinite loop in the chunking function.
+"""
+
+from lightrag.rerank import chunk_documents_for_rerank
+
+
+class TestOverlapValidation:
+    """Test suite for overlap_tokens validation"""
+
+    def test_overlap_greater_than_max_tokens(self):
+        """Test that overlap_tokens > max_tokens is clamped and doesn't hang"""
+        documents = [" ".join([f"word{i}" for i in range(100)])]
+
+        # This should clamp overlap_tokens to 29 (max_tokens - 1)
+        chunked_docs, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=30, overlap_tokens=32
+        )
+
+        # Should complete without hanging
+        assert len(chunked_docs) > 0
+        assert all(idx == 0 for idx in doc_indices)
+
+    def test_overlap_equal_to_max_tokens(self):
+        """Test that overlap_tokens == max_tokens is clamped and doesn't hang"""
+        documents = [" ".join([f"word{i}" for i in range(100)])]
+
+        # This should clamp overlap_tokens to 29 (max_tokens - 1)
+        chunked_docs, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=30, overlap_tokens=30
+        )
+
+        # Should complete without hanging
+        assert len(chunked_docs) > 0
+        assert all(idx == 0 for idx in doc_indices)
+
+    def test_overlap_slightly_less_than_max_tokens(self):
+        """Test that overlap_tokens < max_tokens works normally"""
+        documents = [" ".join([f"word{i}" for i in range(100)])]
+
+        # This should work without clamping
+        chunked_docs, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=30, overlap_tokens=29
+        )
+
+        # Should complete successfully
+        assert len(chunked_docs) > 0
+        assert all(idx == 0 for idx in doc_indices)
+
+    def test_small_max_tokens_with_large_overlap(self):
+        """Test edge case with very small max_tokens"""
+        documents = [" ".join([f"word{i}" for i in range(50)])]
+
+        # max_tokens=5, overlap_tokens=10 should clamp to 4
+        chunked_docs, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=5, overlap_tokens=10
+        )
+
+        # Should complete without hanging
+        assert len(chunked_docs) > 0
+        assert all(idx == 0 for idx in doc_indices)
+
+    def test_multiple_documents_with_invalid_overlap(self):
+        """Test multiple documents with overlap_tokens >= max_tokens"""
+        documents = [
+            " ".join([f"word{i}" for i in range(50)]),
+            "short document",
+            " ".join([f"word{i}" for i in range(75)]),
+        ]
+
+        # overlap_tokens > max_tokens
+        chunked_docs, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=25, overlap_tokens=30
+        )
+
+        # Should complete successfully and chunk the long documents
+        assert len(chunked_docs) >= len(documents)
+        # Short document should not be chunked
+        assert "short document" in chunked_docs
+
+    def test_normal_operation_unaffected(self):
+        """Test that normal cases continue to work correctly"""
+        documents = [
+            " ".join([f"word{i}" for i in range(100)]),
+            "short doc",
+        ]
+
+        # Normal case: overlap_tokens (10) < max_tokens (50)
+        chunked_docs, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=50, overlap_tokens=10
+        )
+
+        # Long document should be chunked, short one should not
+        assert len(chunked_docs) > 2  # At least 3 chunks (2 from long doc + 1 short)
+        assert "short doc" in chunked_docs
+        # Verify doc_indices maps correctly
+        assert doc_indices[-1] == 1  # Last chunk is from second document
+
+    def test_edge_case_max_tokens_one(self):
+        """Test edge case where max_tokens=1"""
+        documents = [" ".join([f"word{i}" for i in range(20)])]
+
+        # max_tokens=1, overlap_tokens=5 should clamp to 0
+        chunked_docs, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=1, overlap_tokens=5
+        )
+
+        # Should complete without hanging
+        assert len(chunked_docs) > 0
+        assert all(idx == 0 for idx in doc_indices)
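
The suite runs with: pytest tests/test_overlap_validation.py -q. A further case in the same style that one might add (hypothetical, not part of this commit): an empty document never exceeds max_tokens, so it passes through as a single chunk even at max_tokens=1.

    def test_empty_document_with_max_tokens_one(self):
        """Hypothetical extra case: empty documents pass through unchanged"""
        chunked_docs, doc_indices = chunk_documents_for_rerank(
            [""], max_tokens=1, overlap_tokens=5
        )

        # 0 tokens <= max_tokens, so the document is kept whole
        assert chunked_docs == [""]
        assert doc_indices == [0]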