chore: sync with upstream (#4)

* chore: sync with upstream

- Cohere rerank improvements
- Content deduplication
- Dependency updates

* fix: address CodeRabbit review feedback

- Harden env parsing for RERANK_MAX_TOKENS_PER_DOC with try/except
- Add @pytest.mark.offline to test_overlap_validation
- Remove unused doc_indices variable
This commit is contained in:
clssck 2025-12-03 13:16:28 +01:00 committed by GitHub
parent 99f950671e
commit 9bae6267f6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 2135 additions and 940 deletions

206
.github/dependabot.yml vendored Normal file
View file

@ -0,0 +1,206 @@
# Keep GitHub Actions up to date with GitHub's Dependabot...
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
version: 2
updates:
# ============================================================
# GitHub Actions
# PR Strategy:
# - All updates (major/minor/patch): Grouped into a single PR
# ============================================================
- package-ecosystem: github-actions
directory: /
groups:
github-actions:
patterns:
- "*" # Group all Actions updates into a single larger pull request
schedule:
interval: weekly
day: monday
time: "02:00"
timezone: "Asia/Shanghai"
labels:
- "dependencies"
- "github-actions"
open-pull-requests-limit: 2
# ============================================================
# Python (pip) Dependencies
# PR Strategy:
# - Major updates: Individual PR per package (except numpy which is ignored)
# - Minor updates: Grouped by category (llm-providers, storage, etc.)
# - Patch updates: Grouped by category
# ============================================================
- package-ecosystem: "pip"
directory: "/"
schedule:
interval: "weekly"
day: "wednesday"
time: "02:00"
timezone: "Asia/Shanghai"
cooldown:
default-days: 5
semver-major-days: 30
semver-minor-days: 7
semver-patch-days: 3
groups:
# Core dependencies - LLM providers and embeddings
llm-providers:
patterns:
- "openai"
- "anthropic"
- "google-*"
- "boto3"
- "botocore"
- "ollama"
update-types:
- "minor"
- "patch"
# Storage backends
storage:
patterns:
- "neo4j"
- "pymongo"
- "redis"
- "psycopg*"
- "asyncpg"
- "milvus*"
- "qdrant*"
update-types:
- "minor"
- "patch"
# Data processing and ML
data-processing:
patterns:
- "numpy"
- "scipy"
- "pandas"
- "tiktoken"
- "transformers"
- "torch*"
update-types:
- "minor"
- "patch"
# Web framework and API
web-framework:
patterns:
- "fastapi"
- "uvicorn"
- "gunicorn"
- "starlette"
- "pydantic*"
update-types:
- "minor"
- "patch"
# Development and testing tools
dev-tools:
patterns:
- "pytest*"
- "ruff"
- "pre-commit"
- "black"
- "mypy"
update-types:
- "minor"
- "patch"
# Minor and patch updates for everything else
python-minor-patch:
patterns:
- "*"
update-types:
- "minor"
- "patch"
ignore:
- dependency-name: "numpy"
update-types:
- "version-update:semver-major"
labels:
- "dependencies"
- "python"
open-pull-requests-limit: 5
# ============================================================
# Frontend (bun) Dependencies
# PR Strategy:
# - Major updates: Individual PR per package
# - Minor updates: Grouped by category (react, ui-components, etc.)
# - Patch updates: Grouped by category
# ============================================================
- package-ecosystem: "bun"
directory: "/lightrag_webui"
schedule:
interval: "weekly"
day: "friday"
time: "02:00"
timezone: "Asia/Shanghai"
cooldown:
default-days: 5
semver-major-days: 30
semver-minor-days: 7
semver-patch-days: 3
groups:
# React ecosystem
react:
patterns:
- "react"
- "react-dom"
- "react-router*"
- "@types/react*"
update-types:
- "minor"
- "patch"
# UI components and styling
ui-components:
patterns:
- "@radix-ui/*"
- "tailwind*"
- "@tailwindcss/*"
- "lucide-react"
- "class-variance-authority"
- "clsx"
update-types:
- "minor"
- "patch"
# Graph visualization
graph-viz:
patterns:
- "sigma"
- "@sigma/*"
- "graphology*"
update-types:
- "minor"
- "patch"
# Build tools and dev dependencies
build-tools:
patterns:
- "vite"
- "@vitejs/*"
- "typescript"
- "eslint*"
- "@eslint/*"
- "typescript-eslint"
- "prettier"
- "prettier-*"
- "@types/bun"
update-types:
- "minor"
- "patch"
# Content rendering libraries (math, diagrams, etc.)
content-rendering:
patterns:
- "katex"
- "mermaid"
update-types:
- "minor"
- "patch"
# All other minor and patch updates
frontend-minor-patch:
patterns:
- "*"
update-types:
- "minor"
- "patch"
labels:
- "dependencies"
- "frontend"
open-pull-requests-limit: 5

View file

@ -12,7 +12,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@ -25,7 +27,7 @@ jobs:
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push Docker image
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: .
file: ./Dockerfile

View file

@ -102,6 +102,9 @@ RERANK_BINDING=null
# RERANK_MODEL=rerank-v3.5
# RERANK_BINDING_HOST=https://api.cohere.com/v2/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Cohere rerank chunking configuration (useful for models with token limits like ColBERT)
# RERANK_ENABLE_CHUNKING=true
# RERANK_MAX_TOKENS_PER_DOC=480
### Default value for Jina AI
# RERANK_MODEL=jina-reranker-v2-base-multilingual

View file

@ -15,9 +15,12 @@ Configuration Required:
EMBEDDING_BINDING_HOST
EMBEDDING_BINDING_API_KEY
3. Set your vLLM deployed AI rerank model setting with env vars:
RERANK_MODEL
RERANK_BINDING_HOST
RERANK_BINDING=cohere
RERANK_MODEL (e.g., answerai-colbert-small-v1 or rerank-v3.5)
RERANK_BINDING_HOST (e.g., https://api.cohere.com/v2/rerank or LiteLLM proxy)
RERANK_BINDING_API_KEY
RERANK_ENABLE_CHUNKING=true (optional, for models with token limits)
RERANK_MAX_TOKENS_PER_DOC=480 (optional, default 4096)
Note: Rerank is controlled per query via the 'enable_rerank' parameter (default: True)
"""
@ -66,9 +69,11 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
rerank_model_func = partial(
cohere_rerank,
model=os.getenv("RERANK_MODEL"),
model=os.getenv("RERANK_MODEL", "rerank-v3.5"),
api_key=os.getenv("RERANK_BINDING_API_KEY"),
base_url=os.getenv("RERANK_BINDING_HOST"),
base_url=os.getenv("RERANK_BINDING_HOST", "https://api.cohere.com/v2/rerank"),
enable_chunking=os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true",
max_tokens_per_doc=int(os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")),
)

View file

@ -1 +1 @@
__api_version__ = "0258"
__api_version__ = "0259"

View file

@ -1023,15 +1023,30 @@ def create_app(args):
query: str, documents: list, top_n: int = None, extra_body: dict = None
):
"""Server rerank function with configuration from environment variables"""
return await selected_rerank_func(
query=query,
documents=documents,
top_n=top_n,
api_key=args.rerank_binding_api_key,
model=args.rerank_model,
base_url=args.rerank_binding_host,
extra_body=extra_body,
)
# Prepare kwargs for rerank function
kwargs = {
"query": query,
"documents": documents,
"top_n": top_n,
"api_key": args.rerank_binding_api_key,
"model": args.rerank_model,
"base_url": args.rerank_binding_host,
}
# Add Cohere-specific parameters if using cohere binding
if args.rerank_binding == "cohere":
# Enable chunking if configured (useful for models with token limits like ColBERT)
kwargs["enable_chunking"] = (
os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true"
)
try:
kwargs["max_tokens_per_doc"] = int(
os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")
)
except ValueError:
kwargs["max_tokens_per_doc"] = 4096
return await selected_rerank_func(**kwargs, extra_body=extra_body)
rerank_model_func = server_rerank_func
logger.info(

View file

@ -24,7 +24,11 @@ from pydantic import BaseModel, Field, field_validator
from lightrag import LightRAG
from lightrag.base import DeletionResult, DocProcessingStatus, DocStatus
from lightrag.utils import generate_track_id
from lightrag.utils import (
generate_track_id,
compute_mdhash_id,
sanitize_text_for_encoding,
)
from lightrag.api.utils_api import get_combined_auth_dependency
from ..config import global_args
@ -159,7 +163,7 @@ class ReprocessResponse(BaseModel):
Attributes:
status: Status of the reprocessing operation
message: Message describing the operation result
track_id: Tracking ID for monitoring reprocessing progress
track_id: Always empty string. Reprocessed documents retain their original track_id.
"""
status: Literal["reprocessing_started"] = Field(
@ -167,7 +171,8 @@ class ReprocessResponse(BaseModel):
)
message: str = Field(description="Human-readable message describing the operation")
track_id: str = Field(
description="Tracking ID for monitoring reprocessing progress"
default="",
description="Always empty string. Reprocessed documents retain their original track_id from initial upload.",
)
class Config:
@ -175,7 +180,7 @@ class ReprocessResponse(BaseModel):
"example": {
"status": "reprocessing_started",
"message": "Reprocessing of failed documents has been initiated in background",
"track_id": "retry_20250729_170612_def456",
"track_id": "",
}
}
@ -2097,12 +2102,14 @@ def create_document_routes(
# Check if filename already exists in doc_status storage
existing_doc_data = await rag.doc_status.get_doc_by_file_path(safe_filename)
if existing_doc_data:
# Get document status information for error message
# Get document status and track_id from existing document
status = existing_doc_data.get("status", "unknown")
# Use `or ""` to handle both missing key and None value (e.g., legacy rows without track_id)
existing_track_id = existing_doc_data.get("track_id") or ""
return InsertResponse(
status="duplicated",
message=f"File '{safe_filename}' already exists in document storage (Status: {status}).",
track_id="",
track_id=existing_track_id,
)
file_path = doc_manager.input_dir / safe_filename
@ -2166,14 +2173,30 @@ def create_document_routes(
request.file_source
)
if existing_doc_data:
# Get document status information for error message
# Get document status and track_id from existing document
status = existing_doc_data.get("status", "unknown")
# Use `or ""` to handle both missing key and None value (e.g., legacy rows without track_id)
existing_track_id = existing_doc_data.get("track_id") or ""
return InsertResponse(
status="duplicated",
message=f"File source '{request.file_source}' already exists in document storage (Status: {status}).",
track_id="",
track_id=existing_track_id,
)
# Check if content already exists by computing content hash (doc_id)
sanitized_text = sanitize_text_for_encoding(request.text)
content_doc_id = compute_mdhash_id(sanitized_text, prefix="doc-")
existing_doc = await rag.doc_status.get_by_id(content_doc_id)
if existing_doc:
# Content already exists, return duplicated with existing track_id
status = existing_doc.get("status", "unknown")
existing_track_id = existing_doc.get("track_id") or ""
return InsertResponse(
status="duplicated",
message=f"Identical content already exists in document storage (doc_id: {content_doc_id}, Status: {status}).",
track_id=existing_track_id,
)
# Generate track_id for text insertion
track_id = generate_track_id("insert")
@ -2232,14 +2255,31 @@ def create_document_routes(
file_source
)
if existing_doc_data:
# Get document status information for error message
# Get document status and track_id from existing document
status = existing_doc_data.get("status", "unknown")
# Use `or ""` to handle both missing key and None value (e.g., legacy rows without track_id)
existing_track_id = existing_doc_data.get("track_id") or ""
return InsertResponse(
status="duplicated",
message=f"File source '{file_source}' already exists in document storage (Status: {status}).",
track_id="",
track_id=existing_track_id,
)
# Check if any content already exists by computing content hash (doc_id)
for text in request.texts:
sanitized_text = sanitize_text_for_encoding(text)
content_doc_id = compute_mdhash_id(sanitized_text, prefix="doc-")
existing_doc = await rag.doc_status.get_by_id(content_doc_id)
if existing_doc:
# Content already exists, return duplicated with existing track_id
status = existing_doc.get("status", "unknown")
existing_track_id = existing_doc.get("track_id") or ""
return InsertResponse(
status="duplicated",
message=f"Identical content already exists in document storage (doc_id: {content_doc_id}, Status: {status}).",
track_id=existing_track_id,
)
# Generate track_id for texts insertion
track_id = generate_track_id("insert")
@ -3058,29 +3098,27 @@ def create_document_routes(
This is useful for recovering from server crashes, network errors, LLM service
outages, or other temporary failures that caused document processing to fail.
The processing happens in the background and can be monitored using the
returned track_id or by checking the pipeline status.
The processing happens in the background and can be monitored by checking the
pipeline status. The reprocessed documents retain their original track_id from
initial upload, so use their original track_id to monitor progress.
Returns:
ReprocessResponse: Response with status, message, and track_id
ReprocessResponse: Response with status and message.
track_id is always empty string because reprocessed documents retain
their original track_id from initial upload.
Raises:
HTTPException: If an error occurs while initiating reprocessing (500).
"""
try:
# Generate track_id with "retry" prefix for retry operation
track_id = generate_track_id("retry")
# Start the reprocessing in the background
# Note: Reprocessed documents retain their original track_id from initial upload
background_tasks.add_task(rag.apipeline_process_enqueue_documents)
logger.info(
f"Reprocessing of failed documents initiated with track_id: {track_id}"
)
logger.info("Reprocessing of failed documents initiated")
return ReprocessResponse(
status="reprocessing_started",
message="Reprocessing of failed documents has been initiated in background",
track_id=track_id,
message="Reprocessing of failed documents has been initiated in background. Documents retain their original track_id.",
)
except Exception as e:

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import os
import aiohttp
from typing import Any, List, Dict, Optional
from typing import Any, List, Dict, Optional, Tuple
from tenacity import (
retry,
stop_after_attempt,
@ -19,6 +19,158 @@ from dotenv import load_dotenv
load_dotenv(dotenv_path=".env", override=False)
def chunk_documents_for_rerank(
documents: List[str],
max_tokens: int = 480,
overlap_tokens: int = 32,
tokenizer_model: str = "gpt-4o-mini",
) -> Tuple[List[str], List[int]]:
"""
Chunk documents that exceed token limit for reranking.
Args:
documents: List of document strings to chunk
max_tokens: Maximum tokens per chunk (default 480 to leave margin for 512 limit)
overlap_tokens: Number of tokens to overlap between chunks
tokenizer_model: Model name for tiktoken tokenizer
Returns:
Tuple of (chunked_documents, original_doc_indices)
- chunked_documents: List of document chunks (may be more than input)
- original_doc_indices: Maps each chunk back to its original document index
"""
# Clamp overlap_tokens to ensure the loop always advances
# If overlap_tokens >= max_tokens, the chunking loop would hang
if overlap_tokens >= max_tokens:
original_overlap = overlap_tokens
# Ensure overlap is at least 1 token less than max to guarantee progress
# For very small max_tokens (e.g., 1), set overlap to 0
overlap_tokens = max(0, max_tokens - 1)
logger.warning(
f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). "
f"Clamping to {overlap_tokens} to prevent infinite loop."
)
try:
from .utils import TiktokenTokenizer
tokenizer = TiktokenTokenizer(model_name=tokenizer_model)
except Exception as e:
logger.warning(
f"Failed to initialize tokenizer: {e}. Using character-based approximation."
)
# Fallback: approximate 1 token ≈ 4 characters
max_chars = max_tokens * 4
overlap_chars = overlap_tokens * 4
chunked_docs = []
doc_indices = []
for idx, doc in enumerate(documents):
if len(doc) <= max_chars:
chunked_docs.append(doc)
doc_indices.append(idx)
else:
# Split into overlapping chunks
start = 0
while start < len(doc):
end = min(start + max_chars, len(doc))
chunk = doc[start:end]
chunked_docs.append(chunk)
doc_indices.append(idx)
if end >= len(doc):
break
start = end - overlap_chars
return chunked_docs, doc_indices
# Use tokenizer for accurate chunking
chunked_docs = []
doc_indices = []
for idx, doc in enumerate(documents):
tokens = tokenizer.encode(doc)
if len(tokens) <= max_tokens:
# Document fits in one chunk
chunked_docs.append(doc)
doc_indices.append(idx)
else:
# Split into overlapping chunks
start = 0
while start < len(tokens):
end = min(start + max_tokens, len(tokens))
chunk_tokens = tokens[start:end]
chunk_text = tokenizer.decode(chunk_tokens)
chunked_docs.append(chunk_text)
doc_indices.append(idx)
if end >= len(tokens):
break
start = end - overlap_tokens
return chunked_docs, doc_indices
def aggregate_chunk_scores(
chunk_results: List[Dict[str, Any]],
doc_indices: List[int],
num_original_docs: int,
aggregation: str = "max",
) -> List[Dict[str, Any]]:
"""
Aggregate rerank scores from document chunks back to original documents.
Args:
chunk_results: Rerank results for chunks [{"index": chunk_idx, "relevance_score": score}, ...]
doc_indices: Maps each chunk index to original document index
num_original_docs: Total number of original documents
aggregation: Strategy for aggregating scores ("max", "mean", "first")
Returns:
List of results for original documents [{"index": doc_idx, "relevance_score": score}, ...]
"""
# Group scores by original document index
doc_scores: Dict[int, List[float]] = {i: [] for i in range(num_original_docs)}
for result in chunk_results:
chunk_idx = result["index"]
score = result["relevance_score"]
if 0 <= chunk_idx < len(doc_indices):
original_doc_idx = doc_indices[chunk_idx]
doc_scores[original_doc_idx].append(score)
# Aggregate scores
aggregated_results = []
for doc_idx, scores in doc_scores.items():
if not scores:
continue
if aggregation == "max":
final_score = max(scores)
elif aggregation == "mean":
final_score = sum(scores) / len(scores)
elif aggregation == "first":
final_score = scores[0]
else:
logger.warning(f"Unknown aggregation strategy: {aggregation}, using max")
final_score = max(scores)
aggregated_results.append(
{
"index": doc_idx,
"relevance_score": final_score,
}
)
# Sort by relevance score (descending)
aggregated_results.sort(key=lambda x: x["relevance_score"], reverse=True)
return aggregated_results
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
@ -38,6 +190,8 @@ async def generic_rerank_api(
extra_body: Optional[Dict[str, Any]] = None,
response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
enable_chunking: bool = False,
max_tokens_per_doc: int = 480,
) -> List[Dict[str, Any]]:
"""
Generic rerank API call for Jina/Cohere/Aliyun models.
@ -52,6 +206,9 @@ async def generic_rerank_api(
return_documents: Whether to return document text (Jina only)
extra_body: Additional body parameters
response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
request_format: Request format type
enable_chunking: Whether to chunk documents exceeding token limit
max_tokens_per_doc: Maximum tokens per document for chunking
Returns:
List of dictionary of ["index": int, "relevance_score": float]
@ -63,6 +220,27 @@ async def generic_rerank_api(
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
# Handle document chunking if enabled
original_documents = documents
doc_indices = None
original_top_n = top_n # Save original top_n for post-aggregation limiting
if enable_chunking:
documents, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=max_tokens_per_doc
)
logger.debug(
f"Chunked {len(original_documents)} documents into {len(documents)} chunks"
)
# When chunking is enabled, disable top_n at API level to get all chunk scores
# This ensures proper document-level coverage after aggregation
# We'll apply top_n to aggregated document results instead
if top_n is not None:
logger.debug(
f"Chunking enabled: disabled API-level top_n={top_n} to ensure complete document coverage"
)
top_n = None
# Build request payload based on request format
if request_format == "aliyun":
# Aliyun format: nested input/parameters structure
@ -86,7 +264,7 @@ async def generic_rerank_api(
if extra_body:
payload["parameters"].update(extra_body)
else:
# Standard format for Jina/Cohere
# Standard format for Jina/Cohere/OpenAI
payload = {
"model": model,
"query": query,
@ -98,7 +276,7 @@ async def generic_rerank_api(
payload["top_n"] = top_n
# Only Jina API supports return_documents parameter
if return_documents is not None:
if return_documents is not None and response_format in ("standard",):
payload["return_documents"] = return_documents
# Add extra parameters
@ -147,7 +325,6 @@ async def generic_rerank_api(
f"Expected 'output.results' to be list, got {type(results)}: {results}"
)
results = []
elif response_format == "standard":
# Standard format: {"results": [...]}
results = response_json.get("results", [])
@ -158,16 +335,35 @@ async def generic_rerank_api(
results = []
else:
raise ValueError(f"Unsupported response format: {response_format}")
if not results:
logger.warning("Rerank API returned empty results")
return []
# Standardize return format
return [
standardized_results = [
{"index": result["index"], "relevance_score": result["relevance_score"]}
for result in results
]
# Aggregate chunk scores back to original documents if chunking was enabled
if enable_chunking and doc_indices:
standardized_results = aggregate_chunk_scores(
standardized_results,
doc_indices,
len(original_documents),
aggregation="max",
)
# Apply original top_n limit at document level (post-aggregation)
# This preserves document-level semantics: top_n limits documents, not chunks
if (
original_top_n is not None
and len(standardized_results) > original_top_n
):
standardized_results = standardized_results[:original_top_n]
return standardized_results
async def cohere_rerank(
query: str,
@ -177,21 +373,46 @@ async def cohere_rerank(
model: str = "rerank-v3.5",
base_url: str = "https://api.cohere.com/v2/rerank",
extra_body: Optional[Dict[str, Any]] = None,
enable_chunking: bool = False,
max_tokens_per_doc: int = 4096,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Cohere API.
Supports both standard Cohere API and Cohere-compatible proxies
Args:
query: The search query
documents: List of strings to rerank
top_n: Number of top results to return
api_key: API key
model: rerank model name
api_key: API key for authentication
model: rerank model name (default: rerank-v3.5)
base_url: API endpoint
extra_body: Additional body for http request(reserved for extra params)
enable_chunking: Whether to chunk documents exceeding max_tokens_per_doc
max_tokens_per_doc: Maximum tokens per document (default: 4096 for Cohere v3.5)
Returns:
List of dictionary of ["index": int, "relevance_score": float]
Example:
>>> # Standard Cohere API
>>> results = await cohere_rerank(
... query="What is the meaning of life?",
... documents=["Doc1", "Doc2"],
... api_key="your-cohere-key"
... )
>>> # LiteLLM proxy with user authentication
>>> results = await cohere_rerank(
... query="What is vector search?",
... documents=["Doc1", "Doc2"],
... model="answerai-colbert-small-v1",
... base_url="https://llm-proxy.example.com/v2/rerank",
... api_key="your-proxy-key",
... enable_chunking=True,
... max_tokens_per_doc=480
... )
"""
if api_key is None:
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
@ -206,6 +427,8 @@ async def cohere_rerank(
return_documents=None, # Cohere doesn't support this parameter
extra_body=extra_body,
response_format="standard",
enable_chunking=enable_chunking,
max_tokens_per_doc=max_tokens_per_doc,
)

File diff suppressed because it is too large Load diff

View file

@ -16,17 +16,17 @@
"preview-no-bun": "vite preview"
},
"dependencies": {
"@faker-js/faker": "^9.9.0",
"@faker-js/faker": "^10.1.0",
"@radix-ui/react-alert-dialog": "^1.1.15",
"@radix-ui/react-checkbox": "^1.3.3",
"@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-hover-card": "^1.1.15",
"@radix-ui/react-popover": "^1.1.15",
"@radix-ui/react-progress": "^1.1.7",
"@radix-ui/react-progress": "^1.1.8",
"@radix-ui/react-scroll-area": "^1.2.10",
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-slot": "^1.2.3",
"@radix-ui/react-separator": "^1.1.8",
"@radix-ui/react-slot": "^1.2.4",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-tooltip": "^1.2.8",
"@radix-ui/react-use-controllable-state": "^1.2.2",
@ -43,7 +43,7 @@
"@sigma/node-border": "^3.0.0",
"@tanstack/react-query": "^5.87.1",
"@tanstack/react-table": "^8.21.3",
"axios": "^1.12.2",
"axios": "^1.13.2",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"cmdk": "^1.1.1",
@ -53,21 +53,21 @@
"graphology-layout-force": "^0.2.4",
"graphology-layout-forceatlas2": "^0.10.1",
"graphology-layout-noverlap": "^0.4.2",
"i18next": "^24.2.3",
"katex": "^0.16.23",
"lucide-react": "^0.475.0",
"mermaid": "^11.12.0",
"i18next": "^25.6.3",
"katex": "^0.16.25",
"mermaid": "^11.12.1",
"lucide-react": "^0.554.0",
"minisearch": "^7.2.0",
"react": "^19.2.0",
"react-dom": "^19.2.0",
"react-dropzone": "^14.3.8",
"react-error-boundary": "^5.0.0",
"react-i18next": "^15.7.4",
"react-markdown": "^9.1.0",
"react-error-boundary": "^6.0.0",
"react-i18next": "^16.3.5",
"react-markdown": "^10.1.0",
"react-number-format": "^5.4.4",
"react-router-dom": "^7.9.4",
"react-router-dom": "^7.9.6",
"react-select": "^5.10.2",
"react-syntax-highlighter": "^15.6.6",
"react-syntax-highlighter": "^16.1.0",
"rehype-katex": "^7.0.1",
"rehype-raw": "^7.0.0",
"rehype-react": "^8.0.0",
@ -75,8 +75,8 @@
"remark-math": "^6.0.0",
"seedrandom": "^3.0.5",
"sigma": "^3.0.2",
"sonner": "^1.7.4",
"tailwind-merge": "^3.3.1",
"sonner": "^2.0.7",
"tailwind-merge": "^3.4.0",
"tailwind-scrollbar": "^4.0.2",
"typography": "^0.16.24",
"unist-util-visit": "^5.0.0",
@ -84,32 +84,32 @@
},
"devDependencies": {
"@biomejs/biome": "^1.9.3",
"@eslint/js": "^9.37.0",
"@stylistic/eslint-plugin-js": "^3.1.0",
"@eslint/js": "^9.39.1",
"@stylistic/eslint-plugin-js": "^4.4.1",
"@tailwindcss/typography": "^0.5.15",
"@tailwindcss/vite": "^4.1.14",
"@types/bun": "^1.2.23",
"@tailwindcss/vite": "^4.1.17",
"@types/bun": "^1.3.3",
"@types/katex": "^0.16.7",
"@types/node": "^22.18.9",
"@types/react": "^19.2.2",
"@types/react-dom": "^19.2.1",
"@types/node": "^24.10.1",
"@types/react": "^19.2.7",
"@types/react-dom": "^19.2.3",
"@types/react-i18next": "^8.1.0",
"@types/react-syntax-highlighter": "^15.5.13",
"@types/seedrandom": "^3.0.8",
"@vitejs/plugin-react-swc": "^3.11.0",
"eslint": "^9.37.0",
"@vitejs/plugin-react-swc": "^4.2.2",
"eslint": "^9.39.1",
"eslint-config-prettier": "^10.1.8",
"eslint-plugin-react": "^7.37.5",
"eslint-plugin-react-hooks": "^5.2.0",
"eslint-plugin-react-refresh": "^0.4.23",
"globals": "^15.15.0",
"eslint-plugin-react-hooks": "^7.0.1",
"eslint-plugin-react-refresh": "^0.4.24",
"globals": "^16.5.0",
"graphology-types": "^0.24.8",
"prettier": "^3.6.2",
"prettier-plugin-tailwindcss": "^0.6.14",
"tailwindcss": "^4.1.14",
"prettier-plugin-tailwindcss": "^0.7.1",
"typescript-eslint": "^8.48.0",
"tailwindcss": "^4.1.17",
"tailwindcss-animate": "^1.0.7",
"typescript": "~5.7.3",
"typescript-eslint": "^8.46.0",
"vite": "^6.3.6"
"typescript": "~5.9.3",
"vite": "^7.2.4"
}
}

View file

@ -153,7 +153,6 @@ export const ChatMessage = ({
setKatexPlugin(() => rehypeKatex)
} catch (error) {
console.error('Failed to load KaTeX plugin:', error)
// Set to null to ensure we don't try to use a failed plugin
setKatexPlugin(null)
}
}

View file

@ -0,0 +1,116 @@
"""
Test for overlap_tokens validation to prevent infinite loop.
This test validates the fix for the bug where overlap_tokens >= max_tokens
causes an infinite loop in the chunking function.
"""
import pytest
from lightrag.rerank import chunk_documents_for_rerank
@pytest.mark.offline
class TestOverlapValidation:
"""Test suite for overlap_tokens validation"""
def test_overlap_greater_than_max_tokens(self):
"""Test that overlap_tokens > max_tokens is clamped and doesn't hang"""
documents = [" ".join([f"word{i}" for i in range(100)])]
# This should clamp overlap_tokens to 29 (max_tokens - 1)
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=30, overlap_tokens=32
)
# Should complete without hanging
assert len(chunked_docs) > 0
assert all(idx == 0 for idx in doc_indices)
def test_overlap_equal_to_max_tokens(self):
"""Test that overlap_tokens == max_tokens is clamped and doesn't hang"""
documents = [" ".join([f"word{i}" for i in range(100)])]
# This should clamp overlap_tokens to 29 (max_tokens - 1)
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=30, overlap_tokens=30
)
# Should complete without hanging
assert len(chunked_docs) > 0
assert all(idx == 0 for idx in doc_indices)
def test_overlap_slightly_less_than_max_tokens(self):
"""Test that overlap_tokens < max_tokens works normally"""
documents = [" ".join([f"word{i}" for i in range(100)])]
# This should work without clamping
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=30, overlap_tokens=29
)
# Should complete successfully
assert len(chunked_docs) > 0
assert all(idx == 0 for idx in doc_indices)
def test_small_max_tokens_with_large_overlap(self):
"""Test edge case with very small max_tokens"""
documents = [" ".join([f"word{i}" for i in range(50)])]
# max_tokens=5, overlap_tokens=10 should clamp to 4
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=5, overlap_tokens=10
)
# Should complete without hanging
assert len(chunked_docs) > 0
assert all(idx == 0 for idx in doc_indices)
def test_multiple_documents_with_invalid_overlap(self):
"""Test multiple documents with overlap_tokens >= max_tokens"""
documents = [
" ".join([f"word{i}" for i in range(50)]),
"short document",
" ".join([f"word{i}" for i in range(75)]),
]
# overlap_tokens > max_tokens
chunked_docs, _ = chunk_documents_for_rerank(
documents, max_tokens=25, overlap_tokens=30
)
# Should complete successfully and chunk the long documents
assert len(chunked_docs) >= len(documents)
# Short document should not be chunked
assert "short document" in chunked_docs
def test_normal_operation_unaffected(self):
"""Test that normal cases continue to work correctly"""
documents = [
" ".join([f"word{i}" for i in range(100)]),
"short doc",
]
# Normal case: overlap_tokens (10) < max_tokens (50)
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=50, overlap_tokens=10
)
# Long document should be chunked, short one should not
assert len(chunked_docs) > 2 # At least 3 chunks (2 from long doc + 1 short)
assert "short doc" in chunked_docs
# Verify doc_indices maps correctly
assert doc_indices[-1] == 1 # Last chunk is from second document
def test_edge_case_max_tokens_one(self):
"""Test edge case where max_tokens=1"""
documents = [" ".join([f"word{i}" for i in range(20)])]
# max_tokens=1, overlap_tokens=5 should clamp to 0
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=1, overlap_tokens=5
)
# Should complete without hanging
assert len(chunked_docs) > 0
assert all(idx == 0 for idx in doc_indices)

View file

@ -0,0 +1,564 @@
"""
Unit tests for rerank document chunking functionality.
Tests the chunk_documents_for_rerank and aggregate_chunk_scores functions
in lightrag/rerank.py to ensure proper document splitting and score aggregation.
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from lightrag.rerank import (
chunk_documents_for_rerank,
aggregate_chunk_scores,
cohere_rerank,
)
class TestChunkDocumentsForRerank:
"""Test suite for chunk_documents_for_rerank function"""
def test_no_chunking_needed_for_short_docs(self):
"""Documents shorter than max_tokens should not be chunked"""
documents = [
"Short doc 1",
"Short doc 2",
"Short doc 3",
]
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=100, overlap_tokens=10
)
# No chunking should occur
assert len(chunked_docs) == 3
assert chunked_docs == documents
assert doc_indices == [0, 1, 2]
def test_chunking_with_character_fallback(self):
"""Test chunking falls back to character-based when tokenizer unavailable"""
# Create a very long document that exceeds character limit
long_doc = "a" * 2000 # 2000 characters
documents = [long_doc, "short doc"]
with patch("lightrag.utils.TiktokenTokenizer", side_effect=ImportError):
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents,
max_tokens=100, # 100 tokens = ~400 chars
overlap_tokens=10, # 10 tokens = ~40 chars
)
# First doc should be split into chunks, second doc stays whole
assert len(chunked_docs) > 2 # At least one chunk from first doc + second doc
assert chunked_docs[-1] == "short doc" # Last chunk is the short doc
# Verify doc_indices maps chunks to correct original document
assert doc_indices[-1] == 1 # Last chunk maps to document 1
def test_chunking_with_tiktoken_tokenizer(self):
"""Test chunking with actual tokenizer"""
# Create document with known token count
# Approximate: "word " = ~1 token, so 200 words ~ 200 tokens
long_doc = " ".join([f"word{i}" for i in range(200)])
documents = [long_doc, "short"]
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=50, overlap_tokens=10
)
# Long doc should be split, short doc should remain
assert len(chunked_docs) > 2
assert doc_indices[-1] == 1 # Last chunk is from second document
# Verify overlapping chunks contain overlapping content
if len(chunked_docs) > 2:
# Check that consecutive chunks from same doc have some overlap
for i in range(len(doc_indices) - 1):
if doc_indices[i] == doc_indices[i + 1] == 0:
# Both chunks from first doc, should have overlap
chunk1_words = chunked_docs[i].split()
chunk2_words = chunked_docs[i + 1].split()
# At least one word should be common due to overlap
assert any(word in chunk2_words for word in chunk1_words[-5:])
def test_empty_documents(self):
"""Test handling of empty document list"""
documents = []
chunked_docs, doc_indices = chunk_documents_for_rerank(documents)
assert chunked_docs == []
assert doc_indices == []
def test_single_document_chunking(self):
"""Test chunking of a single long document"""
# Create document with ~100 tokens
long_doc = " ".join([f"token{i}" for i in range(100)])
documents = [long_doc]
chunked_docs, doc_indices = chunk_documents_for_rerank(
documents, max_tokens=30, overlap_tokens=5
)
# Should create multiple chunks
assert len(chunked_docs) > 1
# All chunks should map to document 0
assert all(idx == 0 for idx in doc_indices)
class TestAggregateChunkScores:
"""Test suite for aggregate_chunk_scores function"""
def test_no_chunking_simple_aggregation(self):
"""Test aggregation when no chunking occurred (1:1 mapping)"""
chunk_results = [
{"index": 0, "relevance_score": 0.9},
{"index": 1, "relevance_score": 0.7},
{"index": 2, "relevance_score": 0.5},
]
doc_indices = [0, 1, 2] # 1:1 mapping
num_original_docs = 3
aggregated = aggregate_chunk_scores(
chunk_results, doc_indices, num_original_docs, aggregation="max"
)
# Results should be sorted by score
assert len(aggregated) == 3
assert aggregated[0]["index"] == 0
assert aggregated[0]["relevance_score"] == 0.9
assert aggregated[1]["index"] == 1
assert aggregated[1]["relevance_score"] == 0.7
assert aggregated[2]["index"] == 2
assert aggregated[2]["relevance_score"] == 0.5
def test_max_aggregation_with_chunks(self):
"""Test max aggregation strategy with multiple chunks per document"""
# 5 chunks: first 3 from doc 0, last 2 from doc 1
chunk_results = [
{"index": 0, "relevance_score": 0.5},
{"index": 1, "relevance_score": 0.8},
{"index": 2, "relevance_score": 0.6},
{"index": 3, "relevance_score": 0.7},
{"index": 4, "relevance_score": 0.4},
]
doc_indices = [0, 0, 0, 1, 1]
num_original_docs = 2
aggregated = aggregate_chunk_scores(
chunk_results, doc_indices, num_original_docs, aggregation="max"
)
# Should take max score for each document
assert len(aggregated) == 2
assert aggregated[0]["index"] == 0
assert aggregated[0]["relevance_score"] == 0.8 # max of 0.5, 0.8, 0.6
assert aggregated[1]["index"] == 1
assert aggregated[1]["relevance_score"] == 0.7 # max of 0.7, 0.4
def test_mean_aggregation_with_chunks(self):
"""Test mean aggregation strategy"""
chunk_results = [
{"index": 0, "relevance_score": 0.6},
{"index": 1, "relevance_score": 0.8},
{"index": 2, "relevance_score": 0.4},
]
doc_indices = [0, 0, 1] # First two chunks from doc 0, last from doc 1
num_original_docs = 2
aggregated = aggregate_chunk_scores(
chunk_results, doc_indices, num_original_docs, aggregation="mean"
)
assert len(aggregated) == 2
assert aggregated[0]["index"] == 0
assert aggregated[0]["relevance_score"] == pytest.approx(0.7) # (0.6 + 0.8) / 2
assert aggregated[1]["index"] == 1
assert aggregated[1]["relevance_score"] == 0.4
def test_first_aggregation_with_chunks(self):
"""Test first aggregation strategy"""
chunk_results = [
{"index": 0, "relevance_score": 0.6},
{"index": 1, "relevance_score": 0.8},
{"index": 2, "relevance_score": 0.4},
]
doc_indices = [0, 0, 1]
num_original_docs = 2
aggregated = aggregate_chunk_scores(
chunk_results, doc_indices, num_original_docs, aggregation="first"
)
assert len(aggregated) == 2
# First should use first score seen for each doc
assert aggregated[0]["index"] == 0
assert aggregated[0]["relevance_score"] == 0.6 # First score for doc 0
assert aggregated[1]["index"] == 1
assert aggregated[1]["relevance_score"] == 0.4
def test_empty_chunk_results(self):
"""Test handling of empty results"""
aggregated = aggregate_chunk_scores([], [], 3, aggregation="max")
assert aggregated == []
def test_documents_with_no_scores(self):
"""Test when some documents have no chunks/scores"""
chunk_results = [
{"index": 0, "relevance_score": 0.9},
{"index": 1, "relevance_score": 0.7},
]
doc_indices = [0, 0] # Both chunks from document 0
num_original_docs = 3 # But we have 3 documents total
aggregated = aggregate_chunk_scores(
chunk_results, doc_indices, num_original_docs, aggregation="max"
)
# Only doc 0 should appear in results
assert len(aggregated) == 1
assert aggregated[0]["index"] == 0
def test_unknown_aggregation_strategy(self):
"""Test that unknown strategy falls back to max"""
chunk_results = [
{"index": 0, "relevance_score": 0.6},
{"index": 1, "relevance_score": 0.8},
]
doc_indices = [0, 0]
num_original_docs = 1
# Use invalid strategy
aggregated = aggregate_chunk_scores(
chunk_results, doc_indices, num_original_docs, aggregation="invalid"
)
# Should fall back to max
assert aggregated[0]["relevance_score"] == 0.8
@pytest.mark.offline
class TestTopNWithChunking:
"""Tests for top_n behavior when chunking is enabled (Bug fix verification)"""
@pytest.mark.asyncio
async def test_top_n_limits_documents_not_chunks(self):
"""
Test that top_n correctly limits documents (not chunks) when chunking is enabled.
Bug scenario: 10 docs expand to 50 chunks. With old behavior, top_n=5 would
return scores for only 5 chunks (possibly all from 1-2 docs). After aggregation,
fewer than 5 documents would be returned.
Fixed behavior: top_n=5 should return exactly 5 documents after aggregation.
"""
# Setup: 5 documents, each producing multiple chunks when chunked
# Using small max_tokens to force chunking
long_docs = [" ".join([f"doc{i}_word{j}" for j in range(50)]) for i in range(5)]
query = "test query"
# First, determine how many chunks will be created by actual chunking
_, doc_indices = chunk_documents_for_rerank(
long_docs, max_tokens=50, overlap_tokens=10
)
num_chunks = len(doc_indices)
# Mock API returns scores for ALL chunks (simulating disabled API-level top_n)
# Give different scores to ensure doc 0 gets highest, doc 1 second, etc.
# Assign scores based on original document index (lower doc index = higher score)
mock_chunk_scores = []
for i in range(num_chunks):
original_doc = doc_indices[i]
# Higher score for lower doc index, with small variation per chunk
base_score = 0.9 - (original_doc * 0.1)
mock_chunk_scores.append({"index": i, "relevance_score": base_score})
mock_response = Mock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={"results": mock_chunk_scores})
mock_response.request_info = None
mock_response.history = None
mock_response.headers = {}
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
mock_response.__aexit__ = AsyncMock(return_value=None)
mock_session = Mock()
mock_session.post = Mock(return_value=mock_response)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
result = await cohere_rerank(
query=query,
documents=long_docs,
api_key="test-key",
base_url="http://test.com/rerank",
enable_chunking=True,
max_tokens_per_doc=50, # Match chunking above
top_n=3, # Request top 3 documents
)
# Verify: should get exactly 3 documents (not unlimited chunks)
assert len(result) == 3
# All results should have valid document indices (0-4)
assert all(0 <= r["index"] < 5 for r in result)
# Results should be sorted by score (descending)
assert all(
result[i]["relevance_score"] >= result[i + 1]["relevance_score"]
for i in range(len(result) - 1)
)
# The top 3 docs should be 0, 1, 2 (highest scores)
result_indices = [r["index"] for r in result]
assert set(result_indices) == {0, 1, 2}
@pytest.mark.asyncio
async def test_api_receives_no_top_n_when_chunking_enabled(self):
"""
Test that the API request does NOT include top_n when chunking is enabled.
This ensures all chunk scores are retrieved for proper aggregation.
"""
documents = [" ".join([f"word{i}" for i in range(100)]), "short doc"]
query = "test query"
captured_payload = {}
mock_response = Mock()
mock_response.status = 200
mock_response.json = AsyncMock(
return_value={
"results": [
{"index": 0, "relevance_score": 0.9},
{"index": 1, "relevance_score": 0.8},
{"index": 2, "relevance_score": 0.7},
]
}
)
mock_response.request_info = None
mock_response.history = None
mock_response.headers = {}
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
mock_response.__aexit__ = AsyncMock(return_value=None)
def capture_post(*args, **kwargs):
captured_payload.update(kwargs.get("json", {}))
return mock_response
mock_session = Mock()
mock_session.post = Mock(side_effect=capture_post)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
await cohere_rerank(
query=query,
documents=documents,
api_key="test-key",
base_url="http://test.com/rerank",
enable_chunking=True,
max_tokens_per_doc=30,
top_n=1, # User wants top 1 document
)
# Verify: API payload should NOT have top_n (disabled for chunking)
assert "top_n" not in captured_payload
@pytest.mark.asyncio
async def test_top_n_not_modified_when_chunking_disabled(self):
"""
Test that top_n is passed through to API when chunking is disabled.
"""
documents = ["doc1", "doc2"]
query = "test query"
captured_payload = {}
mock_response = Mock()
mock_response.status = 200
mock_response.json = AsyncMock(
return_value={
"results": [
{"index": 0, "relevance_score": 0.9},
]
}
)
mock_response.request_info = None
mock_response.history = None
mock_response.headers = {}
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
mock_response.__aexit__ = AsyncMock(return_value=None)
def capture_post(*args, **kwargs):
captured_payload.update(kwargs.get("json", {}))
return mock_response
mock_session = Mock()
mock_session.post = Mock(side_effect=capture_post)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
await cohere_rerank(
query=query,
documents=documents,
api_key="test-key",
base_url="http://test.com/rerank",
enable_chunking=False, # Chunking disabled
top_n=1,
)
# Verify: API payload should have top_n when chunking is disabled
assert captured_payload.get("top_n") == 1
@pytest.mark.offline
class TestCohereRerankChunking:
"""Integration tests for cohere_rerank with chunking enabled"""
@pytest.mark.asyncio
async def test_cohere_rerank_with_chunking_disabled(self):
"""Test that chunking can be disabled"""
documents = ["doc1", "doc2"]
query = "test query"
# Mock the generic_rerank_api
with patch(
"lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
) as mock_api:
mock_api.return_value = [
{"index": 0, "relevance_score": 0.9},
{"index": 1, "relevance_score": 0.7},
]
result = await cohere_rerank(
query=query,
documents=documents,
api_key="test-key",
enable_chunking=False,
max_tokens_per_doc=100,
)
# Verify generic_rerank_api was called with correct parameters
mock_api.assert_called_once()
call_kwargs = mock_api.call_args[1]
assert call_kwargs["enable_chunking"] is False
assert call_kwargs["max_tokens_per_doc"] == 100
# Result should mirror mocked scores
assert len(result) == 2
assert result[0]["index"] == 0
assert result[0]["relevance_score"] == 0.9
assert result[1]["index"] == 1
assert result[1]["relevance_score"] == 0.7
@pytest.mark.asyncio
async def test_cohere_rerank_with_chunking_enabled(self):
"""Test that chunking parameters are passed through"""
documents = ["doc1", "doc2"]
query = "test query"
with patch(
"lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
) as mock_api:
mock_api.return_value = [
{"index": 0, "relevance_score": 0.9},
{"index": 1, "relevance_score": 0.7},
]
result = await cohere_rerank(
query=query,
documents=documents,
api_key="test-key",
enable_chunking=True,
max_tokens_per_doc=480,
)
# Verify parameters were passed
call_kwargs = mock_api.call_args[1]
assert call_kwargs["enable_chunking"] is True
assert call_kwargs["max_tokens_per_doc"] == 480
# Result should mirror mocked scores
assert len(result) == 2
assert result[0]["index"] == 0
assert result[0]["relevance_score"] == 0.9
assert result[1]["index"] == 1
assert result[1]["relevance_score"] == 0.7
@pytest.mark.asyncio
async def test_cohere_rerank_default_parameters(self):
"""Test default parameter values for cohere_rerank"""
documents = ["doc1"]
query = "test"
with patch(
"lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
) as mock_api:
mock_api.return_value = [{"index": 0, "relevance_score": 0.9}]
result = await cohere_rerank(
query=query, documents=documents, api_key="test-key"
)
# Verify default values
call_kwargs = mock_api.call_args[1]
assert call_kwargs["enable_chunking"] is False
assert call_kwargs["max_tokens_per_doc"] == 4096
assert call_kwargs["model"] == "rerank-v3.5"
# Result should mirror mocked scores
assert len(result) == 1
assert result[0]["index"] == 0
assert result[0]["relevance_score"] == 0.9
@pytest.mark.offline
class TestEndToEndChunking:
"""End-to-end tests for chunking workflow"""
@pytest.mark.asyncio
async def test_end_to_end_chunking_workflow(self):
"""Test complete chunking workflow from documents to aggregated results"""
# Create documents where first one needs chunking
long_doc = " ".join([f"word{i}" for i in range(100)])
documents = [long_doc, "short doc"]
query = "test query"
# Mock the HTTP call inside generic_rerank_api
mock_response = Mock()
mock_response.status = 200
mock_response.json = AsyncMock(
return_value={
"results": [
{"index": 0, "relevance_score": 0.5}, # chunk 0 from doc 0
{"index": 1, "relevance_score": 0.8}, # chunk 1 from doc 0
{"index": 2, "relevance_score": 0.6}, # chunk 2 from doc 0
{"index": 3, "relevance_score": 0.7}, # doc 1 (short)
]
}
)
mock_response.request_info = None
mock_response.history = None
mock_response.headers = {}
# Make mock_response an async context manager (for `async with session.post() as response`)
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
mock_response.__aexit__ = AsyncMock(return_value=None)
mock_session = Mock()
# session.post() returns an async context manager, so return mock_response which is now one
mock_session.post = Mock(return_value=mock_response)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
result = await cohere_rerank(
query=query,
documents=documents,
api_key="test-key",
base_url="http://test.com/rerank",
enable_chunking=True,
max_tokens_per_doc=30, # Force chunking of long doc
)
# Should get 2 results (one per original document)
# The long doc's chunks should be aggregated
assert len(result) <= len(documents)
# Results should be sorted by score
assert all(
result[i]["relevance_score"] >= result[i + 1]["relevance_score"]
for i in range(len(result) - 1)
)