Update rerank_example & readme

This commit is contained in:
zrguo 2025-07-15 12:17:27 +08:00
parent 7c882313bb
commit 9a9f0f2463
3 changed files with 129 additions and 89 deletions

View file

@ -1,36 +1,24 @@
# Rerank Integration in LightRAG # Rerank Integration Guide
This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality. LightRAG supports reranking functionality to improve retrieval quality by re-ordering documents based on their relevance to the query. Reranking is now controlled per query via the `enable_rerank` parameter (default: True).
## Overview ## Quick Start
Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix).
## Architecture
The rerank integration follows a simplified design pattern:
- **Single Function Configuration**: All rerank settings (model, API keys, top_k, etc.) are contained within the rerank function
- **Async Processing**: Non-blocking rerank operations
- **Error Handling**: Graceful fallback to original results
- **Optional Feature**: Can be enabled/disabled via configuration
- **Code Reuse**: Single generic implementation for Jina/Cohere compatible APIs
## Configuration
### Environment Variables ### Environment Variables
Set this variable in your `.env` file or environment: Set these variables in your `.env` file or environment for rerank model configuration:
```bash ```bash
# Enable/disable reranking # Rerank model configuration (required when enable_rerank=True in queries)
ENABLE_RERANK=True RERANK_MODEL=BAAI/bge-reranker-v2-m3
RERANK_BINDING_HOST=https://api.your-provider.com/v1/rerank
RERANK_BINDING_API_KEY=your_api_key_here
``` ```
### Programmatic Configuration ### Programmatic Configuration
```python ```python
from lightrag import LightRAG from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank, RerankModel from lightrag.rerank import custom_rerank, RerankModel
# Method 1: Using a custom rerank function with all settings included # Method 1: Using a custom rerank function with all settings included
@ -49,8 +37,19 @@ rag = LightRAG(
working_dir="./rag_storage", working_dir="./rag_storage",
llm_model_func=your_llm_func, llm_model_func=your_llm_func,
embedding_func=your_embedding_func, embedding_func=your_embedding_func,
enable_rerank=True, rerank_model_func=my_rerank_func, # Configure rerank function
rerank_model_func=my_rerank_func, )
# Query with rerank enabled (default)
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=True) # Control rerank per query
)
# Query with rerank disabled
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=False)
) )
# Method 2: Using RerankModel wrapper # Method 2: Using RerankModel wrapper
@ -67,9 +66,17 @@ rag = LightRAG(
working_dir="./rag_storage", working_dir="./rag_storage",
llm_model_func=your_llm_func, llm_model_func=your_llm_func,
embedding_func=your_embedding_func, embedding_func=your_embedding_func,
enable_rerank=True,
rerank_model_func=rerank_model.rerank, rerank_model_func=rerank_model.rerank,
) )
# Control rerank per query
result = await rag.aquery(
"your query",
param=QueryParam(
enable_rerank=True, # Enable rerank for this query
chunk_top_k=5 # Number of chunks to keep after reranking
)
)
``` ```
## Supported Providers ## Supported Providers
@ -164,7 +171,6 @@ async def main():
working_dir="./rag_storage", working_dir="./rag_storage",
llm_model_func=gpt_4o_mini_complete, llm_model_func=gpt_4o_mini_complete,
embedding_func=openai_embedding, embedding_func=openai_embedding,
enable_rerank=True,
rerank_model_func=my_rerank_func, rerank_model_func=my_rerank_func,
) )
@ -180,7 +186,7 @@ async def main():
# Query with rerank (automatically applied) # Query with rerank (automatically applied)
result = await rag.aquery( result = await rag.aquery(
"Your question here", "Your question here",
param=QueryParam(mode="hybrid", top_k=5) # This top_k is passed to rerank function param=QueryParam(enable_rerank=True) # Rerank is enabled for this query via enable_rerank
) )
print(result) print(result)

View file

@ -9,7 +9,11 @@ Configuration Required:
2. Set your embedding API key and base URL in embedding_func() 2. Set your embedding API key and base URL in embedding_func()
3. Set your rerank API key and base URL in the rerank configuration 3. Set your rerank API key and base URL in the rerank configuration
4. Or use environment variables (.env file): 4. Or use environment variables (.env file):
- ENABLE_RERANK=True - RERANK_MODEL=your_rerank_model
- RERANK_BINDING_HOST=your_rerank_endpoint
- RERANK_BINDING_API_KEY=your_rerank_api_key
Note: Rerank is now controlled per query via the 'enable_rerank' parameter (default: True)
""" """
import asyncio import asyncio
@ -83,8 +87,7 @@ async def create_rag_with_rerank():
max_token_size=8192, max_token_size=8192,
func=embedding_func, func=embedding_func,
), ),
# Simplified Rerank Configuration # Rerank Configuration - provide the rerank function
enable_rerank=True,
rerank_model_func=my_rerank_func, rerank_model_func=my_rerank_func,
) )
@ -120,7 +123,6 @@ async def create_rag_with_rerank_model():
max_token_size=8192, max_token_size=8192,
func=embedding_func, func=embedding_func,
), ),
enable_rerank=True,
rerank_model_func=rerank_model.rerank, rerank_model_func=rerank_model.rerank,
) )
@ -130,9 +132,9 @@ async def create_rag_with_rerank_model():
return rag return rag
async def test_rerank_with_different_topk(): async def test_rerank_with_different_settings():
""" """
Test rerank functionality with different top_k settings Test rerank functionality with different enable_rerank settings
""" """
print("🚀 Setting up LightRAG with Rerank functionality...") print("🚀 Setting up LightRAG with Rerank functionality...")
@ -154,16 +156,41 @@ async def test_rerank_with_different_topk():
print(f"\n🔍 Testing query: '{query}'") print(f"\n🔍 Testing query: '{query}'")
print("=" * 80) print("=" * 80)
# Test different top_k values to show parameter priority # Test with rerank enabled (default)
top_k_values = [2, 5, 10] print("\n📊 Testing with enable_rerank=True (default):")
result_with_rerank = await rag.aquery(
query,
param=QueryParam(
mode="naive",
top_k=10,
chunk_top_k=5,
enable_rerank=True, # Explicitly enable rerank
),
)
print(f" Result length: {len(result_with_rerank)} characters")
print(f" Preview: {result_with_rerank[:100]}...")
for top_k in top_k_values: # Test with rerank disabled
print(f"\n📊 Testing with QueryParam(top_k={top_k}):") print("\n📊 Testing with enable_rerank=False:")
result_without_rerank = await rag.aquery(
query,
param=QueryParam(
mode="naive",
top_k=10,
chunk_top_k=5,
enable_rerank=False, # Disable rerank
),
)
print(f" Result length: {len(result_without_rerank)} characters")
print(f" Preview: {result_without_rerank[:100]}...")
# Test naive mode with specific top_k # Test with default settings (enable_rerank defaults to True)
result = await rag.aquery(query, param=QueryParam(mode="naive", top_k=top_k)) print("\n📊 Testing with default settings (enable_rerank defaults to True):")
print(f" Result length: {len(result)} characters") result_default = await rag.aquery(
print(f" Preview: {result[:100]}...") query, param=QueryParam(mode="naive", top_k=10, chunk_top_k=5)
)
print(f" Result length: {len(result_default)} characters")
print(f" Preview: {result_default[:100]}...")
async def test_direct_rerank(): async def test_direct_rerank():
@ -209,17 +236,21 @@ async def main():
print("=" * 60) print("=" * 60)
try: try:
# Test rerank with different top_k values # Test rerank with different enable_rerank settings
await test_rerank_with_different_topk() await test_rerank_with_different_settings()
# Test direct rerank # Test direct rerank
await test_direct_rerank() await test_direct_rerank()
print("\n✅ Example completed successfully!") print("\n✅ Example completed successfully!")
print("\n💡 Key Points:") print("\n💡 Key Points:")
print(" ✓ All rerank configurations are contained within rerank_model_func") print(" ✓ Rerank is now controlled per query via 'enable_rerank' parameter")
print(" ✓ Rerank improves document relevance ordering") print(" ✓ Default value for enable_rerank is True")
print(" ✓ Configure API keys within your rerank function") print(" ✓ Rerank function is configured at LightRAG initialization")
print(" ✓ Per-query enable_rerank setting overrides default behavior")
print(
" ✓ If enable_rerank=True but no rerank model is configured, a warning is issued"
)
print(" ✓ Monitor API usage and costs when using rerank services") print(" ✓ Monitor API usage and costs when using rerank services")
except Exception as e: except Exception as e:

View file

@ -10,55 +10,58 @@ from .utils import logger
class RerankModel(BaseModel): class RerankModel(BaseModel):
""" """
Pydantic model class for defining a custom rerank model. Wrapper for rerank functions that can be used with LightRAG.
This class provides a convenient wrapper for rerank functions, allowing you to
encapsulate all rerank configurations (API keys, model settings, etc.) in one place.
Attributes:
rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents.
The function should take query and documents as input and return reranked results.
kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
This should include all necessary configurations such as model name, API key, base_url, etc.
Example usage: Example usage:
Rerank model example with Jina: ```python
```python from lightrag.rerank import RerankModel, jina_rerank
rerank_model = RerankModel(
rerank_func=jina_rerank, # Create rerank model
kwargs={ rerank_model = RerankModel(
"model": "BAAI/bge-reranker-v2-m3", rerank_func=jina_rerank,
"api_key": "your_api_key_here", kwargs={
"base_url": "https://api.jina.ai/v1/rerank" "model": "BAAI/bge-reranker-v2-m3",
} "api_key": "your_api_key_here",
"base_url": "https://api.jina.ai/v1/rerank"
}
)
# Use in LightRAG
rag = LightRAG(
rerank_model_func=rerank_model.rerank,
# ... other configurations
)
# Query with rerank enabled (default)
result = await rag.aquery(
"your query",
param=QueryParam(enable_rerank=True)
)
```
Or define a custom function directly:
```python
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
return await jina_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_api_key_here",
top_k=top_k or 10,
**kwargs
) )
# Use in LightRAG rag = LightRAG(
rag = LightRAG( rerank_model_func=my_rerank_func,
enable_rerank=True, # ... other configurations
rerank_model_func=rerank_model.rerank, )
# ... other configurations
)
```
Or define a custom function directly: # Control rerank per query
```python result = await rag.aquery(
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): "your query",
return await jina_rerank( param=QueryParam(enable_rerank=True) # Enable rerank for this query
query=query, )
documents=documents, ```
model="BAAI/bge-reranker-v2-m3",
api_key="your_api_key_here",
top_k=top_k or 10,
**kwargs
)
rag = LightRAG(
enable_rerank=True,
rerank_model_func=my_rerank_func,
# ... other configurations
)
```
""" """
rerank_func: Callable[[Any], List[Dict]] rerank_func: Callable[[Any], List[Dict]]