Update rerank_example & readme
This commit is contained in:
parent
7c882313bb
commit
9a9f0f2463
3 changed files with 129 additions and 89 deletions
|
|
@ -1,36 +1,24 @@
|
|||
# Rerank Integration in LightRAG
|
||||
# Rerank Integration Guide
|
||||
|
||||
This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality.
|
||||
LightRAG supports reranking functionality to improve retrieval quality by re-ordering documents based on their relevance to the query. Reranking is now controlled per query via the `enable_rerank` parameter (default: True).
|
||||
|
||||
## Overview
|
||||
|
||||
Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix).
|
||||
|
||||
## Architecture
|
||||
|
||||
The rerank integration follows a simplified design pattern:
|
||||
|
||||
- **Single Function Configuration**: All rerank settings (model, API keys, top_k, etc.) are contained within the rerank function
|
||||
- **Async Processing**: Non-blocking rerank operations
|
||||
- **Error Handling**: Graceful fallback to original results
|
||||
- **Optional Feature**: Can be enabled/disabled via configuration
|
||||
- **Code Reuse**: Single generic implementation for Jina/Cohere compatible APIs
|
||||
|
||||
## Configuration
|
||||
## Quick Start
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Set this variable in your `.env` file or environment:
|
||||
Set these variables in your `.env` file or environment for rerank model configuration:
|
||||
|
||||
```bash
|
||||
# Enable/disable reranking
|
||||
ENABLE_RERANK=True
|
||||
# Rerank model configuration (required when enable_rerank=True in queries)
|
||||
RERANK_MODEL=BAAI/bge-reranker-v2-m3
|
||||
RERANK_BINDING_HOST=https://api.your-provider.com/v1/rerank
|
||||
RERANK_BINDING_API_KEY=your_api_key_here
|
||||
```
|
||||
|
||||
### Programmatic Configuration
|
||||
|
||||
```python
|
||||
from lightrag import LightRAG
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.rerank import custom_rerank, RerankModel
|
||||
|
||||
# Method 1: Using a custom rerank function with all settings included
|
||||
|
|
@ -49,8 +37,19 @@ rag = LightRAG(
|
|||
working_dir="./rag_storage",
|
||||
llm_model_func=your_llm_func,
|
||||
embedding_func=your_embedding_func,
|
||||
enable_rerank=True,
|
||||
rerank_model_func=my_rerank_func,
|
||||
rerank_model_func=my_rerank_func, # Configure rerank function
|
||||
)
|
||||
|
||||
# Query with rerank enabled (default)
|
||||
result = await rag.aquery(
|
||||
"your query",
|
||||
param=QueryParam(enable_rerank=True) # Control rerank per query
|
||||
)
|
||||
|
||||
# Query with rerank disabled
|
||||
result = await rag.aquery(
|
||||
"your query",
|
||||
param=QueryParam(enable_rerank=False)
|
||||
)
|
||||
|
||||
# Method 2: Using RerankModel wrapper
|
||||
|
|
@ -67,9 +66,17 @@ rag = LightRAG(
|
|||
working_dir="./rag_storage",
|
||||
llm_model_func=your_llm_func,
|
||||
embedding_func=your_embedding_func,
|
||||
enable_rerank=True,
|
||||
rerank_model_func=rerank_model.rerank,
|
||||
)
|
||||
|
||||
# Control rerank per query
|
||||
result = await rag.aquery(
|
||||
"your query",
|
||||
param=QueryParam(
|
||||
enable_rerank=True, # Enable rerank for this query
|
||||
chunk_top_k=5 # Number of chunks to keep after reranking
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
## Supported Providers
|
||||
|
|
@ -164,7 +171,6 @@ async def main():
|
|||
working_dir="./rag_storage",
|
||||
llm_model_func=gpt_4o_mini_complete,
|
||||
embedding_func=openai_embedding,
|
||||
enable_rerank=True,
|
||||
rerank_model_func=my_rerank_func,
|
||||
)
|
||||
|
||||
|
|
@ -180,7 +186,7 @@ async def main():
|
|||
# Query with rerank (automatically applied)
|
||||
result = await rag.aquery(
|
||||
"Your question here",
|
||||
param=QueryParam(mode="hybrid", top_k=5) # This top_k is passed to rerank function
|
||||
param=QueryParam(enable_rerank=True) # This top_k is passed to rerank function
|
||||
)
|
||||
|
||||
print(result)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,11 @@ Configuration Required:
|
|||
2. Set your embedding API key and base URL in embedding_func()
|
||||
3. Set your rerank API key and base URL in the rerank configuration
|
||||
4. Or use environment variables (.env file):
|
||||
- ENABLE_RERANK=True
|
||||
- RERANK_MODEL=your_rerank_model
|
||||
- RERANK_BINDING_HOST=your_rerank_endpoint
|
||||
- RERANK_BINDING_API_KEY=your_rerank_api_key
|
||||
|
||||
Note: Rerank is now controlled per query via the 'enable_rerank' parameter (default: True)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
|
@ -83,8 +87,7 @@ async def create_rag_with_rerank():
|
|||
max_token_size=8192,
|
||||
func=embedding_func,
|
||||
),
|
||||
# Simplified Rerank Configuration
|
||||
enable_rerank=True,
|
||||
# Rerank Configuration - provide the rerank function
|
||||
rerank_model_func=my_rerank_func,
|
||||
)
|
||||
|
||||
|
|
@ -120,7 +123,6 @@ async def create_rag_with_rerank_model():
|
|||
max_token_size=8192,
|
||||
func=embedding_func,
|
||||
),
|
||||
enable_rerank=True,
|
||||
rerank_model_func=rerank_model.rerank,
|
||||
)
|
||||
|
||||
|
|
@ -130,9 +132,9 @@ async def create_rag_with_rerank_model():
|
|||
return rag
|
||||
|
||||
|
||||
async def test_rerank_with_different_topk():
|
||||
async def test_rerank_with_different_settings():
|
||||
"""
|
||||
Test rerank functionality with different top_k settings
|
||||
Test rerank functionality with different enable_rerank settings
|
||||
"""
|
||||
print("🚀 Setting up LightRAG with Rerank functionality...")
|
||||
|
||||
|
|
@ -154,16 +156,41 @@ async def test_rerank_with_different_topk():
|
|||
print(f"\n🔍 Testing query: '{query}'")
|
||||
print("=" * 80)
|
||||
|
||||
# Test different top_k values to show parameter priority
|
||||
top_k_values = [2, 5, 10]
|
||||
# Test with rerank enabled (default)
|
||||
print("\n📊 Testing with enable_rerank=True (default):")
|
||||
result_with_rerank = await rag.aquery(
|
||||
query,
|
||||
param=QueryParam(
|
||||
mode="naive",
|
||||
top_k=10,
|
||||
chunk_top_k=5,
|
||||
enable_rerank=True, # Explicitly enable rerank
|
||||
),
|
||||
)
|
||||
print(f" Result length: {len(result_with_rerank)} characters")
|
||||
print(f" Preview: {result_with_rerank[:100]}...")
|
||||
|
||||
for top_k in top_k_values:
|
||||
print(f"\n📊 Testing with QueryParam(top_k={top_k}):")
|
||||
# Test with rerank disabled
|
||||
print("\n📊 Testing with enable_rerank=False:")
|
||||
result_without_rerank = await rag.aquery(
|
||||
query,
|
||||
param=QueryParam(
|
||||
mode="naive",
|
||||
top_k=10,
|
||||
chunk_top_k=5,
|
||||
enable_rerank=False, # Disable rerank
|
||||
),
|
||||
)
|
||||
print(f" Result length: {len(result_without_rerank)} characters")
|
||||
print(f" Preview: {result_without_rerank[:100]}...")
|
||||
|
||||
# Test naive mode with specific top_k
|
||||
result = await rag.aquery(query, param=QueryParam(mode="naive", top_k=top_k))
|
||||
print(f" Result length: {len(result)} characters")
|
||||
print(f" Preview: {result[:100]}...")
|
||||
# Test with default settings (enable_rerank defaults to True)
|
||||
print("\n📊 Testing with default settings (enable_rerank defaults to True):")
|
||||
result_default = await rag.aquery(
|
||||
query, param=QueryParam(mode="naive", top_k=10, chunk_top_k=5)
|
||||
)
|
||||
print(f" Result length: {len(result_default)} characters")
|
||||
print(f" Preview: {result_default[:100]}...")
|
||||
|
||||
|
||||
async def test_direct_rerank():
|
||||
|
|
@ -209,17 +236,21 @@ async def main():
|
|||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# Test rerank with different top_k values
|
||||
await test_rerank_with_different_topk()
|
||||
# Test rerank with different enable_rerank settings
|
||||
await test_rerank_with_different_settings()
|
||||
|
||||
# Test direct rerank
|
||||
await test_direct_rerank()
|
||||
|
||||
print("\n✅ Example completed successfully!")
|
||||
print("\n💡 Key Points:")
|
||||
print(" ✓ All rerank configurations are contained within rerank_model_func")
|
||||
print(" ✓ Rerank improves document relevance ordering")
|
||||
print(" ✓ Configure API keys within your rerank function")
|
||||
print(" ✓ Rerank is now controlled per query via 'enable_rerank' parameter")
|
||||
print(" ✓ Default value for enable_rerank is True")
|
||||
print(" ✓ Rerank function is configured at LightRAG initialization")
|
||||
print(" ✓ Per-query enable_rerank setting overrides default behavior")
|
||||
print(
|
||||
" ✓ If enable_rerank=True but no rerank model is configured, a warning is issued"
|
||||
)
|
||||
print(" ✓ Monitor API usage and costs when using rerank services")
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -10,55 +10,58 @@ from .utils import logger
|
|||
|
||||
class RerankModel(BaseModel):
|
||||
"""
|
||||
Pydantic model class for defining a custom rerank model.
|
||||
|
||||
This class provides a convenient wrapper for rerank functions, allowing you to
|
||||
encapsulate all rerank configurations (API keys, model settings, etc.) in one place.
|
||||
|
||||
Attributes:
|
||||
rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents.
|
||||
The function should take query and documents as input and return reranked results.
|
||||
kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
|
||||
This should include all necessary configurations such as model name, API key, base_url, etc.
|
||||
Wrapper for rerank functions that can be used with LightRAG.
|
||||
|
||||
Example usage:
|
||||
Rerank model example with Jina:
|
||||
```python
|
||||
rerank_model = RerankModel(
|
||||
rerank_func=jina_rerank,
|
||||
kwargs={
|
||||
"model": "BAAI/bge-reranker-v2-m3",
|
||||
"api_key": "your_api_key_here",
|
||||
"base_url": "https://api.jina.ai/v1/rerank"
|
||||
}
|
||||
```python
|
||||
from lightrag.rerank import RerankModel, jina_rerank
|
||||
|
||||
# Create rerank model
|
||||
rerank_model = RerankModel(
|
||||
rerank_func=jina_rerank,
|
||||
kwargs={
|
||||
"model": "BAAI/bge-reranker-v2-m3",
|
||||
"api_key": "your_api_key_here",
|
||||
"base_url": "https://api.jina.ai/v1/rerank"
|
||||
}
|
||||
)
|
||||
|
||||
# Use in LightRAG
|
||||
rag = LightRAG(
|
||||
rerank_model_func=rerank_model.rerank,
|
||||
# ... other configurations
|
||||
)
|
||||
|
||||
# Query with rerank enabled (default)
|
||||
result = await rag.aquery(
|
||||
"your query",
|
||||
param=QueryParam(enable_rerank=True)
|
||||
)
|
||||
```
|
||||
|
||||
Or define a custom function directly:
|
||||
```python
|
||||
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
|
||||
return await jina_rerank(
|
||||
query=query,
|
||||
documents=documents,
|
||||
model="BAAI/bge-reranker-v2-m3",
|
||||
api_key="your_api_key_here",
|
||||
top_k=top_k or 10,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Use in LightRAG
|
||||
rag = LightRAG(
|
||||
enable_rerank=True,
|
||||
rerank_model_func=rerank_model.rerank,
|
||||
# ... other configurations
|
||||
)
|
||||
```
|
||||
rag = LightRAG(
|
||||
rerank_model_func=my_rerank_func,
|
||||
# ... other configurations
|
||||
)
|
||||
|
||||
Or define a custom function directly:
|
||||
```python
|
||||
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
|
||||
return await jina_rerank(
|
||||
query=query,
|
||||
documents=documents,
|
||||
model="BAAI/bge-reranker-v2-m3",
|
||||
api_key="your_api_key_here",
|
||||
top_k=top_k or 10,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
rag = LightRAG(
|
||||
enable_rerank=True,
|
||||
rerank_model_func=my_rerank_func,
|
||||
# ... other configurations
|
||||
)
|
||||
```
|
||||
# Control rerank per query
|
||||
result = await rag.aquery(
|
||||
"your query",
|
||||
param=QueryParam(enable_rerank=True) # Enable rerank for this query
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
rerank_func: Callable[[Any], List[Dict]]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue