feat: Add multiple rerank provider support to LightRAG Server by adding new env vars and cli params
- Add --enable-rerank CLI argument and ENABLE_RERANK env var - Simplify rerank configuration logic to only check enable flag and binding - Update health endpoint to show enable_rerank and rerank_configured status - Improve logging messages for rerank enable/disable states - Maintain backward compatibility with default value True
This commit is contained in:
parent
0019a3adc6
commit
580cb7906c
6 changed files with 368 additions and 568 deletions
|
|
@ -1,281 +0,0 @@
|
||||||
# Rerank Integration Guide
|
|
||||||
|
|
||||||
LightRAG supports reranking functionality to improve retrieval quality by re-ordering documents based on their relevance to the query. Reranking is now controlled per query via the `enable_rerank` parameter (default: True).
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
### Environment Variables
|
|
||||||
|
|
||||||
Set these variables in your `.env` file or environment for rerank model configuration:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Rerank model configuration (required when enable_rerank=True in queries)
|
|
||||||
RERANK_MODEL=BAAI/bge-reranker-v2-m3
|
|
||||||
RERANK_BINDING_HOST=https://api.your-provider.com/v1/rerank
|
|
||||||
RERANK_BINDING_API_KEY=your_api_key_here
|
|
||||||
```
|
|
||||||
|
|
||||||
### Programmatic Configuration
|
|
||||||
|
|
||||||
```python
|
|
||||||
from lightrag import LightRAG, QueryParam
|
|
||||||
from lightrag.rerank import custom_rerank, RerankModel
|
|
||||||
|
|
||||||
# Method 1: Using a custom rerank function with all settings included
|
|
||||||
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
|
|
||||||
return await custom_rerank(
|
|
||||||
query=query,
|
|
||||||
documents=documents,
|
|
||||||
model="BAAI/bge-reranker-v2-m3",
|
|
||||||
base_url="https://api.your-provider.com/v1/rerank",
|
|
||||||
api_key="your_api_key_here",
|
|
||||||
top_n=top_n or 10, # Handle top_n within the function
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
rag = LightRAG(
|
|
||||||
working_dir="./rag_storage",
|
|
||||||
llm_model_func=your_llm_func,
|
|
||||||
embedding_func=your_embedding_func,
|
|
||||||
rerank_model_func=my_rerank_func, # Configure rerank function
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query with rerank enabled (default)
|
|
||||||
result = await rag.aquery(
|
|
||||||
"your query",
|
|
||||||
param=QueryParam(enable_rerank=True) # Control rerank per query
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query with rerank disabled
|
|
||||||
result = await rag.aquery(
|
|
||||||
"your query",
|
|
||||||
param=QueryParam(enable_rerank=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Method 2: Using RerankModel wrapper
|
|
||||||
rerank_model = RerankModel(
|
|
||||||
rerank_func=custom_rerank,
|
|
||||||
kwargs={
|
|
||||||
"model": "BAAI/bge-reranker-v2-m3",
|
|
||||||
"base_url": "https://api.your-provider.com/v1/rerank",
|
|
||||||
"api_key": "your_api_key_here",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
rag = LightRAG(
|
|
||||||
working_dir="./rag_storage",
|
|
||||||
llm_model_func=your_llm_func,
|
|
||||||
embedding_func=your_embedding_func,
|
|
||||||
rerank_model_func=rerank_model.rerank,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Control rerank per query
|
|
||||||
result = await rag.aquery(
|
|
||||||
"your query",
|
|
||||||
param=QueryParam(
|
|
||||||
enable_rerank=True, # Enable rerank for this query
|
|
||||||
chunk_top_k=5 # Number of chunks to keep after reranking
|
|
||||||
)
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Supported Providers
|
|
||||||
|
|
||||||
### 1. Custom/Generic API (Recommended)
|
|
||||||
|
|
||||||
For Jina/Cohere compatible APIs:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from lightrag.rerank import custom_rerank
|
|
||||||
|
|
||||||
# Your custom API endpoint
|
|
||||||
result = await custom_rerank(
|
|
||||||
query="your query",
|
|
||||||
documents=documents,
|
|
||||||
model="BAAI/bge-reranker-v2-m3",
|
|
||||||
base_url="https://api.your-provider.com/v1/rerank",
|
|
||||||
api_key="your_api_key_here",
|
|
||||||
top_n=10
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Jina AI
|
|
||||||
|
|
||||||
```python
|
|
||||||
from lightrag.rerank import jina_rerank
|
|
||||||
|
|
||||||
result = await jina_rerank(
|
|
||||||
query="your query",
|
|
||||||
documents=documents,
|
|
||||||
model="BAAI/bge-reranker-v2-m3",
|
|
||||||
api_key="your_jina_api_key",
|
|
||||||
top_n=10
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Cohere
|
|
||||||
|
|
||||||
```python
|
|
||||||
from lightrag.rerank import cohere_rerank
|
|
||||||
|
|
||||||
result = await cohere_rerank(
|
|
||||||
query="your query",
|
|
||||||
documents=documents,
|
|
||||||
model="rerank-english-v2.0",
|
|
||||||
api_key="your_cohere_api_key",
|
|
||||||
top_n=10
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Integration Points
|
|
||||||
|
|
||||||
Reranking is automatically applied at these key retrieval stages:
|
|
||||||
|
|
||||||
1. **Naive Mode**: After vector similarity search in `_get_vector_context`
|
|
||||||
2. **Local Mode**: After entity retrieval in `_get_node_data`
|
|
||||||
3. **Global Mode**: After relationship retrieval in `_get_edge_data`
|
|
||||||
4. **Hybrid/Mix Modes**: Applied to all relevant components
|
|
||||||
|
|
||||||
## Configuration Parameters
|
|
||||||
|
|
||||||
| Parameter | Type | Default | Description |
|
|
||||||
|-----------|------|---------|-------------|
|
|
||||||
| `enable_rerank` | bool | False | Enable/disable reranking |
|
|
||||||
| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_n, etc.) |
|
|
||||||
|
|
||||||
## Example Usage
|
|
||||||
|
|
||||||
### Basic Usage
|
|
||||||
|
|
||||||
```python
|
|
||||||
import asyncio
|
|
||||||
from lightrag import LightRAG, QueryParam
|
|
||||||
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding
|
|
||||||
from lightrag.kg.shared_storage import initialize_pipeline_status
|
|
||||||
from lightrag.rerank import jina_rerank
|
|
||||||
|
|
||||||
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
|
|
||||||
"""Custom rerank function with all settings included"""
|
|
||||||
return await jina_rerank(
|
|
||||||
query=query,
|
|
||||||
documents=documents,
|
|
||||||
model="BAAI/bge-reranker-v2-m3",
|
|
||||||
api_key="your_jina_api_key_here",
|
|
||||||
top_n=top_n or 10, # Default top_n if not provided
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
# Initialize with rerank enabled
|
|
||||||
rag = LightRAG(
|
|
||||||
working_dir="./rag_storage",
|
|
||||||
llm_model_func=gpt_4o_mini_complete,
|
|
||||||
embedding_func=openai_embedding,
|
|
||||||
rerank_model_func=my_rerank_func,
|
|
||||||
)
|
|
||||||
|
|
||||||
await rag.initialize_storages()
|
|
||||||
await initialize_pipeline_status()
|
|
||||||
|
|
||||||
# Insert documents
|
|
||||||
await rag.ainsert([
|
|
||||||
"Document 1 content...",
|
|
||||||
"Document 2 content...",
|
|
||||||
])
|
|
||||||
|
|
||||||
# Query with rerank (automatically applied)
|
|
||||||
result = await rag.aquery(
|
|
||||||
"Your question here",
|
|
||||||
param=QueryParam(enable_rerank=True) # This top_n is passed to rerank function
|
|
||||||
)
|
|
||||||
|
|
||||||
print(result)
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
```
|
|
||||||
|
|
||||||
### Direct Rerank Usage
|
|
||||||
|
|
||||||
```python
|
|
||||||
from lightrag.rerank import custom_rerank
|
|
||||||
|
|
||||||
async def test_rerank():
|
|
||||||
documents = [
|
|
||||||
{"content": "Text about topic A"},
|
|
||||||
{"content": "Text about topic B"},
|
|
||||||
{"content": "Text about topic C"},
|
|
||||||
]
|
|
||||||
|
|
||||||
reranked = await custom_rerank(
|
|
||||||
query="Tell me about topic A",
|
|
||||||
documents=documents,
|
|
||||||
model="BAAI/bge-reranker-v2-m3",
|
|
||||||
base_url="https://api.your-provider.com/v1/rerank",
|
|
||||||
api_key="your_api_key_here",
|
|
||||||
top_n=2
|
|
||||||
)
|
|
||||||
|
|
||||||
for doc in reranked:
|
|
||||||
print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
|
|
||||||
```
|
|
||||||
|
|
||||||
## Best Practices
|
|
||||||
|
|
||||||
1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_n handling) within your rerank function
|
|
||||||
2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff
|
|
||||||
3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function
|
|
||||||
4. **Fallback**: Always handle rerank failures gracefully (returns original results)
|
|
||||||
5. **Top-n Handling**: Handle top_n parameter appropriately within your rerank function
|
|
||||||
6. **Cost Management**: Consider rerank API costs in your budget planning
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
### Common Issues
|
|
||||||
|
|
||||||
1. **API Key Missing**: Ensure API keys are properly configured within your rerank function
|
|
||||||
2. **Network Issues**: Check API endpoints and network connectivity
|
|
||||||
3. **Model Errors**: Verify the rerank model name is supported by your API
|
|
||||||
4. **Document Format**: Ensure documents have `content` or `text` fields
|
|
||||||
|
|
||||||
### Debug Mode
|
|
||||||
|
|
||||||
Enable debug logging to see rerank operations:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import logging
|
|
||||||
logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Error Handling
|
|
||||||
|
|
||||||
The rerank integration includes automatic fallback:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# If rerank fails, original documents are returned
|
|
||||||
# No exceptions are raised to the user
|
|
||||||
# Errors are logged for debugging
|
|
||||||
```
|
|
||||||
|
|
||||||
## API Compatibility
|
|
||||||
|
|
||||||
The generic rerank API expects this response format:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"results": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"relevance_score": 0.95
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"index": 2,
|
|
||||||
"relevance_score": 0.87
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
This is compatible with:
|
|
||||||
- Jina AI Rerank API
|
|
||||||
- Cohere Rerank API
|
|
||||||
- Custom APIs following the same format
|
|
||||||
30
env.example
30
env.example
|
|
@ -85,16 +85,36 @@ ENABLE_LLM_CACHE=true
|
||||||
### If reranking is enabled, the impact of chunk selection strategies will be diminished.
|
### If reranking is enabled, the impact of chunk selection strategies will be diminished.
|
||||||
# KG_CHUNK_PICK_METHOD=VECTOR
|
# KG_CHUNK_PICK_METHOD=VECTOR
|
||||||
|
|
||||||
|
#########################################################
|
||||||
### Reranking configuration
|
### Reranking configuration
|
||||||
### Reranker Set ENABLE_RERANK to true in reranking model is configed
|
### RERANK_BINDING type: cohere, jina, aliyun
|
||||||
# ENABLE_RERANK=True
|
### For rerank model deployed by vLLM use cohere binding
|
||||||
### Minimum rerank score for document chunk exclusion (set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enought)
|
#########################################################
|
||||||
|
ENABLE_RERANK=False
|
||||||
|
RERANK_BINDING=cohere
|
||||||
|
### rerank score chunk filter(set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enought)
|
||||||
# MIN_RERANK_SCORE=0.0
|
# MIN_RERANK_SCORE=0.0
|
||||||
### Rerank model configuration (required when ENABLE_RERANK=True)
|
|
||||||
# RERANK_MODEL=jina-reranker-v2-base-multilingual
|
### For local deployment
|
||||||
|
# RERANK_MODEL=BAAI/bge-reranker-v2-m3
|
||||||
|
# RERANK_BINDING_HOST=http://localhost:8000
|
||||||
|
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
||||||
|
|
||||||
|
### Default value for Cohere AI
|
||||||
|
# RERANK_MODEL=rerank-v3.5
|
||||||
|
# RERANK_BINDING_HOST=https://ai.znipower.com:5017/rerank
|
||||||
|
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
||||||
|
|
||||||
|
### Default value for Jina AI
|
||||||
|
# RERANK_MODELjina-reranker-v2-base-multilingual
|
||||||
# RERANK_BINDING_HOST=https://api.jina.ai/v1/rerank
|
# RERANK_BINDING_HOST=https://api.jina.ai/v1/rerank
|
||||||
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
||||||
|
|
||||||
|
### Default value for Aliyun
|
||||||
|
# RERANK_MODEL=gte-rerank-v2
|
||||||
|
# RERANK_BINDING_HOST=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank
|
||||||
|
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
### Document processing configuration
|
### Document processing configuration
|
||||||
########################################
|
########################################
|
||||||
|
|
|
||||||
|
|
@ -225,6 +225,19 @@ def parse_args() -> argparse.Namespace:
|
||||||
choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
|
choices=["lollms", "ollama", "openai", "azure_openai", "aws_bedrock", "jina"],
|
||||||
help="Embedding binding type (default: from env or ollama)",
|
help="Embedding binding type (default: from env or ollama)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rerank-binding",
|
||||||
|
type=str,
|
||||||
|
default=get_env_value("RERANK_BINDING", "cohere"),
|
||||||
|
choices=["cohere", "jina", "aliyun"],
|
||||||
|
help="Rerank binding type (default: from env or cohere)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-rerank",
|
||||||
|
action="store_true",
|
||||||
|
default=get_env_value("ENABLE_RERANK", True, bool),
|
||||||
|
help="Enable rerank functionality (default: from env or True)",
|
||||||
|
)
|
||||||
|
|
||||||
# Conditionally add binding options defined in binding_options module
|
# Conditionally add binding options defined in binding_options module
|
||||||
# This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
|
# This will add command line arguments for all binding options (e.g., --ollama-embedding-num_ctx)
|
||||||
|
|
@ -340,6 +353,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
|
args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
|
||||||
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
|
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
|
||||||
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
|
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
|
||||||
|
# Note: rerank_binding is already set by argparse, no need to override from env
|
||||||
|
|
||||||
# Min rerank score configuration
|
# Min rerank score configuration
|
||||||
args.min_rerank_score = get_env_value(
|
args.min_rerank_score = get_env_value(
|
||||||
|
|
|
||||||
|
|
@ -390,33 +390,44 @@ def create_app(args):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Configure rerank function if model and API are configured
|
# Configure rerank function based on enable_rerank parameter
|
||||||
rerank_model_func = None
|
rerank_model_func = None
|
||||||
if args.rerank_binding_api_key and args.rerank_binding_host:
|
if args.enable_rerank and args.rerank_binding:
|
||||||
from lightrag.rerank import custom_rerank
|
from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank
|
||||||
|
|
||||||
|
# Map rerank binding to corresponding function
|
||||||
|
rerank_functions = {
|
||||||
|
"cohere": cohere_rerank,
|
||||||
|
"jina": jina_rerank,
|
||||||
|
"aliyun": ali_rerank,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Select the appropriate rerank function based on binding
|
||||||
|
selected_rerank_func = rerank_functions.get(args.rerank_binding)
|
||||||
|
if not selected_rerank_func:
|
||||||
|
logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
|
||||||
|
raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")
|
||||||
|
|
||||||
async def server_rerank_func(
|
async def server_rerank_func(
|
||||||
query: str, documents: list, top_n: int = None, **kwargs
|
query: str, documents: list, top_n: int = None, extra_body: dict = None
|
||||||
):
|
):
|
||||||
"""Server rerank function with configuration from environment variables"""
|
"""Server rerank function with configuration from environment variables"""
|
||||||
return await custom_rerank(
|
return await selected_rerank_func(
|
||||||
query=query,
|
query=query,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
model=args.rerank_model,
|
model=args.rerank_model,
|
||||||
base_url=args.rerank_binding_host,
|
base_url=args.rerank_binding_host,
|
||||||
api_key=args.rerank_binding_api_key,
|
api_key=args.rerank_binding_api_key,
|
||||||
top_n=top_n,
|
top_n=top_n,
|
||||||
**kwargs,
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
rerank_model_func = server_rerank_func
|
rerank_model_func = server_rerank_func
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Rerank model configured: {args.rerank_model} (can be enabled per query)"
|
f"Rerank enabled: {args.rerank_model} using {args.rerank_binding} provider"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info("Rerank disabled")
|
||||||
"Rerank model not configured. Set RERANK_BINDING_API_KEY and RERANK_BINDING_HOST to enable reranking."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create ollama_server_infos from command line arguments
|
# Create ollama_server_infos from command line arguments
|
||||||
from lightrag.api.config import OllamaServerInfos
|
from lightrag.api.config import OllamaServerInfos
|
||||||
|
|
@ -622,13 +633,15 @@ def create_app(args):
|
||||||
"enable_llm_cache": args.enable_llm_cache,
|
"enable_llm_cache": args.enable_llm_cache,
|
||||||
"workspace": args.workspace,
|
"workspace": args.workspace,
|
||||||
"max_graph_nodes": args.max_graph_nodes,
|
"max_graph_nodes": args.max_graph_nodes,
|
||||||
# Rerank configuration (based on whether rerank model is configured)
|
# Rerank configuration
|
||||||
"enable_rerank": rerank_model_func is not None,
|
"enable_rerank": args.enable_rerank,
|
||||||
"rerank_model": args.rerank_model
|
"rerank_configured": rerank_model_func is not None,
|
||||||
if rerank_model_func is not None
|
"rerank_binding": args.rerank_binding
|
||||||
|
if args.enable_rerank
|
||||||
else None,
|
else None,
|
||||||
|
"rerank_model": args.rerank_model if args.enable_rerank else None,
|
||||||
"rerank_binding_host": args.rerank_binding_host
|
"rerank_binding_host": args.rerank_binding_host
|
||||||
if rerank_model_func is not None
|
if args.enable_rerank
|
||||||
else None,
|
else None,
|
||||||
# Environment variable status (requested configuration)
|
# Environment variable status (requested configuration)
|
||||||
"summary_language": args.summary_language,
|
"summary_language": args.summary_language,
|
||||||
|
|
|
||||||
|
|
@ -2,270 +2,194 @@ from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from typing import Callable, Any, List, Dict, Optional
|
from typing import Any, List, Dict, Optional
|
||||||
from pydantic import BaseModel, Field
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
retry_if_exception_type,
|
||||||
|
)
|
||||||
from .utils import logger
|
from .utils import logger
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
class RerankModel(BaseModel):
|
# use the .env that is inside the current folder
|
||||||
"""
|
# allows to use different .env file for each lightrag instance
|
||||||
Wrapper for rerank functions that can be used with LightRAG.
|
# the OS environment variables take precedence over the .env file
|
||||||
|
load_dotenv(dotenv_path=".env", override=False)
|
||||||
Example usage:
|
|
||||||
```python
|
|
||||||
from lightrag.rerank import RerankModel, jina_rerank
|
|
||||||
|
|
||||||
# Create rerank model
|
|
||||||
rerank_model = RerankModel(
|
|
||||||
rerank_func=jina_rerank,
|
|
||||||
kwargs={
|
|
||||||
"model": "BAAI/bge-reranker-v2-m3",
|
|
||||||
"api_key": "your_api_key_here",
|
|
||||||
"base_url": "https://api.jina.ai/v1/rerank"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use in LightRAG
|
|
||||||
rag = LightRAG(
|
|
||||||
rerank_model_func=rerank_model.rerank,
|
|
||||||
# ... other configurations
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query with rerank enabled (default)
|
|
||||||
result = await rag.aquery(
|
|
||||||
"your query",
|
|
||||||
param=QueryParam(enable_rerank=True)
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
Or define a custom function directly:
|
|
||||||
```python
|
|
||||||
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
|
|
||||||
return await jina_rerank(
|
|
||||||
query=query,
|
|
||||||
documents=documents,
|
|
||||||
model="BAAI/bge-reranker-v2-m3",
|
|
||||||
api_key="your_api_key_here",
|
|
||||||
top_n=top_n or 10,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
rag = LightRAG(
|
|
||||||
rerank_model_func=my_rerank_func,
|
|
||||||
# ... other configurations
|
|
||||||
)
|
|
||||||
|
|
||||||
# Control rerank per query
|
|
||||||
result = await rag.aquery(
|
|
||||||
"your query",
|
|
||||||
param=QueryParam(enable_rerank=True) # Enable rerank for this query
|
|
||||||
)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
rerank_func: Callable[[Any], List[Dict]]
|
|
||||||
kwargs: Dict[str, Any] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
async def rerank(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
documents: List[Dict[str, Any]],
|
|
||||||
top_n: Optional[int] = None,
|
|
||||||
**extra_kwargs,
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""Rerank documents using the configured model function."""
|
|
||||||
# Merge extra kwargs with model kwargs
|
|
||||||
kwargs = {**self.kwargs, **extra_kwargs}
|
|
||||||
return await self.rerank_func(
|
|
||||||
query=query, documents=documents, top_n=top_n, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiRerankModel(BaseModel):
|
|
||||||
"""Multiple rerank models for different modes/scenarios."""
|
|
||||||
|
|
||||||
# Primary rerank model (used if mode-specific models are not defined)
|
|
||||||
rerank_model: Optional[RerankModel] = None
|
|
||||||
|
|
||||||
# Mode-specific rerank models
|
|
||||||
entity_rerank_model: Optional[RerankModel] = None
|
|
||||||
relation_rerank_model: Optional[RerankModel] = None
|
|
||||||
chunk_rerank_model: Optional[RerankModel] = None
|
|
||||||
|
|
||||||
async def rerank(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
documents: List[Dict[str, Any]],
|
|
||||||
mode: str = "default",
|
|
||||||
top_n: Optional[int] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""Rerank using the appropriate model based on mode."""
|
|
||||||
|
|
||||||
# Select model based on mode
|
|
||||||
if mode == "entity" and self.entity_rerank_model:
|
|
||||||
model = self.entity_rerank_model
|
|
||||||
elif mode == "relation" and self.relation_rerank_model:
|
|
||||||
model = self.relation_rerank_model
|
|
||||||
elif mode == "chunk" and self.chunk_rerank_model:
|
|
||||||
model = self.chunk_rerank_model
|
|
||||||
elif self.rerank_model:
|
|
||||||
model = self.rerank_model
|
|
||||||
else:
|
|
||||||
logger.warning(f"No rerank model available for mode: {mode}")
|
|
||||||
return documents
|
|
||||||
|
|
||||||
return await model.rerank(query, documents, top_n, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
retry=(
|
||||||
|
retry_if_exception_type(aiohttp.ClientError)
|
||||||
|
| retry_if_exception_type(aiohttp.ClientResponseError)
|
||||||
|
),
|
||||||
|
)
|
||||||
async def generic_rerank_api(
|
async def generic_rerank_api(
|
||||||
query: str,
|
query: str,
|
||||||
documents: List[Dict[str, Any]],
|
documents: List[str],
|
||||||
model: str,
|
model: str,
|
||||||
base_url: str,
|
base_url: str,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
top_n: Optional[int] = None,
|
top_n: Optional[int] = None,
|
||||||
**kwargs,
|
return_documents: Optional[bool] = None,
|
||||||
|
extra_body: Optional[Dict[str, Any]] = None,
|
||||||
|
response_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
|
||||||
|
request_format: str = "standard", # "standard" (Jina/Cohere) or "aliyun"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Generic rerank function that works with Jina/Cohere compatible APIs.
|
Generic rerank API call for Jina/Cohere/Aliyun models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: The search query
|
query: The search query
|
||||||
documents: List of documents to rerank
|
documents: List of strings to rerank
|
||||||
model: Model identifier
|
model: Model name to use
|
||||||
base_url: API endpoint URL
|
base_url: API endpoint URL
|
||||||
api_key: API authentication key
|
api_key: API key for authentication
|
||||||
top_n: Number of top results to return
|
top_n: Number of top results to return
|
||||||
**kwargs: Additional API-specific parameters
|
return_documents: Whether to return document text (Jina only)
|
||||||
|
extra_body: Additional body parameters
|
||||||
|
response_format: Response format type ("standard" for Jina/Cohere, "aliyun" for Aliyun)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of reranked documents with relevance scores
|
List of dictionary of ["index": int, "relevance_score": float]
|
||||||
"""
|
"""
|
||||||
if not api_key:
|
if not api_key:
|
||||||
logger.warning("No API key provided for rerank service")
|
raise ValueError("API key is required")
|
||||||
return documents
|
|
||||||
|
|
||||||
if not documents:
|
headers = {
|
||||||
return documents
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
}
|
||||||
|
|
||||||
# Prepare documents for reranking - handle both text and dict formats
|
# Build request payload based on request format
|
||||||
prepared_docs = []
|
if request_format == "aliyun":
|
||||||
for doc in documents:
|
# Aliyun format: nested input/parameters structure
|
||||||
if isinstance(doc, dict):
|
payload = {
|
||||||
# Use 'content' field if available, otherwise use 'text' or convert to string
|
"model": model,
|
||||||
text = doc.get("content") or doc.get("text") or str(doc)
|
"input": {
|
||||||
else:
|
"query": query,
|
||||||
text = str(doc)
|
"documents": documents,
|
||||||
prepared_docs.append(text)
|
},
|
||||||
|
"parameters": {},
|
||||||
|
}
|
||||||
|
|
||||||
# Prepare request
|
# Add optional parameters to parameters object
|
||||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
if top_n is not None:
|
||||||
|
payload["parameters"]["top_n"] = top_n
|
||||||
|
|
||||||
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
|
if return_documents is not None:
|
||||||
|
payload["parameters"]["return_documents"] = return_documents
|
||||||
|
|
||||||
if top_n is not None:
|
# Add extra parameters to parameters object
|
||||||
data["top_n"] = min(top_n, len(prepared_docs))
|
if extra_body:
|
||||||
|
payload["parameters"].update(extra_body)
|
||||||
|
else:
|
||||||
|
# Standard format for Jina/Cohere
|
||||||
|
payload = {
|
||||||
|
"model": model,
|
||||||
|
"query": query,
|
||||||
|
"documents": documents,
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
# Add optional parameters
|
||||||
async with aiohttp.ClientSession() as session:
|
if top_n is not None:
|
||||||
async with session.post(base_url, headers=headers, json=data) as response:
|
payload["top_n"] = top_n
|
||||||
if response.status != 200:
|
|
||||||
error_text = await response.text()
|
|
||||||
logger.error(f"Rerank API error {response.status}: {error_text}")
|
|
||||||
return documents
|
|
||||||
|
|
||||||
result = await response.json()
|
# Only Jina API supports return_documents parameter
|
||||||
|
if return_documents is not None:
|
||||||
|
payload["return_documents"] = return_documents
|
||||||
|
|
||||||
# Extract reranked results
|
# Add extra parameters
|
||||||
if "results" in result:
|
if extra_body:
|
||||||
# Standard format: results contain index and relevance_score
|
payload.update(extra_body)
|
||||||
reranked_docs = []
|
|
||||||
for item in result["results"]:
|
|
||||||
if "index" in item:
|
|
||||||
doc_idx = item["index"]
|
|
||||||
if 0 <= doc_idx < len(documents):
|
|
||||||
reranked_doc = documents[doc_idx].copy()
|
|
||||||
if "relevance_score" in item:
|
|
||||||
reranked_doc["rerank_score"] = item[
|
|
||||||
"relevance_score"
|
|
||||||
]
|
|
||||||
reranked_docs.append(reranked_doc)
|
|
||||||
return reranked_docs
|
|
||||||
else:
|
|
||||||
logger.warning("Unexpected rerank API response format")
|
|
||||||
return documents
|
|
||||||
|
|
||||||
except Exception as e:
|
logger.debug(
|
||||||
logger.error(f"Error during reranking: {e}")
|
f"Rerank request: {len(documents)} documents, model: {model}, format: {response_format}"
|
||||||
return documents
|
|
||||||
|
|
||||||
|
|
||||||
async def jina_rerank(
|
|
||||||
query: str,
|
|
||||||
documents: List[Dict[str, Any]],
|
|
||||||
model: str = "BAAI/bge-reranker-v2-m3",
|
|
||||||
top_n: Optional[int] = None,
|
|
||||||
base_url: str = "https://api.jina.ai/v1/rerank",
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Rerank documents using Jina AI API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: The search query
|
|
||||||
documents: List of documents to rerank
|
|
||||||
model: Jina rerank model name
|
|
||||||
top_n: Number of top results to return
|
|
||||||
base_url: Jina API endpoint
|
|
||||||
api_key: Jina API key
|
|
||||||
**kwargs: Additional parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of reranked documents with relevance scores
|
|
||||||
"""
|
|
||||||
if api_key is None:
|
|
||||||
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
|
|
||||||
|
|
||||||
return await generic_rerank_api(
|
|
||||||
query=query,
|
|
||||||
documents=documents,
|
|
||||||
model=model,
|
|
||||||
base_url=base_url,
|
|
||||||
api_key=api_key,
|
|
||||||
top_n=top_n,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(base_url, headers=headers, json=payload) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
error_text = await response.text()
|
||||||
|
content_type = response.headers.get("content-type", "").lower()
|
||||||
|
is_html_error = (
|
||||||
|
error_text.strip().startswith("<!DOCTYPE html>")
|
||||||
|
or "text/html" in content_type
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_html_error:
|
||||||
|
if response.status == 502:
|
||||||
|
clean_error = "Bad Gateway (502) - Rerank service temporarily unavailable. Please try again in a few minutes."
|
||||||
|
elif response.status == 503:
|
||||||
|
clean_error = "Service Unavailable (503) - Rerank service is temporarily overloaded. Please try again later."
|
||||||
|
elif response.status == 504:
|
||||||
|
clean_error = "Gateway Timeout (504) - Rerank service request timed out. Please try again."
|
||||||
|
else:
|
||||||
|
clean_error = f"HTTP {response.status} - Rerank service error. Please try again later."
|
||||||
|
else:
|
||||||
|
clean_error = error_text
|
||||||
|
|
||||||
|
logger.error(f"Rerank API error {response.status}: {clean_error}")
|
||||||
|
raise aiohttp.ClientResponseError(
|
||||||
|
request_info=response.request_info,
|
||||||
|
history=response.history,
|
||||||
|
status=response.status,
|
||||||
|
message=f"Rerank API error: {clean_error}",
|
||||||
|
)
|
||||||
|
|
||||||
|
response_json = await response.json()
|
||||||
|
|
||||||
|
# Handle different response formats
|
||||||
|
if response_format == "aliyun":
|
||||||
|
# Aliyun format: {"output": {"results": [...]}}
|
||||||
|
output = response_json.get("output", {})
|
||||||
|
results = output.get("results", [])
|
||||||
|
elif response_format == "standard":
|
||||||
|
# Standard format: {"results": [...]}
|
||||||
|
results = response_json.get("results", [])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported response format: {response_format}")
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
logger.warning("Rerank API returned empty results")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Standardize return format
|
||||||
|
return [
|
||||||
|
{"index": result["index"], "relevance_score": result["relevance_score"]}
|
||||||
|
for result in results
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
async def cohere_rerank(
|
async def cohere_rerank(
|
||||||
query: str,
|
query: str,
|
||||||
documents: List[Dict[str, Any]],
|
documents: List[str],
|
||||||
model: str = "rerank-english-v2.0",
|
|
||||||
top_n: Optional[int] = None,
|
top_n: Optional[int] = None,
|
||||||
base_url: str = "https://api.cohere.ai/v1/rerank",
|
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
**kwargs,
|
model: str = "rerank-v3.5",
|
||||||
|
base_url: str = "https://ai.znipower.com:5017/rerank",
|
||||||
|
extra_body: Optional[Dict[str, Any]] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Rerank documents using Cohere API.
|
Rerank documents using Cohere API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: The search query
|
query: The search query
|
||||||
documents: List of documents to rerank
|
documents: List of strings to rerank
|
||||||
model: Cohere rerank model name
|
|
||||||
top_n: Number of top results to return
|
top_n: Number of top results to return
|
||||||
base_url: Cohere API endpoint
|
api_key: API key
|
||||||
api_key: Cohere API key
|
model: rerank model name
|
||||||
**kwargs: Additional parameters
|
base_url: API endpoint
|
||||||
|
extra_body: Additional body for http request(reserved for extra params)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of reranked documents with relevance scores
|
List of dictionary of ["index": int, "relevance_score": float]
|
||||||
"""
|
"""
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")
|
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
|
||||||
|
|
||||||
return await generic_rerank_api(
|
return await generic_rerank_api(
|
||||||
query=query,
|
query=query,
|
||||||
|
|
@ -274,24 +198,39 @@ async def cohere_rerank(
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
top_n=top_n,
|
top_n=top_n,
|
||||||
**kwargs,
|
return_documents=None, # Cohere doesn't support this parameter
|
||||||
|
extra_body=extra_body,
|
||||||
|
response_format="standard",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Convenience function for custom API endpoints
|
async def jina_rerank(
|
||||||
async def custom_rerank(
|
|
||||||
query: str,
|
query: str,
|
||||||
documents: List[Dict[str, Any]],
|
documents: List[str],
|
||||||
model: str,
|
|
||||||
base_url: str,
|
|
||||||
api_key: str,
|
|
||||||
top_n: Optional[int] = None,
|
top_n: Optional[int] = None,
|
||||||
**kwargs,
|
api_key: Optional[str] = None,
|
||||||
|
model: str = "jina-reranker-v2-base-multilingual",
|
||||||
|
base_url: str = "https://api.jina.ai/v1/rerank",
|
||||||
|
extra_body: Optional[Dict[str, Any]] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Rerank documents using a custom API endpoint.
|
Rerank documents using Jina AI API.
|
||||||
This is useful for self-hosted or custom rerank services.
|
|
||||||
|
Args:
|
||||||
|
query: The search query
|
||||||
|
documents: List of strings to rerank
|
||||||
|
top_n: Number of top results to return
|
||||||
|
api_key: API key
|
||||||
|
model: rerank model name
|
||||||
|
base_url: API endpoint
|
||||||
|
extra_body: Additional body for http request(reserved for extra params)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionary of ["index": int, "relevance_score": float]
|
||||||
"""
|
"""
|
||||||
|
if api_key is None:
|
||||||
|
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
|
||||||
|
|
||||||
return await generic_rerank_api(
|
return await generic_rerank_api(
|
||||||
query=query,
|
query=query,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
|
|
@ -299,26 +238,112 @@ async def custom_rerank(
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
top_n=top_n,
|
top_n=top_n,
|
||||||
**kwargs,
|
return_documents=False,
|
||||||
|
extra_body=extra_body,
|
||||||
|
response_format="standard",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def ali_rerank(
|
||||||
|
query: str,
|
||||||
|
documents: List[str],
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
model: str = "gte-rerank-v2",
|
||||||
|
base_url: str = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank",
|
||||||
|
extra_body: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Rerank documents using Aliyun DashScope API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query
|
||||||
|
documents: List of strings to rerank
|
||||||
|
top_n: Number of top results to return
|
||||||
|
api_key: Aliyun API key
|
||||||
|
model: rerank model name
|
||||||
|
base_url: API endpoint
|
||||||
|
extra_body: Additional body for http request(reserved for extra params)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionary of ["index": int, "relevance_score": float]
|
||||||
|
"""
|
||||||
|
if api_key is None:
|
||||||
|
api_key = os.getenv("DASHSCOPE_API_KEY") or os.getenv("RERANK_BINDING_API_KEY")
|
||||||
|
|
||||||
|
return await generic_rerank_api(
|
||||||
|
query=query,
|
||||||
|
documents=documents,
|
||||||
|
model=model,
|
||||||
|
base_url=base_url,
|
||||||
|
api_key=api_key,
|
||||||
|
top_n=top_n,
|
||||||
|
return_documents=False, # Aliyun doesn't need this parameter
|
||||||
|
extra_body=extra_body,
|
||||||
|
response_format="aliyun",
|
||||||
|
request_format="aliyun",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
"""Please run this test as a module:
|
||||||
|
python -m lightrag.rerank
|
||||||
|
"""
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Example usage
|
# Example usage - documents should be strings, not dictionaries
|
||||||
docs = [
|
docs = [
|
||||||
{"content": "The capital of France is Paris."},
|
"The capital of France is Paris.",
|
||||||
{"content": "Tokyo is the capital of Japan."},
|
"Tokyo is the capital of Japan.",
|
||||||
{"content": "London is the capital of England."},
|
"London is the capital of England.",
|
||||||
]
|
]
|
||||||
|
|
||||||
query = "What is the capital of France?"
|
query = "What is the capital of France?"
|
||||||
|
|
||||||
result = await jina_rerank(
|
# Test Jina rerank
|
||||||
query=query, documents=docs, top_n=2, api_key="your-api-key-here"
|
try:
|
||||||
)
|
print("=== Jina Rerank ===")
|
||||||
print(result)
|
result = await jina_rerank(
|
||||||
|
query=query,
|
||||||
|
documents=docs,
|
||||||
|
top_n=2,
|
||||||
|
)
|
||||||
|
print("Results:")
|
||||||
|
for item in result:
|
||||||
|
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
|
||||||
|
print(f"Document: {docs[item['index']]}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Jina Error: {e}")
|
||||||
|
|
||||||
|
# Test Cohere rerank
|
||||||
|
try:
|
||||||
|
print("\n=== Cohere Rerank ===")
|
||||||
|
result = await cohere_rerank(
|
||||||
|
query=query,
|
||||||
|
documents=docs,
|
||||||
|
top_n=2,
|
||||||
|
)
|
||||||
|
print("Results:")
|
||||||
|
for item in result:
|
||||||
|
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
|
||||||
|
print(f"Document: {docs[item['index']]}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Cohere Error: {e}")
|
||||||
|
|
||||||
|
# Test Aliyun rerank
|
||||||
|
try:
|
||||||
|
print("\n=== Aliyun Rerank ===")
|
||||||
|
result = await ali_rerank(
|
||||||
|
query=query,
|
||||||
|
documents=docs,
|
||||||
|
top_n=2,
|
||||||
|
)
|
||||||
|
print("Results:")
|
||||||
|
for item in result:
|
||||||
|
print(f"Index: {item['index']}, Score: {item['relevance_score']:.4f}")
|
||||||
|
print(f"Document: {docs[item['index']]}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Aliyun Error: {e}")
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
|
||||||
|
|
@ -1978,17 +1978,50 @@ async def apply_rerank_if_enabled(
|
||||||
return retrieved_docs
|
return retrieved_docs
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Apply reranking - let rerank_model_func handle top_k internally
|
# Extract document content for reranking
|
||||||
reranked_docs = await rerank_func(
|
document_texts = []
|
||||||
|
for doc in retrieved_docs:
|
||||||
|
# Try multiple possible content fields
|
||||||
|
content = (
|
||||||
|
doc.get("content")
|
||||||
|
or doc.get("text")
|
||||||
|
or doc.get("chunk_content")
|
||||||
|
or doc.get("document")
|
||||||
|
or str(doc)
|
||||||
|
)
|
||||||
|
document_texts.append(content)
|
||||||
|
|
||||||
|
# Call the new rerank function that returns index-based results
|
||||||
|
rerank_results = await rerank_func(
|
||||||
query=query,
|
query=query,
|
||||||
documents=retrieved_docs,
|
documents=document_texts,
|
||||||
top_n=top_n,
|
top_n=top_n or len(retrieved_docs),
|
||||||
)
|
)
|
||||||
if reranked_docs and len(reranked_docs) > 0:
|
|
||||||
if len(reranked_docs) > top_n:
|
# Process rerank results based on return format
|
||||||
reranked_docs = reranked_docs[:top_n]
|
if rerank_results and len(rerank_results) > 0:
|
||||||
logger.info(f"Successfully reranked: {len(retrieved_docs)} chunks")
|
# Check if results are in the new index-based format
|
||||||
return reranked_docs
|
if isinstance(rerank_results[0], dict) and "index" in rerank_results[0]:
|
||||||
|
# New format: [{"index": 0, "relevance_score": 0.85}, ...]
|
||||||
|
reranked_docs = []
|
||||||
|
for result in rerank_results:
|
||||||
|
index = result["index"]
|
||||||
|
relevance_score = result["relevance_score"]
|
||||||
|
|
||||||
|
# Get original document and add rerank score
|
||||||
|
if 0 <= index < len(retrieved_docs):
|
||||||
|
doc = retrieved_docs[index].copy()
|
||||||
|
doc["rerank_score"] = relevance_score
|
||||||
|
reranked_docs.append(doc)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Successfully reranked: {len(reranked_docs)} chunks from {len(retrieved_docs)} original chunks"
|
||||||
|
)
|
||||||
|
return reranked_docs
|
||||||
|
else:
|
||||||
|
# Legacy format: assume it's already reranked documents
|
||||||
|
logger.info(f"Using legacy rerank format: {len(rerank_results)} chunks")
|
||||||
|
return rerank_results[:top_n] if top_n else rerank_results
|
||||||
else:
|
else:
|
||||||
logger.warning("Rerank returned empty results, using original chunks")
|
logger.warning("Rerank returned empty results, using original chunks")
|
||||||
return retrieved_docs
|
return retrieved_docs
|
||||||
|
|
@ -2027,13 +2060,6 @@ async def process_chunks_unified(
|
||||||
|
|
||||||
# 1. Apply reranking if enabled and query is provided
|
# 1. Apply reranking if enabled and query is provided
|
||||||
if query_param.enable_rerank and query and unique_chunks:
|
if query_param.enable_rerank and query and unique_chunks:
|
||||||
# 保存 chunk_id 字段,因为 rerank 可能会丢失这个字段
|
|
||||||
chunk_ids = {}
|
|
||||||
for chunk in unique_chunks:
|
|
||||||
chunk_id = chunk.get("chunk_id")
|
|
||||||
if chunk_id:
|
|
||||||
chunk_ids[id(chunk)] = chunk_id
|
|
||||||
|
|
||||||
rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
|
rerank_top_k = query_param.chunk_top_k or len(unique_chunks)
|
||||||
unique_chunks = await apply_rerank_if_enabled(
|
unique_chunks = await apply_rerank_if_enabled(
|
||||||
query=query,
|
query=query,
|
||||||
|
|
@ -2043,11 +2069,6 @@ async def process_chunks_unified(
|
||||||
top_n=rerank_top_k,
|
top_n=rerank_top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 恢复 chunk_id 字段
|
|
||||||
for chunk in unique_chunks:
|
|
||||||
if id(chunk) in chunk_ids:
|
|
||||||
chunk["chunk_id"] = chunk_ids[id(chunk)]
|
|
||||||
|
|
||||||
# 2. Filter by minimum rerank score if reranking is enabled
|
# 2. Filter by minimum rerank score if reranking is enabled
|
||||||
if query_param.enable_rerank and unique_chunks:
|
if query_param.enable_rerank and unique_chunks:
|
||||||
min_rerank_score = global_config.get("min_rerank_score", 0.5)
|
min_rerank_score = global_config.get("min_rerank_score", 0.5)
|
||||||
|
|
@ -2095,13 +2116,6 @@ async def process_chunks_unified(
|
||||||
|
|
||||||
original_count = len(unique_chunks)
|
original_count = len(unique_chunks)
|
||||||
|
|
||||||
# Keep chunk_id field, cause truncate_list_by_token_size will lose it
|
|
||||||
chunk_ids_map = {}
|
|
||||||
for i, chunk in enumerate(unique_chunks):
|
|
||||||
chunk_id = chunk.get("chunk_id")
|
|
||||||
if chunk_id:
|
|
||||||
chunk_ids_map[i] = chunk_id
|
|
||||||
|
|
||||||
unique_chunks = truncate_list_by_token_size(
|
unique_chunks = truncate_list_by_token_size(
|
||||||
unique_chunks,
|
unique_chunks,
|
||||||
key=lambda x: json.dumps(x, ensure_ascii=False),
|
key=lambda x: json.dumps(x, ensure_ascii=False),
|
||||||
|
|
@ -2109,11 +2123,6 @@ async def process_chunks_unified(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# restore chunk_id feiled
|
|
||||||
for i, chunk in enumerate(unique_chunks):
|
|
||||||
if i in chunk_ids_map:
|
|
||||||
chunk["chunk_id"] = chunk_ids_map[i]
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
|
f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
|
||||||
f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"
|
f"(chunk available tokens: {chunk_token_limit}, source: {source_type})"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue