chore: sync with upstream (#4)
* chore: sync with upstream - Cohere rerank improvements - Content deduplication - Dependency updates * fix: address CodeRabbit review feedback - Harden env parsing for RERANK_MAX_TOKENS_PER_DOC with try/except - Add @pytest.mark.offline to test_overlap_validation - Remove unused doc_indices variable
This commit is contained in:
parent
99f950671e
commit
9bae6267f6
13 changed files with 2135 additions and 940 deletions
206
.github/dependabot.yml
vendored
Normal file
206
.github/dependabot.yml
vendored
Normal file
|
|
@ -0,0 +1,206 @@
|
||||||
|
# Keep GitHub Actions up to date with GitHub's Dependabot...
|
||||||
|
# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/keeping-your-actions-up-to-date-with-dependabot
|
||||||
|
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#package-ecosystem
|
||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
# ============================================================
|
||||||
|
# GitHub Actions
|
||||||
|
# PR Strategy:
|
||||||
|
# - All updates (major/minor/patch): Grouped into a single PR
|
||||||
|
# ============================================================
|
||||||
|
- package-ecosystem: github-actions
|
||||||
|
directory: /
|
||||||
|
groups:
|
||||||
|
github-actions:
|
||||||
|
patterns:
|
||||||
|
- "*" # Group all Actions updates into a single larger pull request
|
||||||
|
schedule:
|
||||||
|
interval: weekly
|
||||||
|
day: monday
|
||||||
|
time: "02:00"
|
||||||
|
timezone: "Asia/Shanghai"
|
||||||
|
labels:
|
||||||
|
- "dependencies"
|
||||||
|
- "github-actions"
|
||||||
|
open-pull-requests-limit: 2
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Python (pip) Dependencies
|
||||||
|
# PR Strategy:
|
||||||
|
# - Major updates: Individual PR per package (except numpy which is ignored)
|
||||||
|
# - Minor updates: Grouped by category (llm-providers, storage, etc.)
|
||||||
|
# - Patch updates: Grouped by category
|
||||||
|
# ============================================================
|
||||||
|
- package-ecosystem: "pip"
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
day: "wednesday"
|
||||||
|
time: "02:00"
|
||||||
|
timezone: "Asia/Shanghai"
|
||||||
|
cooldown:
|
||||||
|
default-days: 5
|
||||||
|
semver-major-days: 30
|
||||||
|
semver-minor-days: 7
|
||||||
|
semver-patch-days: 3
|
||||||
|
groups:
|
||||||
|
# Core dependencies - LLM providers and embeddings
|
||||||
|
llm-providers:
|
||||||
|
patterns:
|
||||||
|
- "openai"
|
||||||
|
- "anthropic"
|
||||||
|
- "google-*"
|
||||||
|
- "boto3"
|
||||||
|
- "botocore"
|
||||||
|
- "ollama"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
# Storage backends
|
||||||
|
storage:
|
||||||
|
patterns:
|
||||||
|
- "neo4j"
|
||||||
|
- "pymongo"
|
||||||
|
- "redis"
|
||||||
|
- "psycopg*"
|
||||||
|
- "asyncpg"
|
||||||
|
- "milvus*"
|
||||||
|
- "qdrant*"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
# Data processing and ML
|
||||||
|
data-processing:
|
||||||
|
patterns:
|
||||||
|
- "numpy"
|
||||||
|
- "scipy"
|
||||||
|
- "pandas"
|
||||||
|
- "tiktoken"
|
||||||
|
- "transformers"
|
||||||
|
- "torch*"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
# Web framework and API
|
||||||
|
web-framework:
|
||||||
|
patterns:
|
||||||
|
- "fastapi"
|
||||||
|
- "uvicorn"
|
||||||
|
- "gunicorn"
|
||||||
|
- "starlette"
|
||||||
|
- "pydantic*"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
# Development and testing tools
|
||||||
|
dev-tools:
|
||||||
|
patterns:
|
||||||
|
- "pytest*"
|
||||||
|
- "ruff"
|
||||||
|
- "pre-commit"
|
||||||
|
- "black"
|
||||||
|
- "mypy"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
# Minor and patch updates for everything else
|
||||||
|
python-minor-patch:
|
||||||
|
patterns:
|
||||||
|
- "*"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
ignore:
|
||||||
|
- dependency-name: "numpy"
|
||||||
|
update-types:
|
||||||
|
- "version-update:semver-major"
|
||||||
|
labels:
|
||||||
|
- "dependencies"
|
||||||
|
- "python"
|
||||||
|
open-pull-requests-limit: 5
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Frontend (bun) Dependencies
|
||||||
|
# PR Strategy:
|
||||||
|
# - Major updates: Individual PR per package
|
||||||
|
# - Minor updates: Grouped by category (react, ui-components, etc.)
|
||||||
|
# - Patch updates: Grouped by category
|
||||||
|
# ============================================================
|
||||||
|
- package-ecosystem: "bun"
|
||||||
|
directory: "/lightrag_webui"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
day: "friday"
|
||||||
|
time: "02:00"
|
||||||
|
timezone: "Asia/Shanghai"
|
||||||
|
cooldown:
|
||||||
|
default-days: 5
|
||||||
|
semver-major-days: 30
|
||||||
|
semver-minor-days: 7
|
||||||
|
semver-patch-days: 3
|
||||||
|
groups:
|
||||||
|
# React ecosystem
|
||||||
|
react:
|
||||||
|
patterns:
|
||||||
|
- "react"
|
||||||
|
- "react-dom"
|
||||||
|
- "react-router*"
|
||||||
|
- "@types/react*"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
# UI components and styling
|
||||||
|
ui-components:
|
||||||
|
patterns:
|
||||||
|
- "@radix-ui/*"
|
||||||
|
- "tailwind*"
|
||||||
|
- "@tailwindcss/*"
|
||||||
|
- "lucide-react"
|
||||||
|
- "class-variance-authority"
|
||||||
|
- "clsx"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
# Graph visualization
|
||||||
|
graph-viz:
|
||||||
|
patterns:
|
||||||
|
- "sigma"
|
||||||
|
- "@sigma/*"
|
||||||
|
- "graphology*"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
# Build tools and dev dependencies
|
||||||
|
build-tools:
|
||||||
|
patterns:
|
||||||
|
- "vite"
|
||||||
|
- "@vitejs/*"
|
||||||
|
- "typescript"
|
||||||
|
- "eslint*"
|
||||||
|
- "@eslint/*"
|
||||||
|
- "typescript-eslint"
|
||||||
|
- "prettier"
|
||||||
|
- "prettier-*"
|
||||||
|
- "@types/bun"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
# Content rendering libraries (math, diagrams, etc.)
|
||||||
|
content-rendering:
|
||||||
|
patterns:
|
||||||
|
- "katex"
|
||||||
|
- "mermaid"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
# All other minor and patch updates
|
||||||
|
frontend-minor-patch:
|
||||||
|
patterns:
|
||||||
|
- "*"
|
||||||
|
update-types:
|
||||||
|
- "minor"
|
||||||
|
- "patch"
|
||||||
|
labels:
|
||||||
|
- "dependencies"
|
||||||
|
- "frontend"
|
||||||
|
open-pull-requests-limit: 5
|
||||||
6
.github/workflows/docker-publish.yml
vendored
6
.github/workflows/docker-publish.yml
vendored
|
|
@ -12,7 +12,9 @@ jobs:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v6
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
@ -25,7 +27,7 @@ jobs:
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
- name: Build and push Docker image
|
- name: Build and push Docker image
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v6
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ./Dockerfile
|
file: ./Dockerfile
|
||||||
|
|
|
||||||
|
|
@ -102,6 +102,9 @@ RERANK_BINDING=null
|
||||||
# RERANK_MODEL=rerank-v3.5
|
# RERANK_MODEL=rerank-v3.5
|
||||||
# RERANK_BINDING_HOST=https://api.cohere.com/v2/rerank
|
# RERANK_BINDING_HOST=https://api.cohere.com/v2/rerank
|
||||||
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
||||||
|
### Cohere rerank chunking configuration (useful for models with token limits like ColBERT)
|
||||||
|
# RERANK_ENABLE_CHUNKING=true
|
||||||
|
# RERANK_MAX_TOKENS_PER_DOC=480
|
||||||
|
|
||||||
### Default value for Jina AI
|
### Default value for Jina AI
|
||||||
# RERANK_MODEL=jina-reranker-v2-base-multilingual
|
# RERANK_MODEL=jina-reranker-v2-base-multilingual
|
||||||
|
|
|
||||||
|
|
@ -15,9 +15,12 @@ Configuration Required:
|
||||||
EMBEDDING_BINDING_HOST
|
EMBEDDING_BINDING_HOST
|
||||||
EMBEDDING_BINDING_API_KEY
|
EMBEDDING_BINDING_API_KEY
|
||||||
3. Set your vLLM deployed AI rerank model setting with env vars:
|
3. Set your vLLM deployed AI rerank model setting with env vars:
|
||||||
RERANK_MODEL
|
RERANK_BINDING=cohere
|
||||||
RERANK_BINDING_HOST
|
RERANK_MODEL (e.g., answerai-colbert-small-v1 or rerank-v3.5)
|
||||||
|
RERANK_BINDING_HOST (e.g., https://api.cohere.com/v2/rerank or LiteLLM proxy)
|
||||||
RERANK_BINDING_API_KEY
|
RERANK_BINDING_API_KEY
|
||||||
|
RERANK_ENABLE_CHUNKING=true (optional, for models with token limits)
|
||||||
|
RERANK_MAX_TOKENS_PER_DOC=480 (optional, default 4096)
|
||||||
|
|
||||||
Note: Rerank is controlled per query via the 'enable_rerank' parameter (default: True)
|
Note: Rerank is controlled per query via the 'enable_rerank' parameter (default: True)
|
||||||
"""
|
"""
|
||||||
|
|
@ -66,9 +69,11 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||||
|
|
||||||
rerank_model_func = partial(
|
rerank_model_func = partial(
|
||||||
cohere_rerank,
|
cohere_rerank,
|
||||||
model=os.getenv("RERANK_MODEL"),
|
model=os.getenv("RERANK_MODEL", "rerank-v3.5"),
|
||||||
api_key=os.getenv("RERANK_BINDING_API_KEY"),
|
api_key=os.getenv("RERANK_BINDING_API_KEY"),
|
||||||
base_url=os.getenv("RERANK_BINDING_HOST"),
|
base_url=os.getenv("RERANK_BINDING_HOST", "https://api.cohere.com/v2/rerank"),
|
||||||
|
enable_chunking=os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true",
|
||||||
|
max_tokens_per_doc=int(os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
__api_version__ = "0258"
|
__api_version__ = "0259"
|
||||||
|
|
|
||||||
|
|
@ -1023,15 +1023,30 @@ def create_app(args):
|
||||||
query: str, documents: list, top_n: int = None, extra_body: dict = None
|
query: str, documents: list, top_n: int = None, extra_body: dict = None
|
||||||
):
|
):
|
||||||
"""Server rerank function with configuration from environment variables"""
|
"""Server rerank function with configuration from environment variables"""
|
||||||
return await selected_rerank_func(
|
# Prepare kwargs for rerank function
|
||||||
query=query,
|
kwargs = {
|
||||||
documents=documents,
|
"query": query,
|
||||||
top_n=top_n,
|
"documents": documents,
|
||||||
api_key=args.rerank_binding_api_key,
|
"top_n": top_n,
|
||||||
model=args.rerank_model,
|
"api_key": args.rerank_binding_api_key,
|
||||||
base_url=args.rerank_binding_host,
|
"model": args.rerank_model,
|
||||||
extra_body=extra_body,
|
"base_url": args.rerank_binding_host,
|
||||||
)
|
}
|
||||||
|
|
||||||
|
# Add Cohere-specific parameters if using cohere binding
|
||||||
|
if args.rerank_binding == "cohere":
|
||||||
|
# Enable chunking if configured (useful for models with token limits like ColBERT)
|
||||||
|
kwargs["enable_chunking"] = (
|
||||||
|
os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
kwargs["max_tokens_per_doc"] = int(
|
||||||
|
os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
kwargs["max_tokens_per_doc"] = 4096
|
||||||
|
|
||||||
|
return await selected_rerank_func(**kwargs, extra_body=extra_body)
|
||||||
|
|
||||||
rerank_model_func = server_rerank_func
|
rerank_model_func = server_rerank_func
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,11 @@ from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from lightrag import LightRAG
|
from lightrag import LightRAG
|
||||||
from lightrag.base import DeletionResult, DocProcessingStatus, DocStatus
|
from lightrag.base import DeletionResult, DocProcessingStatus, DocStatus
|
||||||
from lightrag.utils import generate_track_id
|
from lightrag.utils import (
|
||||||
|
generate_track_id,
|
||||||
|
compute_mdhash_id,
|
||||||
|
sanitize_text_for_encoding,
|
||||||
|
)
|
||||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||||
from ..config import global_args
|
from ..config import global_args
|
||||||
|
|
||||||
|
|
@ -159,7 +163,7 @@ class ReprocessResponse(BaseModel):
|
||||||
Attributes:
|
Attributes:
|
||||||
status: Status of the reprocessing operation
|
status: Status of the reprocessing operation
|
||||||
message: Message describing the operation result
|
message: Message describing the operation result
|
||||||
track_id: Tracking ID for monitoring reprocessing progress
|
track_id: Always empty string. Reprocessed documents retain their original track_id.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
status: Literal["reprocessing_started"] = Field(
|
status: Literal["reprocessing_started"] = Field(
|
||||||
|
|
@ -167,7 +171,8 @@ class ReprocessResponse(BaseModel):
|
||||||
)
|
)
|
||||||
message: str = Field(description="Human-readable message describing the operation")
|
message: str = Field(description="Human-readable message describing the operation")
|
||||||
track_id: str = Field(
|
track_id: str = Field(
|
||||||
description="Tracking ID for monitoring reprocessing progress"
|
default="",
|
||||||
|
description="Always empty string. Reprocessed documents retain their original track_id from initial upload.",
|
||||||
)
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
@ -175,7 +180,7 @@ class ReprocessResponse(BaseModel):
|
||||||
"example": {
|
"example": {
|
||||||
"status": "reprocessing_started",
|
"status": "reprocessing_started",
|
||||||
"message": "Reprocessing of failed documents has been initiated in background",
|
"message": "Reprocessing of failed documents has been initiated in background",
|
||||||
"track_id": "retry_20250729_170612_def456",
|
"track_id": "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -2097,12 +2102,14 @@ def create_document_routes(
|
||||||
# Check if filename already exists in doc_status storage
|
# Check if filename already exists in doc_status storage
|
||||||
existing_doc_data = await rag.doc_status.get_doc_by_file_path(safe_filename)
|
existing_doc_data = await rag.doc_status.get_doc_by_file_path(safe_filename)
|
||||||
if existing_doc_data:
|
if existing_doc_data:
|
||||||
# Get document status information for error message
|
# Get document status and track_id from existing document
|
||||||
status = existing_doc_data.get("status", "unknown")
|
status = existing_doc_data.get("status", "unknown")
|
||||||
|
# Use `or ""` to handle both missing key and None value (e.g., legacy rows without track_id)
|
||||||
|
existing_track_id = existing_doc_data.get("track_id") or ""
|
||||||
return InsertResponse(
|
return InsertResponse(
|
||||||
status="duplicated",
|
status="duplicated",
|
||||||
message=f"File '{safe_filename}' already exists in document storage (Status: {status}).",
|
message=f"File '{safe_filename}' already exists in document storage (Status: {status}).",
|
||||||
track_id="",
|
track_id=existing_track_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
file_path = doc_manager.input_dir / safe_filename
|
file_path = doc_manager.input_dir / safe_filename
|
||||||
|
|
@ -2166,14 +2173,30 @@ def create_document_routes(
|
||||||
request.file_source
|
request.file_source
|
||||||
)
|
)
|
||||||
if existing_doc_data:
|
if existing_doc_data:
|
||||||
# Get document status information for error message
|
# Get document status and track_id from existing document
|
||||||
status = existing_doc_data.get("status", "unknown")
|
status = existing_doc_data.get("status", "unknown")
|
||||||
|
# Use `or ""` to handle both missing key and None value (e.g., legacy rows without track_id)
|
||||||
|
existing_track_id = existing_doc_data.get("track_id") or ""
|
||||||
return InsertResponse(
|
return InsertResponse(
|
||||||
status="duplicated",
|
status="duplicated",
|
||||||
message=f"File source '{request.file_source}' already exists in document storage (Status: {status}).",
|
message=f"File source '{request.file_source}' already exists in document storage (Status: {status}).",
|
||||||
track_id="",
|
track_id=existing_track_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if content already exists by computing content hash (doc_id)
|
||||||
|
sanitized_text = sanitize_text_for_encoding(request.text)
|
||||||
|
content_doc_id = compute_mdhash_id(sanitized_text, prefix="doc-")
|
||||||
|
existing_doc = await rag.doc_status.get_by_id(content_doc_id)
|
||||||
|
if existing_doc:
|
||||||
|
# Content already exists, return duplicated with existing track_id
|
||||||
|
status = existing_doc.get("status", "unknown")
|
||||||
|
existing_track_id = existing_doc.get("track_id") or ""
|
||||||
|
return InsertResponse(
|
||||||
|
status="duplicated",
|
||||||
|
message=f"Identical content already exists in document storage (doc_id: {content_doc_id}, Status: {status}).",
|
||||||
|
track_id=existing_track_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Generate track_id for text insertion
|
# Generate track_id for text insertion
|
||||||
track_id = generate_track_id("insert")
|
track_id = generate_track_id("insert")
|
||||||
|
|
||||||
|
|
@ -2232,14 +2255,31 @@ def create_document_routes(
|
||||||
file_source
|
file_source
|
||||||
)
|
)
|
||||||
if existing_doc_data:
|
if existing_doc_data:
|
||||||
# Get document status information for error message
|
# Get document status and track_id from existing document
|
||||||
status = existing_doc_data.get("status", "unknown")
|
status = existing_doc_data.get("status", "unknown")
|
||||||
|
# Use `or ""` to handle both missing key and None value (e.g., legacy rows without track_id)
|
||||||
|
existing_track_id = existing_doc_data.get("track_id") or ""
|
||||||
return InsertResponse(
|
return InsertResponse(
|
||||||
status="duplicated",
|
status="duplicated",
|
||||||
message=f"File source '{file_source}' already exists in document storage (Status: {status}).",
|
message=f"File source '{file_source}' already exists in document storage (Status: {status}).",
|
||||||
track_id="",
|
track_id=existing_track_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if any content already exists by computing content hash (doc_id)
|
||||||
|
for text in request.texts:
|
||||||
|
sanitized_text = sanitize_text_for_encoding(text)
|
||||||
|
content_doc_id = compute_mdhash_id(sanitized_text, prefix="doc-")
|
||||||
|
existing_doc = await rag.doc_status.get_by_id(content_doc_id)
|
||||||
|
if existing_doc:
|
||||||
|
# Content already exists, return duplicated with existing track_id
|
||||||
|
status = existing_doc.get("status", "unknown")
|
||||||
|
existing_track_id = existing_doc.get("track_id") or ""
|
||||||
|
return InsertResponse(
|
||||||
|
status="duplicated",
|
||||||
|
message=f"Identical content already exists in document storage (doc_id: {content_doc_id}, Status: {status}).",
|
||||||
|
track_id=existing_track_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Generate track_id for texts insertion
|
# Generate track_id for texts insertion
|
||||||
track_id = generate_track_id("insert")
|
track_id = generate_track_id("insert")
|
||||||
|
|
||||||
|
|
@ -3058,29 +3098,27 @@ def create_document_routes(
|
||||||
This is useful for recovering from server crashes, network errors, LLM service
|
This is useful for recovering from server crashes, network errors, LLM service
|
||||||
outages, or other temporary failures that caused document processing to fail.
|
outages, or other temporary failures that caused document processing to fail.
|
||||||
|
|
||||||
The processing happens in the background and can be monitored using the
|
The processing happens in the background and can be monitored by checking the
|
||||||
returned track_id or by checking the pipeline status.
|
pipeline status. The reprocessed documents retain their original track_id from
|
||||||
|
initial upload, so use their original track_id to monitor progress.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ReprocessResponse: Response with status, message, and track_id
|
ReprocessResponse: Response with status and message.
|
||||||
|
track_id is always empty string because reprocessed documents retain
|
||||||
|
their original track_id from initial upload.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: If an error occurs while initiating reprocessing (500).
|
HTTPException: If an error occurs while initiating reprocessing (500).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Generate track_id with "retry" prefix for retry operation
|
|
||||||
track_id = generate_track_id("retry")
|
|
||||||
|
|
||||||
# Start the reprocessing in the background
|
# Start the reprocessing in the background
|
||||||
|
# Note: Reprocessed documents retain their original track_id from initial upload
|
||||||
background_tasks.add_task(rag.apipeline_process_enqueue_documents)
|
background_tasks.add_task(rag.apipeline_process_enqueue_documents)
|
||||||
logger.info(
|
logger.info("Reprocessing of failed documents initiated")
|
||||||
f"Reprocessing of failed documents initiated with track_id: {track_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return ReprocessResponse(
|
return ReprocessResponse(
|
||||||
status="reprocessing_started",
|
status="reprocessing_started",
|
||||||
message="Reprocessing of failed documents has been initiated in background",
|
message="Reprocessing of failed documents has been initiated in background. Documents retain their original track_id.",
|
||||||
track_id=track_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from typing import Any, List, Dict, Optional
|
from typing import Any, List, Dict, Optional, Tuple
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
retry,
|
retry,
|
||||||
stop_after_attempt,
|
stop_after_attempt,
|
||||||
|
|
@ -19,6 +19,158 @@ from dotenv import load_dotenv
|
||||||
load_dotenv(dotenv_path=".env", override=False)
|
load_dotenv(dotenv_path=".env", override=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _rerank_chunk_spans(length: int, window: int, overlap: int) -> List[Tuple[int, int]]:
    """Return ``(start, end)`` slices of at most ``window`` items covering ``length`` items.

    Consecutive spans share ``overlap`` items. The caller must guarantee
    ``window >= 1`` and ``0 <= overlap < window`` so every step moves forward.
    """
    if length <= 0:
        return []

    step = window - overlap  # strictly positive under the caller's guarantees
    spans: List[Tuple[int, int]] = []
    start = 0
    while True:
        end = min(start + window, length)
        spans.append((start, end))
        if end >= length:
            return spans
        start += step


def chunk_documents_for_rerank(
    documents: List[str],
    max_tokens: int = 480,
    overlap_tokens: int = 32,
    tokenizer_model: str = "gpt-4o-mini",
) -> Tuple[List[str], List[int]]:
    """
    Chunk documents that exceed token limit for reranking.

    Documents that already fit are passed through unchanged; longer ones are
    split into overlapping windows. If the tokenizer cannot be initialized,
    the function falls back to a character-based approximation
    (1 token ~ 4 characters).

    Args:
        documents: List of document strings to chunk
        max_tokens: Maximum tokens per chunk (default 480 to leave margin for 512 limit).
            Must be >= 1.
        overlap_tokens: Number of tokens to overlap between chunks. Negative
            values are clamped to 0; values >= max_tokens are clamped to
            max_tokens - 1.
        tokenizer_model: Model name for tiktoken tokenizer

    Returns:
        Tuple of (chunked_documents, original_doc_indices)
        - chunked_documents: List of document chunks (may be more than input)
        - original_doc_indices: Maps each chunk back to its original document index

    Raises:
        ValueError: If max_tokens is less than 1. A non-positive window can
            never advance through a document, so the chunking loop would
            otherwise never terminate.
    """
    # Reject a non-positive window up front: with max_tokens <= 0 the window
    # end never moves past the window start and the loop would spin forever.
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be a positive integer, got {max_tokens}")

    if overlap_tokens < 0:
        # A negative overlap would skip text between consecutive chunks.
        logger.warning(
            f"overlap_tokens ({overlap_tokens}) must not be negative. Clamping to 0."
        )
        overlap_tokens = 0
    elif overlap_tokens >= max_tokens:
        # Clamp overlap_tokens to ensure the loop always advances
        original_overlap = overlap_tokens
        # Ensure overlap is at least 1 token less than max to guarantee progress
        # For very small max_tokens (e.g., 1), set overlap to 0
        overlap_tokens = max(0, max_tokens - 1)
        logger.warning(
            f"overlap_tokens ({original_overlap}) must be less than max_tokens ({max_tokens}). "
            f"Clamping to {overlap_tokens} to prevent infinite loop."
        )

    try:
        from .utils import TiktokenTokenizer

        tokenizer = TiktokenTokenizer(model_name=tokenizer_model)
    except Exception as e:
        # Deliberately broad: chunking is best-effort, so any tokenizer
        # initialization failure degrades to the character approximation.
        logger.warning(
            f"Failed to initialize tokenizer: {e}. Using character-based approximation."
        )
        tokenizer = None

    if tokenizer is None:
        # Fallback: approximate 1 token ≈ 4 characters
        window = max_tokens * 4
        overlap = overlap_tokens * 4
    else:
        window = max_tokens
        overlap = overlap_tokens

    chunked_docs: List[str] = []
    doc_indices: List[int] = []

    for idx, doc in enumerate(documents):
        # "units" are characters in fallback mode and token ids otherwise.
        units = doc if tokenizer is None else tokenizer.encode(doc)

        if len(units) <= window:
            # Document fits in one chunk: keep the original text untouched.
            chunked_docs.append(doc)
            doc_indices.append(idx)
            continue

        # Split into overlapping chunks
        for start, end in _rerank_chunk_spans(len(units), window, overlap):
            piece = units[start:end]
            chunked_docs.append(piece if tokenizer is None else tokenizer.decode(piece))
            doc_indices.append(idx)

    return chunked_docs, doc_indices
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_chunk_scores(
    chunk_results: List[Dict[str, Any]],
    doc_indices: List[int],
    num_original_docs: int,
    aggregation: str = "max",
) -> List[Dict[str, Any]]:
    """
    Aggregate rerank scores from document chunks back to original documents.

    Args:
        chunk_results: Rerank results for chunks [{"index": chunk_idx, "relevance_score": score}, ...]
        doc_indices: Maps each chunk index to original document index
        num_original_docs: Total number of original documents
        aggregation: Strategy for aggregating scores ("max", "mean", "first").
            Any other value logs one warning and falls back to "max".

    Returns:
        List of results for original documents [{"index": doc_idx, "relevance_score": score}, ...],
        sorted by relevance score in descending order. Documents that received
        no chunk score are omitted.
    """
    # Resolve the strategy once so an unknown value is reported a single time
    # instead of once per document.
    reducers = {
        "max": max,
        "mean": lambda scores: sum(scores) / len(scores),
        "first": lambda scores: scores[0],
    }
    reduce_scores = reducers.get(aggregation)
    if reduce_scores is None:
        logger.warning(f"Unknown aggregation strategy: {aggregation}, using max")
        reduce_scores = max

    # Group scores by original document index
    scores_per_doc: List[List[float]] = [[] for _ in range(num_original_docs)]

    for result in chunk_results:
        chunk_idx = result["index"]
        score = result["relevance_score"]

        # Ignore results that point at a chunk we never sent.
        if not 0 <= chunk_idx < len(doc_indices):
            continue

        original_doc_idx = doc_indices[chunk_idx]
        # Ignore mappings outside the declared document range rather than
        # failing the whole rerank call.
        if not 0 <= original_doc_idx < num_original_docs:
            continue

        scores_per_doc[original_doc_idx].append(score)

    # Aggregate scores; building in document order keeps ties stable.
    aggregated_results = [
        {"index": doc_idx, "relevance_score": reduce_scores(scores)}
        for doc_idx, scores in enumerate(scores_per_doc)
        if scores
    ]

    # Sort by relevance score (descending)
    aggregated_results.sort(key=lambda x: x["relevance_score"], reverse=True)

    return aggregated_results
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
|
@ -38,6 +190,8 @@ async def generic_rerank_api(
|
||||||
extra_body: Optional[Dict[str, Any]] = None,
|
extra_body: Optional[Dict[str, Any]] = None,
|
||||||
response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
|
response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
|
||||||
request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
|
request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
|
||||||
|
enable_chunking: bool = False,
|
||||||
|
max_tokens_per_doc: int = 480,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Generic rerank API call for Jina/Cohere/Aliyun models.
|
Generic rerank API call for Jina/Cohere/Aliyun models.
|
||||||
|
|
@ -52,6 +206,9 @@ async def generic_rerank_api(
|
||||||
return_documents: Whether to return document text (Jina only)
|
return_documents: Whether to return document text (Jina only)
|
||||||
extra_body: Additional body parameters
|
extra_body: Additional body parameters
|
||||||
response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
|
response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
|
||||||
|
request_format: Request format type
|
||||||
|
enable_chunking: Whether to chunk documents exceeding token limit
|
||||||
|
max_tokens_per_doc: Maximum tokens per document for chunking
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of dictionary of ["index": int, "relevance_score": float]
|
List of dictionary of ["index": int, "relevance_score": float]
|
||||||
|
|
@ -63,6 +220,27 @@ async def generic_rerank_api(
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
|
# Handle document chunking if enabled
|
||||||
|
original_documents = documents
|
||||||
|
doc_indices = None
|
||||||
|
original_top_n = top_n # Save original top_n for post-aggregation limiting
|
||||||
|
|
||||||
|
if enable_chunking:
|
||||||
|
documents, doc_indices = chunk_documents_for_rerank(
|
||||||
|
documents, max_tokens=max_tokens_per_doc
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Chunked {len(original_documents)} documents into {len(documents)} chunks"
|
||||||
|
)
|
||||||
|
# When chunking is enabled, disable top_n at API level to get all chunk scores
|
||||||
|
# This ensures proper document-level coverage after aggregation
|
||||||
|
# We'll apply top_n to aggregated document results instead
|
||||||
|
if top_n is not None:
|
||||||
|
logger.debug(
|
||||||
|
f"Chunking enabled: disabled API-level top_n={top_n} to ensure complete document coverage"
|
||||||
|
)
|
||||||
|
top_n = None
|
||||||
|
|
||||||
# Build request payload based on request format
|
# Build request payload based on request format
|
||||||
if request_format == "aliyun":
|
if request_format == "aliyun":
|
||||||
# Aliyun format: nested input/parameters structure
|
# Aliyun format: nested input/parameters structure
|
||||||
|
|
@ -86,7 +264,7 @@ async def generic_rerank_api(
|
||||||
if extra_body:
|
if extra_body:
|
||||||
payload["parameters"].update(extra_body)
|
payload["parameters"].update(extra_body)
|
||||||
else:
|
else:
|
||||||
# Standard format for Jina/Cohere
|
# Standard format for Jina/Cohere/OpenAI
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"query": query,
|
"query": query,
|
||||||
|
|
@ -98,7 +276,7 @@ async def generic_rerank_api(
|
||||||
payload["top_n"] = top_n
|
payload["top_n"] = top_n
|
||||||
|
|
||||||
# Only Jina API supports return_documents parameter
|
# Only Jina API supports return_documents parameter
|
||||||
if return_documents is not None:
|
if return_documents is not None and response_format in ("standard",):
|
||||||
payload["return_documents"] = return_documents
|
payload["return_documents"] = return_documents
|
||||||
|
|
||||||
# Add extra parameters
|
# Add extra parameters
|
||||||
|
|
@ -147,7 +325,6 @@ async def generic_rerank_api(
|
||||||
f"Expected 'output.results' to be list, got {type(results)}: {results}"
|
f"Expected 'output.results' to be list, got {type(results)}: {results}"
|
||||||
)
|
)
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
elif response_format == "standard":
|
elif response_format == "standard":
|
||||||
# Standard format: {"results": [...]}
|
# Standard format: {"results": [...]}
|
||||||
results = response_json.get("results", [])
|
results = response_json.get("results", [])
|
||||||
|
|
@ -158,16 +335,35 @@ async def generic_rerank_api(
|
||||||
results = []
|
results = []
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported response format: {response_format}")
|
raise ValueError(f"Unsupported response format: {response_format}")
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
logger.warning("Rerank API returned empty results")
|
logger.warning("Rerank API returned empty results")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Standardize return format
|
# Standardize return format
|
||||||
return [
|
standardized_results = [
|
||||||
{"index": result["index"], "relevance_score": result["relevance_score"]}
|
{"index": result["index"], "relevance_score": result["relevance_score"]}
|
||||||
for result in results
|
for result in results
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Aggregate chunk scores back to original documents if chunking was enabled
|
||||||
|
if enable_chunking and doc_indices:
|
||||||
|
standardized_results = aggregate_chunk_scores(
|
||||||
|
standardized_results,
|
||||||
|
doc_indices,
|
||||||
|
len(original_documents),
|
||||||
|
aggregation="max",
|
||||||
|
)
|
||||||
|
# Apply original top_n limit at document level (post-aggregation)
|
||||||
|
# This preserves document-level semantics: top_n limits documents, not chunks
|
||||||
|
if (
|
||||||
|
original_top_n is not None
|
||||||
|
and len(standardized_results) > original_top_n
|
||||||
|
):
|
||||||
|
standardized_results = standardized_results[:original_top_n]
|
||||||
|
|
||||||
|
return standardized_results
|
||||||
|
|
||||||
|
|
||||||
async def cohere_rerank(
|
async def cohere_rerank(
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -177,21 +373,46 @@ async def cohere_rerank(
|
||||||
model: str = "rerank-v3.5",
|
model: str = "rerank-v3.5",
|
||||||
base_url: str = "https://api.cohere.com/v2/rerank",
|
base_url: str = "https://api.cohere.com/v2/rerank",
|
||||||
extra_body: Optional[Dict[str, Any]] = None,
|
extra_body: Optional[Dict[str, Any]] = None,
|
||||||
|
enable_chunking: bool = False,
|
||||||
|
max_tokens_per_doc: int = 4096,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Rerank documents using Cohere API.
|
Rerank documents using Cohere API.
|
||||||
|
|
||||||
|
Supports both standard Cohere API and Cohere-compatible proxies
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: The search query
|
query: The search query
|
||||||
documents: List of strings to rerank
|
documents: List of strings to rerank
|
||||||
top_n: Number of top results to return
|
top_n: Number of top results to return
|
||||||
api_key: API key
|
api_key: API key for authentication
|
||||||
model: rerank model name
|
model: rerank model name (default: rerank-v3.5)
|
||||||
base_url: API endpoint
|
base_url: API endpoint
|
||||||
extra_body: Additional body for http request(reserved for extra params)
|
extra_body: Additional body for http request(reserved for extra params)
|
||||||
|
enable_chunking: Whether to chunk documents exceeding max_tokens_per_doc
|
||||||
|
max_tokens_per_doc: Maximum tokens per document (default: 4096 for Cohere v3.5)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of dictionary of ["index": int, "relevance_score": float]
|
List of dictionary of ["index": int, "relevance_score": float]
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> # Standard Cohere API
|
||||||
|
>>> results = await cohere_rerank(
|
||||||
|
... query="What is the meaning of life?",
|
||||||
|
... documents=["Doc1", "Doc2"],
|
||||||
|
... api_key="your-cohere-key"
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # LiteLLM proxy with user authentication
|
||||||
|
>>> results = await cohere_rerank(
|
||||||
|
... query="What is vector search?",
|
||||||
|
... documents=["Doc1", "Doc2"],
|
||||||
|
... model="answerai-colbert-small-v1",
|
||||||
|
... base_url="https://llm-proxy.example.com/v2/rerank",
|
||||||
|
... api_key="your-proxy-key",
|
||||||
|
... enable_chunking=True,
|
||||||
|
... max_tokens_per_doc=480
|
||||||
|
... )
|
||||||
"""
|
"""
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
|
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
|
||||||
|
|
@ -206,6 +427,8 @@ async def cohere_rerank(
|
||||||
return_documents=None, # Cohere doesn't support this parameter
|
return_documents=None, # Cohere doesn't support this parameter
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
response_format="standard",
|
response_format="standard",
|
||||||
|
enable_chunking=enable_chunking,
|
||||||
|
max_tokens_per_doc=max_tokens_per_doc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -16,17 +16,17 @@
|
||||||
"preview-no-bun": "vite preview"
|
"preview-no-bun": "vite preview"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@faker-js/faker": "^9.9.0",
|
"@faker-js/faker": "^10.1.0",
|
||||||
"@radix-ui/react-alert-dialog": "^1.1.15",
|
"@radix-ui/react-alert-dialog": "^1.1.15",
|
||||||
"@radix-ui/react-checkbox": "^1.3.3",
|
"@radix-ui/react-checkbox": "^1.3.3",
|
||||||
"@radix-ui/react-dialog": "^1.1.15",
|
"@radix-ui/react-dialog": "^1.1.15",
|
||||||
"@radix-ui/react-hover-card": "^1.1.15",
|
"@radix-ui/react-hover-card": "^1.1.15",
|
||||||
"@radix-ui/react-popover": "^1.1.15",
|
"@radix-ui/react-popover": "^1.1.15",
|
||||||
"@radix-ui/react-progress": "^1.1.7",
|
"@radix-ui/react-progress": "^1.1.8",
|
||||||
"@radix-ui/react-scroll-area": "^1.2.10",
|
"@radix-ui/react-scroll-area": "^1.2.10",
|
||||||
"@radix-ui/react-select": "^2.2.6",
|
"@radix-ui/react-select": "^2.2.6",
|
||||||
"@radix-ui/react-separator": "^1.1.7",
|
"@radix-ui/react-separator": "^1.1.8",
|
||||||
"@radix-ui/react-slot": "^1.2.3",
|
"@radix-ui/react-slot": "^1.2.4",
|
||||||
"@radix-ui/react-tabs": "^1.1.13",
|
"@radix-ui/react-tabs": "^1.1.13",
|
||||||
"@radix-ui/react-tooltip": "^1.2.8",
|
"@radix-ui/react-tooltip": "^1.2.8",
|
||||||
"@radix-ui/react-use-controllable-state": "^1.2.2",
|
"@radix-ui/react-use-controllable-state": "^1.2.2",
|
||||||
|
|
@ -43,7 +43,7 @@
|
||||||
"@sigma/node-border": "^3.0.0",
|
"@sigma/node-border": "^3.0.0",
|
||||||
"@tanstack/react-query": "^5.87.1",
|
"@tanstack/react-query": "^5.87.1",
|
||||||
"@tanstack/react-table": "^8.21.3",
|
"@tanstack/react-table": "^8.21.3",
|
||||||
"axios": "^1.12.2",
|
"axios": "^1.13.2",
|
||||||
"class-variance-authority": "^0.7.1",
|
"class-variance-authority": "^0.7.1",
|
||||||
"clsx": "^2.1.1",
|
"clsx": "^2.1.1",
|
||||||
"cmdk": "^1.1.1",
|
"cmdk": "^1.1.1",
|
||||||
|
|
@ -53,21 +53,21 @@
|
||||||
"graphology-layout-force": "^0.2.4",
|
"graphology-layout-force": "^0.2.4",
|
||||||
"graphology-layout-forceatlas2": "^0.10.1",
|
"graphology-layout-forceatlas2": "^0.10.1",
|
||||||
"graphology-layout-noverlap": "^0.4.2",
|
"graphology-layout-noverlap": "^0.4.2",
|
||||||
"i18next": "^24.2.3",
|
"i18next": "^25.6.3",
|
||||||
"katex": "^0.16.23",
|
"katex": "^0.16.25",
|
||||||
"lucide-react": "^0.475.0",
|
"mermaid": "^11.12.1",
|
||||||
"mermaid": "^11.12.0",
|
"lucide-react": "^0.554.0",
|
||||||
"minisearch": "^7.2.0",
|
"minisearch": "^7.2.0",
|
||||||
"react": "^19.2.0",
|
"react": "^19.2.0",
|
||||||
"react-dom": "^19.2.0",
|
"react-dom": "^19.2.0",
|
||||||
"react-dropzone": "^14.3.8",
|
"react-dropzone": "^14.3.8",
|
||||||
"react-error-boundary": "^5.0.0",
|
"react-error-boundary": "^6.0.0",
|
||||||
"react-i18next": "^15.7.4",
|
"react-i18next": "^16.3.5",
|
||||||
"react-markdown": "^9.1.0",
|
"react-markdown": "^10.1.0",
|
||||||
"react-number-format": "^5.4.4",
|
"react-number-format": "^5.4.4",
|
||||||
"react-router-dom": "^7.9.4",
|
"react-router-dom": "^7.9.6",
|
||||||
"react-select": "^5.10.2",
|
"react-select": "^5.10.2",
|
||||||
"react-syntax-highlighter": "^15.6.6",
|
"react-syntax-highlighter": "^16.1.0",
|
||||||
"rehype-katex": "^7.0.1",
|
"rehype-katex": "^7.0.1",
|
||||||
"rehype-raw": "^7.0.0",
|
"rehype-raw": "^7.0.0",
|
||||||
"rehype-react": "^8.0.0",
|
"rehype-react": "^8.0.0",
|
||||||
|
|
@ -75,8 +75,8 @@
|
||||||
"remark-math": "^6.0.0",
|
"remark-math": "^6.0.0",
|
||||||
"seedrandom": "^3.0.5",
|
"seedrandom": "^3.0.5",
|
||||||
"sigma": "^3.0.2",
|
"sigma": "^3.0.2",
|
||||||
"sonner": "^1.7.4",
|
"sonner": "^2.0.7",
|
||||||
"tailwind-merge": "^3.3.1",
|
"tailwind-merge": "^3.4.0",
|
||||||
"tailwind-scrollbar": "^4.0.2",
|
"tailwind-scrollbar": "^4.0.2",
|
||||||
"typography": "^0.16.24",
|
"typography": "^0.16.24",
|
||||||
"unist-util-visit": "^5.0.0",
|
"unist-util-visit": "^5.0.0",
|
||||||
|
|
@ -84,32 +84,32 @@
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@biomejs/biome": "^1.9.3",
|
"@biomejs/biome": "^1.9.3",
|
||||||
"@eslint/js": "^9.37.0",
|
"@eslint/js": "^9.39.1",
|
||||||
"@stylistic/eslint-plugin-js": "^3.1.0",
|
"@stylistic/eslint-plugin-js": "^4.4.1",
|
||||||
"@tailwindcss/typography": "^0.5.15",
|
"@tailwindcss/typography": "^0.5.15",
|
||||||
"@tailwindcss/vite": "^4.1.14",
|
"@tailwindcss/vite": "^4.1.17",
|
||||||
"@types/bun": "^1.2.23",
|
"@types/bun": "^1.3.3",
|
||||||
"@types/katex": "^0.16.7",
|
"@types/katex": "^0.16.7",
|
||||||
"@types/node": "^22.18.9",
|
"@types/node": "^24.10.1",
|
||||||
"@types/react": "^19.2.2",
|
"@types/react": "^19.2.7",
|
||||||
"@types/react-dom": "^19.2.1",
|
"@types/react-dom": "^19.2.3",
|
||||||
"@types/react-i18next": "^8.1.0",
|
"@types/react-i18next": "^8.1.0",
|
||||||
"@types/react-syntax-highlighter": "^15.5.13",
|
"@types/react-syntax-highlighter": "^15.5.13",
|
||||||
"@types/seedrandom": "^3.0.8",
|
"@types/seedrandom": "^3.0.8",
|
||||||
"@vitejs/plugin-react-swc": "^3.11.0",
|
"@vitejs/plugin-react-swc": "^4.2.2",
|
||||||
"eslint": "^9.37.0",
|
"eslint": "^9.39.1",
|
||||||
"eslint-config-prettier": "^10.1.8",
|
"eslint-config-prettier": "^10.1.8",
|
||||||
"eslint-plugin-react": "^7.37.5",
|
"eslint-plugin-react": "^7.37.5",
|
||||||
"eslint-plugin-react-hooks": "^5.2.0",
|
"eslint-plugin-react-hooks": "^7.0.1",
|
||||||
"eslint-plugin-react-refresh": "^0.4.23",
|
"eslint-plugin-react-refresh": "^0.4.24",
|
||||||
"globals": "^15.15.0",
|
"globals": "^16.5.0",
|
||||||
"graphology-types": "^0.24.8",
|
"graphology-types": "^0.24.8",
|
||||||
"prettier": "^3.6.2",
|
"prettier": "^3.6.2",
|
||||||
"prettier-plugin-tailwindcss": "^0.6.14",
|
"prettier-plugin-tailwindcss": "^0.7.1",
|
||||||
"tailwindcss": "^4.1.14",
|
"typescript-eslint": "^8.48.0",
|
||||||
|
"tailwindcss": "^4.1.17",
|
||||||
"tailwindcss-animate": "^1.0.7",
|
"tailwindcss-animate": "^1.0.7",
|
||||||
"typescript": "~5.7.3",
|
"typescript": "~5.9.3",
|
||||||
"typescript-eslint": "^8.46.0",
|
"vite": "^7.2.4"
|
||||||
"vite": "^6.3.6"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -153,7 +153,6 @@ export const ChatMessage = ({
|
||||||
setKatexPlugin(() => rehypeKatex)
|
setKatexPlugin(() => rehypeKatex)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to load KaTeX plugin:', error)
|
console.error('Failed to load KaTeX plugin:', error)
|
||||||
// Set to null to ensure we don't try to use a failed plugin
|
|
||||||
setKatexPlugin(null)
|
setKatexPlugin(null)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
116
tests/test_overlap_validation.py
Normal file
116
tests/test_overlap_validation.py
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
"""
|
||||||
|
Test for overlap_tokens validation to prevent infinite loop.
|
||||||
|
|
||||||
|
This test validates the fix for the bug where overlap_tokens >= max_tokens
|
||||||
|
causes an infinite loop in the chunking function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lightrag.rerank import chunk_documents_for_rerank
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.offline
class TestOverlapValidation:
    """Guard against the infinite loop caused by overlap_tokens >= max_tokens."""

    def test_overlap_greater_than_max_tokens(self):
        """An overlap larger than max_tokens must be clamped, not hang."""
        corpus = [" ".join(f"word{n}" for n in range(100))]

        # overlap (32) > max_tokens (30): expected to be clamped to 29
        pieces, owners = chunk_documents_for_rerank(
            corpus, max_tokens=30, overlap_tokens=32
        )

        # Getting here at all means the call terminated
        assert len(pieces) > 0
        assert all(owner == 0 for owner in owners)

    def test_overlap_equal_to_max_tokens(self):
        """An overlap equal to max_tokens must be clamped, not hang."""
        corpus = [" ".join(f"word{n}" for n in range(100))]

        # overlap (30) == max_tokens (30): expected to be clamped to 29
        pieces, owners = chunk_documents_for_rerank(
            corpus, max_tokens=30, overlap_tokens=30
        )

        # Getting here at all means the call terminated
        assert len(pieces) > 0
        assert all(owner == 0 for owner in owners)

    def test_overlap_slightly_less_than_max_tokens(self):
        """An overlap just under max_tokens is a valid input and needs no clamping."""
        corpus = [" ".join(f"word{n}" for n in range(100))]

        # overlap (29) < max_tokens (30): used as given
        pieces, owners = chunk_documents_for_rerank(
            corpus, max_tokens=30, overlap_tokens=29
        )

        # The call must finish and attribute every chunk to document 0
        assert len(pieces) > 0
        assert all(owner == 0 for owner in owners)

    def test_small_max_tokens_with_large_overlap(self):
        """Edge case: a very small max_tokens paired with a larger overlap."""
        corpus = [" ".join(f"word{n}" for n in range(50))]

        # max_tokens=5 with overlap_tokens=10: expected to be clamped to 4
        pieces, owners = chunk_documents_for_rerank(
            corpus, max_tokens=5, overlap_tokens=10
        )

        # Getting here at all means the call terminated
        assert len(pieces) > 0
        assert all(owner == 0 for owner in owners)

    def test_multiple_documents_with_invalid_overlap(self):
        """Several documents at once, with overlap_tokens >= max_tokens."""
        corpus = [
            " ".join(f"word{n}" for n in range(50)),
            "short document",
            " ".join(f"word{n}" for n in range(75)),
        ]

        # overlap (30) > max_tokens (25)
        pieces, _ = chunk_documents_for_rerank(
            corpus, max_tokens=25, overlap_tokens=30
        )

        # Chunking can only add entries, never drop a document
        assert len(pieces) >= len(corpus)
        # The short document fits the budget, so it passes through verbatim
        assert "short document" in pieces

    def test_normal_operation_unaffected(self):
        """Regular inputs (overlap well below max_tokens) keep working."""
        corpus = [
            " ".join(f"word{n}" for n in range(100)),
            "short doc",
        ]

        # Ordinary case: overlap_tokens (10) < max_tokens (50)
        pieces, owners = chunk_documents_for_rerank(
            corpus, max_tokens=50, overlap_tokens=10
        )

        # At least 3 entries: 2+ chunks of the long document plus the short one
        assert len(pieces) > 2
        assert "short doc" in pieces
        # The final chunk must be attributed to the second document
        assert owners[-1] == 1

    def test_edge_case_max_tokens_one(self):
        """Edge case: max_tokens=1 leaves no room for any overlap."""
        corpus = [" ".join(f"word{n}" for n in range(20))]

        # max_tokens=1 with overlap_tokens=5: expected to be clamped to 0
        pieces, owners = chunk_documents_for_rerank(
            corpus, max_tokens=1, overlap_tokens=5
        )

        # Getting here at all means the call terminated
        assert len(pieces) > 0
        assert all(owner == 0 for owner in owners)
|
||||||
564
tests/test_rerank_chunking.py
Normal file
564
tests/test_rerank_chunking.py
Normal file
|
|
@ -0,0 +1,564 @@
|
||||||
|
"""
|
||||||
|
Unit tests for rerank document chunking functionality.
|
||||||
|
|
||||||
|
Tests the chunk_documents_for_rerank and aggregate_chunk_scores functions
|
||||||
|
in lightrag/rerank.py to ensure proper document splitting and score aggregation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch, AsyncMock
|
||||||
|
from lightrag.rerank import (
|
||||||
|
chunk_documents_for_rerank,
|
||||||
|
aggregate_chunk_scores,
|
||||||
|
cohere_rerank,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChunkDocumentsForRerank:
    """Exercise chunk_documents_for_rerank: splitting and chunk-to-document mapping."""

    def test_no_chunking_needed_for_short_docs(self):
        """Inputs under the token budget come back untouched."""
        corpus = ["Short doc 1", "Short doc 2", "Short doc 3"]

        pieces, owners = chunk_documents_for_rerank(
            corpus, max_tokens=100, overlap_tokens=10
        )

        # One output per input, same text, identity mapping
        assert len(pieces) == 3
        assert pieces == corpus
        assert owners == [0, 1, 2]

    def test_chunking_with_character_fallback(self):
        """With the tokenizer unavailable, chunking falls back to character counts."""
        oversized = "a" * 2000  # 2000 characters
        corpus = [oversized, "short doc"]

        # Make the tokenizer constructor raise so the fallback path is taken
        tokenizer_missing = patch(
            "lightrag.utils.TiktokenTokenizer", side_effect=ImportError
        )
        with tokenizer_missing:
            pieces, owners = chunk_documents_for_rerank(
                corpus,
                max_tokens=100,  # 100 tokens = ~400 chars
                overlap_tokens=10,  # 10 tokens = ~40 chars
            )

        # The long document is split; the short one is kept whole and comes last
        assert len(pieces) > 2
        assert pieces[-1] == "short doc"
        # The last chunk must map back to document 1
        assert owners[-1] == 1

    def test_chunking_with_tiktoken_tokenizer(self):
        """Chunking with the real tokenizer yields overlapping neighbours."""
        # Roughly one token per "word<i> ", so 200 words ~ 200 tokens
        corpus = [" ".join(f"word{n}" for n in range(200)), "short"]

        pieces, owners = chunk_documents_for_rerank(
            corpus, max_tokens=50, overlap_tokens=10
        )

        # The long document is split; the short one stays as the final entry
        assert len(pieces) > 2
        assert owners[-1] == 1

        # Consecutive chunks cut from document 0 must share at least one word
        for pos, (owner, next_owner) in enumerate(zip(owners, owners[1:])):
            if owner != 0 or next_owner != 0:
                continue
            tail_words = pieces[pos].split()[-5:]
            following_words = pieces[pos + 1].split()
            assert any(word in following_words for word in tail_words)

    def test_empty_documents(self):
        """An empty input list produces two empty outputs."""
        pieces, owners = chunk_documents_for_rerank([])

        assert pieces == []
        assert owners == []

    def test_single_document_chunking(self):
        """One long document is split and every chunk points back to it."""
        # About 100 tokens of text
        corpus = [" ".join(f"token{n}" for n in range(100))]

        pieces, owners = chunk_documents_for_rerank(
            corpus, max_tokens=30, overlap_tokens=5
        )

        # More than one chunk, all attributed to document 0
        assert len(pieces) > 1
        assert all(owner == 0 for owner in owners)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAggregateChunkScores:
    """Exercise aggregate_chunk_scores: folding chunk scores back onto documents."""

    def test_no_chunking_simple_aggregation(self):
        """A 1:1 chunk-to-document mapping passes scores straight through."""
        scored_chunks = [
            {"index": 0, "relevance_score": 0.9},
            {"index": 1, "relevance_score": 0.7},
            {"index": 2, "relevance_score": 0.5},
        ]
        owners = [0, 1, 2]  # chunk i belongs to document i

        merged = aggregate_chunk_scores(scored_chunks, owners, 3, aggregation="max")

        # Three documents, ordered by descending score
        assert len(merged) == 3
        assert [row["index"] for row in merged] == [0, 1, 2]
        assert [row["relevance_score"] for row in merged] == [0.9, 0.7, 0.5]

    def test_max_aggregation_with_chunks(self):
        """The "max" strategy keeps each document's best chunk score."""
        # Five chunks: the first three from document 0, the last two from document 1
        scored_chunks = [
            {"index": 0, "relevance_score": 0.5},
            {"index": 1, "relevance_score": 0.8},
            {"index": 2, "relevance_score": 0.6},
            {"index": 3, "relevance_score": 0.7},
            {"index": 4, "relevance_score": 0.4},
        ]
        owners = [0, 0, 0, 1, 1]

        merged = aggregate_chunk_scores(scored_chunks, owners, 2, aggregation="max")

        assert len(merged) == 2
        first, second = merged[0], merged[1]
        assert first["index"] == 0
        assert first["relevance_score"] == 0.8  # max of 0.5, 0.8, 0.6
        assert second["index"] == 1
        assert second["relevance_score"] == 0.7  # max of 0.7, 0.4

    def test_mean_aggregation_with_chunks(self):
        """The "mean" strategy averages a document's chunk scores."""
        scored_chunks = [
            {"index": 0, "relevance_score": 0.6},
            {"index": 1, "relevance_score": 0.8},
            {"index": 2, "relevance_score": 0.4},
        ]
        owners = [0, 0, 1]  # two chunks from document 0, one from document 1

        merged = aggregate_chunk_scores(scored_chunks, owners, 2, aggregation="mean")

        assert len(merged) == 2
        first, second = merged[0], merged[1]
        assert first["index"] == 0
        # (0.6 + 0.8) / 2, compared approximately because of float rounding
        assert first["relevance_score"] == pytest.approx(0.7)
        assert second["index"] == 1
        assert second["relevance_score"] == 0.4

    def test_first_aggregation_with_chunks(self):
        """The "first" strategy keeps the first score seen per document."""
        scored_chunks = [
            {"index": 0, "relevance_score": 0.6},
            {"index": 1, "relevance_score": 0.8},
            {"index": 2, "relevance_score": 0.4},
        ]
        owners = [0, 0, 1]

        merged = aggregate_chunk_scores(scored_chunks, owners, 2, aggregation="first")

        assert len(merged) == 2
        first, second = merged[0], merged[1]
        assert first["index"] == 0
        assert first["relevance_score"] == 0.6  # first score listed for document 0
        assert second["index"] == 1
        assert second["relevance_score"] == 0.4

    def test_empty_chunk_results(self):
        """No chunk results in, no document results out."""
        assert aggregate_chunk_scores([], [], 3, aggregation="max") == []

    def test_documents_with_no_scores(self):
        """Documents that received no chunk score are left out of the output."""
        scored_chunks = [
            {"index": 0, "relevance_score": 0.9},
            {"index": 1, "relevance_score": 0.7},
        ]
        owners = [0, 0]  # both chunks belong to document 0
        total_docs = 3  # ...although three documents exist overall

        merged = aggregate_chunk_scores(
            scored_chunks, owners, total_docs, aggregation="max"
        )

        # Only document 0 has a score to report
        assert len(merged) == 1
        assert merged[0]["index"] == 0

    def test_unknown_aggregation_strategy(self):
        """An unrecognised strategy name behaves like "max"."""
        scored_chunks = [
            {"index": 0, "relevance_score": 0.6},
            {"index": 1, "relevance_score": 0.8},
        ]
        owners = [0, 0]

        # "invalid" is deliberately not a supported strategy
        merged = aggregate_chunk_scores(
            scored_chunks, owners, 1, aggregation="invalid"
        )

        # Same outcome as the "max" strategy
        assert merged[0]["relevance_score"] == 0.8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.offline
|
||||||
|
class TestTopNWithChunking:
|
||||||
|
"""Tests for top_n behavior when chunking is enabled (Bug fix verification)"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_top_n_limits_documents_not_chunks(self):
    """
    top_n must cap the number of documents, not chunks, when chunking is on.

    Bug scenario: 10 docs expand to 50 chunks. If top_n=5 were forwarded to the
    API, only 5 chunk scores would come back (possibly all from 1-2 docs), so
    fewer than 5 documents would survive aggregation.

    Expected behavior: top_n is applied after aggregation, so the caller gets
    exactly the requested number of documents.
    """
    # Five long documents; the small token budget forces each to be chunked
    corpus = [
        " ".join(f"doc{d}_word{w}" for w in range(50)) for d in range(5)
    ]

    # Run the real chunker first to learn which document each chunk belongs to
    _, owners = chunk_documents_for_rerank(
        corpus, max_tokens=50, overlap_tokens=10
    )

    # The fake API scores EVERY chunk (as if API-level top_n were disabled).
    # A lower document index gets a higher score, so doc 0 ranks first.
    fake_scores = [
        {"index": position, "relevance_score": 0.9 - (owner * 0.1)}
        for position, owner in enumerate(owners)
    ]

    # Response object usable both as an awaitable JSON source and as an
    # async context manager
    response = Mock(status=200, request_info=None, history=None, headers={})
    response.json = AsyncMock(return_value={"results": fake_scores})
    response.__aenter__ = AsyncMock(return_value=response)
    response.__aexit__ = AsyncMock(return_value=None)

    # Session whose post() hands back the canned response
    session = Mock()
    session.post = Mock(return_value=response)
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=None)

    with patch("lightrag.rerank.aiohttp.ClientSession", return_value=session):
        ranked = await cohere_rerank(
            query="test query",
            documents=corpus,
            api_key="test-key",
            base_url="http://test.com/rerank",
            enable_chunking=True,
            max_tokens_per_doc=50,  # same budget as the chunking call above
            top_n=3,  # ask for the top 3 documents
        )

    # Exactly 3 documents come back (not an unbounded list of chunks)
    assert len(ranked) == 3
    # Every index refers to one of the five original documents
    assert all(0 <= row["index"] < 5 for row in ranked)
    # Scores are in non-increasing order
    assert all(
        current["relevance_score"] >= following["relevance_score"]
        for current, following in zip(ranked, ranked[1:])
    )
    # Documents 0, 1 and 2 carry the highest scores
    assert {row["index"] for row in ranked} == {0, 1, 2}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_api_receives_no_top_n_when_chunking_enabled(self):
    """
    The outgoing API request must NOT carry top_n when chunking is enabled.

    Without top_n the API scores every chunk, which aggregation needs.
    """
    corpus = [" ".join(f"word{n}" for n in range(100)), "short doc"]

    # Filled in by the fake post() with the JSON body of the request
    sent_body = {}

    canned_results = {
        "results": [
            {"index": 0, "relevance_score": 0.9},
            {"index": 1, "relevance_score": 0.8},
            {"index": 2, "relevance_score": 0.7},
        ]
    }

    # Response object usable both as an awaitable JSON source and as an
    # async context manager
    response = Mock(status=200, request_info=None, history=None, headers={})
    response.json = AsyncMock(return_value=canned_results)
    response.__aenter__ = AsyncMock(return_value=response)
    response.__aexit__ = AsyncMock(return_value=None)

    def record_post(*args, **kwargs):
        # Keep a copy of the payload so it can be inspected after the call
        sent_body.update(kwargs.get("json", {}))
        return response

    session = Mock()
    session.post = Mock(side_effect=record_post)
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=None)

    with patch("lightrag.rerank.aiohttp.ClientSession", return_value=session):
        await cohere_rerank(
            query="test query",
            documents=corpus,
            api_key="test-key",
            base_url="http://test.com/rerank",
            enable_chunking=True,
            max_tokens_per_doc=30,
            top_n=1,  # the caller wants only the single best document
        )

    # top_n was withheld from the API payload because chunking is enabled
    assert "top_n" not in sent_body
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_top_n_not_modified_when_chunking_disabled(self):
    """
    With chunking turned off, the caller's top_n must reach the API payload.
    """
    # JSON body of the outgoing request, recorded by record_post()
    sent_json = {}

    # Canned successful HTTP response
    api_response = Mock(status=200, request_info=None, history=None, headers={})
    api_response.json = AsyncMock(
        return_value={"results": [{"index": 0, "relevance_score": 0.9}]}
    )
    # Usable as `async with session.post(...) as response`
    api_response.__aenter__ = AsyncMock(return_value=api_response)
    api_response.__aexit__ = AsyncMock(return_value=None)

    def record_post(*_positional, **request_kwargs):
        # Keep the JSON body for inspection and return the canned response
        sent_json.update(request_kwargs.get("json", {}))
        return api_response

    fake_session = Mock(post=Mock(side_effect=record_post))
    fake_session.__aenter__ = AsyncMock(return_value=fake_session)
    fake_session.__aexit__ = AsyncMock(return_value=None)

    with patch("lightrag.rerank.aiohttp.ClientSession", return_value=fake_session):
        await cohere_rerank(
            query="test query",
            documents=["doc1", "doc2"],
            api_key="test-key",
            base_url="http://test.com/rerank",
            enable_chunking=False,  # Chunking disabled
            top_n=1,
        )

    # top_n must be passed through untouched when chunking is disabled
    assert sent_json.get("top_n") == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.offline
class TestCohereRerankChunking:
    """Checks which chunking options cohere_rerank forwards to generic_rerank_api."""

    @pytest.mark.asyncio
    async def test_cohere_rerank_with_chunking_disabled(self):
        """enable_chunking=False is forwarded and the stubbed scores come back."""
        canned_scores = [
            {"index": 0, "relevance_score": 0.9},
            {"index": 1, "relevance_score": 0.7},
        ]

        # Swap generic_rerank_api for an async stub returning the canned scores
        with patch(
            "lightrag.rerank.generic_rerank_api",
            new_callable=AsyncMock,
            return_value=canned_scores,
        ) as rerank_stub:
            result = await cohere_rerank(
                query="test query",
                documents=["doc1", "doc2"],
                api_key="test-key",
                enable_chunking=False,
                max_tokens_per_doc=100,
            )

            # The stub must have been invoked exactly once with our settings
            rerank_stub.assert_called_once()
            forwarded = rerank_stub.call_args.kwargs
            assert forwarded["enable_chunking"] is False
            assert forwarded["max_tokens_per_doc"] == 100
            # The returned entries match the stubbed scores
            assert len(result) == 2
            assert result[0]["index"] == 0
            assert result[0]["relevance_score"] == 0.9
            assert result[1]["index"] == 1
            assert result[1]["relevance_score"] == 0.7

    @pytest.mark.asyncio
    async def test_cohere_rerank_with_chunking_enabled(self):
        """enable_chunking=True and max_tokens_per_doc are forwarded unchanged."""
        canned_scores = [
            {"index": 0, "relevance_score": 0.9},
            {"index": 1, "relevance_score": 0.7},
        ]

        with patch(
            "lightrag.rerank.generic_rerank_api",
            new_callable=AsyncMock,
            return_value=canned_scores,
        ) as rerank_stub:
            result = await cohere_rerank(
                query="test query",
                documents=["doc1", "doc2"],
                api_key="test-key",
                enable_chunking=True,
                max_tokens_per_doc=480,
            )

            # Keyword arguments seen by the stub
            forwarded = rerank_stub.call_args.kwargs
            assert forwarded["enable_chunking"] is True
            assert forwarded["max_tokens_per_doc"] == 480
            # The returned entries match the stubbed scores
            assert len(result) == 2
            assert result[0]["index"] == 0
            assert result[0]["relevance_score"] == 0.9
            assert result[1]["index"] == 1
            assert result[1]["relevance_score"] == 0.7

    @pytest.mark.asyncio
    async def test_cohere_rerank_default_parameters(self):
        """Omitted options fall back to the defaults asserted below."""
        with patch(
            "lightrag.rerank.generic_rerank_api",
            new_callable=AsyncMock,
            return_value=[{"index": 0, "relevance_score": 0.9}],
        ) as rerank_stub:
            result = await cohere_rerank(
                query="test", documents=["doc1"], api_key="test-key"
            )

            # Defaults observed on the forwarded call
            forwarded = rerank_stub.call_args.kwargs
            assert forwarded["enable_chunking"] is False
            assert forwarded["max_tokens_per_doc"] == 4096
            assert forwarded["model"] == "rerank-v3.5"
            # The returned entry matches the stubbed score
            assert len(result) == 1
            assert result[0]["index"] == 0
            assert result[0]["relevance_score"] == 0.9
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.offline
class TestEndToEndChunking:
    """Drives cohere_rerank with chunking enabled against a mocked HTTP session."""

    @pytest.mark.asyncio
    async def test_end_to_end_chunking_workflow(self):
        """Rerank a long and a short document and check the aggregated output."""
        # The first document is long enough that it needs chunking
        documents = [" ".join(f"word{n}" for n in range(100)), "short doc"]

        # Per-chunk scores the fake API hands back
        chunk_scores = [
            {"index": 0, "relevance_score": 0.5},  # chunk 0 from doc 0
            {"index": 1, "relevance_score": 0.8},  # chunk 1 from doc 0
            {"index": 2, "relevance_score": 0.6},  # chunk 2 from doc 0
            {"index": 3, "relevance_score": 0.7},  # doc 1 (short)
        ]

        # Stand-in for the HTTP response used inside generic_rerank_api
        http_response = Mock(status=200, request_info=None, history=None, headers={})
        http_response.json = AsyncMock(return_value={"results": chunk_scores})
        # Async context manager protocol for `async with session.post() as response`
        http_response.__aenter__ = AsyncMock(return_value=http_response)
        http_response.__aexit__ = AsyncMock(return_value=None)

        # session.post() returns the response object, which doubles as the
        # async context manager configured above
        http_session = Mock(post=Mock(return_value=http_response))
        http_session.__aenter__ = AsyncMock(return_value=http_session)
        http_session.__aexit__ = AsyncMock(return_value=None)

        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=http_session):
            result = await cohere_rerank(
                query="test query",
                documents=documents,
                api_key="test-key",
                base_url="http://test.com/rerank",
                enable_chunking=True,
                max_tokens_per_doc=30,  # Force chunking of long doc
            )

        # Chunks of the long doc are aggregated, so there can be at most one
        # result per original document
        assert len(result) <= len(documents)
        # Scores must be in non-increasing order
        ranked_scores = [entry["relevance_score"] for entry in result]
        assert all(
            earlier >= later
            for earlier, later in zip(ranked_scores, ranked_scores[1:])
        )
|
||||||
Loading…
Add table
Reference in a new issue