Add Cohere reranker config, chunking, and tests
parent 16eb0d5bee
commit a05bbf105e
5 changed files with 620 additions and 20 deletions
@@ -102,6 +102,9 @@ RERANK_BINDING=null
 # RERANK_MODEL=rerank-v3.5
 # RERANK_BINDING_HOST=https://api.cohere.com/v2/rerank
 # RERANK_BINDING_API_KEY=your_rerank_api_key_here
+### Cohere rerank chunking configuration (useful for models with token limits like ColBERT)
+# RERANK_ENABLE_CHUNKING=true
+# RERANK_MAX_TOKENS_PER_DOC=480
 
 ### Default value for Jina AI
 # RERANK_MODEL=jina-reranker-v2-base-multilingual
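For reference, these two new variables are parsed elsewhere in this commit as a boolean and an integer. A minimal standalone sketch of that parsing (the parsing expressions are copied from this diff; the script around them is hypothetical):

import os

# "true"/"false" string -> bool, numeric string -> int, with this commit's defaults
enable_chunking = os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true"
max_tokens_per_doc = int(os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096"))
print(enable_chunking, max_tokens_per_doc)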
@@ -15,9 +15,12 @@ Configuration Required:
     EMBEDDING_BINDING_HOST
     EMBEDDING_BINDING_API_KEY
 3. Set your vLLM deployed AI rerank model setting with env vars:
-    RERANK_MODEL
-    RERANK_BINDING_HOST
+    RERANK_BINDING=cohere
+    RERANK_MODEL (e.g., answerai-colbert-small-v1 or rerank-v3.5)
+    RERANK_BINDING_HOST (e.g., https://api.cohere.com/v2/rerank or LiteLLM proxy)
     RERANK_BINDING_API_KEY
+    RERANK_ENABLE_CHUNKING=true (optional, for models with token limits)
+    RERANK_MAX_TOKENS_PER_DOC=480 (optional, default 4096)
 
 Note: Rerank is controlled per query via the 'enable_rerank' parameter (default: True)
 """
@@ -66,9 +69,11 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
 
 rerank_model_func = partial(
     cohere_rerank,
-    model=os.getenv("RERANK_MODEL"),
+    model=os.getenv("RERANK_MODEL", "rerank-v3.5"),
     api_key=os.getenv("RERANK_BINDING_API_KEY"),
-    base_url=os.getenv("RERANK_BINDING_HOST"),
+    base_url=os.getenv("RERANK_BINDING_HOST", "https://api.cohere.com/v2/rerank"),
+    enable_chunking=os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true",
+    max_tokens_per_doc=int(os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")),
 )
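A minimal sketch of calling the resulting rerank_model_func, assuming the env vars above are set and lightrag is importable; the query and documents are illustrative:

import asyncio

async def demo():
    # Per the rerank.py docstrings in this commit, results are
    # [{"index": int, "relevance_score": float}, ...] sorted by score.
    results = await rerank_model_func(
        query="What is vector search?",
        documents=["A note about BM25.", "A note about vector search."],
        top_n=2,
    )
    for r in results:
        print(r["index"], r["relevance_score"])

asyncio.run(demo())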
@@ -967,15 +967,27 @@ def create_app(args):
         query: str, documents: list, top_n: int = None, extra_body: dict = None
     ):
         """Server rerank function with configuration from environment variables"""
-        return await selected_rerank_func(
-            query=query,
-            documents=documents,
-            top_n=top_n,
-            api_key=args.rerank_binding_api_key,
-            model=args.rerank_model,
-            base_url=args.rerank_binding_host,
-            extra_body=extra_body,
-        )
+        # Prepare kwargs for rerank function
+        kwargs = {
+            "query": query,
+            "documents": documents,
+            "top_n": top_n,
+            "api_key": args.rerank_binding_api_key,
+            "model": args.rerank_model,
+            "base_url": args.rerank_binding_host,
+        }
+
+        # Add Cohere-specific parameters if using cohere binding
+        if args.rerank_binding == "cohere":
+            # Enable chunking if configured (useful for models with token limits like ColBERT)
+            kwargs["enable_chunking"] = (
+                os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true"
+            )
+            kwargs["max_tokens_per_doc"] = int(
+                os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")
+            )
+
+        return await selected_rerank_func(**kwargs, extra_body=extra_body)
 
     rerank_model_func = server_rerank_func
     logger.info(
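For illustration, with RERANK_BINDING=cohere and the chunking variables from this commit set in the environment, the kwargs dict built above would end up shaped like this (all values hypothetical):

kwargs = {
    "query": "test query",
    "documents": ["doc1", "doc2"],
    "top_n": 2,
    "api_key": "your_rerank_api_key_here",
    "model": "rerank-v3.5",
    "base_url": "https://api.cohere.com/v2/rerank",
    # Added only when args.rerank_binding == "cohere":
    "enable_chunking": True,
    "max_tokens_per_doc": 480,
}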
lightrag/rerank.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import os
 import aiohttp
-from typing import Any, List, Dict, Optional
+from typing import Any, List, Dict, Optional, Tuple
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -19,6 +19,146 @@ from dotenv import load_dotenv
 load_dotenv(dotenv_path=".env", override=False)
 
 
+def chunk_documents_for_rerank(
+    documents: List[str],
+    max_tokens: int = 480,
+    overlap_tokens: int = 32,
+    tokenizer_model: str = "gpt-4o-mini",
+) -> Tuple[List[str], List[int]]:
+    """
+    Chunk documents that exceed token limit for reranking.
+
+    Args:
+        documents: List of document strings to chunk
+        max_tokens: Maximum tokens per chunk (default 480 to leave margin for 512 limit)
+        overlap_tokens: Number of tokens to overlap between chunks
+        tokenizer_model: Model name for tiktoken tokenizer
+
+    Returns:
+        Tuple of (chunked_documents, original_doc_indices)
+        - chunked_documents: List of document chunks (may be more than input)
+        - original_doc_indices: Maps each chunk back to its original document index
+    """
+    try:
+        from .utils import TiktokenTokenizer
+
+        tokenizer = TiktokenTokenizer(model_name=tokenizer_model)
+    except Exception as e:
+        logger.warning(
+            f"Failed to initialize tokenizer: {e}. Using character-based approximation."
+        )
+        # Fallback: approximate 1 token ≈ 4 characters
+        max_chars = max_tokens * 4
+        overlap_chars = overlap_tokens * 4
+
+        chunked_docs = []
+        doc_indices = []
+
+        for idx, doc in enumerate(documents):
+            if len(doc) <= max_chars:
+                chunked_docs.append(doc)
+                doc_indices.append(idx)
+            else:
+                # Split into overlapping chunks
+                start = 0
+                while start < len(doc):
+                    end = min(start + max_chars, len(doc))
+                    chunk = doc[start:end]
+                    chunked_docs.append(chunk)
+                    doc_indices.append(idx)
+
+                    if end >= len(doc):
+                        break
+                    start = end - overlap_chars
+
+        return chunked_docs, doc_indices
+
+    # Use tokenizer for accurate chunking
+    chunked_docs = []
+    doc_indices = []
+
+    for idx, doc in enumerate(documents):
+        tokens = tokenizer.encode(doc)
+
+        if len(tokens) <= max_tokens:
+            # Document fits in one chunk
+            chunked_docs.append(doc)
+            doc_indices.append(idx)
+        else:
+            # Split into overlapping chunks
+            start = 0
+            while start < len(tokens):
+                end = min(start + max_tokens, len(tokens))
+                chunk_tokens = tokens[start:end]
+                chunk_text = tokenizer.decode(chunk_tokens)
+                chunked_docs.append(chunk_text)
+                doc_indices.append(idx)
+
+                if end >= len(tokens):
+                    break
+                start = end - overlap_tokens
+
+    return chunked_docs, doc_indices
+
+
+def aggregate_chunk_scores(
+    chunk_results: List[Dict[str, Any]],
+    doc_indices: List[int],
+    num_original_docs: int,
+    aggregation: str = "max",
+) -> List[Dict[str, Any]]:
+    """
+    Aggregate rerank scores from document chunks back to original documents.
+
+    Args:
+        chunk_results: Rerank results for chunks [{"index": chunk_idx, "relevance_score": score}, ...]
+        doc_indices: Maps each chunk index to original document index
+        num_original_docs: Total number of original documents
+        aggregation: Strategy for aggregating scores ("max", "mean", "first")
+
+    Returns:
+        List of results for original documents [{"index": doc_idx, "relevance_score": score}, ...]
+    """
+    # Group scores by original document index
+    doc_scores: Dict[int, List[float]] = {i: [] for i in range(num_original_docs)}
+
+    for result in chunk_results:
+        chunk_idx = result["index"]
+        score = result["relevance_score"]
+
+        if 0 <= chunk_idx < len(doc_indices):
+            original_doc_idx = doc_indices[chunk_idx]
+            doc_scores[original_doc_idx].append(score)
+
+    # Aggregate scores
+    aggregated_results = []
+    for doc_idx, scores in doc_scores.items():
+        if not scores:
+            continue
+
+        if aggregation == "max":
+            final_score = max(scores)
+        elif aggregation == "mean":
+            final_score = sum(scores) / len(scores)
+        elif aggregation == "first":
+            final_score = scores[0]
+        else:
+            logger.warning(f"Unknown aggregation strategy: {aggregation}, using max")
+            final_score = max(scores)
+
+        aggregated_results.append(
+            {
+                "index": doc_idx,
+                "relevance_score": final_score,
+            }
+        )
+
+    # Sort by relevance score (descending)
+    aggregated_results.sort(key=lambda x: x["relevance_score"], reverse=True)
+
+    return aggregated_results
+
+
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
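A worked round trip through the two new helpers, assuming lightrag.rerank is importable; the chunk scores are invented to show "max" aggregation:

from lightrag.rerank import chunk_documents_for_rerank, aggregate_chunk_scores

docs = ["word " * 2000, "short doc"]  # first doc exceeds any small token budget
chunks, doc_indices = chunk_documents_for_rerank(docs, max_tokens=480)
# doc_indices might come back as e.g. [0, 0, 0, 1]: three chunks from doc 0, one from doc 1

# Pretend the rerank API scored each of those four chunks, then fold the
# scores back onto the two original documents (max per document, sorted desc).
chunk_scores = [
    {"index": 0, "relevance_score": 0.5},
    {"index": 1, "relevance_score": 0.8},
    {"index": 2, "relevance_score": 0.6},
    {"index": 3, "relevance_score": 0.7},
]
print(aggregate_chunk_scores(chunk_scores, [0, 0, 0, 1], len(docs)))
# -> [{'index': 0, 'relevance_score': 0.8}, {'index': 1, 'relevance_score': 0.7}]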
@@ -38,6 +178,8 @@ async def generic_rerank_api(
     extra_body: Optional[Dict[str, Any]] = None,
     response_format: str = "standard",  # "standard" (Jina/Cohere) or "aliyun"
     request_format: str = "standard",  # "standard" (Jina/Cohere) or "aliyun"
+    enable_chunking: bool = False,
+    max_tokens_per_doc: int = 480,
 ) -> List[Dict[str, Any]]:
     """
     Generic rerank API call for Jina/Cohere/Aliyun models.
@@ -52,6 +194,9 @@ async def generic_rerank_api(
         return_documents: Whether to return document text (Jina only)
         extra_body: Additional body parameters
         response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
+        request_format: Request format type
+        enable_chunking: Whether to chunk documents exceeding token limit
+        max_tokens_per_doc: Maximum tokens per document for chunking
 
     Returns:
         List of dictionary of ["index": int, "relevance_score": float]
@@ -63,6 +208,17 @@ async def generic_rerank_api(
     if api_key is not None:
         headers["Authorization"] = f"Bearer {api_key}"
 
+    # Handle document chunking if enabled
+    original_documents = documents
+    doc_indices = None
+    if enable_chunking:
+        documents, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=max_tokens_per_doc
+        )
+        logger.debug(
+            f"Chunked {len(original_documents)} documents into {len(documents)} chunks"
+        )
+
     # Build request payload based on request format
     if request_format == "aliyun":
         # Aliyun format: nested input/parameters structure
@@ -86,7 +242,7 @@ async def generic_rerank_api(
         if extra_body:
             payload["parameters"].update(extra_body)
     else:
-        # Standard format for Jina/Cohere
+        # Standard format for Jina/Cohere/OpenAI
         payload = {
             "model": model,
             "query": query,
@@ -98,7 +254,7 @@ async def generic_rerank_api(
         payload["top_n"] = top_n
 
     # Only Jina API supports return_documents parameter
-    if return_documents is not None:
+    if return_documents is not None and response_format in ("standard",):
         payload["return_documents"] = return_documents
 
     # Add extra parameters
@@ -147,7 +303,6 @@ async def generic_rerank_api(
                 f"Expected 'output.results' to be list, got {type(results)}: {results}"
             )
             results = []
-
     elif response_format == "standard":
         # Standard format: {"results": [...]}
         results = response_json.get("results", [])
@@ -158,16 +313,28 @@ async def generic_rerank_api(
             results = []
     else:
         raise ValueError(f"Unsupported response format: {response_format}")
 
     if not results:
         logger.warning("Rerank API returned empty results")
         return []
 
     # Standardize return format
-    return [
+    standardized_results = [
         {"index": result["index"], "relevance_score": result["relevance_score"]}
         for result in results
     ]
 
+    # Aggregate chunk scores back to original documents if chunking was enabled
+    if enable_chunking and doc_indices:
+        standardized_results = aggregate_chunk_scores(
+            standardized_results,
+            doc_indices,
+            len(original_documents),
+            aggregation="max",
+        )
+
+    return standardized_results
+
 
 async def cohere_rerank(
     query: str,
@@ -177,21 +344,46 @@ async def cohere_rerank(
     model: str = "rerank-v3.5",
     base_url: str = "https://api.cohere.com/v2/rerank",
     extra_body: Optional[Dict[str, Any]] = None,
+    enable_chunking: bool = False,
+    max_tokens_per_doc: int = 4096,
 ) -> List[Dict[str, Any]]:
     """
     Rerank documents using Cohere API.
 
+    Supports both standard Cohere API and Cohere-compatible proxies
+
     Args:
         query: The search query
         documents: List of strings to rerank
         top_n: Number of top results to return
-        api_key: API key
-        model: rerank model name
+        api_key: API key for authentication
+        model: rerank model name (default: rerank-v3.5)
         base_url: API endpoint
         extra_body: Additional body for http request(reserved for extra params)
+        enable_chunking: Whether to chunk documents exceeding max_tokens_per_doc
+        max_tokens_per_doc: Maximum tokens per document (default: 4096 for Cohere v3.5)
 
     Returns:
         List of dictionary of ["index": int, "relevance_score": float]
+
+    Example:
+        >>> # Standard Cohere API
+        >>> results = await cohere_rerank(
+        ...     query="What is the meaning of life?",
+        ...     documents=["Doc1", "Doc2"],
+        ...     api_key="your-cohere-key"
+        ... )
+
+        >>> # LiteLLM proxy with user authentication
+        >>> results = await cohere_rerank(
+        ...     query="What is vector search?",
+        ...     documents=["Doc1", "Doc2"],
+        ...     model="answerai-colbert-small-v1",
+        ...     base_url="https://llm-proxy.example.com/v2/rerank",
+        ...     api_key="your-proxy-key",
+        ...     enable_chunking=True,
+        ...     max_tokens_per_doc=480
+        ... )
     """
     if api_key is None:
         api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
@@ -206,6 +398,8 @@ async def cohere_rerank(
         return_documents=None,  # Cohere doesn't support this parameter
         extra_body=extra_body,
         response_format="standard",
+        enable_chunking=enable_chunking,
+        max_tokens_per_doc=max_tokens_per_doc,
     )
tests/test_rerank_chunking.py (new file, 386 lines)

@@ -0,0 +1,386 @@
+"""
+Unit tests for rerank document chunking functionality.
+
+Tests the chunk_documents_for_rerank and aggregate_chunk_scores functions
+in lightrag/rerank.py to ensure proper document splitting and score aggregation.
+"""
+
+import pytest
+from unittest.mock import Mock, patch, AsyncMock
+from lightrag.rerank import (
+    chunk_documents_for_rerank,
+    aggregate_chunk_scores,
+    cohere_rerank,
+)
+
+
+class TestChunkDocumentsForRerank:
+    """Test suite for chunk_documents_for_rerank function"""
+
+    def test_no_chunking_needed_for_short_docs(self):
+        """Documents shorter than max_tokens should not be chunked"""
+        documents = [
+            "Short doc 1",
+            "Short doc 2",
+            "Short doc 3",
+        ]
+
+        chunked_docs, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=100, overlap_tokens=10
+        )
+
+        # No chunking should occur
+        assert len(chunked_docs) == 3
+        assert chunked_docs == documents
+        assert doc_indices == [0, 1, 2]
+
+    def test_chunking_with_character_fallback(self):
+        """Test chunking falls back to character-based when tokenizer unavailable"""
+        # Create a very long document that exceeds character limit
+        long_doc = "a" * 2000  # 2000 characters
+        documents = [long_doc, "short doc"]
+
+        with patch("lightrag.rerank.TiktokenTokenizer", side_effect=ImportError):
+            chunked_docs, doc_indices = chunk_documents_for_rerank(
+                documents,
+                max_tokens=100,  # 100 tokens = ~400 chars
+                overlap_tokens=10,  # 10 tokens = ~40 chars
+            )
+
+        # First doc should be split into chunks, second doc stays whole
+        assert len(chunked_docs) > 2  # At least one chunk from first doc + second doc
+        assert chunked_docs[-1] == "short doc"  # Last chunk is the short doc
+        # Verify doc_indices maps chunks to correct original document
+        assert doc_indices[-1] == 1  # Last chunk maps to document 1
+
+    def test_chunking_with_tiktoken_tokenizer(self):
+        """Test chunking with actual tokenizer"""
+        # Create document with known token count
+        # Approximate: "word " = ~1 token, so 200 words ~ 200 tokens
+        long_doc = " ".join([f"word{i}" for i in range(200)])
+        documents = [long_doc, "short"]
+
+        chunked_docs, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=50, overlap_tokens=10
+        )
+
+        # Long doc should be split, short doc should remain
+        assert len(chunked_docs) > 2
+        assert doc_indices[-1] == 1  # Last chunk is from second document
+
+        # Verify overlapping chunks contain overlapping content
+        if len(chunked_docs) > 2:
+            # Check that consecutive chunks from same doc have some overlap
+            for i in range(len(doc_indices) - 1):
+                if doc_indices[i] == doc_indices[i + 1] == 0:
+                    # Both chunks from first doc, should have overlap
+                    chunk1_words = chunked_docs[i].split()
+                    chunk2_words = chunked_docs[i + 1].split()
+                    # At least one word should be common due to overlap
+                    assert any(word in chunk2_words for word in chunk1_words[-5:])
+
+    def test_empty_documents(self):
+        """Test handling of empty document list"""
+        documents = []
+        chunked_docs, doc_indices = chunk_documents_for_rerank(documents)
+
+        assert chunked_docs == []
+        assert doc_indices == []
+
+    def test_single_document_chunking(self):
+        """Test chunking of a single long document"""
+        # Create document with ~100 tokens
+        long_doc = " ".join([f"token{i}" for i in range(100)])
+        documents = [long_doc]
+
+        chunked_docs, doc_indices = chunk_documents_for_rerank(
+            documents, max_tokens=30, overlap_tokens=5
+        )
+
+        # Should create multiple chunks
+        assert len(chunked_docs) > 1
+        # All chunks should map to document 0
+        assert all(idx == 0 for idx in doc_indices)
+
+
+class TestAggregateChunkScores:
+    """Test suite for aggregate_chunk_scores function"""
+
+    def test_no_chunking_simple_aggregation(self):
+        """Test aggregation when no chunking occurred (1:1 mapping)"""
+        chunk_results = [
+            {"index": 0, "relevance_score": 0.9},
+            {"index": 1, "relevance_score": 0.7},
+            {"index": 2, "relevance_score": 0.5},
+        ]
+        doc_indices = [0, 1, 2]  # 1:1 mapping
+        num_original_docs = 3
+
+        aggregated = aggregate_chunk_scores(
+            chunk_results, doc_indices, num_original_docs, aggregation="max"
+        )
+
+        # Results should be sorted by score
+        assert len(aggregated) == 3
+        assert aggregated[0]["index"] == 0
+        assert aggregated[0]["relevance_score"] == 0.9
+        assert aggregated[1]["index"] == 1
+        assert aggregated[1]["relevance_score"] == 0.7
+        assert aggregated[2]["index"] == 2
+        assert aggregated[2]["relevance_score"] == 0.5
+
+    def test_max_aggregation_with_chunks(self):
+        """Test max aggregation strategy with multiple chunks per document"""
+        # 5 chunks: first 3 from doc 0, last 2 from doc 1
+        chunk_results = [
+            {"index": 0, "relevance_score": 0.5},
+            {"index": 1, "relevance_score": 0.8},
+            {"index": 2, "relevance_score": 0.6},
+            {"index": 3, "relevance_score": 0.7},
+            {"index": 4, "relevance_score": 0.4},
+        ]
+        doc_indices = [0, 0, 0, 1, 1]
+        num_original_docs = 2
+
+        aggregated = aggregate_chunk_scores(
+            chunk_results, doc_indices, num_original_docs, aggregation="max"
+        )
+
+        # Should take max score for each document
+        assert len(aggregated) == 2
+        assert aggregated[0]["index"] == 0
+        assert aggregated[0]["relevance_score"] == 0.8  # max of 0.5, 0.8, 0.6
+        assert aggregated[1]["index"] == 1
+        assert aggregated[1]["relevance_score"] == 0.7  # max of 0.7, 0.4
+
+    def test_mean_aggregation_with_chunks(self):
+        """Test mean aggregation strategy"""
+        chunk_results = [
+            {"index": 0, "relevance_score": 0.6},
+            {"index": 1, "relevance_score": 0.8},
+            {"index": 2, "relevance_score": 0.4},
+        ]
+        doc_indices = [0, 0, 1]  # First two chunks from doc 0, last from doc 1
+        num_original_docs = 2
+
+        aggregated = aggregate_chunk_scores(
+            chunk_results, doc_indices, num_original_docs, aggregation="mean"
+        )
+
+        assert len(aggregated) == 2
+        assert aggregated[0]["index"] == 0
+        assert aggregated[0]["relevance_score"] == pytest.approx(0.7)  # (0.6 + 0.8) / 2
+        assert aggregated[1]["index"] == 1
+        assert aggregated[1]["relevance_score"] == 0.4
+
+    def test_first_aggregation_with_chunks(self):
+        """Test first aggregation strategy"""
+        chunk_results = [
+            {"index": 0, "relevance_score": 0.6},
+            {"index": 1, "relevance_score": 0.8},
+            {"index": 2, "relevance_score": 0.4},
+        ]
+        doc_indices = [0, 0, 1]
+        num_original_docs = 2
+
+        aggregated = aggregate_chunk_scores(
+            chunk_results, doc_indices, num_original_docs, aggregation="first"
+        )
+
+        assert len(aggregated) == 2
+        # First should use first score seen for each doc
+        assert aggregated[0]["index"] == 0
+        assert aggregated[0]["relevance_score"] == 0.6  # First score for doc 0
+        assert aggregated[1]["index"] == 1
+        assert aggregated[1]["relevance_score"] == 0.4
+
+    def test_empty_chunk_results(self):
+        """Test handling of empty results"""
+        aggregated = aggregate_chunk_scores([], [], 3, aggregation="max")
+        assert aggregated == []
+
+    def test_documents_with_no_scores(self):
+        """Test when some documents have no chunks/scores"""
+        chunk_results = [
+            {"index": 0, "relevance_score": 0.9},
+            {"index": 1, "relevance_score": 0.7},
+        ]
+        doc_indices = [0, 0]  # Both chunks from document 0
+        num_original_docs = 3  # But we have 3 documents total
+
+        aggregated = aggregate_chunk_scores(
+            chunk_results, doc_indices, num_original_docs, aggregation="max"
+        )
+
+        # Only doc 0 should appear in results
+        assert len(aggregated) == 1
+        assert aggregated[0]["index"] == 0
+
+    def test_unknown_aggregation_strategy(self):
+        """Test that unknown strategy falls back to max"""
+        chunk_results = [
+            {"index": 0, "relevance_score": 0.6},
+            {"index": 1, "relevance_score": 0.8},
+        ]
+        doc_indices = [0, 0]
+        num_original_docs = 1
+
+        # Use invalid strategy
+        aggregated = aggregate_chunk_scores(
+            chunk_results, doc_indices, num_original_docs, aggregation="invalid"
+        )
+
+        # Should fall back to max
+        assert aggregated[0]["relevance_score"] == 0.8
+
+
+@pytest.mark.offline
+class TestCohereRerankChunking:
+    """Integration tests for cohere_rerank with chunking enabled"""
+
+    @pytest.mark.asyncio
+    async def test_cohere_rerank_with_chunking_disabled(self):
+        """Test that chunking can be disabled"""
+        documents = ["doc1", "doc2"]
+        query = "test query"
+
+        # Mock the generic_rerank_api
+        with patch(
+            "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
+        ) as mock_api:
+            mock_api.return_value = [
+                {"index": 0, "relevance_score": 0.9},
+                {"index": 1, "relevance_score": 0.7},
+            ]
+
+            result = await cohere_rerank(
+                query=query,
+                documents=documents,
+                api_key="test-key",
+                enable_chunking=False,
+                max_tokens_per_doc=100,
+            )
+
+            # Verify generic_rerank_api was called with correct parameters
+            mock_api.assert_called_once()
+            call_kwargs = mock_api.call_args[1]
+            assert call_kwargs["enable_chunking"] is False
+            assert call_kwargs["max_tokens_per_doc"] == 100
+            # Result should mirror mocked scores
+            assert len(result) == 2
+            assert result[0]["index"] == 0
+            assert result[0]["relevance_score"] == 0.9
+            assert result[1]["index"] == 1
+            assert result[1]["relevance_score"] == 0.7
+
+    @pytest.mark.asyncio
+    async def test_cohere_rerank_with_chunking_enabled(self):
+        """Test that chunking parameters are passed through"""
+        documents = ["doc1", "doc2"]
+        query = "test query"
+
+        with patch(
+            "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
+        ) as mock_api:
+            mock_api.return_value = [
+                {"index": 0, "relevance_score": 0.9},
+                {"index": 1, "relevance_score": 0.7},
+            ]
+
+            result = await cohere_rerank(
+                query=query,
+                documents=documents,
+                api_key="test-key",
+                enable_chunking=True,
+                max_tokens_per_doc=480,
+            )
+
+            # Verify parameters were passed
+            call_kwargs = mock_api.call_args[1]
+            assert call_kwargs["enable_chunking"] is True
+            assert call_kwargs["max_tokens_per_doc"] == 480
+            # Result should mirror mocked scores
+            assert len(result) == 2
+            assert result[0]["index"] == 0
+            assert result[0]["relevance_score"] == 0.9
+            assert result[1]["index"] == 1
+            assert result[1]["relevance_score"] == 0.7
+
+    @pytest.mark.asyncio
+    async def test_cohere_rerank_default_parameters(self):
+        """Test default parameter values for cohere_rerank"""
+        documents = ["doc1"]
+        query = "test"
+
+        with patch(
+            "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
+        ) as mock_api:
+            mock_api.return_value = [{"index": 0, "relevance_score": 0.9}]
+
+            result = await cohere_rerank(
+                query=query, documents=documents, api_key="test-key"
+            )
+
+            # Verify default values
+            call_kwargs = mock_api.call_args[1]
+            assert call_kwargs["enable_chunking"] is False
+            assert call_kwargs["max_tokens_per_doc"] == 4096
+            assert call_kwargs["model"] == "rerank-v3.5"
+            # Result should mirror mocked scores
+            assert len(result) == 1
+            assert result[0]["index"] == 0
+            assert result[0]["relevance_score"] == 0.9
+
+
+@pytest.mark.offline
+class TestEndToEndChunking:
+    """End-to-end tests for chunking workflow"""
+
+    @pytest.mark.asyncio
+    async def test_end_to_end_chunking_workflow(self):
+        """Test complete chunking workflow from documents to aggregated results"""
+        # Create documents where first one needs chunking
+        long_doc = " ".join([f"word{i}" for i in range(100)])
+        documents = [long_doc, "short doc"]
+        query = "test query"
+
+        # Mock the HTTP call inside generic_rerank_api
+        mock_response = Mock()
+        mock_response.status = 200
+        mock_response.json = AsyncMock(
+            return_value={
+                "results": [
+                    {"index": 0, "relevance_score": 0.5},  # chunk 0 from doc 0
+                    {"index": 1, "relevance_score": 0.8},  # chunk 1 from doc 0
+                    {"index": 2, "relevance_score": 0.6},  # chunk 2 from doc 0
+                    {"index": 3, "relevance_score": 0.7},  # doc 1 (short)
+                ]
+            }
+        )
+        mock_response.request_info = None
+        mock_response.history = None
+        mock_response.headers = {}
+
+        mock_session = Mock()
+        mock_session.post = AsyncMock(return_value=mock_response)
+        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+        mock_session.__aexit__ = AsyncMock()
+
+        with patch("aiohttp.ClientSession", return_value=mock_session):
+            result = await cohere_rerank(
+                query=query,
+                documents=documents,
+                api_key="test-key",
+                base_url="http://test.com/rerank",
+                enable_chunking=True,
+                max_tokens_per_doc=30,  # Force chunking of long doc
+            )
+
+        # Should get 2 results (one per original document)
+        # The long doc's chunks should be aggregated
+        assert len(result) <= len(documents)
+        # Results should be sorted by score
+        assert all(
+            result[i]["relevance_score"] >= result[i + 1]["relevance_score"]
+            for i in range(len(result) - 1)
+        )
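A minimal way to run just these tests, assuming pytest and pytest-asyncio are installed and the asyncio/offline markers are registered in the project's pytest configuration:

import pytest

# Equivalent to running `pytest -q tests/test_rerank_chunking.py` from the repo root.
pytest.main(["-q", "tests/test_rerank_chunking.py"])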