diff --git a/lightrag/rerank.py b/lightrag/rerank.py
index 35551f5a..81632b71 100644
--- a/lightrag/rerank.py
+++ b/lightrag/rerank.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import os
 import aiohttp
-from typing import Any, List, Dict, Optional
+from typing import Any, List, Dict, Optional, Tuple
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -19,6 +19,158 @@ from dotenv import load_dotenv
 load_dotenv(dotenv_path=".env", override=False)
 
 
+def chunk_documents_for_rerank(
+    documents: List[str],
+    max_tokens: int = 480,
+    overlap_tokens: int = 32,
+    tokenizer_model: str = "gpt-4o-mini",
+) -> Tuple[List[str], List[int]]:
+    """
+    Chunk documents that exceed the token limit for reranking.
+
+    Args:
+        documents: List of document strings to chunk
+        max_tokens: Maximum tokens per chunk (default 480, leaving headroom under the common 512-token limit)
+        overlap_tokens: Number of tokens to overlap between consecutive chunks
+        tokenizer_model: Model name for the tiktoken tokenizer
+
+    Returns:
+        Tuple of (chunked_documents, original_doc_indices)
+        - chunked_documents: List of document chunks (may contain more entries than the input)
+        - original_doc_indices: Maps each chunk back to its original document index
+    """
+    # Clamp overlap_tokens so the chunking loop always advances:
+    # if overlap_tokens >= max_tokens, start would never move forward.
+    if overlap_tokens >= max_tokens:
+        original_overlap = overlap_tokens
+        # Keep the overlap at least 1 token below max_tokens to guarantee progress.
+        # For very small max_tokens (e.g., 1), fall back to no overlap at all.
+        overlap_tokens = max(0, max_tokens - 1)
+        logger.warning(
+            f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). "
+            f"Clamping to {overlap_tokens} to prevent an infinite loop."
+        )
+
+    try:
+        from .utils import TiktokenTokenizer
+
+        tokenizer = TiktokenTokenizer(model_name=tokenizer_model)
+    except Exception as e:
+        logger.warning(
+            f"Failed to initialize tokenizer: {e}. Using character-based approximation."
+        )
+        # Fallback: approximate 1 token ≈ 4 characters
+        max_chars = max_tokens * 4
+        overlap_chars = overlap_tokens * 4
+
+        chunked_docs = []
+        doc_indices = []
+
+        for idx, doc in enumerate(documents):
+            if len(doc) <= max_chars:
+                chunked_docs.append(doc)
+                doc_indices.append(idx)
+            else:
+                # Split into overlapping chunks
+                start = 0
+                while start < len(doc):
+                    end = min(start + max_chars, len(doc))
+                    chunk = doc[start:end]
+                    chunked_docs.append(chunk)
+                    doc_indices.append(idx)
+
+                    if end >= len(doc):
+                        break
+                    start = end - overlap_chars
+
+        return chunked_docs, doc_indices
+
+    # Use the tokenizer for accurate chunking
+    chunked_docs = []
+    doc_indices = []
+
+    for idx, doc in enumerate(documents):
+        tokens = tokenizer.encode(doc)
+
+        if len(tokens) <= max_tokens:
+            # Document fits in one chunk
+            chunked_docs.append(doc)
+            doc_indices.append(idx)
+        else:
+            # Split into overlapping chunks
+            start = 0
+            while start < len(tokens):
+                end = min(start + max_tokens, len(tokens))
+                chunk_tokens = tokens[start:end]
+                chunk_text = tokenizer.decode(chunk_tokens)
+                chunked_docs.append(chunk_text)
+                doc_indices.append(idx)
+
+                if end >= len(tokens):
+                    break
+                start = end - overlap_tokens
+
+    return chunked_docs, doc_indices
+
+
+def aggregate_chunk_scores(
+    chunk_results: List[Dict[str, Any]],
+    doc_indices: List[int],
+    num_original_docs: int,
+    aggregation: str = "max",
+) -> List[Dict[str, Any]]:
+    """
+    Aggregate rerank scores from document chunks back to their original documents.
+
+    Args:
+        chunk_results: Rerank results for chunks [{"index": chunk_idx, "relevance_score": score}, ...]
+        doc_indices: Maps each chunk index to original document index
+        num_original_docs: Total number of original documents
+        aggregation: Strategy for aggregating scores ("max", "mean", "first")
+
+    Returns:
+        List of results for original documents [{"index": doc_idx, "relevance_score": score}, ...]
+    """
+    # Group scores by original document index
+    doc_scores: Dict[int, List[float]] = {i: [] for i in range(num_original_docs)}
+
+    for result in chunk_results:
+        chunk_idx = result["index"]
+        score = result["relevance_score"]
+
+        if 0 <= chunk_idx < len(doc_indices):
+            original_doc_idx = doc_indices[chunk_idx]
+            doc_scores[original_doc_idx].append(score)
+
+    # Aggregate scores
+    aggregated_results = []
+    for doc_idx, scores in doc_scores.items():
+        if not scores:
+            continue
+
+        if aggregation == "max":
+            final_score = max(scores)
+        elif aggregation == "mean":
+            final_score = sum(scores) / len(scores)
+        elif aggregation == "first":
+            final_score = scores[0]
+        else:
+            logger.warning(f"Unknown aggregation strategy: {aggregation}, using max")
+            final_score = max(scores)
+
+        aggregated_results.append(
+            {
+                "index": doc_idx,
+                "relevance_score": final_score,
+            }
+        )
+
+    # Sort by relevance score (descending)
+    aggregated_results.sort(key=lambda x: x["relevance_score"], reverse=True)
+
+    return aggregated_results
+
+
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -38,6 +190,8 @@ async def generic_rerank_api(
     extra_body: Optional[Dict[str, Any]] = None,
     response_format: str = "standard",  # "standard" (Jina/Cohere) or "aliyun"
     request_format: str = "standard",  # "standard" (Jina/Cohere) or "aliyun"
+    enable_chunking: bool = False,
+    max_tokens_per_doc: int = 480,
 ) -> List[Dict[str, Any]]:
     """
     Generic rerank API call for Jina/Cohere/Aliyun models.
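Reviewer note: to make the fan-out/fold-back contract of the two new helpers concrete, here is a minimal sketch (not part of the diff; the document list and chunk scores below are fabricated for illustration):

```python
from lightrag.rerank import chunk_documents_for_rerank, aggregate_chunk_scores

# Hypothetical inputs: one short doc and one doc long enough to be split.
docs = ["short doc", "long doc " * 2000]

# Fan out: long documents become several overlapping chunks; doc_indices
# records which original document each chunk came from.
chunks, doc_indices = chunk_documents_for_rerank(
    docs, max_tokens=480, overlap_tokens=32
)

# Pretend the rerank API scored every chunk (generic_rerank_api does this
# for real); these scores are made up.
chunk_results = [
    {"index": i, "relevance_score": 1.0 / (i + 1)} for i in range(len(chunks))
]

# Fold back: each original document keeps the max score over its chunks,
# and the result list comes back sorted by score, descending.
results = aggregate_chunk_scores(
    chunk_results, doc_indices, num_original_docs=len(docs), aggregation="max"
)
assert len(results) <= len(docs)
```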
@@ -52,6 +206,9 @@ async def generic_rerank_api(
         return_documents: Whether to return document text (Jina only)
         extra_body: Additional body parameters
         response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
+        request_format: Request format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
+        enable_chunking: Whether to chunk documents that exceed the token limit
+        max_tokens_per_doc: Maximum tokens per document when chunking is enabled
 
     Returns:
         List of dictionary of ["index": int, "relevance_score": float]
@@ -63,6 +220,17 @@
     if api_key is not None:
         headers["Authorization"] = f"Bearer {api_key}"
 
+    # Handle document chunking if enabled
+    original_documents = documents
+    doc_indices = None
+    if enable_chunking:
+        documents, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=max_tokens_per_doc
+        )
+        logger.debug(
+            f"Chunked {len(original_documents)} documents into {len(documents)} chunks"
+        )
+
     # Build request payload based on request format
     if request_format == "aliyun":
         # Aliyun format: nested input/parameters structure
@@ -86,7 +254,7 @@
         if extra_body:
             payload["parameters"].update(extra_body)
     else:
-        # Standard format for Jina/Cohere
+        # Standard format for Jina/Cohere/OpenAI
         payload = {
             "model": model,
             "query": query,
@@ -98,7 +266,7 @@
             payload["top_n"] = top_n
 
         # Only Jina API supports return_documents parameter
-        if return_documents is not None:
+        if return_documents is not None and response_format == "standard":
             payload["return_documents"] = return_documents
 
         # Add extra parameters
@@ -147,7 +315,6 @@
                     f"Expected 'output.results' to be list, got {type(results)}: {results}"
                 )
                 results = []
-
     elif response_format == "standard":
         # Standard format: {"results": [...]}
         results = response_json.get("results", [])
@@ -158,16 +325,28 @@
             results = []
     else:
         raise ValueError(f"Unsupported response format: {response_format}")
+
     if not results:
         logger.warning("Rerank API returned empty results")
         return []
 
     # Standardize return format
-    return [
+    standardized_results = [
         {"index": result["index"], "relevance_score": result["relevance_score"]}
         for result in results
     ]
 
+    # Aggregate chunk scores back to original documents if chunking was enabled
+    if enable_chunking and doc_indices:
+        standardized_results = aggregate_chunk_scores(
+            standardized_results,
+            doc_indices,
+            len(original_documents),
+            aggregation="max",
+        )
+
+    return standardized_results
+
 
 async def cohere_rerank(
     query: str,
     documents: List[str],
     top_n: Optional[int] = None,
     api_key: Optional[str] = None,
@@ -177,21 +356,46 @@
     model: str = "rerank-v3.5",
     base_url: str = "https://api.cohere.com/v2/rerank",
     extra_body: Optional[Dict[str, Any]] = None,
+    enable_chunking: bool = False,
+    max_tokens_per_doc: int = 4096,
 ) -> List[Dict[str, Any]]:
     """
     Rerank documents using Cohere API.
+    Supports both the standard Cohere API and Cohere-compatible proxies.
+
     Args:
         query: The search query
         documents: List of strings to rerank
         top_n: Number of top results to return
-        api_key: API key
-        model: rerank model name
+        api_key: API key for authentication
+        model: rerank model name (default: rerank-v3.5)
         base_url: API endpoint
         extra_body: Additional body for http request(reserved for extra params)
+        enable_chunking: Whether to chunk documents exceeding max_tokens_per_doc
+        max_tokens_per_doc: Maximum tokens per document (default: 4096 for Cohere v3.5)
 
     Returns:
         List of dictionary of ["index": int, "relevance_score": float]
+
+    Example:
+        >>> # Standard Cohere API
+        >>> results = await cohere_rerank(
+        ...     query="What is the meaning of life?",
+        ...     documents=["Doc1", "Doc2"],
+        ...     api_key="your-cohere-key"
+        ... )
+
+        >>> # LiteLLM proxy with user authentication
+        >>> results = await cohere_rerank(
+        ...     query="What is vector search?",
+        ...     documents=["Doc1", "Doc2"],
+        ...     model="answerai-colbert-small-v1",
+        ...     base_url="https://llm-proxy.example.com/v2/rerank",
+        ...     api_key="your-proxy-key",
+        ...     enable_chunking=True,
+        ...     max_tokens_per_doc=480
+        ... )
     """
     if api_key is None:
         api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
@@ -206,6 +410,8 @@
         return_documents=None,  # Cohere doesn't support this parameter
         extra_body=extra_body,
         response_format="standard",
+        enable_chunking=enable_chunking,
+        max_tokens_per_doc=max_tokens_per_doc,
     )
 
 
diff --git a/tests/test_overlap_validation.py b/tests/test_overlap_validation.py
new file mode 100644
index 00000000..7f84a3cf
--- /dev/null
+++ b/tests/test_overlap_validation.py
@@ -0,0 +1,113 @@
+"""
+Test for overlap_tokens validation to prevent an infinite loop.
+
+This test validates the fix for the bug where overlap_tokens >= max_tokens
+causes an infinite loop in the chunking function.
+""" + +from lightrag.rerank import chunk_documents_for_rerank + + +class TestOverlapValidation: + """Test suite for overlap_tokens validation""" + + def test_overlap_greater_than_max_tokens(self): + """Test that overlap_tokens > max_tokens is clamped and doesn't hang""" + documents = [" ".join([f"word{i}" for i in range(100)])] + + # This should clamp overlap_tokens to 29 (max_tokens - 1) + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=30, overlap_tokens=32 + ) + + # Should complete without hanging + assert len(chunked_docs) > 0 + assert all(idx == 0 for idx in doc_indices) + + def test_overlap_equal_to_max_tokens(self): + """Test that overlap_tokens == max_tokens is clamped and doesn't hang""" + documents = [" ".join([f"word{i}" for i in range(100)])] + + # This should clamp overlap_tokens to 29 (max_tokens - 1) + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=30, overlap_tokens=30 + ) + + # Should complete without hanging + assert len(chunked_docs) > 0 + assert all(idx == 0 for idx in doc_indices) + + def test_overlap_slightly_less_than_max_tokens(self): + """Test that overlap_tokens < max_tokens works normally""" + documents = [" ".join([f"word{i}" for i in range(100)])] + + # This should work without clamping + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=30, overlap_tokens=29 + ) + + # Should complete successfully + assert len(chunked_docs) > 0 + assert all(idx == 0 for idx in doc_indices) + + def test_small_max_tokens_with_large_overlap(self): + """Test edge case with very small max_tokens""" + documents = [" ".join([f"word{i}" for i in range(50)])] + + # max_tokens=5, overlap_tokens=10 should clamp to 4 + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=5, overlap_tokens=10 + ) + + # Should complete without hanging + assert len(chunked_docs) > 0 + assert all(idx == 0 for idx in doc_indices) + + def test_multiple_documents_with_invalid_overlap(self): + """Test multiple documents with overlap_tokens >= max_tokens""" + documents = [ + " ".join([f"word{i}" for i in range(50)]), + "short document", + " ".join([f"word{i}" for i in range(75)]), + ] + + # overlap_tokens > max_tokens + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=25, overlap_tokens=30 + ) + + # Should complete successfully and chunk the long documents + assert len(chunked_docs) >= len(documents) + # Short document should not be chunked + assert "short document" in chunked_docs + + def test_normal_operation_unaffected(self): + """Test that normal cases continue to work correctly""" + documents = [ + " ".join([f"word{i}" for i in range(100)]), + "short doc", + ] + + # Normal case: overlap_tokens (10) < max_tokens (50) + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=50, overlap_tokens=10 + ) + + # Long document should be chunked, short one should not + assert len(chunked_docs) > 2 # At least 3 chunks (2 from long doc + 1 short) + assert "short doc" in chunked_docs + # Verify doc_indices maps correctly + assert doc_indices[-1] == 1 # Last chunk is from second document + + def test_edge_case_max_tokens_one(self): + """Test edge case where max_tokens=1""" + documents = [" ".join([f"word{i}" for i in range(20)])] + + # max_tokens=1, overlap_tokens=5 should clamp to 0 + chunked_docs, doc_indices = chunk_documents_for_rerank( + documents, max_tokens=1, overlap_tokens=5 + ) + + # Should complete without hanging + assert 
len(chunked_docs) > 0 + assert all(idx == 0 for idx in doc_indices)
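If the suite grows, a companion test for the aggregation side could look like the sketch below (not part of this PR; it assumes the same pytest layout and uses only `aggregate_chunk_scores` as defined in the diff above):

```python
# Sketch of a companion test for aggregate_chunk_scores (hypothetical, not in this PR):
# two chunks map back to doc 0, one chunk to doc 1; "max" keeps the best
# chunk score per document, and results come back sorted descending.
from lightrag.rerank import aggregate_chunk_scores


def test_aggregate_max_over_chunks():
    chunk_results = [
        {"index": 0, "relevance_score": 0.2},  # chunk 0 -> doc 0
        {"index": 1, "relevance_score": 0.9},  # chunk 1 -> doc 0
        {"index": 2, "relevance_score": 0.5},  # chunk 2 -> doc 1
    ]
    results = aggregate_chunk_scores(
        chunk_results, doc_indices=[0, 0, 1], num_original_docs=2, aggregation="max"
    )
    # Doc 0 wins with its best chunk (0.9); doc 1 follows with 0.5
    assert results[0] == {"index": 0, "relevance_score": 0.9}
    assert results[1] == {"index": 1, "relevance_score": 0.5}
```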