chore: sync with upstream (#4)

* chore: sync with upstream
  - Cohere rerank improvements
  - Content deduplication
  - Dependency updates
* fix: address CodeRabbit review feedback
  - Harden env parsing for RERANK_MAX_TOKENS_PER_DOC with try/except
  - Add @pytest.mark.offline to test_overlap_validation
  - Remove unused doc_indices variable
This commit is contained in:
parent
99f950671e
commit
9bae6267f6
13 changed files with 2135 additions and 940 deletions
206  .github/dependabot.yml  vendored  Normal file
@ -0,0 +1,206 @@
# Keep GitHub Actions up to date with GitHub's Dependabot...
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
version: 2
updates:
  # ============================================================
  # GitHub Actions
  # PR Strategy:
  # - All updates (major/minor/patch): Grouped into a single PR
  # ============================================================
  - package-ecosystem: github-actions
    directory: /
    groups:
      github-actions:
        patterns:
          - "*" # Group all Actions updates into a single larger pull request
    schedule:
      interval: weekly
      day: monday
      time: "02:00"
      timezone: "Asia/Shanghai"
    labels:
      - "dependencies"
      - "github-actions"
    open-pull-requests-limit: 2

  # ============================================================
  # Python (pip) Dependencies
  # PR Strategy:
  # - Major updates: Individual PR per package (except numpy which is ignored)
  # - Minor updates: Grouped by category (llm-providers, storage, etc.)
  # - Patch updates: Grouped by category
  # ============================================================
  - package-ecosystem: "pip"
    directory: "/"
    schedule:
      interval: "weekly"
      day: "wednesday"
      time: "02:00"
      timezone: "Asia/Shanghai"
    cooldown:
      default-days: 5
      semver-major-days: 30
      semver-minor-days: 7
      semver-patch-days: 3
    groups:
      # Core dependencies - LLM providers and embeddings
      llm-providers:
        patterns:
          - "openai"
          - "anthropic"
          - "google-*"
          - "boto3"
          - "botocore"
          - "ollama"
        update-types:
          - "minor"
          - "patch"
      # Storage backends
      storage:
        patterns:
          - "neo4j"
          - "pymongo"
          - "redis"
          - "psycopg*"
          - "asyncpg"
          - "milvus*"
          - "qdrant*"
        update-types:
          - "minor"
          - "patch"
      # Data processing and ML
      data-processing:
        patterns:
          - "numpy"
          - "scipy"
          - "pandas"
          - "tiktoken"
          - "transformers"
          - "torch*"
        update-types:
          - "minor"
          - "patch"
      # Web framework and API
      web-framework:
        patterns:
          - "fastapi"
          - "uvicorn"
          - "gunicorn"
          - "starlette"
          - "pydantic*"
        update-types:
          - "minor"
          - "patch"
      # Development and testing tools
      dev-tools:
        patterns:
          - "pytest*"
          - "ruff"
          - "pre-commit"
          - "black"
          - "mypy"
        update-types:
          - "minor"
          - "patch"
      # Minor and patch updates for everything else
      python-minor-patch:
        patterns:
          - "*"
        update-types:
          - "minor"
          - "patch"
    ignore:
      - dependency-name: "numpy"
        update-types:
          - "version-update:semver-major"
    labels:
      - "dependencies"
      - "python"
    open-pull-requests-limit: 5

  # ============================================================
  # Frontend (bun) Dependencies
  # PR Strategy:
  # - Major updates: Individual PR per package
  # - Minor updates: Grouped by category (react, ui-components, etc.)
  # - Patch updates: Grouped by category
  # ============================================================
  - package-ecosystem: "bun"
    directory: "/lightrag_webui"
    schedule:
      interval: "weekly"
      day: "friday"
      time: "02:00"
      timezone: "Asia/Shanghai"
    cooldown:
      default-days: 5
      semver-major-days: 30
      semver-minor-days: 7
      semver-patch-days: 3
    groups:
      # React ecosystem
      react:
        patterns:
          - "react"
          - "react-dom"
          - "react-router*"
          - "@types/react*"
        update-types:
          - "minor"
          - "patch"
      # UI components and styling
      ui-components:
        patterns:
          - "@radix-ui/*"
          - "tailwind*"
          - "@tailwindcss/*"
          - "lucide-react"
          - "class-variance-authority"
          - "clsx"
        update-types:
          - "minor"
          - "patch"
      # Graph visualization
      graph-viz:
        patterns:
          - "sigma"
          - "@sigma/*"
          - "graphology*"
        update-types:
          - "minor"
          - "patch"
      # Build tools and dev dependencies
      build-tools:
        patterns:
          - "vite"
          - "@vitejs/*"
          - "typescript"
          - "eslint*"
          - "@eslint/*"
          - "typescript-eslint"
          - "prettier"
          - "prettier-*"
          - "@types/bun"
        update-types:
          - "minor"
          - "patch"
      # Content rendering libraries (math, diagrams, etc.)
      content-rendering:
        patterns:
          - "katex"
          - "mermaid"
        update-types:
          - "minor"
          - "patch"
      # All other minor and patch updates
      frontend-minor-patch:
        patterns:
          - "*"
        update-types:
          - "minor"
          - "patch"
    labels:
      - "dependencies"
      - "frontend"
    open-pull-requests-limit: 5
6  .github/workflows/docker-publish.yml  vendored

@ -12,7 +12,9 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        uses: actions/checkout@v6
        with:
          fetch-depth: 0

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

@ -25,7 +27,7 @@ jobs:
          password: ${{ secrets.GITHUB_TOKEN }}

      - name: Build and push Docker image
        uses: docker/build-push-action@v5
        uses: docker/build-push-action@v6
        with:
          context: .
          file: ./Dockerfile
@ -102,6 +102,9 @@ RERANK_BINDING=null
# RERANK_MODEL=rerank-v3.5
# RERANK_BINDING_HOST=https://api.cohere.com/v2/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Cohere rerank chunking configuration (useful for models with token limits like ColBERT)
# RERANK_ENABLE_CHUNKING=true
# RERANK_MAX_TOKENS_PER_DOC=480

### Default value for Jina AI
# RERANK_MODEL=jina-reranker-v2-base-multilingual
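For reference, a minimal sketch of how these two chunking settings can be read safely from the environment. The variable names and defaults come from the example above; the helper function itself is hypothetical and not part of this diff:

import os

def read_rerank_chunking_settings() -> tuple[bool, int]:
    # Hypothetical helper mirroring the documented defaults.
    enable_chunking = os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true"
    try:
        max_tokens_per_doc = int(os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096"))
    except ValueError:
        # Fall back to the documented default when the value is not an integer.
        max_tokens_per_doc = 4096
    return enable_chunking, max_tokens_per_doc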
@ -15,9 +15,12 @@ Configuration Required:
      EMBEDDING_BINDING_HOST
      EMBEDDING_BINDING_API_KEY
   3. Set your vLLM deployed AI rerank model setting with env vars:
      RERANK_MODEL
      RERANK_BINDING_HOST
      RERANK_BINDING=cohere
      RERANK_MODEL (e.g., answerai-colbert-small-v1 or rerank-v3.5)
      RERANK_BINDING_HOST (e.g., https://api.cohere.com/v2/rerank or LiteLLM proxy)
      RERANK_BINDING_API_KEY
      RERANK_ENABLE_CHUNKING=true (optional, for models with token limits)
      RERANK_MAX_TOKENS_PER_DOC=480 (optional, default 4096)

   Note: Rerank is controlled per query via the 'enable_rerank' parameter (default: True)
"""

@ -66,9 +69,11 @@ async def embedding_func(texts: list[str]) -> np.ndarray:

rerank_model_func = partial(
    cohere_rerank,
    model=os.getenv("RERANK_MODEL"),
    model=os.getenv("RERANK_MODEL", "rerank-v3.5"),
    api_key=os.getenv("RERANK_BINDING_API_KEY"),
    base_url=os.getenv("RERANK_BINDING_HOST"),
    base_url=os.getenv("RERANK_BINDING_HOST", "https://api.cohere.com/v2/rerank"),
    enable_chunking=os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true",
    max_tokens_per_doc=int(os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")),
)
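As a rough standalone usage sketch (not part of the diff), the rerank_model_func built above can be awaited directly; the query, documents, and result shape below are illustrative assumptions, and in the real example script the call would happen inside its own async flow rather than a separate asyncio.run:

import asyncio

async def demo() -> None:
    # Hypothetical call; rerank_model_func is the partial defined in the example above.
    results = await rerank_model_func(
        query="What is LightRAG?",
        documents=["LightRAG is a graph-based RAG framework.", "Unrelated text."],
        top_n=1,
    )
    # Expected shape: [{"index": 0, "relevance_score": <float>}]
    print(results)

asyncio.run(demo())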
@ -1 +1 @@
__api_version__ = "0258"
__api_version__ = "0259"
@ -1023,15 +1023,30 @@ def create_app(args):
            query: str, documents: list, top_n: int = None, extra_body: dict = None
        ):
            """Server rerank function with configuration from environment variables"""
            return await selected_rerank_func(
                query=query,
                documents=documents,
                top_n=top_n,
                api_key=args.rerank_binding_api_key,
                model=args.rerank_model,
                base_url=args.rerank_binding_host,
                extra_body=extra_body,
            )
            # Prepare kwargs for rerank function
            kwargs = {
                "query": query,
                "documents": documents,
                "top_n": top_n,
                "api_key": args.rerank_binding_api_key,
                "model": args.rerank_model,
                "base_url": args.rerank_binding_host,
            }

            # Add Cohere-specific parameters if using cohere binding
            if args.rerank_binding == "cohere":
                # Enable chunking if configured (useful for models with token limits like ColBERT)
                kwargs["enable_chunking"] = (
                    os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true"
                )
                try:
                    kwargs["max_tokens_per_doc"] = int(
                        os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")
                    )
                except ValueError:
                    kwargs["max_tokens_per_doc"] = 4096

            return await selected_rerank_func(**kwargs, extra_body=extra_body)

        rerank_model_func = server_rerank_func
        logger.info(
|
|
|||
|
|
@ -24,7 +24,11 @@ from pydantic import BaseModel, Field, field_validator
|
|||
|
||||
from lightrag import LightRAG
|
||||
from lightrag.base import DeletionResult, DocProcessingStatus, DocStatus
|
||||
from lightrag.utils import generate_track_id
|
||||
from lightrag.utils import (
|
||||
generate_track_id,
|
||||
compute_mdhash_id,
|
||||
sanitize_text_for_encoding,
|
||||
)
|
||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||
from ..config import global_args
|
||||
|
||||
|
|
@ -159,7 +163,7 @@ class ReprocessResponse(BaseModel):
|
|||
Attributes:
|
||||
status: Status of the reprocessing operation
|
||||
message: Message describing the operation result
|
||||
track_id: Tracking ID for monitoring reprocessing progress
|
||||
track_id: Always empty string. Reprocessed documents retain their original track_id.
|
||||
"""
|
||||
|
||||
status: Literal["reprocessing_started"] = Field(
|
||||
|
|
@ -167,7 +171,8 @@ class ReprocessResponse(BaseModel):
|
|||
)
|
||||
message: str = Field(description="Human-readable message describing the operation")
|
||||
track_id: str = Field(
|
||||
description="Tracking ID for monitoring reprocessing progress"
|
||||
default="",
|
||||
description="Always empty string. Reprocessed documents retain their original track_id from initial upload.",
|
||||
)
|
||||
|
||||
class Config:
|
||||
|
|
@ -175,7 +180,7 @@ class ReprocessResponse(BaseModel):
|
|||
"example": {
|
||||
"status": "reprocessing_started",
|
||||
"message": "Reprocessing of failed documents has been initiated in background",
|
||||
"track_id": "retry_20250729_170612_def456",
|
||||
"track_id": "",
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2097,12 +2102,14 @@ def create_document_routes(
|
|||
# Check if filename already exists in doc_status storage
|
||||
existing_doc_data = await rag.doc_status.get_doc_by_file_path(safe_filename)
|
||||
if existing_doc_data:
|
||||
# Get document status information for error message
|
||||
# Get document status and track_id from existing document
|
||||
status = existing_doc_data.get("status", "unknown")
|
||||
# Use `or ""` to handle both missing key and None value (e.g., legacy rows without track_id)
|
||||
existing_track_id = existing_doc_data.get("track_id") or ""
|
||||
return InsertResponse(
|
||||
status="duplicated",
|
||||
message=f"File '{safe_filename}' already exists in document storage (Status: {status}).",
|
||||
track_id="",
|
||||
track_id=existing_track_id,
|
||||
)
|
||||
|
||||
file_path = doc_manager.input_dir / safe_filename
|
||||
|
|
@ -2166,14 +2173,30 @@ def create_document_routes(
|
|||
request.file_source
|
||||
)
|
||||
if existing_doc_data:
|
||||
# Get document status information for error message
|
||||
# Get document status and track_id from existing document
|
||||
status = existing_doc_data.get("status", "unknown")
|
||||
# Use `or ""` to handle both missing key and None value (e.g., legacy rows without track_id)
|
||||
existing_track_id = existing_doc_data.get("track_id") or ""
|
||||
return InsertResponse(
|
||||
status="duplicated",
|
||||
message=f"File source '{request.file_source}' already exists in document storage (Status: {status}).",
|
||||
track_id="",
|
||||
track_id=existing_track_id,
|
||||
)
|
||||
|
||||
# Check if content already exists by computing content hash (doc_id)
|
||||
sanitized_text = sanitize_text_for_encoding(request.text)
|
||||
content_doc_id = compute_mdhash_id(sanitized_text, prefix="doc-")
|
||||
existing_doc = await rag.doc_status.get_by_id(content_doc_id)
|
||||
if existing_doc:
|
||||
# Content already exists, return duplicated with existing track_id
|
||||
status = existing_doc.get("status", "unknown")
|
||||
existing_track_id = existing_doc.get("track_id") or ""
|
||||
return InsertResponse(
|
||||
status="duplicated",
|
||||
message=f"Identical content already exists in document storage (doc_id: {content_doc_id}, Status: {status}).",
|
||||
track_id=existing_track_id,
|
||||
)
|
||||
|
||||
# Generate track_id for text insertion
|
||||
track_id = generate_track_id("insert")
|
||||
|
||||
|
|
@ -2232,14 +2255,31 @@ def create_document_routes(
|
|||
file_source
|
||||
)
|
||||
if existing_doc_data:
|
||||
# Get document status information for error message
|
||||
# Get document status and track_id from existing document
|
||||
status = existing_doc_data.get("status", "unknown")
|
||||
# Use `or ""` to handle both missing key and None value (e.g., legacy rows without track_id)
|
||||
existing_track_id = existing_doc_data.get("track_id") or ""
|
||||
return InsertResponse(
|
||||
status="duplicated",
|
||||
message=f"File source '{file_source}' already exists in document storage (Status: {status}).",
|
||||
track_id="",
|
||||
track_id=existing_track_id,
|
||||
)
|
||||
|
||||
# Check if any content already exists by computing content hash (doc_id)
|
||||
for text in request.texts:
|
||||
sanitized_text = sanitize_text_for_encoding(text)
|
||||
content_doc_id = compute_mdhash_id(sanitized_text, prefix="doc-")
|
||||
existing_doc = await rag.doc_status.get_by_id(content_doc_id)
|
||||
if existing_doc:
|
||||
# Content already exists, return duplicated with existing track_id
|
||||
status = existing_doc.get("status", "unknown")
|
||||
existing_track_id = existing_doc.get("track_id") or ""
|
||||
return InsertResponse(
|
||||
status="duplicated",
|
||||
message=f"Identical content already exists in document storage (doc_id: {content_doc_id}, Status: {status}).",
|
||||
track_id=existing_track_id,
|
||||
)
|
||||
|
||||
# Generate track_id for texts insertion
|
||||
track_id = generate_track_id("insert")
|
||||
|
||||
|
|
@ -3058,29 +3098,27 @@ def create_document_routes(
|
|||
This is useful for recovering from server crashes, network errors, LLM service
|
||||
outages, or other temporary failures that caused document processing to fail.
|
||||
|
||||
The processing happens in the background and can be monitored using the
|
||||
returned track_id or by checking the pipeline status.
|
||||
The processing happens in the background and can be monitored by checking the
|
||||
pipeline status. The reprocessed documents retain their original track_id from
|
||||
initial upload, so use their original track_id to monitor progress.
|
||||
|
||||
Returns:
|
||||
ReprocessResponse: Response with status, message, and track_id
|
||||
ReprocessResponse: Response with status and message.
|
||||
track_id is always empty string because reprocessed documents retain
|
||||
their original track_id from initial upload.
|
||||
|
||||
Raises:
|
||||
HTTPException: If an error occurs while initiating reprocessing (500).
|
||||
"""
|
||||
try:
|
||||
# Generate track_id with "retry" prefix for retry operation
|
||||
track_id = generate_track_id("retry")
|
||||
|
||||
# Start the reprocessing in the background
|
||||
# Note: Reprocessed documents retain their original track_id from initial upload
|
||||
background_tasks.add_task(rag.apipeline_process_enqueue_documents)
|
||||
logger.info(
|
||||
f"Reprocessing of failed documents initiated with track_id: {track_id}"
|
||||
)
|
||||
logger.info("Reprocessing of failed documents initiated")
|
||||
|
||||
return ReprocessResponse(
|
||||
status="reprocessing_started",
|
||||
message="Reprocessing of failed documents has been initiated in background",
|
||||
track_id=track_id,
|
||||
message="Reprocessing of failed documents has been initiated in background. Documents retain their original track_id.",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import os
|
||||
import aiohttp
|
||||
from typing import Any, List, Dict, Optional
|
||||
from typing import Any, List, Dict, Optional, Tuple
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
|
|
@ -19,6 +19,158 @@ from dotenv import load_dotenv
|
|||
load_dotenv(dotenv_path=".env", override=False)
|
||||
|
||||
|
||||
def chunk_documents_for_rerank(
|
||||
documents: List[str],
|
||||
max_tokens: int = 480,
|
||||
overlap_tokens: int = 32,
|
||||
tokenizer_model: str = "gpt-4o-mini",
|
||||
) -> Tuple[List[str], List[int]]:
|
||||
"""
|
||||
Chunk documents that exceed token limit for reranking.
|
||||
|
||||
Args:
|
||||
documents: List of document strings to chunk
|
||||
max_tokens: Maximum tokens per chunk (default 480 to leave margin for 512 limit)
|
||||
overlap_tokens: Number of tokens to overlap between chunks
|
||||
tokenizer_model: Model name for tiktoken tokenizer
|
||||
|
||||
Returns:
|
||||
Tuple of (chunked_documents, original_doc_indices)
|
||||
- chunked_documents: List of document chunks (may be more than input)
|
||||
- original_doc_indices: Maps each chunk back to its original document index
|
||||
"""
|
||||
# Clamp overlap_tokens to ensure the loop always advances
|
||||
# If overlap_tokens >= max_tokens, the chunking loop would hang
|
||||
if overlap_tokens >= max_tokens:
|
||||
original_overlap = overlap_tokens
|
||||
# Ensure overlap is at least 1 token less than max to guarantee progress
|
||||
# For very small max_tokens (e.g., 1), set overlap to 0
|
||||
overlap_tokens = max(0, max_tokens - 1)
|
||||
logger.warning(
|
||||
f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). "
|
||||
f"Clamping to {overlap_tokens} to prevent infinite loop."
|
||||
)
|
||||
|
||||
try:
|
||||
from .utils import TiktokenTokenizer
|
||||
|
||||
tokenizer = TiktokenTokenizer(model_name=tokenizer_model)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to initialize tokenizer: {e}. Using character-based approximation."
|
||||
)
|
||||
# Fallback: approximate 1 token ≈ 4 characters
|
||||
max_chars = max_tokens * 4
|
||||
overlap_chars = overlap_tokens * 4
|
||||
|
||||
chunked_docs = []
|
||||
doc_indices = []
|
||||
|
||||
for idx, doc in enumerate(documents):
|
||||
if len(doc) <= max_chars:
|
||||
chunked_docs.append(doc)
|
||||
doc_indices.append(idx)
|
||||
else:
|
||||
# Split into overlapping chunks
|
||||
start = 0
|
||||
while start < len(doc):
|
||||
end = min(start + max_chars, len(doc))
|
||||
chunk = doc[start:end]
|
||||
chunked_docs.append(chunk)
|
||||
doc_indices.append(idx)
|
||||
|
||||
if end >= len(doc):
|
||||
break
|
||||
start = end - overlap_chars
|
||||
|
||||
return chunked_docs, doc_indices
|
||||
|
||||
# Use tokenizer for accurate chunking
|
||||
chunked_docs = []
|
||||
doc_indices = []
|
||||
|
||||
for idx, doc in enumerate(documents):
|
||||
tokens = tokenizer.encode(doc)
|
||||
|
||||
if len(tokens) <= max_tokens:
|
||||
# Document fits in one chunk
|
||||
chunked_docs.append(doc)
|
||||
doc_indices.append(idx)
|
||||
else:
|
||||
# Split into overlapping chunks
|
||||
start = 0
|
||||
while start < len(tokens):
|
||||
end = min(start + max_tokens, len(tokens))
|
||||
chunk_tokens = tokens[start:end]
|
||||
chunk_text = tokenizer.decode(chunk_tokens)
|
||||
chunked_docs.append(chunk_text)
|
||||
doc_indices.append(idx)
|
||||
|
||||
if end >= len(tokens):
|
||||
break
|
||||
start = end - overlap_tokens
|
||||
|
||||
return chunked_docs, doc_indices
|
||||
|
||||
|
||||
def aggregate_chunk_scores(
|
||||
chunk_results: List[Dict[str, Any]],
|
||||
doc_indices: List[int],
|
||||
num_original_docs: int,
|
||||
aggregation: str = "max",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Aggregate rerank scores from document chunks back to original documents.
|
||||
|
||||
Args:
|
||||
chunk_results: Rerank results for chunks [{"index": chunk_idx, "relevance_score": score}, ...]
|
||||
doc_indices: Maps each chunk index to original document index
|
||||
num_original_docs: Total number of original documents
|
||||
aggregation: Strategy for aggregating scores ("max", "mean", "first")
|
||||
|
||||
Returns:
|
||||
List of results for original documents [{"index": doc_idx, "relevance_score": score}, ...]
|
||||
"""
|
||||
# Group scores by original document index
|
||||
doc_scores: Dict[int, List[float]] = {i: [] for i in range(num_original_docs)}
|
||||
|
||||
for result in chunk_results:
|
||||
chunk_idx = result["index"]
|
||||
score = result["relevance_score"]
|
||||
|
||||
if 0 <= chunk_idx < len(doc_indices):
|
||||
original_doc_idx = doc_indices[chunk_idx]
|
||||
doc_scores[original_doc_idx].append(score)
|
||||
|
||||
# Aggregate scores
|
||||
aggregated_results = []
|
||||
for doc_idx, scores in doc_scores.items():
|
||||
if not scores:
|
||||
continue
|
||||
|
||||
if aggregation == "max":
|
||||
final_score = max(scores)
|
||||
elif aggregation == "mean":
|
||||
final_score = sum(scores) / len(scores)
|
||||
elif aggregation == "first":
|
||||
final_score = scores[0]
|
||||
else:
|
||||
logger.warning(f"Unknown aggregation strategy: {aggregation}, using max")
|
||||
final_score = max(scores)
|
||||
|
||||
aggregated_results.append(
|
||||
{
|
||||
"index": doc_idx,
|
||||
"relevance_score": final_score,
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by relevance score (descending)
|
||||
aggregated_results.sort(key=lambda x: x["relevance_score"], reverse=True)
|
||||
|
||||
return aggregated_results
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
|
|
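A small offline sketch of how the two helpers introduced above compose; the document text and chunk scores are invented, and no rerank API is called:

from lightrag.rerank import aggregate_chunk_scores, chunk_documents_for_rerank

# One long document that exceeds the per-chunk limit, plus one short document.
docs = ["word " * 300, "short doc"]
chunks, doc_indices = chunk_documents_for_rerank(docs, max_tokens=50, overlap_tokens=10)

# Pretend the rerank API scored every chunk (scores are made up).
chunk_scores = [
    {"index": i, "relevance_score": 1.0 - 0.01 * i} for i in range(len(chunks))
]

# Collapse chunk scores back to one entry per original document, sorted by score.
doc_scores = aggregate_chunk_scores(
    chunk_scores, doc_indices, num_original_docs=len(docs), aggregation="max"
)
# An original top_n can now be applied at the document level, e.g. doc_scores[:1].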
@ -38,6 +190,8 @@ async def generic_rerank_api(
|
|||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
|
||||
request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
|
||||
enable_chunking: bool = False,
|
||||
max_tokens_per_doc: int = 480,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generic rerank API call for Jina/Cohere/Aliyun models.
|
||||
|
|
@ -52,6 +206,9 @@ async def generic_rerank_api(
|
|||
return_documents: Whether to return document text (Jina only)
|
||||
extra_body: Additional body parameters
|
||||
response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
|
||||
request_format: Request format type
|
||||
enable_chunking: Whether to chunk documents exceeding token limit
|
||||
max_tokens_per_doc: Maximum tokens per document for chunking
|
||||
|
||||
Returns:
|
||||
List of dictionary of ["index": int, "relevance_score": float]
|
||||
|
|
@ -63,6 +220,27 @@ async def generic_rerank_api(
|
|||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
# Handle document chunking if enabled
|
||||
original_documents = documents
|
||||
doc_indices = None
|
||||
original_top_n = top_n # Save original top_n for post-aggregation limiting
|
||||
|
||||
if enable_chunking:
|
||||
documents, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=max_tokens_per_doc
|
||||
)
|
||||
logger.debug(
|
||||
f"Chunked {len(original_documents)} documents into {len(documents)} chunks"
|
||||
)
|
||||
# When chunking is enabled, disable top_n at API level to get all chunk scores
|
||||
# This ensures proper document-level coverage after aggregation
|
||||
# We'll apply top_n to aggregated document results instead
|
||||
if top_n is not None:
|
||||
logger.debug(
|
||||
f"Chunking enabled: disabled API-level top_n={top_n} to ensure complete document coverage"
|
||||
)
|
||||
top_n = None
|
||||
|
||||
# Build request payload based on request format
|
||||
if request_format == "aliyun":
|
||||
# Aliyun format: nested input/parameters structure
|
||||
|
|
@ -86,7 +264,7 @@ async def generic_rerank_api(
|
|||
if extra_body:
|
||||
payload["parameters"].update(extra_body)
|
||||
else:
|
||||
# Standard format for Jina/Cohere
|
||||
# Standard format for Jina/Cohere/OpenAI
|
||||
payload = {
|
||||
"model": model,
|
||||
"query": query,
|
||||
|
|
@ -98,7 +276,7 @@ async def generic_rerank_api(
|
|||
payload["top_n"] = top_n
|
||||
|
||||
# Only Jina API supports return_documents parameter
|
||||
if return_documents is not None:
|
||||
if return_documents is not None and response_format in ("standard",):
|
||||
payload["return_documents"] = return_documents
|
||||
|
||||
# Add extra parameters
|
||||
|
|
@ -147,7 +325,6 @@ async def generic_rerank_api(
|
|||
f"Expected 'output.results' to be list, got {type(results)}: {results}"
|
||||
)
|
||||
results = []
|
||||
|
||||
elif response_format == "standard":
|
||||
# Standard format: {"results": [...]}
|
||||
results = response_json.get("results", [])
|
||||
|
|
@ -158,16 +335,35 @@ async def generic_rerank_api(
|
|||
results = []
|
||||
else:
|
||||
raise ValueError(f"Unsupported response format: {response_format}")
|
||||
|
||||
if not results:
|
||||
logger.warning("Rerank API returned empty results")
|
||||
return []
|
||||
|
||||
# Standardize return format
|
||||
return [
|
||||
standardized_results = [
|
||||
{"index": result["index"], "relevance_score": result["relevance_score"]}
|
||||
for result in results
|
||||
]
|
||||
|
||||
# Aggregate chunk scores back to original documents if chunking was enabled
|
||||
if enable_chunking and doc_indices:
|
||||
standardized_results = aggregate_chunk_scores(
|
||||
standardized_results,
|
||||
doc_indices,
|
||||
len(original_documents),
|
||||
aggregation="max",
|
||||
)
|
||||
# Apply original top_n limit at document level (post-aggregation)
|
||||
# This preserves document-level semantics: top_n limits documents, not chunks
|
||||
if (
|
||||
original_top_n is not None
|
||||
and len(standardized_results) > original_top_n
|
||||
):
|
||||
standardized_results = standardized_results[:original_top_n]
|
||||
|
||||
return standardized_results
|
||||
|
||||
|
||||
async def cohere_rerank(
|
||||
query: str,
|
||||
|
|
@ -177,21 +373,46 @@ async def cohere_rerank(
|
|||
model: str = "rerank-v3.5",
|
||||
base_url: str = "https://api.cohere.com/v2/rerank",
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
enable_chunking: bool = False,
|
||||
max_tokens_per_doc: int = 4096,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Rerank documents using Cohere API.
|
||||
|
||||
Supports both standard Cohere API and Cohere-compatible proxies
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
documents: List of strings to rerank
|
||||
top_n: Number of top results to return
|
||||
api_key: API key
|
||||
model: rerank model name
|
||||
api_key: API key for authentication
|
||||
model: rerank model name (default: rerank-v3.5)
|
||||
base_url: API endpoint
|
||||
extra_body: Additional body for http request(reserved for extra params)
|
||||
enable_chunking: Whether to chunk documents exceeding max_tokens_per_doc
|
||||
max_tokens_per_doc: Maximum tokens per document (default: 4096 for Cohere v3.5)
|
||||
|
||||
Returns:
|
||||
List of dictionary of ["index": int, "relevance_score": float]
|
||||
|
||||
Example:
|
||||
>>> # Standard Cohere API
|
||||
>>> results = await cohere_rerank(
|
||||
... query="What is the meaning of life?",
|
||||
... documents=["Doc1", "Doc2"],
|
||||
... api_key="your-cohere-key"
|
||||
... )
|
||||
|
||||
>>> # LiteLLM proxy with user authentication
|
||||
>>> results = await cohere_rerank(
|
||||
... query="What is vector search?",
|
||||
... documents=["Doc1", "Doc2"],
|
||||
... model="answerai-colbert-small-v1",
|
||||
... base_url="https://llm-proxy.example.com/v2/rerank",
|
||||
... api_key="your-proxy-key",
|
||||
... enable_chunking=True,
|
||||
... max_tokens_per_doc=480
|
||||
... )
|
||||
"""
|
||||
if api_key is None:
|
||||
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
|
||||
|
|
@ -206,6 +427,8 @@ async def cohere_rerank(
|
|||
return_documents=None, # Cohere doesn't support this parameter
|
||||
extra_body=extra_body,
|
||||
response_format="standard",
|
||||
enable_chunking=enable_chunking,
|
||||
max_tokens_per_doc=max_tokens_per_doc,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large.
|
|
@ -16,17 +16,17 @@
|
|||
"preview-no-bun": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@faker-js/faker": "^9.9.0",
|
||||
"@faker-js/faker": "^10.1.0",
|
||||
"@radix-ui/react-alert-dialog": "^1.1.15",
|
||||
"@radix-ui/react-checkbox": "^1.3.3",
|
||||
"@radix-ui/react-dialog": "^1.1.15",
|
||||
"@radix-ui/react-hover-card": "^1.1.15",
|
||||
"@radix-ui/react-popover": "^1.1.15",
|
||||
"@radix-ui/react-progress": "^1.1.7",
|
||||
"@radix-ui/react-progress": "^1.1.8",
|
||||
"@radix-ui/react-scroll-area": "^1.2.10",
|
||||
"@radix-ui/react-select": "^2.2.6",
|
||||
"@radix-ui/react-separator": "^1.1.7",
|
||||
"@radix-ui/react-slot": "^1.2.3",
|
||||
"@radix-ui/react-separator": "^1.1.8",
|
||||
"@radix-ui/react-slot": "^1.2.4",
|
||||
"@radix-ui/react-tabs": "^1.1.13",
|
||||
"@radix-ui/react-tooltip": "^1.2.8",
|
||||
"@radix-ui/react-use-controllable-state": "^1.2.2",
|
||||
|
|
@ -43,7 +43,7 @@
|
|||
"@sigma/node-border": "^3.0.0",
|
||||
"@tanstack/react-query": "^5.87.1",
|
||||
"@tanstack/react-table": "^8.21.3",
|
||||
"axios": "^1.12.2",
|
||||
"axios": "^1.13.2",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.1.1",
|
||||
|
|
@ -53,21 +53,21 @@
|
|||
"graphology-layout-force": "^0.2.4",
|
||||
"graphology-layout-forceatlas2": "^0.10.1",
|
||||
"graphology-layout-noverlap": "^0.4.2",
|
||||
"i18next": "^24.2.3",
|
||||
"katex": "^0.16.23",
|
||||
"lucide-react": "^0.475.0",
|
||||
"mermaid": "^11.12.0",
|
||||
"i18next": "^25.6.3",
|
||||
"katex": "^0.16.25",
|
||||
"mermaid": "^11.12.1",
|
||||
"lucide-react": "^0.554.0",
|
||||
"minisearch": "^7.2.0",
|
||||
"react": "^19.2.0",
|
||||
"react-dom": "^19.2.0",
|
||||
"react-dropzone": "^14.3.8",
|
||||
"react-error-boundary": "^5.0.0",
|
||||
"react-i18next": "^15.7.4",
|
||||
"react-markdown": "^9.1.0",
|
||||
"react-error-boundary": "^6.0.0",
|
||||
"react-i18next": "^16.3.5",
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-number-format": "^5.4.4",
|
||||
"react-router-dom": "^7.9.4",
|
||||
"react-router-dom": "^7.9.6",
|
||||
"react-select": "^5.10.2",
|
||||
"react-syntax-highlighter": "^15.6.6",
|
||||
"react-syntax-highlighter": "^16.1.0",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"rehype-raw": "^7.0.0",
|
||||
"rehype-react": "^8.0.0",
|
||||
|
|
@ -75,8 +75,8 @@
|
|||
"remark-math": "^6.0.0",
|
||||
"seedrandom": "^3.0.5",
|
||||
"sigma": "^3.0.2",
|
||||
"sonner": "^1.7.4",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"sonner": "^2.0.7",
|
||||
"tailwind-merge": "^3.4.0",
|
||||
"tailwind-scrollbar": "^4.0.2",
|
||||
"typography": "^0.16.24",
|
||||
"unist-util-visit": "^5.0.0",
|
||||
|
|
@ -84,32 +84,32 @@
|
|||
},
|
||||
"devDependencies": {
|
||||
"@biomejs/biome": "^1.9.3",
|
||||
"@eslint/js": "^9.37.0",
|
||||
"@stylistic/eslint-plugin-js": "^3.1.0",
|
||||
"@eslint/js": "^9.39.1",
|
||||
"@stylistic/eslint-plugin-js": "^4.4.1",
|
||||
"@tailwindcss/typography": "^0.5.15",
|
||||
"@tailwindcss/vite": "^4.1.14",
|
||||
"@types/bun": "^1.2.23",
|
||||
"@tailwindcss/vite": "^4.1.17",
|
||||
"@types/bun": "^1.3.3",
|
||||
"@types/katex": "^0.16.7",
|
||||
"@types/node": "^22.18.9",
|
||||
"@types/react": "^19.2.2",
|
||||
"@types/react-dom": "^19.2.1",
|
||||
"@types/node": "^24.10.1",
|
||||
"@types/react": "^19.2.7",
|
||||
"@types/react-dom": "^19.2.3",
|
||||
"@types/react-i18next": "^8.1.0",
|
||||
"@types/react-syntax-highlighter": "^15.5.13",
|
||||
"@types/seedrandom": "^3.0.8",
|
||||
"@vitejs/plugin-react-swc": "^3.11.0",
|
||||
"eslint": "^9.37.0",
|
||||
"@vitejs/plugin-react-swc": "^4.2.2",
|
||||
"eslint": "^9.39.1",
|
||||
"eslint-config-prettier": "^10.1.8",
|
||||
"eslint-plugin-react": "^7.37.5",
|
||||
"eslint-plugin-react-hooks": "^5.2.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.23",
|
||||
"globals": "^15.15.0",
|
||||
"eslint-plugin-react-hooks": "^7.0.1",
|
||||
"eslint-plugin-react-refresh": "^0.4.24",
|
||||
"globals": "^16.5.0",
|
||||
"graphology-types": "^0.24.8",
|
||||
"prettier": "^3.6.2",
|
||||
"prettier-plugin-tailwindcss": "^0.6.14",
|
||||
"tailwindcss": "^4.1.14",
|
||||
"prettier-plugin-tailwindcss": "^0.7.1",
|
||||
"typescript-eslint": "^8.48.0",
|
||||
"tailwindcss": "^4.1.17",
|
||||
"tailwindcss-animate": "^1.0.7",
|
||||
"typescript": "~5.7.3",
|
||||
"typescript-eslint": "^8.46.0",
|
||||
"vite": "^6.3.6"
|
||||
"typescript": "~5.9.3",
|
||||
"vite": "^7.2.4"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -153,7 +153,6 @@ export const ChatMessage = ({
|
|||
setKatexPlugin(() => rehypeKatex)
|
||||
} catch (error) {
|
||||
console.error('Failed to load KaTeX plugin:', error)
|
||||
// Set to null to ensure we don't try to use a failed plugin
|
||||
setKatexPlugin(null)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
116  tests/test_overlap_validation.py  Normal file
|
|
@ -0,0 +1,116 @@
|
|||
"""
|
||||
Test for overlap_tokens validation to prevent infinite loop.
|
||||
|
||||
This test validates the fix for the bug where overlap_tokens >= max_tokens
|
||||
causes an infinite loop in the chunking function.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from lightrag.rerank import chunk_documents_for_rerank
|
||||
|
||||
|
||||
@pytest.mark.offline
|
||||
class TestOverlapValidation:
|
||||
"""Test suite for overlap_tokens validation"""
|
||||
|
||||
def test_overlap_greater_than_max_tokens(self):
|
||||
"""Test that overlap_tokens > max_tokens is clamped and doesn't hang"""
|
||||
documents = [" ".join([f"word{i}" for i in range(100)])]
|
||||
|
||||
# This should clamp overlap_tokens to 29 (max_tokens - 1)
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=30, overlap_tokens=32
|
||||
)
|
||||
|
||||
# Should complete without hanging
|
||||
assert len(chunked_docs) > 0
|
||||
assert all(idx == 0 for idx in doc_indices)
|
||||
|
||||
def test_overlap_equal_to_max_tokens(self):
|
||||
"""Test that overlap_tokens == max_tokens is clamped and doesn't hang"""
|
||||
documents = [" ".join([f"word{i}" for i in range(100)])]
|
||||
|
||||
# This should clamp overlap_tokens to 29 (max_tokens - 1)
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=30, overlap_tokens=30
|
||||
)
|
||||
|
||||
# Should complete without hanging
|
||||
assert len(chunked_docs) > 0
|
||||
assert all(idx == 0 for idx in doc_indices)
|
||||
|
||||
def test_overlap_slightly_less_than_max_tokens(self):
|
||||
"""Test that overlap_tokens < max_tokens works normally"""
|
||||
documents = [" ".join([f"word{i}" for i in range(100)])]
|
||||
|
||||
# This should work without clamping
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=30, overlap_tokens=29
|
||||
)
|
||||
|
||||
# Should complete successfully
|
||||
assert len(chunked_docs) > 0
|
||||
assert all(idx == 0 for idx in doc_indices)
|
||||
|
||||
def test_small_max_tokens_with_large_overlap(self):
|
||||
"""Test edge case with very small max_tokens"""
|
||||
documents = [" ".join([f"word{i}" for i in range(50)])]
|
||||
|
||||
# max_tokens=5, overlap_tokens=10 should clamp to 4
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=5, overlap_tokens=10
|
||||
)
|
||||
|
||||
# Should complete without hanging
|
||||
assert len(chunked_docs) > 0
|
||||
assert all(idx == 0 for idx in doc_indices)
|
||||
|
||||
def test_multiple_documents_with_invalid_overlap(self):
|
||||
"""Test multiple documents with overlap_tokens >= max_tokens"""
|
||||
documents = [
|
||||
" ".join([f"word{i}" for i in range(50)]),
|
||||
"short document",
|
||||
" ".join([f"word{i}" for i in range(75)]),
|
||||
]
|
||||
|
||||
# overlap_tokens > max_tokens
|
||||
chunked_docs, _ = chunk_documents_for_rerank(
|
||||
documents, max_tokens=25, overlap_tokens=30
|
||||
)
|
||||
|
||||
# Should complete successfully and chunk the long documents
|
||||
assert len(chunked_docs) >= len(documents)
|
||||
# Short document should not be chunked
|
||||
assert "short document" in chunked_docs
|
||||
|
||||
def test_normal_operation_unaffected(self):
|
||||
"""Test that normal cases continue to work correctly"""
|
||||
documents = [
|
||||
" ".join([f"word{i}" for i in range(100)]),
|
||||
"short doc",
|
||||
]
|
||||
|
||||
# Normal case: overlap_tokens (10) < max_tokens (50)
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=50, overlap_tokens=10
|
||||
)
|
||||
|
||||
# Long document should be chunked, short one should not
|
||||
assert len(chunked_docs) > 2 # At least 3 chunks (2 from long doc + 1 short)
|
||||
assert "short doc" in chunked_docs
|
||||
# Verify doc_indices maps correctly
|
||||
assert doc_indices[-1] == 1 # Last chunk is from second document
|
||||
|
||||
def test_edge_case_max_tokens_one(self):
|
||||
"""Test edge case where max_tokens=1"""
|
||||
documents = [" ".join([f"word{i}" for i in range(20)])]
|
||||
|
||||
# max_tokens=1, overlap_tokens=5 should clamp to 0
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=1, overlap_tokens=5
|
||||
)
|
||||
|
||||
# Should complete without hanging
|
||||
assert len(chunked_docs) > 0
|
||||
assert all(idx == 0 for idx in doc_indices)
|
||||
564  tests/test_rerank_chunking.py  Normal file
|
|
@ -0,0 +1,564 @@
|
|||
"""
|
||||
Unit tests for rerank document chunking functionality.
|
||||
|
||||
Tests the chunk_documents_for_rerank and aggregate_chunk_scores functions
|
||||
in lightrag/rerank.py to ensure proper document splitting and score aggregation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from lightrag.rerank import (
|
||||
chunk_documents_for_rerank,
|
||||
aggregate_chunk_scores,
|
||||
cohere_rerank,
|
||||
)
|
||||
|
||||
|
||||
class TestChunkDocumentsForRerank:
|
||||
"""Test suite for chunk_documents_for_rerank function"""
|
||||
|
||||
def test_no_chunking_needed_for_short_docs(self):
|
||||
"""Documents shorter than max_tokens should not be chunked"""
|
||||
documents = [
|
||||
"Short doc 1",
|
||||
"Short doc 2",
|
||||
"Short doc 3",
|
||||
]
|
||||
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=100, overlap_tokens=10
|
||||
)
|
||||
|
||||
# No chunking should occur
|
||||
assert len(chunked_docs) == 3
|
||||
assert chunked_docs == documents
|
||||
assert doc_indices == [0, 1, 2]
|
||||
|
||||
def test_chunking_with_character_fallback(self):
|
||||
"""Test chunking falls back to character-based when tokenizer unavailable"""
|
||||
# Create a very long document that exceeds character limit
|
||||
long_doc = "a" * 2000 # 2000 characters
|
||||
documents = [long_doc, "short doc"]
|
||||
|
||||
with patch("lightrag.utils.TiktokenTokenizer", side_effect=ImportError):
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents,
|
||||
max_tokens=100, # 100 tokens = ~400 chars
|
||||
overlap_tokens=10, # 10 tokens = ~40 chars
|
||||
)
|
||||
|
||||
# First doc should be split into chunks, second doc stays whole
|
||||
assert len(chunked_docs) > 2 # At least one chunk from first doc + second doc
|
||||
assert chunked_docs[-1] == "short doc" # Last chunk is the short doc
|
||||
# Verify doc_indices maps chunks to correct original document
|
||||
assert doc_indices[-1] == 1 # Last chunk maps to document 1
|
||||
|
||||
def test_chunking_with_tiktoken_tokenizer(self):
|
||||
"""Test chunking with actual tokenizer"""
|
||||
# Create document with known token count
|
||||
# Approximate: "word " = ~1 token, so 200 words ~ 200 tokens
|
||||
long_doc = " ".join([f"word{i}" for i in range(200)])
|
||||
documents = [long_doc, "short"]
|
||||
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=50, overlap_tokens=10
|
||||
)
|
||||
|
||||
# Long doc should be split, short doc should remain
|
||||
assert len(chunked_docs) > 2
|
||||
assert doc_indices[-1] == 1 # Last chunk is from second document
|
||||
|
||||
# Verify overlapping chunks contain overlapping content
|
||||
if len(chunked_docs) > 2:
|
||||
# Check that consecutive chunks from same doc have some overlap
|
||||
for i in range(len(doc_indices) - 1):
|
||||
if doc_indices[i] == doc_indices[i + 1] == 0:
|
||||
# Both chunks from first doc, should have overlap
|
||||
chunk1_words = chunked_docs[i].split()
|
||||
chunk2_words = chunked_docs[i + 1].split()
|
||||
# At least one word should be common due to overlap
|
||||
assert any(word in chunk2_words for word in chunk1_words[-5:])
|
||||
|
||||
def test_empty_documents(self):
|
||||
"""Test handling of empty document list"""
|
||||
documents = []
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(documents)
|
||||
|
||||
assert chunked_docs == []
|
||||
assert doc_indices == []
|
||||
|
||||
def test_single_document_chunking(self):
|
||||
"""Test chunking of a single long document"""
|
||||
# Create document with ~100 tokens
|
||||
long_doc = " ".join([f"token{i}" for i in range(100)])
|
||||
documents = [long_doc]
|
||||
|
||||
chunked_docs, doc_indices = chunk_documents_for_rerank(
|
||||
documents, max_tokens=30, overlap_tokens=5
|
||||
)
|
||||
|
||||
# Should create multiple chunks
|
||||
assert len(chunked_docs) > 1
|
||||
# All chunks should map to document 0
|
||||
assert all(idx == 0 for idx in doc_indices)
|
||||
|
||||
|
||||
class TestAggregateChunkScores:
|
||||
"""Test suite for aggregate_chunk_scores function"""
|
||||
|
||||
def test_no_chunking_simple_aggregation(self):
|
||||
"""Test aggregation when no chunking occurred (1:1 mapping)"""
|
||||
chunk_results = [
|
||||
{"index": 0, "relevance_score": 0.9},
|
||||
{"index": 1, "relevance_score": 0.7},
|
||||
{"index": 2, "relevance_score": 0.5},
|
||||
]
|
||||
doc_indices = [0, 1, 2] # 1:1 mapping
|
||||
num_original_docs = 3
|
||||
|
||||
aggregated = aggregate_chunk_scores(
|
||||
chunk_results, doc_indices, num_original_docs, aggregation="max"
|
||||
)
|
||||
|
||||
# Results should be sorted by score
|
||||
assert len(aggregated) == 3
|
||||
assert aggregated[0]["index"] == 0
|
||||
assert aggregated[0]["relevance_score"] == 0.9
|
||||
assert aggregated[1]["index"] == 1
|
||||
assert aggregated[1]["relevance_score"] == 0.7
|
||||
assert aggregated[2]["index"] == 2
|
||||
assert aggregated[2]["relevance_score"] == 0.5
|
||||
|
||||
def test_max_aggregation_with_chunks(self):
|
||||
"""Test max aggregation strategy with multiple chunks per document"""
|
||||
# 5 chunks: first 3 from doc 0, last 2 from doc 1
|
||||
chunk_results = [
|
||||
{"index": 0, "relevance_score": 0.5},
|
||||
{"index": 1, "relevance_score": 0.8},
|
||||
{"index": 2, "relevance_score": 0.6},
|
||||
{"index": 3, "relevance_score": 0.7},
|
||||
{"index": 4, "relevance_score": 0.4},
|
||||
]
|
||||
doc_indices = [0, 0, 0, 1, 1]
|
||||
num_original_docs = 2
|
||||
|
||||
aggregated = aggregate_chunk_scores(
|
||||
chunk_results, doc_indices, num_original_docs, aggregation="max"
|
||||
)
|
||||
|
||||
# Should take max score for each document
|
||||
assert len(aggregated) == 2
|
||||
assert aggregated[0]["index"] == 0
|
||||
assert aggregated[0]["relevance_score"] == 0.8 # max of 0.5, 0.8, 0.6
|
||||
assert aggregated[1]["index"] == 1
|
||||
assert aggregated[1]["relevance_score"] == 0.7 # max of 0.7, 0.4
|
||||
|
||||
def test_mean_aggregation_with_chunks(self):
|
||||
"""Test mean aggregation strategy"""
|
||||
chunk_results = [
|
||||
{"index": 0, "relevance_score": 0.6},
|
||||
{"index": 1, "relevance_score": 0.8},
|
||||
{"index": 2, "relevance_score": 0.4},
|
||||
]
|
||||
doc_indices = [0, 0, 1] # First two chunks from doc 0, last from doc 1
|
||||
num_original_docs = 2
|
||||
|
||||
aggregated = aggregate_chunk_scores(
|
||||
chunk_results, doc_indices, num_original_docs, aggregation="mean"
|
||||
)
|
||||
|
||||
assert len(aggregated) == 2
|
||||
assert aggregated[0]["index"] == 0
|
||||
assert aggregated[0]["relevance_score"] == pytest.approx(0.7) # (0.6 + 0.8) / 2
|
||||
assert aggregated[1]["index"] == 1
|
||||
assert aggregated[1]["relevance_score"] == 0.4
|
||||
|
||||
def test_first_aggregation_with_chunks(self):
|
||||
"""Test first aggregation strategy"""
|
||||
chunk_results = [
|
||||
{"index": 0, "relevance_score": 0.6},
|
||||
{"index": 1, "relevance_score": 0.8},
|
||||
{"index": 2, "relevance_score": 0.4},
|
||||
]
|
||||
doc_indices = [0, 0, 1]
|
||||
num_original_docs = 2
|
||||
|
||||
aggregated = aggregate_chunk_scores(
|
||||
chunk_results, doc_indices, num_original_docs, aggregation="first"
|
||||
)
|
||||
|
||||
assert len(aggregated) == 2
|
||||
# First should use first score seen for each doc
|
||||
assert aggregated[0]["index"] == 0
|
||||
assert aggregated[0]["relevance_score"] == 0.6 # First score for doc 0
|
||||
assert aggregated[1]["index"] == 1
|
||||
assert aggregated[1]["relevance_score"] == 0.4
|
||||
|
||||
def test_empty_chunk_results(self):
|
||||
"""Test handling of empty results"""
|
||||
aggregated = aggregate_chunk_scores([], [], 3, aggregation="max")
|
||||
assert aggregated == []
|
||||
|
||||
def test_documents_with_no_scores(self):
|
||||
"""Test when some documents have no chunks/scores"""
|
||||
chunk_results = [
|
||||
{"index": 0, "relevance_score": 0.9},
|
||||
{"index": 1, "relevance_score": 0.7},
|
||||
]
|
||||
doc_indices = [0, 0] # Both chunks from document 0
|
||||
num_original_docs = 3 # But we have 3 documents total
|
||||
|
||||
aggregated = aggregate_chunk_scores(
|
||||
chunk_results, doc_indices, num_original_docs, aggregation="max"
|
||||
)
|
||||
|
||||
# Only doc 0 should appear in results
|
||||
assert len(aggregated) == 1
|
||||
assert aggregated[0]["index"] == 0
|
||||
|
||||
def test_unknown_aggregation_strategy(self):
|
||||
"""Test that unknown strategy falls back to max"""
|
||||
chunk_results = [
|
||||
{"index": 0, "relevance_score": 0.6},
|
||||
{"index": 1, "relevance_score": 0.8},
|
||||
]
|
||||
doc_indices = [0, 0]
|
||||
num_original_docs = 1
|
||||
|
||||
# Use invalid strategy
|
||||
aggregated = aggregate_chunk_scores(
|
||||
chunk_results, doc_indices, num_original_docs, aggregation="invalid"
|
||||
)
|
||||
|
||||
# Should fall back to max
|
||||
assert aggregated[0]["relevance_score"] == 0.8
|
||||
|
||||
|
||||
@pytest.mark.offline
|
||||
class TestTopNWithChunking:
|
||||
"""Tests for top_n behavior when chunking is enabled (Bug fix verification)"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_top_n_limits_documents_not_chunks(self):
|
||||
"""
|
||||
Test that top_n correctly limits documents (not chunks) when chunking is enabled.
|
||||
|
||||
Bug scenario: 10 docs expand to 50 chunks. With old behavior, top_n=5 would
|
||||
return scores for only 5 chunks (possibly all from 1-2 docs). After aggregation,
|
||||
fewer than 5 documents would be returned.
|
||||
|
||||
Fixed behavior: top_n=5 should return exactly 5 documents after aggregation.
|
||||
"""
|
||||
# Setup: 5 documents, each producing multiple chunks when chunked
|
||||
# Using small max_tokens to force chunking
|
||||
long_docs = [" ".join([f"doc{i}_word{j}" for j in range(50)]) for i in range(5)]
|
||||
query = "test query"
|
||||
|
||||
# First, determine how many chunks will be created by actual chunking
|
||||
_, doc_indices = chunk_documents_for_rerank(
|
||||
long_docs, max_tokens=50, overlap_tokens=10
|
||||
)
|
||||
num_chunks = len(doc_indices)
|
||||
|
||||
# Mock API returns scores for ALL chunks (simulating disabled API-level top_n)
|
||||
# Give different scores to ensure doc 0 gets highest, doc 1 second, etc.
|
||||
# Assign scores based on original document index (lower doc index = higher score)
|
||||
mock_chunk_scores = []
|
||||
for i in range(num_chunks):
|
||||
original_doc = doc_indices[i]
|
||||
# Higher score for lower doc index, with small variation per chunk
|
||||
base_score = 0.9 - (original_doc * 0.1)
|
||||
mock_chunk_scores.append({"index": i, "relevance_score": base_score})
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={"results": mock_chunk_scores})
|
||||
mock_response.request_info = None
|
||||
mock_response.history = None
|
||||
mock_response.headers = {}
|
||||
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_response.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(return_value=mock_response)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await cohere_rerank(
|
||||
query=query,
|
||||
documents=long_docs,
|
||||
api_key="test-key",
|
||||
base_url="http://test.com/rerank",
|
||||
enable_chunking=True,
|
||||
max_tokens_per_doc=50, # Match chunking above
|
||||
top_n=3, # Request top 3 documents
|
||||
)
|
||||
|
||||
# Verify: should get exactly 3 documents (not unlimited chunks)
|
||||
assert len(result) == 3
|
||||
# All results should have valid document indices (0-4)
|
||||
assert all(0 <= r["index"] < 5 for r in result)
|
||||
# Results should be sorted by score (descending)
|
||||
assert all(
|
||||
result[i]["relevance_score"] >= result[i + 1]["relevance_score"]
|
||||
for i in range(len(result) - 1)
|
||||
)
|
||||
# The top 3 docs should be 0, 1, 2 (highest scores)
|
||||
result_indices = [r["index"] for r in result]
|
||||
assert set(result_indices) == {0, 1, 2}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_receives_no_top_n_when_chunking_enabled(self):
|
||||
"""
|
||||
Test that the API request does NOT include top_n when chunking is enabled.
|
||||
|
||||
This ensures all chunk scores are retrieved for proper aggregation.
|
||||
"""
|
||||
documents = [" ".join([f"word{i}" for i in range(100)]), "short doc"]
|
||||
query = "test query"
|
||||
|
||||
captured_payload = {}
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(
|
||||
return_value={
|
||||
"results": [
|
||||
{"index": 0, "relevance_score": 0.9},
|
||||
{"index": 1, "relevance_score": 0.8},
|
||||
{"index": 2, "relevance_score": 0.7},
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response.request_info = None
|
||||
mock_response.history = None
|
||||
mock_response.headers = {}
|
||||
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_response.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
def capture_post(*args, **kwargs):
|
||||
captured_payload.update(kwargs.get("json", {}))
|
||||
return mock_response
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(side_effect=capture_post)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
|
||||
await cohere_rerank(
|
||||
query=query,
|
||||
documents=documents,
|
||||
api_key="test-key",
|
||||
base_url="http://test.com/rerank",
|
||||
enable_chunking=True,
|
||||
max_tokens_per_doc=30,
|
||||
top_n=1, # User wants top 1 document
|
||||
)
|
||||
|
||||
# Verify: API payload should NOT have top_n (disabled for chunking)
|
||||
assert "top_n" not in captured_payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_top_n_not_modified_when_chunking_disabled(self):
|
||||
"""
|
||||
Test that top_n is passed through to API when chunking is disabled.
|
||||
"""
|
||||
documents = ["doc1", "doc2"]
|
||||
query = "test query"
|
||||
|
||||
captured_payload = {}
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(
|
||||
return_value={
|
||||
"results": [
|
||||
{"index": 0, "relevance_score": 0.9},
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response.request_info = None
|
||||
mock_response.history = None
|
||||
mock_response.headers = {}
|
||||
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_response.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
def capture_post(*args, **kwargs):
|
||||
captured_payload.update(kwargs.get("json", {}))
|
||||
return mock_response
|
||||
|
||||
mock_session = Mock()
|
||||
mock_session.post = Mock(side_effect=capture_post)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
|
||||
await cohere_rerank(
|
||||
query=query,
|
||||
documents=documents,
|
||||
api_key="test-key",
|
||||
base_url="http://test.com/rerank",
|
||||
enable_chunking=False, # Chunking disabled
|
||||
top_n=1,
|
||||
)
|
||||
|
||||
# Verify: API payload should have top_n when chunking is disabled
|
||||
assert captured_payload.get("top_n") == 1
|
||||
|
||||
|
||||
@pytest.mark.offline
|
||||
class TestCohereRerankChunking:
|
||||
"""Integration tests for cohere_rerank with chunking enabled"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cohere_rerank_with_chunking_disabled(self):
|
||||
"""Test that chunking can be disabled"""
|
||||
documents = ["doc1", "doc2"]
|
||||
query = "test query"
|
||||
|
||||
# Mock the generic_rerank_api
|
||||
with patch(
|
||||
"lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
|
||||
) as mock_api:
|
||||
mock_api.return_value = [
|
||||
{"index": 0, "relevance_score": 0.9},
|
||||
{"index": 1, "relevance_score": 0.7},
|
||||
]
|
||||
|
||||
result = await cohere_rerank(
|
||||
query=query,
|
||||
documents=documents,
|
||||
api_key="test-key",
|
||||
enable_chunking=False,
|
||||
max_tokens_per_doc=100,
|
||||
)
|
||||
|
||||
# Verify generic_rerank_api was called with correct parameters
|
||||
mock_api.assert_called_once()
|
||||
call_kwargs = mock_api.call_args[1]
|
||||
assert call_kwargs["enable_chunking"] is False
|
||||
assert call_kwargs["max_tokens_per_doc"] == 100
|
||||
# Result should mirror mocked scores
|
||||
assert len(result) == 2
|
||||
assert result[0]["index"] == 0
|
||||
assert result[0]["relevance_score"] == 0.9
|
||||
assert result[1]["index"] == 1
|
||||
assert result[1]["relevance_score"] == 0.7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cohere_rerank_with_chunking_enabled(self):
|
||||
"""Test that chunking parameters are passed through"""
|
||||
documents = ["doc1", "doc2"]
|
||||
query = "test query"
|
||||
|
||||
with patch(
|
||||
"lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
|
||||
) as mock_api:
|
||||
mock_api.return_value = [
|
||||
{"index": 0, "relevance_score": 0.9},
|
||||
{"index": 1, "relevance_score": 0.7},
|
||||
]
|
||||
|
||||
result = await cohere_rerank(
|
||||
query=query,
|
||||
documents=documents,
|
||||
api_key="test-key",
|
||||
enable_chunking=True,
|
||||
max_tokens_per_doc=480,
|
||||
)
|
||||
|
||||
# Verify parameters were passed
|
||||
call_kwargs = mock_api.call_args[1]
|
||||
assert call_kwargs["enable_chunking"] is True
|
||||
assert call_kwargs["max_tokens_per_doc"] == 480
|
||||
# Result should mirror mocked scores
|
||||
assert len(result) == 2
|
||||
assert result[0]["index"] == 0
|
||||
assert result[0]["relevance_score"] == 0.9
|
||||
assert result[1]["index"] == 1
|
||||
assert result[1]["relevance_score"] == 0.7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cohere_rerank_default_parameters(self):
|
||||
"""Test default parameter values for cohere_rerank"""
|
||||
documents = ["doc1"]
|
||||
query = "test"
|
||||
|
||||
with patch(
|
||||
"lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
|
||||
) as mock_api:
|
||||
mock_api.return_value = [{"index": 0, "relevance_score": 0.9}]
|
||||
|
||||
result = await cohere_rerank(
|
||||
query=query, documents=documents, api_key="test-key"
|
||||
)
|
||||
|
||||
# Verify default values
|
||||
call_kwargs = mock_api.call_args[1]
|
||||
assert call_kwargs["enable_chunking"] is False
|
||||
assert call_kwargs["max_tokens_per_doc"] == 4096
|
||||
assert call_kwargs["model"] == "rerank-v3.5"
|
||||
# Result should mirror mocked scores
|
||||
assert len(result) == 1
|
||||
assert result[0]["index"] == 0
|
||||
assert result[0]["relevance_score"] == 0.9
|
||||
|
||||
|
||||
@pytest.mark.offline
|
||||
class TestEndToEndChunking:
|
||||
"""End-to-end tests for chunking workflow"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_chunking_workflow(self):
|
||||
"""Test complete chunking workflow from documents to aggregated results"""
|
||||
# Create documents where first one needs chunking
|
||||
long_doc = " ".join([f"word{i}" for i in range(100)])
|
||||
documents = [long_doc, "short doc"]
|
||||
query = "test query"
|
||||
|
||||
# Mock the HTTP call inside generic_rerank_api
|
||||
mock_response = Mock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(
|
||||
return_value={
|
||||
"results": [
|
||||
{"index": 0, "relevance_score": 0.5}, # chunk 0 from doc 0
|
||||
{"index": 1, "relevance_score": 0.8}, # chunk 1 from doc 0
|
||||
{"index": 2, "relevance_score": 0.6}, # chunk 2 from doc 0
|
||||
{"index": 3, "relevance_score": 0.7}, # doc 1 (short)
|
||||
]
|
||||
}
|
||||
)
|
||||
mock_response.request_info = None
|
||||
mock_response.history = None
|
||||
mock_response.headers = {}
|
||||
# Make mock_response an async context manager (for `async with session.post() as response`)
|
||||
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_response.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_session = Mock()
|
||||
# session.post() returns an async context manager, so return mock_response which is now one
|
||||
mock_session.post = Mock(return_value=mock_response)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await cohere_rerank(
|
||||
query=query,
|
||||
documents=documents,
|
||||
api_key="test-key",
|
||||
base_url="http://test.com/rerank",
|
||||
enable_chunking=True,
|
||||
max_tokens_per_doc=30, # Force chunking of long doc
|
||||
)
|
||||
|
||||
# Should get 2 results (one per original document)
|
||||
# The long doc's chunks should be aggregated
|
||||
assert len(result) <= len(documents)
|
||||
# Results should be sorted by score
|
||||
assert all(
|
||||
result[i]["relevance_score"] >= result[i + 1]["relevance_score"]
|
||||
for i in range(len(result) - 1)
|
||||
)
|
||||