Merge branch 'optimize-reranker'

This commit is contained in:
yangdx 2025-08-23 01:08:50 +08:00
commit 3d5eeedc36
10 changed files with 397 additions and 689 deletions

View file

@ -1,281 +0,0 @@
# Rerank Integration Guide
LightRAG supports reranking functionality to improve retrieval quality by re-ordering documents based on their relevance to the query. Reranking is now controlled per query via the `enable_rerank` parameter (default: True).
## Quick Start
### Environment Variables
Set these variables in your `.env` file or environment for rerank model configuration:
```bash
# Rerank model configuration (required when enable_rerank=True in queries)
RERANK_MODEL=BAAI/bge-reranker-v2-m3
RERANK_BINDING_HOST=https://api.your-provider.com/v1/rerank
RERANK_BINDING_API_KEY=your_api_key_here
```
### Programmatic Configuration
```python
from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank, RerankModel
# Method 1: Using a custom rerank function with all settings included
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
return await custom_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_n=top_n or 10, # Handle top_n within the function
**kwargs
)
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=your_llm_func,
embedding_func=your_embedding_func,
rerank_model_func=my_rerank_func, # Configure rerank function
)
# Query with rerank enabled (default)
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=True) # Control rerank per query
)
# Query with rerank disabled
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=False)
)
# Method 2: Using RerankModel wrapper
rerank_model = RerankModel(
rerank_func=custom_rerank,
kwargs={
"model": "BAAI/bge-reranker-v2-m3",
"base_url": "https://api.your-provider.com/v1/rerank",
"api_key": "your_api_key_here",
}
)
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=your_llm_func,
embedding_func=your_embedding_func,
rerank_model_func=rerank_model.rerank,
)
# Control rerank per query
result = await rag.aquery(
"your query",
param=QueryParam(
enable_rerank=True, # Enable rerank for this query
chunk_top_k=5 # Number of chunks to keep after reranking
)
)
```
## Supported Providers
### 1. Custom/Generic API (Recommended)
For Jina/Cohere compatible APIs:
```python
from lightrag.rerank import custom_rerank
# Your custom API endpoint
result = await custom_rerank(
query="your query",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_n=10
)
```
### 2. Jina AI
```python
from lightrag.rerank import jina_rerank
result = await jina_rerank(
query="your query",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_jina_api_key",
top_n=10
)
```
### 3. Cohere
```python
from lightrag.rerank import cohere_rerank
result = await cohere_rerank(
query="your query",
documents=documents,
model="rerank-english-v2.0",
api_key="your_cohere_api_key",
top_n=10
)
```
## Integration Points
Reranking is automatically applied at these key retrieval stages:
1. **Naive Mode**: After vector similarity search in `_get_vector_context`
2. **Local Mode**: After entity retrieval in `_get_node_data`
3. **Global Mode**: After relationship retrieval in `_get_edge_data`
4. **Hybrid/Mix Modes**: Applied to all relevant components
## Configuration Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enable_rerank` | bool | True | Enable/disable reranking for a query (set via `QueryParam`) |
| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_n, etc.) |
## Example Usage
### Basic Usage
```python
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.rerank import jina_rerank
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
"""Custom rerank function with all settings included"""
return await jina_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_jina_api_key_here",
top_n=top_n or 10, # Default top_n if not provided
**kwargs
)
async def main():
# Initialize with rerank enabled
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=gpt_4o_mini_complete,
embedding_func=openai_embedding,
rerank_model_func=my_rerank_func,
)
await rag.initialize_storages()
await initialize_pipeline_status()
# Insert documents
await rag.ainsert([
"Document 1 content...",
"Document 2 content...",
])
# Query with rerank (automatically applied)
result = await rag.aquery(
"Your question here",
param=QueryParam(enable_rerank=True) # Enable rerank for this query
)
print(result)
asyncio.run(main())
```
### Direct Rerank Usage
```python
from lightrag.rerank import custom_rerank
async def test_rerank():
documents = [
{"content": "Text about topic A"},
{"content": "Text about topic B"},
{"content": "Text about topic C"},
]
reranked = await custom_rerank(
query="Tell me about topic A",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_n=2
)
for doc in reranked:
print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
```
## Best Practices
1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_n handling) within your rerank function
2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff
3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function
4. **Fallback**: Always handle rerank failures gracefully (returns original results)
5. **Top-n Handling**: Handle top_n parameter appropriately within your rerank function
6. **Cost Management**: Consider rerank API costs in your budget planning
## Troubleshooting
### Common Issues
1. **API Key Missing**: Ensure API keys are properly configured within your rerank function
2. **Network Issues**: Check API endpoints and network connectivity
3. **Model Errors**: Verify the rerank model name is supported by your API
4. **Document Format**: Ensure documents have `content` or `text` fields
### Debug Mode
Enable debug logging to see rerank operations:
```python
import logging
logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG)
```
### Error Handling
The rerank integration includes automatic fallback:
```python
# If rerank fails, original documents are returned
# No exceptions are raised to the user
# Errors are logged for debugging
```
## API Compatibility
The generic rerank API expects this response format:
```json
{
"results": [
{
"index": 0,
"relevance_score": 0.95
},
{
"index": 2,
"relevance_score": 0.87
}
]
}
```
This is compatible with:
- Jina AI Rerank API
- Cohere Rerank API
- Custom APIs following the same format

View file

@ -85,16 +85,36 @@ ENABLE_LLM_CACHE=true
### If reranking is enabled, the impact of chunk selection strategies will be diminished.
# KG_CHUNK_PICK_METHOD=VECTOR
#########################################################
### Reranking configuration
### Set ENABLE_RERANK to true if a reranking model is configured
# ENABLE_RERANK=True
### Minimum rerank score for document chunk exclusion (set to 0.0 to keep all chunks, 0.6 or above if the LLM is not strong enough)
### RERANK_BINDING type: cohere, jina, aliyun
### For rerank model deployed by vLLM use cohere binding
#########################################################
ENABLE_RERANK=False
RERANK_BINDING=cohere
### Rerank score chunk filter (set to 0.0 to keep all chunks, 0.6 or above if the LLM is not strong enough)
# MIN_RERANK_SCORE=0.0
### Rerank model configuration (required when ENABLE_RERANK=True)
# RERANK_MODEL=jina-reranker-v2-base-multilingual
### For local deployment
# RERANK_MODEL=BAAI/bge-reranker-v2-m3
# RERANK_BINDING_HOST=http://localhost:8000
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Default value for Cohere AI
# RERANK_MODEL=rerank-v3.5
# RERANK_BINDING_HOST=https://api.cohere.com/v2/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Default value for Jina AI
# RERANK_MODEL=jina-reranker-v2-base-multilingual
# RERANK_BINDING_HOST=https://api.jina.ai/v1/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Default value for Aliyun
# RERANK_MODEL=gte-rerank-v2
# RERANK_BINDING_HOST=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
########################################
### Document processing configuration
########################################

View file

@ -35,6 +35,7 @@ from lightrag.constants import (
DEFAULT_EMBEDDING_BATCH_NUM,
DEFAULT_OLLAMA_MODEL_NAME,
DEFAULT_OLLAMA_MODEL_TAG,
DEFAULT_RERANK_BINDING,
)
# use the .env that is inside the current folder
@ -76,9 +77,7 @@ def parse_args() -> argparse.Namespace:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(
description="LightRAG FastAPI Server with separate working and input directories"
)
parser = argparse.ArgumentParser(description="LightRAG API Server")
# Server configuration
parser.add_argument(
@ -225,6 +224,19 @@ def parse_args() -> argparse.Namespace:
choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
help="Embedding binding type (default: from env or ollama)",
)
parser.add_argument(
"--rerank-binding",
type=str,
default=get_env_value("RERANK_BINDING", DEFAULT_RERANK_BINDING),
choices=["cohere", "jina", "aliyun"],
help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
)
parser.add_argument(
"--enable-rerank",
action="store_true",
default=get_env_value("ENABLE_RERANK", False, bool),
help="Enable rerank functionality (default: from env or disalbed)",
)
# Conditionally add binding options defined in binding_options module
# This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
@ -337,9 +349,10 @@ def parse_args() -> argparse.Namespace:
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
# Rerank model configuration
args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
args.rerank_model = get_env_value("RERANK_MODEL", None)
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
# Note: rerank_binding is already set by argparse, no need to override from env
# Min rerank score configuration
args.min_rerank_score = get_env_value(

View file

@ -11,6 +11,7 @@ import signal
import sys
import uvicorn
import pipmaster as pm
import inspect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from pathlib import Path
@ -390,33 +391,60 @@ def create_app(args):
),
)
# Configure rerank function if model and API are configured
# Configure rerank function based on enable_rerank parameter
rerank_model_func = None
if args.rerank_binding_api_key and args.rerank_binding_host:
from lightrag.rerank import custom_rerank
if args.enable_rerank and args.rerank_binding:
from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank
# Map rerank binding to corresponding function
rerank_functions = {
"cohere": cohere_rerank,
"jina": jina_rerank,
"aliyun": ali_rerank,
}
# Select the appropriate rerank function based on binding
selected_rerank_func = rerank_functions.get(args.rerank_binding)
if not selected_rerank_func:
logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")
# Get default values from selected_rerank_func if args values are None
if args.rerank_model is None or args.rerank_binding_host is None:
sig = inspect.signature(selected_rerank_func)
# Set default model if args.rerank_model is None
if args.rerank_model is None and "model" in sig.parameters:
default_model = sig.parameters["model"].default
if default_model != inspect.Parameter.empty:
args.rerank_model = default_model
# Set default base_url if args.rerank_binding_host is None
if args.rerank_binding_host is None and "base_url" in sig.parameters:
default_base_url = sig.parameters["base_url"].default
if default_base_url != inspect.Parameter.empty:
args.rerank_binding_host = default_base_url
async def server_rerank_func(
query: str, documents: list, top_n: int = None, **kwargs
query: str, documents: list, top_n: int = None, extra_body: dict = None
):
"""Server rerank function with configuration from environment variables"""
return await custom_rerank(
return await selected_rerank_func(
query=query,
documents=documents,
top_n=top_n,
api_key=args.rerank_binding_api_key,
model=args.rerank_model,
base_url=args.rerank_binding_host,
api_key=args.rerank_binding_api_key,
top_n=top_n,
**kwargs,
extra_body=extra_body,
)
rerank_model_func = server_rerank_func
logger.info(
f"Rerank model configured: {args.rerank_model} (can be enabled per query)"
f"Reranking is enabled: {args.rerank_model or 'default model'} using {args.rerank_binding} provider"
)
else:
logger.info(
"Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
)
logger.info("Reranking is disabled")
# Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos
@ -622,13 +650,14 @@ def create_app(args):
"enable_llm_cache": args.enable_llm_cache,
"workspace": args.workspace,
"max_graph_nodes": args.max_graph_nodes,
# Rerank configuration (based on whether rerank model is configured)
"enable_rerank": rerank_model_func is not None,
"rerank_model": args.rerank_model
if rerank_model_func is not None
# Rerank configuration
"enable_rerank": args.enable_rerank,
"rerank_binding": args.rerank_binding
if args.enable_rerank
else None,
"rerank_model": args.rerank_model if args.enable_rerank else None,
"rerank_binding_host": args.rerank_binding_host
if rerank_model_func is not None
if args.enable_rerank
else None,
# Environment variable status (requested configuration)
"summary_language": args.summary_language,

View file

@ -22,7 +22,6 @@ from .constants import (
DEFAULT_MAX_RELATION_TOKENS,
DEFAULT_MAX_TOTAL_TOKENS,
DEFAULT_HISTORY_TURNS,
DEFAULT_ENABLE_RERANK,
DEFAULT_OLLAMA_MODEL_NAME,
DEFAULT_OLLAMA_MODEL_TAG,
DEFAULT_OLLAMA_MODEL_SIZE,
@ -158,9 +157,7 @@ class QueryParam:
If provided, this will be used instead of the default value from the prompt template.
"""
enable_rerank: bool = (
os.getenv("ENABLE_RERANK", str(DEFAULT_ENABLE_RERANK).lower()).lower() == "true"
)
enable_rerank: bool = os.getenv("ENABLE_RERANK", "false").lower() == "true"
"""Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued.
Defaults to the value of the ENABLE_RERANK environment variable, or False when it is not set.
"""

View file

@ -32,8 +32,8 @@ DEFAULT_KG_CHUNK_PICK_METHOD = "VECTOR"
DEFAULT_HISTORY_TURNS = 0
# Rerank configuration defaults
DEFAULT_ENABLE_RERANK = True
DEFAULT_MIN_RERANK_SCORE = 0.0
DEFAULT_RERANK_BINDING = "cohere"
# File path configuration for vector and graph database(Should not be changed, used in Milvus Schema)
DEFAULT_MAX_FILE_PATH_LENGTH = 32768

View file

@ -525,14 +525,6 @@ class LightRAG:
)
)
# Init Rerank
if self.rerank_model_func:
logger.info("Rerank model initialized for improved retrieval quality")
else:
logger.warning(
"Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
)
self._storages_status = StoragesStatus.CREATED
async def initialize_storages(self):

View file

@ -1,101 +0,0 @@
from __future__ import annotations
from typing import Callable, Any
from pydantic import BaseModel, Field
class Model(BaseModel):
    """
    Pydantic description of one custom language-model endpoint.

    A ``Model`` pairs the callable that produces a completion with the fixed
    keyword arguments that accompany every call to it.

    Attributes:
        gen_func: Callable that generates the response from the language model.
        kwargs: Arguments passed to ``gen_func``, such as the model name or
            the API key.

    Example:
        Model(
            gen_func=openai_complete_if_cache,
            kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]},
        )

        Here ``openai_complete_if_cache`` generates the response and ``kwargs``
        carries the model name and API key handed to it.
    """

    gen_func: Callable[[Any], str] = Field(
        ...,
        description="A function that generates the response from the llm. The response must be a string",
    )
    kwargs: dict[str, Any] = Field(
        ...,
        description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
    )

    class Config:
        # Let pydantic accept field types it has no dedicated validator for.
        arbitrary_types_allowed = True
class MultiModel:
    """
    Distribute LLM calls across several language models in round-robin order.

    Useful for spreading load over multiple API keys when a provider enforces
    low rate limits, or for splitting traffic across different models or
    providers.

    Attributes:
        models (List[Model]): The language models to rotate through.

    Usage example:
    ```python
    models = [
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
    ]
    multi_model = MultiModel(models)
    rag = LightRAG(
        llm_model_func=multi_model.llm_model_func,
        # ..other args
    )
    ```
    """

    def __init__(self, models: list[Model]):
        self._models = models
        self._current_model = 0

    def _next_model(self):
        """Advance the rotation cursor and return the model it now points to.

        The cursor is advanced before it is read, so the first call returns
        the second model in the list (or the only model if there is one).
        """
        self._current_model = (self._current_model + 1) % len(self._models)
        return self._models[self._current_model]

    async def llm_model_func(
        self,
        prompt: str,
        system_prompt: str | None = None,
        history_messages: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> str:
        """Send one completion request to the next model in the rotation.

        Args:
            prompt: The user prompt.
            system_prompt: Optional system prompt.
            history_messages: Prior conversation turns. ``None`` (the default)
                is replaced by a fresh empty list on every call.
            **kwargs: Extra arguments forwarded to the model's ``gen_func``.
                ``model``, ``keyword_extraction`` and ``mode`` are dropped.

        Returns:
            Whatever the selected model's ``gen_func`` returns.
        """
        # A fresh list per call avoids sharing one mutable default between calls.
        if history_messages is None:
            history_messages = []

        kwargs.pop("model", None)  # stop from overwriting the custom model name
        kwargs.pop("keyword_extraction", None)
        kwargs.pop("mode", None)

        next_model = self._next_model()
        # Later entries win: per-model configuration overrides caller kwargs
        # instead of raising TypeError on a duplicated keyword.
        args = {
            "prompt": prompt,
            "system_prompt": system_prompt,
            "history_messages": history_messages,
            **kwargs,
            **next_model.kwargs,
        }

        return await next_model.gen_func(**args)
if __name__ == "__main__":
    import asyncio

    async def main():
        # Manual smoke test: send one prompt through the OpenAI binding and
        # print whatever comes back.
        from lightrag.llm.openai import gpt_4o_mini_complete

        reply = await gpt_4o_mini_complete("How are you?")
        print(reply)

    asyncio.run(main())

View file

@ -2,270 +2,199 @@ from __future__ import annotations
import os
import aiohttp
from typing import Callable, Any, List, Dict, Optional
from pydantic import BaseModel, Field
from typing import Any, List, Dict, Optional
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from .utils import logger
from dotenv import load_dotenv
class RerankModel(BaseModel):
    """
    Bundle a rerank function with the keyword arguments it should always receive.

    Pass the ``rerank`` method to LightRAG as ``rerank_model_func``:

    ```python
    from lightrag.rerank import RerankModel, jina_rerank

    rerank_model = RerankModel(
        rerank_func=jina_rerank,
        kwargs={
            "model": "BAAI/bge-reranker-v2-m3",
            "api_key": "your_api_key_here",
            "base_url": "https://api.jina.ai/v1/rerank",
        },
    )

    rag = LightRAG(
        rerank_model_func=rerank_model.rerank,
        # ... other configurations
    )

    # Reranking is switched on or off per query
    result = await rag.aquery(
        "your query",
        param=QueryParam(enable_rerank=True),
    )
    ```

    A plain async function works as well:

    ```python
    async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
        return await jina_rerank(
            query=query,
            documents=documents,
            model="BAAI/bge-reranker-v2-m3",
            api_key="your_api_key_here",
            top_n=top_n or 10,
            **kwargs,
        )

    rag = LightRAG(
        rerank_model_func=my_rerank_func,
        # ... other configurations
    )
    ```
    """

    rerank_func: Callable[[Any], List[Dict]]
    kwargs: Dict[str, Any] = Field(default_factory=dict)

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        top_n: Optional[int] = None,
        **extra_kwargs,
    ) -> List[Dict[str, Any]]:
        """Call the wrapped rerank function with the stored and per-call kwargs.

        Per-call ``extra_kwargs`` override entries of the same name stored in
        ``self.kwargs``.
        """
        merged = dict(self.kwargs)
        merged.update(extra_kwargs)
        return await self.rerank_func(
            query=query,
            documents=documents,
            top_n=top_n,
            **merged,
        )
class MultiRerankModel(BaseModel):
    """Hold one primary rerank model plus optional per-mode overrides."""

    # Fallback model, used when no mode-specific model is set for the mode
    rerank_model: Optional[RerankModel] = None

    # Optional overrides for the "entity", "relation" and "chunk" modes
    entity_rerank_model: Optional[RerankModel] = None
    relation_rerank_model: Optional[RerankModel] = None
    chunk_rerank_model: Optional[RerankModel] = None

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        mode: str = "default",
        top_n: Optional[int] = None,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank with the model configured for ``mode``, else the primary model.

        When neither a mode-specific nor a primary model is set, a warning is
        logged and ``documents`` is returned unchanged.
        """
        overrides = (
            ("entity", self.entity_rerank_model),
            ("relation", self.relation_rerank_model),
            ("chunk", self.chunk_rerank_model),
        )
        # First override whose mode matches and whose model is set, if any.
        chosen = next(
            (candidate for name, candidate in overrides if mode == name and candidate),
            None,
        )
        if chosen is None:
            chosen = self.rerank_model

        if not chosen:
            logger.warning(f"No rerank model available for mode: {mode}")
            return documents

        return await chosen.rerank(query, documents, top_n, **kwargs)
# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(aiohttp.ClientError)
| retry_if_exception_type(aiohttp.ClientResponseError)
),
)
async def generic_rerank_api(
query: str,
documents: List[Dict[str, Any]],
documents: List[str],
model: str,
base_url: str,
api_key: str,
api_key: Optional[str],
top_n: Optional[int] = None,
**kwargs,
return_documents: Optional[bool] = None,
extra_body: Optional[Dict[str, Any]] = None,
response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
) -> List[Dict[str, Any]]:
"""
Generic rerank function that works with Jina/Cohere compatible APIs.
Generic rerank API call for Jina/Cohere/Aliyun models.
Args:
query: The search query
documents: List of documents to rerank
model: Model identifier
documents: List of strings to rerank
model: Model name to use
base_url: API endpoint URL
api_key: API authentication key
api_key: API key for authentication
top_n: Number of top results to return
**kwargs: Additional API-specific parameters
return_documents: Whether to return document text (Jina only)
extra_body: Additional body parameters
response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
Returns:
List of reranked documents with relevance scores
List of dictionary of ["index": int, "relevance_score": float]
"""
if not api_key:
logger.warning("No API key provided for rerank service")
return documents
if not base_url:
raise ValueError("Base URL is required")
if not documents:
return documents
headers = {"Content-Type": "application/json"}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
# Prepare documents for reranking - handle both text and dict formats
prepared_docs = []
for doc in documents:
if isinstance(doc, dict):
# Use 'content' field if available, otherwise use 'text' or convert to string
text = doc.get("content") or doc.get("text") or str(doc)
else:
text = str(doc)
prepared_docs.append(text)
# Build request payload based on request format
if request_format == "aliyun":
# Aliyun format: nested input/parameters structure
payload = {
"model": model,
"input": {
"query": query,
"documents": documents,
},
"parameters": {},
}
# Prepare request
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
# Add optional parameters to parameters object
if top_n is not None:
payload["parameters"]["top_n"] = top_n
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
if return_documents is not None:
payload["parameters"]["return_documents"] = return_documents
if top_n is not None:
data["top_n"] = min(top_n, len(prepared_docs))
# Add extra parameters to parameters object
if extra_body:
payload["parameters"].update(extra_body)
else:
# Standard format for Jina/Cohere
payload = {
"model": model,
"query": query,
"documents": documents,
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=data) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Rerank API error {response.status}: {error_text}")
return documents
# Add optional parameters
if top_n is not None:
payload["top_n"] = top_n
result = await response.json()
# Only Jina API supports return_documents parameter
if return_documents is not None:
payload["return_documents"] = return_documents
# Extract reranked results
if "results" in result:
# Standard format: results contain index and relevance_score
reranked_docs = []
for item in result["results"]:
if "index" in item:
doc_idx = item["index"]
if 0 <= doc_idx < len(documents):
reranked_doc = documents[doc_idx].copy()
if "relevance_score" in item:
reranked_doc["rerank_score"] = item[
"relevance_score"
]
reranked_docs.append(reranked_doc)
return reranked_docs
else:
logger.warning("Unexpected rerank API response format")
return documents
# Add extra parameters
if extra_body:
payload.update(extra_body)
except Exception as e:
logger.error(f"Error during reranking: {e}")
return documents
async def jina_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str = "BAAI/bge-reranker-v2-m3",
top_n: Optional[int] = None,
base_url: str = "https://api.jina.ai/v1/rerank",
api_key: Optional[str] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Jina AI API.
Args:
query: The search query
documents: List of documents to rerank
model: Jina rerank model name
top_n: Number of top results to return
base_url: Jina API endpoint
api_key: Jina API key
**kwargs: Additional parameters
Returns:
List of reranked documents with relevance scores
"""
if api_key is None:
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
return await generic_rerank_api(
query=query,
documents=documents,
model=model,
base_url=base_url,
api_key=api_key,
top_n=top_n,
**kwargs,
logger.debug(
f"Rerank request: {len(documents)} documents, model: {model}, format: {response_format}"
)
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
if response.status != 200:
error_text = await response.text()
content_type = response.headers.get("content-type", "").lower()
is_html_error = (
error_text.strip().startswith("<!DOCTYPE html>")
or "text/html" in content_type
)
if is_html_error:
if response.status == 502:
clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
elif response.status == 503:
clean_error = "Service Unavailable (503) - Rerank service is temporarily overloaded. Please try again later."
elif response.status == 504:
clean_error = "Gateway Timeout (504) - Rerank service request timed out. Please try again."
else:
clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
else:
clean_error = error_text
logger.error(f"Rerank API error {response.status}: {clean_error}")
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"Rerank API error: {clean_error}",
)
response_json = await response.json()
if response_format == "aliyun":
# Aliyun format: {"output": {"results": [...]}}
results = response_json.get("output", {}).get("results", [])
if not isinstance(results, list):
logger.warning(
f"Expected 'output.results' to be list, got {type(results)}: {results}"
)
results = []
elif response_format == "standard":
# Standard format: {"results": [...]}
results = response_json.get("results", [])
if not isinstance(results, list):
logger.warning(
f"Expected 'results' to be list, got {type(results)}: {results}"
)
results = []
else:
raise ValueError(f"Unsupported response format: {response_format}")
if not results:
logger.warning("Rerank API returned empty results")
return []
# Standardize return format
return [
{"index": result["index"], "relevance_score": result["relevance_score"]}
for result in results
]
async def cohere_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str = "rerank-english-v2.0",
documents: List[str],
top_n: Optional[int] = None,
base_url: str = "https://api.cohere.ai/v1/rerank",
api_key: Optional[str] = None,
**kwargs,
model: str = "rerank-v3.5",
base_url: str = "https://api.cohere.com/v2/rerank",
extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Cohere API.
Args:
query: The search query
documents: List of documents to rerank
model: Cohere rerank model name
documents: List of strings to rerank
top_n: Number of top results to return
base_url: Cohere API endpoint
api_key: Cohere API key
**kwargs: Additional parameters
api_key: API key
model: rerank model name
base_url: API endpoint
extra_body: Additional body for http request(reserved for extra params)
Returns:
List of reranked documents with relevance scores
List of dictionary of ["index": int, "relevance_score": float]
"""
if api_key is None:
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
return await generic_rerank_api(
query=query,
@ -274,24 +203,39 @@ async def cohere_rerank(
base_url=base_url,
api_key=api_key,
top_n=top_n,
**kwargs,
return_documents=None, # Cohere doesn't support this parameter
extra_body=extra_body,
response_format="standard",
)
# Convenience function for custom API endpoints
async def custom_rerank(
async def jina_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str,
base_url: str,
api_key: str,
documents: List[str],
top_n: Optional[int] = None,
**kwargs,
api_key: Optional[str] = None,
model: str = "jina-reranker-v2-base-multilingual",
base_url: str = "https://api.jina.ai/v1/rerank",
extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Rerank documents using a custom API endpoint.
This is useful for self-hosted or custom rerank services.
Rerank documents using Jina AI API.
Args:
query: The search query
documents: List of strings to rerank
top_n: Number of top results to return
api_key: API key
model: rerank model name
base_url: API endpoint
extra_body: Additional body for http request(reserved for extra params)
Returns:
List of dictionary of ["index": int, "relevance_score": float]
"""
if api_key is None:
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
return await generic_rerank_api(
query=query,
documents=documents,
@ -299,26 +243,112 @@ async def custom_rerank(
base_url=base_url,
api_key=api_key,
top_n=top_n,
**kwargs,
return_documents=False,
extra_body=extra_body,
response_format="standard",
)
async def ali_rerank(
    query: str,
    documents: List[str],
    top_n: Optional[int] = None,
    api_key: Optional[str] = None,
    model: str = "gte-rerank-v2",
    base_url: str = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
    extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Rerank documents through the Aliyun DashScope rerank endpoint.

    Thin wrapper that forwards to generic_rerank_api with the "aliyun"
    request and response formats selected.

    Args:
        query: The search query
        documents: List of strings to rerank
        top_n: Number of top results to return
        api_key: Aliyun API key; when None, it is read from the
            DASHSCOPE_API_KEY environment variable, falling back to
            RERANK_BINDING_API_KEY
        model: rerank model name
        base_url: API endpoint
        extra_body: Additional body for the HTTP request (reserved for extra params)

    Returns:
        List of dictionaries of the form {"index": int, "relevance_score": float}
    """
    resolved_key = api_key
    if resolved_key is None:
        # No explicit key supplied: consult the environment instead.
        resolved_key = os.getenv("DASHSCOPE_API_KEY") or os.getenv(
            "RERANK_BINDING_API_KEY"
        )

    request_args = {
        "query": query,
        "documents": documents,
        "model": model,
        "base_url": base_url,
        "api_key": resolved_key,
        "top_n": top_n,
        "return_documents": False,  # Aliyun doesn't need this parameter
        "extra_body": extra_body,
        "response_format": "aliyun",
        "request_format": "aliyun",
    }
    return await generic_rerank_api(**request_args)
"""Please run this test as a module:
python -m lightrag.rerank
"""
if __name__ == "__main__":
import asyncio
async def main():
# Example usage
# Example usage - documents should be strings, not dictionaries
docs = [
{"content": "The capital of France is Paris."},
{"content": "Tokyo is the capital of Japan."},
{"content": "London is the capital of England."},
"The capital of France is Paris.",
"Tokyo is the capital of Japan.",
"London is the capital of England.",
]
query = "What is the capital of France?"
result = await jina_rerank(
query=query, documents=docs, top_n=2, api_key="your-api-key-here"
)
print(result)
# Test Jina rerank
try:
print("=== Jina Rerank ===")
result = await jina_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Jina Error: {e}")
# Test Cohere rerank
try:
print("\n=== Cohere Rerank ===")
result = await cohere_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Cohere Error: {e}")
# Test Aliyun rerank
try:
print("\n=== Aliyun Rerank ===")
result = await ali_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Aliyun Error: {e}")
asyncio.run(main())

View file

@ -1978,17 +1978,50 @@ async def apply_rerank_if_enabled(
return retrieved_docs
try:
# Apply reranking - let rerank_model_func handle top_k internally
reranked_docs = await rerank_func(
# Extract document content for reranking
document_texts = []
for doc in retrieved_docs:
# Try multiple possible content fields
content = (
doc.get("content")
or doc.get("text")
or doc.get("chunk_content")
or doc.get("document")
or str(doc)
)
document_texts.append(content)
# Call the new rerank function that returns index-based results
rerank_results = await rerank_func(
query=query,
documents=retrieved_docs,
documents=document_texts,
top_n=top_n,
)
if reranked_docs and len(reranked_docs) > 0:
if len(reranked_docs) > top_n:
reranked_docs = reranked_docs[:top_n]
logger.info(f"Successfully reranked: {len(retrieved_docs)} chunks")
return reranked_docs
# Process rerank results based on return format
if rerank_results and len(rerank_results) > 0:
# Check if results are in the new index-based format
if isinstance(rerank_results[0], dict) and "index" in rerank_results[0]:
# New format: [{"index": 0, "relevance_score": 0.85}, ...]
reranked_docs = []
for result in rerank_results:
index = result["index"]
relevance_score = result["relevance_score"]
# Get original document and add rerank score
if 0 <= index < len(retrieved_docs):
doc = retrieved_docs[index].copy()
doc["rerank_score"] = relevance_score
reranked_docs.append(doc)
logger.info(
f"Successfully reranked: {len(reranked_docs)} chunks from {len(retrieved_docs)} original chunks"
)
return reranked_docs
else:
# Legacy format: assume it's already reranked documents
logger.info(f"Using legacy rerank format: {len(rerank_results)} chunks")
return rerank_results[:top_n] if top_n else rerank_results
else:
logger.warning("Rerank returned empty results, using original chunks")
return retrieved_docs
@ -2027,13 +2060,6 @@ async def process_chunks_unified(
# 1. Apply reranking if enabled and query is provided
if query_param.enable_rerank and query and unique_chunks:
# 保存 chunk_id 字段,因为 rerank 可能会丢失这个字段
chunk_ids = {}
for chunk in unique_chunks:
chunk_id = chunk.get("chunk_id")
if chunk_id:
chunk_ids[id(chunk)] = chunk_id
rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
unique_chunks = await apply_rerank_if_enabled(
query=query,
@ -2043,11 +2069,6 @@ async def process_chunks_unified(
top_n=rerank_top_k,
)
# 恢复 chunk_id 字段
for chunk in unique_chunks:
if id(chunk) in chunk_ids:
chunk["chunk_id"] = chunk_ids[id(chunk)]
# 2. Filter by minimum rerank score if reranking is enabled
if query_param.enable_rerank and unique_chunks:
min_rerank_score = global_config.get("min_rerank_score", 0.5)
@ -2095,13 +2116,6 @@ async def process_chunks_unified(
original_count = len(unique_chunks)
# Keep chunk_id field, because truncate_list_by_token_size will lose it
chunk_ids_map = {}
for i, chunk in enumerate(unique_chunks):
chunk_id = chunk.get("chunk_id")
if chunk_id:
chunk_ids_map[i] = chunk_id
unique_chunks = truncate_list_by_token_size(
unique_chunks,
key=lambda x: json.dumps(x, ensure_ascii=False),
@ -2109,11 +2123,6 @@ async def process_chunks_unified(
tokenizer=tokenizer,
)
# restore chunk_id field
for i, chunk in enumerate(unique_chunks):
if i in chunk_ids_map:
chunk["chunk_id"] = chunk_ids_map[i]
logger.debug(
f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"