feat: Add multiple rerank provider support to LightRAG Server by adding new env vars and cli params

- Add --enable-rerank CLI argument and ENABLE_RERANK env var
- Simplify rerank configuration logic to only check enable flag and binding
- Update health endpoint to show enable_rerank and rerank_configured status
- Improve logging messages for rerank enable/disable states
- Maintain backward compatibility with default value True
This commit is contained in:
yangdx 2025-08-22 19:29:45 +08:00
parent 0019a3adc6
commit 580cb7906c
6 changed files with 368 additions and 568 deletions

View file

@ -1,281 +0,0 @@
# Rerank Integration Guide
LightRAG supports reranking functionality to improve retrieval quality by re-ordering documents based on their relevance to the query. Reranking is now controlled per query via the `enable_rerank` parameter (default: True).
## Quick Start
### Environment Variables
Set these variables in your `.env` file or environment for rerank model configuration:
```bash
# Rerank model configuration (required when enable_rerank=True in queries)
RERANK_MODEL=BAAI/bge-reranker-v2-m3
RERANK_BINDING_HOST=https://api.your-provider.com/v1/rerank
RERANK_BINDING_API_KEY=your_api_key_here
```
### Programmatic Configuration
```python
from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank, RerankModel
# Method 1: Using a custom rerank function with all settings included
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
return await custom_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_n=top_n or 10, # Handle top_n within the function
**kwargs
)
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=your_llm_func,
embedding_func=your_embedding_func,
rerank_model_func=my_rerank_func, # Configure rerank function
)
# Query with rerank enabled (default)
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=True) # Control rerank per query
)
# Query with rerank disabled
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=False)
)
# Method 2: Using RerankModel wrapper
rerank_model = RerankModel(
rerank_func=custom_rerank,
kwargs={
"model": "BAAI/bge-reranker-v2-m3",
"base_url": "https://api.your-provider.com/v1/rerank",
"api_key": "your_api_key_here",
}
)
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=your_llm_func,
embedding_func=your_embedding_func,
rerank_model_func=rerank_model.rerank,
)
# Control rerank per query
result = await rag.aquery(
"your query",
param=QueryParam(
enable_rerank=True, # Enable rerank for this query
chunk_top_k=5 # Number of chunks to keep after reranking
)
)
```
## Supported Providers
### 1. Custom/Generic API (Recommended)
For Jina/Cohere compatible APIs:
```python
from lightrag.rerank import custom_rerank
# Your custom API endpoint
result = await custom_rerank(
query="your query",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_n=10
)
```
### 2. Jina AI
```python
from lightrag.rerank import jina_rerank
result = await jina_rerank(
query="your query",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_jina_api_key",
top_n=10
)
```
### 3. Cohere
```python
from lightrag.rerank import cohere_rerank
result = await cohere_rerank(
query="your query",
documents=documents,
model="rerank-english-v2.0",
api_key="your_cohere_api_key",
top_n=10
)
```
## Integration Points
Reranking is automatically applied at these key retrieval stages:
1. **Naive Mode**: After vector similarity search in `_get_vector_context`
2. **Local Mode**: After entity retrieval in `_get_node_data`
3. **Global Mode**: After relationship retrieval in `_get_edge_data`
4. **Hybrid/Mix Modes**: Applied to all relevant components
## Configuration Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enable_rerank` | bool | True | Enable/disable reranking for a query (set on `QueryParam`) |
| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_n, etc.) |
## Example Usage
### Basic Usage
```python
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.rerank import jina_rerank
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
"""Custom rerank function with all settings included"""
return await jina_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_jina_api_key_here",
top_n=top_n or 10, # Default top_n if not provided
**kwargs
)
async def main():
# Initialize with rerank enabled
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=gpt_4o_mini_complete,
embedding_func=openai_embedding,
rerank_model_func=my_rerank_func,
)
await rag.initialize_storages()
await initialize_pipeline_status()
# Insert documents
await rag.ainsert([
"Document 1 content...",
"Document 2 content...",
])
# Query with rerank (automatically applied)
result = await rag.aquery(
"Your question here",
param=QueryParam(enable_rerank=True) # Enable rerank for this query
)
print(result)
asyncio.run(main())
```
### Direct Rerank Usage
```python
from lightrag.rerank import custom_rerank
async def test_rerank():
documents = [
{"content": "Text about topic A"},
{"content": "Text about topic B"},
{"content": "Text about topic C"},
]
reranked = await custom_rerank(
query="Tell me about topic A",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_n=2
)
for doc in reranked:
print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
```
## Best Practices
1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_n handling) within your rerank function
2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff
3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function
4. **Fallback**: Always handle rerank failures gracefully (returns original results)
5. **Top-n Handling**: Handle top_n parameter appropriately within your rerank function
6. **Cost Management**: Consider rerank API costs in your budget planning
## Troubleshooting
### Common Issues
1. **API Key Missing**: Ensure API keys are properly configured within your rerank function
2. **Network Issues**: Check API endpoints and network connectivity
3. **Model Errors**: Verify the rerank model name is supported by your API
4. **Document Format**: Ensure documents have `content` or `text` fields
### Debug Mode
Enable debug logging to see rerank operations:
```python
import logging
logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG)
```
### Error Handling
The rerank integration includes automatic fallback:
```python
# If rerank fails, original documents are returned
# No exceptions are raised to the user
# Errors are logged for debugging
```
## API Compatibility
The generic rerank API expects this response format:
```json
{
"results": [
{
"index": 0,
"relevance_score": 0.95
},
{
"index": 2,
"relevance_score": 0.87
}
]
}
```
This is compatible with:
- Jina AI Rerank API
- Cohere Rerank API
- Custom APIs following the same format

View file

@ -85,16 +85,36 @@ ENABLE_LLM_CACHE=true
### If reranking is enabled, the impact of chunk selection strategies will be diminished. ### If reranking is enabled, the impact of chunk selection strategies will be diminished.
# KG_CHUNK_PICK_METHOD=VECTOR # KG_CHUNK_PICK_METHOD=VECTOR
#########################################################
### Reranking configuration ### Reranking configuration
### Reranker: set ENABLE_RERANK to true if a reranking model is configured ### RERANK_BINDING type: cohere, jina, aliyun
# ENABLE_RERANK=True ### For rerank model deployed by vLLM use cohere binding
### Minimum rerank score for document chunk exclusion (set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough) #########################################################
ENABLE_RERANK=False
RERANK_BINDING=cohere
### Rerank score chunk filter (set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough)
# MIN_RERANK_SCORE=0.0 # MIN_RERANK_SCORE=0.0
### Rerank model configuration (required when ENABLE_RERANK=True)
# RERANK_MODEL=jina-reranker-v2-base-multilingual ### For local deployment
# RERANK_MODEL=BAAI/bge-reranker-v2-m3
# RERANK_BINDING_HOST=http://localhost:8000
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Default value for Cohere AI
# RERANK_MODEL=rerank-v3.5
# RERANK_BINDING_HOST=https://ai.znipower.com:5017/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Default value for Jina AI
# RERANK_MODEL=jina-reranker-v2-base-multilingual
# RERANK_BINDING_HOST=https://api.jina.ai/v1/rerank # RERANK_BINDING_HOST=https://api.jina.ai/v1/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here # RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Default value for Aliyun
# RERANK_MODEL=gte-rerank-v2
# RERANK_BINDING_HOST=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
######################################## ########################################
### Document processing configuration ### Document processing configuration
######################################## ########################################

View file

@ -225,6 +225,19 @@ def parse_args() -> argparse.Namespace:
choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"], choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
help="Embedding binding type (default: from env or ollama)", help="Embedding binding type (default: from env or ollama)",
) )
parser.add_argument(
"--rerank-binding",
type=str,
default=get_env_value("RERANK_BINDING", "cohere"),
choices=["cohere", "jina", "aliyun"],
help="Rerank binding type (default: from env or cohere)",
)
parser.add_argument(
"--enable-rerank",
action="store_true",
default=get_env_value("ENABLE_RERANK", True, bool),
help="Enable rerank functionality (default: from env or True)",
)
# Conditionally add binding options defined in binding_options module # Conditionally add binding options defined in binding_options module
# This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx) # This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
@ -340,6 +353,7 @@ def parse_args() -> argparse.Namespace:
args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3") args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None) args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None) args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
# Note: rerank_binding is already set by argparse, no need to override from env
# Min rerank score configuration # Min rerank score configuration
args.min_rerank_score = get_env_value( args.min_rerank_score = get_env_value(

View file

@ -390,33 +390,44 @@ def create_app(args):
), ),
) )
# Configure rerank function if model and API are configured # Configure rerank function based on enable_rerank parameter
rerank_model_func = None rerank_model_func = None
if args.rerank_binding_api_key and args.rerank_binding_host: if args.enable_rerank and args.rerank_binding:
from lightrag.rerank import custom_rerank from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank
# Map rerank binding to corresponding function
rerank_functions = {
"cohere": cohere_rerank,
"jina": jina_rerank,
"aliyun": ali_rerank,
}
# Select the appropriate rerank function based on binding
selected_rerank_func = rerank_functions.get(args.rerank_binding)
if not selected_rerank_func:
logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")
async def server_rerank_func( async def server_rerank_func(
query: str, documents: list, top_n: int = None, **kwargs query: str, documents: list, top_n: int = None, extra_body: dict = None
): ):
"""Server rerank function with configuration from environment variables""" """Server rerank function with configuration from environment variables"""
return await custom_rerank( return await selected_rerank_func(
query=query, query=query,
documents=documents, documents=documents,
model=args.rerank_model, model=args.rerank_model,
base_url=args.rerank_binding_host, base_url=args.rerank_binding_host,
api_key=args.rerank_binding_api_key, api_key=args.rerank_binding_api_key,
top_n=top_n, top_n=top_n,
**kwargs, extra_body=extra_body,
) )
rerank_model_func = server_rerank_func rerank_model_func = server_rerank_func
logger.info( logger.info(
f"Rerank model configured: {args.rerank_model} (can be enabled per query)" f"Rerank enabled: {args.rerank_model} using {args.rerank_binding} provider"
) )
else: else:
logger.info( logger.info("Rerank disabled")
"Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
)
# Create ollama_server_infos from command line arguments # Create ollama_server_infos from command line arguments
from lightrag.api.config import OllamaServerInfos from lightrag.api.config import OllamaServerInfos
@ -622,13 +633,15 @@ def create_app(args):
"enable_llm_cache": args.enable_llm_cache, "enable_llm_cache": args.enable_llm_cache,
"workspace": args.workspace, "workspace": args.workspace,
"max_graph_nodes": args.max_graph_nodes, "max_graph_nodes": args.max_graph_nodes,
# Rerank configuration (based on whether rerank model is configured) # Rerank configuration
"enable_rerank": rerank_model_func is not None, "enable_rerank": args.enable_rerank,
"rerank_model": args.rerank_model "rerank_configured": rerank_model_func is not None,
if rerank_model_func is not None "rerank_binding": args.rerank_binding
if args.enable_rerank
else None, else None,
"rerank_model": args.rerank_model if args.enable_rerank else None,
"rerank_binding_host": args.rerank_binding_host "rerank_binding_host": args.rerank_binding_host
if rerank_model_func is not None if args.enable_rerank
else None, else None,
# Environment variable status (requested configuration) # Environment variable status (requested configuration)
"summary_language": args.summary_language, "summary_language": args.summary_language,

View file

@ -2,270 +2,194 @@ from __future__ import annotations
import os import os
import aiohttp import aiohttp
from typing import Callable, Any, List, Dict, Optional from typing import Any, List, Dict, Optional
from pydantic import BaseModel, Field from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from .utils import logger from .utils import logger
from dotenv import load_dotenv
class RerankModel(BaseModel): # use the .env that is inside the current folder
""" # allows to use different .env file for each lightrag instance
Wrapper for rerank functions that can be used with LightRAG. # the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)
Example usage:
```python
from lightrag.rerank import RerankModel, jina_rerank
# Create rerank model
rerank_model = RerankModel(
rerank_func=jina_rerank,
kwargs={
"model": "BAAI/bge-reranker-v2-m3",
"api_key": "your_api_key_here",
"base_url": "https://api.jina.ai/v1/rerank"
}
)
# Use in LightRAG
rag = LightRAG(
rerank_model_func=rerank_model.rerank,
# ... other configurations
)
# Query with rerank enabled (default)
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=True)
)
```
Or define a custom function directly:
```python
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
return await jina_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_api_key_here",
top_n=top_n or 10,
**kwargs
)
rag = LightRAG(
rerank_model_func=my_rerank_func,
# ... other configurations
)
# Control rerank per query
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=True) # Enable rerank for this query
)
```
"""
rerank_func: Callable[[Any], List[Dict]]
kwargs: Dict[str, Any] = Field(default_factory=dict)
async def rerank(
self,
query: str,
documents: List[Dict[str, Any]],
top_n: Optional[int] = None,
**extra_kwargs,
) -> List[Dict[str, Any]]:
"""Rerank documents using the configured model function."""
# Merge extra kwargs with model kwargs
kwargs = {**self.kwargs, **extra_kwargs}
return await self.rerank_func(
query=query, documents=documents, top_n=top_n, **kwargs
)
class MultiRerankModel(BaseModel):
"""Multiple rerank models for different modes/scenarios."""
# Primary rerank model (used if mode-specific models are not defined)
rerank_model: Optional[RerankModel] = None
# Mode-specific rerank models
entity_rerank_model: Optional[RerankModel] = None
relation_rerank_model: Optional[RerankModel] = None
chunk_rerank_model: Optional[RerankModel] = None
async def rerank(
self,
query: str,
documents: List[Dict[str, Any]],
mode: str = "default",
top_n: Optional[int] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""Rerank using the appropriate model based on mode."""
# Select model based on mode
if mode == "entity" and self.entity_rerank_model:
model = self.entity_rerank_model
elif mode == "relation" and self.relation_rerank_model:
model = self.relation_rerank_model
elif mode == "chunk" and self.chunk_rerank_model:
model = self.chunk_rerank_model
elif self.rerank_model:
model = self.rerank_model
else:
logger.warning(f"No rerank model available for mode: {mode}")
return documents
return await model.rerank(query, documents, top_n, **kwargs)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(aiohttp.ClientError)
| retry_if_exception_type(aiohttp.ClientResponseError)
),
)
async def generic_rerank_api( async def generic_rerank_api(
query: str, query: str,
documents: List[Dict[str, Any]], documents: List[str],
model: str, model: str,
base_url: str, base_url: str,
api_key: str, api_key: str,
top_n: Optional[int] = None, top_n: Optional[int] = None,
**kwargs, return_documents: Optional[bool] = None,
extra_body: Optional[Dict[str, Any]] = None,
response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Generic rerank function that works with Jina/Cohere compatible APIs. Generic rerank API call for Jina/Cohere/Aliyun models.
Args: Args:
query: The search query query: The search query
documents: List of documents to rerank documents: List of strings to rerank
model: Model identifier model: Model name to use
base_url: API endpoint URL base_url: API endpoint URL
api_key: API authentication key api_key: API key for authentication
top_n: Number of top results to return top_n: Number of top results to return
**kwargs: Additional API-specific parameters return_documents: Whether to return document text (Jina only)
extra_body: Additional body parameters
response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
Returns: Returns:
List of reranked documents with relevance scores List of dictionary of ["index": int, "relevance_score": float]
""" """
if not api_key: if not api_key:
logger.warning("No API key provided for rerank service") raise ValueError("API key is required")
return documents
if not documents: headers = {
return documents "Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
# Prepare documents for reranking - handle both text and dict formats # Build request payload based on request format
prepared_docs = [] if request_format == "aliyun":
for doc in documents: # Aliyun format: nested input/parameters structure
if isinstance(doc, dict): payload = {
# Use 'content' field if available, otherwise use 'text' or convert to string "model": model,
text = doc.get("content") or doc.get("text") or str(doc) "input": {
else: "query": query,
text = str(doc) "documents": documents,
prepared_docs.append(text) },
"parameters": {},
}
# Prepare request # Add optional parameters to parameters object
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} if top_n is not None:
payload["parameters"]["top_n"] = top_n
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs} if return_documents is not None:
payload["parameters"]["return_documents"] = return_documents
if top_n is not None: # Add extra parameters to parameters object
data["top_n"] = min(top_n, len(prepared_docs)) if extra_body:
payload["parameters"].update(extra_body)
else:
# Standard format for Jina/Cohere
payload = {
"model": model,
"query": query,
"documents": documents,
}
try: # Add optional parameters
async with aiohttp.ClientSession() as session: if top_n is not None:
async with session.post(base_url, headers=headers, json=data) as response: payload["top_n"] = top_n
if response.status != 200:
error_text = await response.text()
logger.error(f"Rerank API error {response.status}: {error_text}")
return documents
result = await response.json() # Only Jina API supports return_documents parameter
if return_documents is not None:
payload["return_documents"] = return_documents
# Extract reranked results # Add extra parameters
if "results" in result: if extra_body:
# Standard format: results contain index and relevance_score payload.update(extra_body)
reranked_docs = []
for item in result["results"]:
if "index" in item:
doc_idx = item["index"]
if 0 <= doc_idx < len(documents):
reranked_doc = documents[doc_idx].copy()
if "relevance_score" in item:
reranked_doc["rerank_score"] = item[
"relevance_score"
]
reranked_docs.append(reranked_doc)
return reranked_docs
else:
logger.warning("Unexpected rerank API response format")
return documents
except Exception as e: logger.debug(
logger.error(f"Error during reranking: {e}") f"Rerank request: {len(documents)} documents, model: {model}, format: {response_format}"
return documents
async def jina_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str = "BAAI/bge-reranker-v2-m3",
top_n: Optional[int] = None,
base_url: str = "https://api.jina.ai/v1/rerank",
api_key: Optional[str] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Jina AI API.
Args:
query: The search query
documents: List of documents to rerank
model: Jina rerank model name
top_n: Number of top results to return
base_url: Jina API endpoint
api_key: Jina API key
**kwargs: Additional parameters
Returns:
List of reranked documents with relevance scores
"""
if api_key is None:
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
return await generic_rerank_api(
query=query,
documents=documents,
model=model,
base_url=base_url,
api_key=api_key,
top_n=top_n,
**kwargs,
) )
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=payload) as response:
if response.status != 200:
error_text = await response.text()
content_type = response.headers.get("content-type", "").lower()
is_html_error = (
error_text.strip().startswith("<!DOCTYPE html>")
or "text/html" in content_type
)
if is_html_error:
if response.status == 502:
clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
elif response.status == 503:
clean_error = "Service Unavailable (503) - Rerank service is temporarily overloaded. Please try again later."
elif response.status == 504:
clean_error = "Gateway Timeout (504) - Rerank service request timed out. Please try again."
else:
clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
else:
clean_error = error_text
logger.error(f"Rerank API error {response.status}: {clean_error}")
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"Rerank API error: {clean_error}",
)
response_json = await response.json()
# Handle different response formats
if response_format == "aliyun":
# Aliyun format: {"output": {"results": [...]}}
output = response_json.get("output", {})
results = output.get("results", [])
elif response_format == "standard":
# Standard format: {"results": [...]}
results = response_json.get("results", [])
else:
raise ValueError(f"Unsupported response format: {response_format}")
if not results:
logger.warning("Rerank API returned empty results")
return []
# Standardize return format
return [
{"index": result["index"], "relevance_score": result["relevance_score"]}
for result in results
]
async def cohere_rerank( async def cohere_rerank(
query: str, query: str,
documents: List[Dict[str, Any]], documents: List[str],
model: str = "rerank-english-v2.0",
top_n: Optional[int] = None, top_n: Optional[int] = None,
base_url: str = "https://api.cohere.ai/v1/rerank",
api_key: Optional[str] = None, api_key: Optional[str] = None,
**kwargs, model: str = "rerank-v3.5",
base_url: str = "https://ai.znipower.com:5017/rerank",
extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Rerank documents using Cohere API. Rerank documents using Cohere API.
Args: Args:
query: The search query query: The search query
documents: List of documents to rerank documents: List of strings to rerank
model: Cohere rerank model name
top_n: Number of top results to return top_n: Number of top results to return
base_url: Cohere API endpoint api_key: API key
api_key: Cohere API key model: rerank model name
**kwargs: Additional parameters base_url: API endpoint
extra_body: Additional body for http request(reserved for extra params)
Returns: Returns:
List of reranked documents with relevance scores List of dictionary of ["index": int, "relevance_score": float]
""" """
if api_key is None: if api_key is None:
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY") api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
return await generic_rerank_api( return await generic_rerank_api(
query=query, query=query,
@ -274,24 +198,39 @@ async def cohere_rerank(
base_url=base_url, base_url=base_url,
api_key=api_key, api_key=api_key,
top_n=top_n, top_n=top_n,
**kwargs, return_documents=None, # Cohere doesn't support this parameter
extra_body=extra_body,
response_format="standard",
) )
# Convenience function for custom API endpoints async def jina_rerank(
async def custom_rerank(
query: str, query: str,
documents: List[Dict[str, Any]], documents: List[str],
model: str,
base_url: str,
api_key: str,
top_n: Optional[int] = None, top_n: Optional[int] = None,
**kwargs, api_key: Optional[str] = None,
model: str = "jina-reranker-v2-base-multilingual",
base_url: str = "https://api.jina.ai/v1/rerank",
extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Rerank documents using a custom API endpoint. Rerank documents using Jina AI API.
This is useful for self-hosted or custom rerank services.
Args:
query: The search query
documents: List of strings to rerank
top_n: Number of top results to return
api_key: API key
model: rerank model name
base_url: API endpoint
extra_body: Additional body for http request(reserved for extra params)
Returns:
List of dictionary of ["index": int, "relevance_score": float]
""" """
if api_key is None:
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
return await generic_rerank_api( return await generic_rerank_api(
query=query, query=query,
documents=documents, documents=documents,
@ -299,26 +238,112 @@ async def custom_rerank(
base_url=base_url, base_url=base_url,
api_key=api_key, api_key=api_key,
top_n=top_n, top_n=top_n,
**kwargs, return_documents=False,
extra_body=extra_body,
response_format="standard",
) )
async def ali_rerank(
query: str,
documents: List[str],
top_n: Optional[int] = None,
api_key: Optional[str] = None,
model: str = "gte-rerank-v2",
base_url: str = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Aliyun DashScope API.
Args:
query: The search query
documents: List of strings to rerank
top_n: Number of top results to return
api_key: Aliyun API key
model: rerank model name
base_url: API endpoint
extra_body: Additional body for http request(reserved for extra params)
Returns:
List of dictionary of ["index": int, "relevance_score": float]
"""
if api_key is None:
api_key = os.getenv("DASHSCOPE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
return await generic_rerank_api(
query=query,
documents=documents,
model=model,
base_url=base_url,
api_key=api_key,
top_n=top_n,
return_documents=False, # Aliyun doesn't need this parameter
extra_body=extra_body,
response_format="aliyun",
request_format="aliyun",
)
"""Please run this test as a module:
python -m lightrag.rerank
"""
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
async def main(): async def main():
# Example usage # Example usage - documents should be strings, not dictionaries
docs = [ docs = [
{"content": "The capital of France is Paris."}, "The capital of France is Paris.",
{"content": "Tokyo is the capital of Japan."}, "Tokyo is the capital of Japan.",
{"content": "London is the capital of England."}, "London is the capital of England.",
] ]
query = "What is the capital of France?" query = "What is the capital of France?"
result = await jina_rerank( # Test Jina rerank
query=query, documents=docs, top_n=2, api_key="your-api-key-here" try:
) print("=== Jina Rerank ===")
print(result) result = await jina_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Jina Error: {e}")
# Test Cohere rerank
try:
print("\n=== Cohere Rerank ===")
result = await cohere_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Cohere Error: {e}")
# Test Aliyun rerank
try:
print("\n=== Aliyun Rerank ===")
result = await ali_rerank(
query=query,
documents=docs,
top_n=2,
)
print("Results:")
for item in result:
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
print(f"Document: {docs[item['index']]}")
except Exception as e:
print(f"Aliyun Error: {e}")
asyncio.run(main()) asyncio.run(main())

View file

@ -1978,17 +1978,50 @@ async def apply_rerank_if_enabled(
return retrieved_docs return retrieved_docs
try: try:
# Apply reranking - let rerank_model_func handle top_k internally # Extract document content for reranking
reranked_docs = await rerank_func( document_texts = []
for doc in retrieved_docs:
# Try multiple possible content fields
content = (
doc.get("content")
or doc.get("text")
or doc.get("chunk_content")
or doc.get("document")
or str(doc)
)
document_texts.append(content)
# Call the new rerank function that returns index-based results
rerank_results = await rerank_func(
query=query, query=query,
documents=retrieved_docs, documents=document_texts,
top_n=top_n, top_n=top_n or len(retrieved_docs),
) )
if reranked_docs and len(reranked_docs) > 0:
if len(reranked_docs) > top_n: # Process rerank results based on return format
reranked_docs = reranked_docs[:top_n] if rerank_results and len(rerank_results) > 0:
logger.info(f"Successfully reranked: {len(retrieved_docs)} chunks") # Check if results are in the new index-based format
return reranked_docs if isinstance(rerank_results[0], dict) and "index" in rerank_results[0]:
# New format: [{"index": 0, "relevance_score": 0.85}, ...]
reranked_docs = []
for result in rerank_results:
index = result["index"]
relevance_score = result["relevance_score"]
# Get original document and add rerank score
if 0 <= index < len(retrieved_docs):
doc = retrieved_docs[index].copy()
doc["rerank_score"] = relevance_score
reranked_docs.append(doc)
logger.info(
f"Successfully reranked: {len(reranked_docs)} chunks from {len(retrieved_docs)} original chunks"
)
return reranked_docs
else:
# Legacy format: assume it's already reranked documents
logger.info(f"Using legacy rerank format: {len(rerank_results)} chunks")
return rerank_results[:top_n] if top_n else rerank_results
else: else:
logger.warning("Rerank returned empty results, using original chunks") logger.warning("Rerank returned empty results, using original chunks")
return retrieved_docs return retrieved_docs
@ -2027,13 +2060,6 @@ async def process_chunks_unified(
# 1. Apply reranking if enabled and query is provided # 1. Apply reranking if enabled and query is provided
if query_param.enable_rerank and query and unique_chunks: if query_param.enable_rerank and query and unique_chunks:
# 保存 chunk_id 字段,因为 rerank 可能会丢失这个字段
chunk_ids = {}
for chunk in unique_chunks:
chunk_id = chunk.get("chunk_id")
if chunk_id:
chunk_ids[id(chunk)] = chunk_id
rerank_top_k = query_param.chunk_top_k or len(unique_chunks) rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
unique_chunks = await apply_rerank_if_enabled( unique_chunks = await apply_rerank_if_enabled(
query=query, query=query,
@ -2043,11 +2069,6 @@ async def process_chunks_unified(
top_n=rerank_top_k, top_n=rerank_top_k,
) )
# 恢复 chunk_id 字段
for chunk in unique_chunks:
if id(chunk) in chunk_ids:
chunk["chunk_id"] = chunk_ids[id(chunk)]
# 2. Filter by minimum rerank score if reranking is enabled # 2. Filter by minimum rerank score if reranking is enabled
if query_param.enable_rerank and unique_chunks: if query_param.enable_rerank and unique_chunks:
min_rerank_score = global_config.get("min_rerank_score", 0.5) min_rerank_score = global_config.get("min_rerank_score", 0.5)
@ -2095,13 +2116,6 @@ async def process_chunks_unified(
original_count = len(unique_chunks) original_count = len(unique_chunks)
# Keep chunk_id field, cause truncate_list_by_token_size will lose it
chunk_ids_map = {}
for i, chunk in enumerate(unique_chunks):
chunk_id = chunk.get("chunk_id")
if chunk_id:
chunk_ids_map[i] = chunk_id
unique_chunks = truncate_list_by_token_size( unique_chunks = truncate_list_by_token_size(
unique_chunks, unique_chunks,
key=lambda x: json.dumps(x, ensure_ascii=False), key=lambda x: json.dumps(x, ensure_ascii=False),
@ -2109,11 +2123,6 @@ async def process_chunks_unified(
tokenizer=tokenizer, tokenizer=tokenizer,
) )
# restore chunk_id field
for i, chunk in enumerate(unique_chunks):
if i in chunk_ids_map:
chunk["chunk_id"] = chunk_ids_map[i]
logger.debug( logger.debug(
f"Token truncation: {len(unique_chunks)} chunks from {original_count} " f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
f"(chunk available tokens: {chunk_token_limit}, source: {source_type})" f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"