Merge pull request #1753 from HKUDS/rerank

Add optional rerank for chunks
Daniel.y 2025-07-09 16:06:55 +08:00 committed by GitHub
commit ba0cffd853
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1240 additions and 161 deletions

View file

@@ -294,6 +294,16 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
"""Maximum number of tokens allowed for each retrieved text chunk."""
@@ -849,6 +859,18 @@ rag = LightRAG(
</details>
### Data Isolation Between LightRAG Instances
The `workspace` parameter provides storage-level data isolation between different LightRAG instances. The workspace is fixed once LightRAG is initialized; changing it afterwards has no effect. Here is how workspaces are implemented for the different storage types:
- **For local file-based databases, data isolation is achieved through workspace subdirectories:** JsonKVStorage, JsonDocStatusStorage, NetworkXStorage, NanoVectorDBStorage, FaissVectorDBStorage.
- **For databases that store data in collections, it is achieved by prefixing the collection name with the workspace:** RedisKVStorage, RedisDocStatusStorage, MilvusVectorDBStorage, QdrantVectorDBStorage, MongoKVStorage, MongoDocStatusStorage, MongoVectorDBStorage, MongoGraphStorage, PGGraphStorage.
- **For relational databases, data is logically isolated by adding a `workspace` field to the tables:** PGKVStorage, PGVectorStorage, PGDocStatusStorage.
- **For the Neo4j graph database, logical data isolation is achieved through labels:** Neo4JStorage
To maintain compatibility with legacy data, the default workspace for PostgreSQL is `default` and for Neo4j is `base` when no workspace is configured. For all external storages, the system provides dedicated workspace environment variables that override the common `WORKSPACE` setting. These storage-specific variables are: `REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`.
## Edit Entities and Relations
LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.

View file

@@ -153,7 +153,7 @@ curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_d
python examples/lightrag_openai_demo.py
```
For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code's LLM and embedding configurations accordingly.
**Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory.
@@ -239,6 +239,7 @@ A full list of LightRAG init parameters:
| **Parameter** | **Type** | **Explanation** | **Default** |
|--------------|----------|-----------------|-------------|
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **workspace** | `str` | Workspace name for data isolation between different LightRAG instances | |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
@@ -300,6 +301,16 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
"""Maximum number of tokens allowed for each retrieved text chunk."""
@@ -895,6 +906,17 @@ async def initialize_rag():
</details>
### Data Isolation Between LightRAG Instances
The `workspace` parameter ensures data isolation between different LightRAG instances. Once initialized, the `workspace` is immutable and cannot be changed. Here is how workspaces are implemented for different types of storage:
- **For local file-based databases, data isolation is achieved through workspace subdirectories:** `JsonKVStorage`, `JsonDocStatusStorage`, `NetworkXStorage`, `NanoVectorDBStorage`, `FaissVectorDBStorage`.
- **For databases that store data in collections, it's done by adding a workspace prefix to the collection name:** `RedisKVStorage`, `RedisDocStatusStorage`, `MilvusVectorDBStorage`, `QdrantVectorDBStorage`, `MongoKVStorage`, `MongoDocStatusStorage`, `MongoVectorDBStorage`, `MongoGraphStorage`, `PGGraphStorage`.
- **For relational databases, data isolation is achieved by adding a `workspace` field to the tables for logical data separation:** `PGKVStorage`, `PGVectorStorage`, `PGDocStatusStorage`.
- **For the Neo4j graph database, logical data isolation is achieved through labels:** `Neo4JStorage`
To maintain compatibility with legacy data, the default workspace for PostgreSQL is `default` and for Neo4j is `base` when no workspace is configured. For all external storages, the system provides dedicated workspace environment variables to override the common `WORKSPACE` environment variable configuration. These storage-specific workspace environment variables are: `REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`.
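For example, a minimal sketch of two isolated instances sharing the same storage backends (the model functions are placeholders, not part of this change):

```python
from lightrag import LightRAG

# Each instance reads and writes only its own workspace.
rag_team_a = LightRAG(
    working_dir="./rag_storage",
    workspace="team_a",  # or set the WORKSPACE environment variable
    llm_model_func=my_llm_func,        # placeholder
    embedding_func=my_embedding_func,  # placeholder
)
rag_team_b = LightRAG(
    working_dir="./rag_storage",
    workspace="team_b",  # its data never mixes with team_a
    llm_model_func=my_llm_func,
    embedding_func=my_embedding_func,
)
```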
## Edit Entities and Relations
LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.

275
docs/rerank_integration.md Normal file
View file

@@ -0,0 +1,275 @@
# Rerank Integration in LightRAG
This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality.
## Overview
Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix).
## Architecture
The rerank integration follows a simplified design pattern:
- **Single Function Configuration**: All rerank settings (model, API keys, top_k, etc.) are contained within the rerank function
- **Async Processing**: Non-blocking rerank operations
- **Error Handling**: Graceful fallback to original results
- **Optional Feature**: Can be enabled/disabled via configuration
- **Code Reuse**: Single generic implementation for Jina/Cohere compatible APIs
## Configuration
### Environment Variables
Set this variable in your `.env` file or environment:
```bash
# Enable/disable reranking
ENABLE_RERANK=True
```
### Programmatic Configuration
```python
from lightrag import LightRAG
from lightrag.rerank import custom_rerank, RerankModel
# Method 1: Using a custom rerank function with all settings included
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
return await custom_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_k=top_k or 10, # Handle top_k within the function
**kwargs
)
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=your_llm_func,
embedding_func=your_embedding_func,
enable_rerank=True,
rerank_model_func=my_rerank_func,
)
# Method 2: Using RerankModel wrapper
rerank_model = RerankModel(
rerank_func=custom_rerank,
kwargs={
"model": "BAAI/bge-reranker-v2-m3",
"base_url": "https://api.your-provider.com/v1/rerank",
"api_key": "your_api_key_here",
}
)
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=your_llm_func,
embedding_func=your_embedding_func,
enable_rerank=True,
rerank_model_func=rerank_model.rerank,
)
```
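Both approaches are equivalent in effect: `RerankModel` simply packages the provider-specific kwargs so the same underlying function can be reused, while a hand-written wrapper gives you full control over defaults such as `top_k`.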
## Supported Providers
### 1. Custom/Generic API (Recommended)
For Jina/Cohere compatible APIs:
```python
from lightrag.rerank import custom_rerank
# Your custom API endpoint
result = await custom_rerank(
query="your query",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_k=10
)
```
### 2. Jina AI
```python
from lightrag.rerank import jina_rerank
result = await jina_rerank(
query="your query",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_jina_api_key",
top_k=10
)
```
### 3. Cohere
```python
from lightrag.rerank import cohere_rerank
result = await cohere_rerank(
query="your query",
documents=documents,
model="rerank-english-v2.0",
api_key="your_cohere_api_key",
top_k=10
)
```
## Integration Points
Reranking is automatically applied at these key retrieval stages:
1. **Naive Mode**: After vector similarity search in `_get_vector_context`
2. **Local Mode**: After entity retrieval in `_get_node_data`
3. **Global Mode**: After relationship retrieval in `_get_edge_data`
4. **Hybrid/Mix Modes**: Applied to all relevant components
## Configuration Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enable_rerank` | bool | False | Enable/disable reranking |
| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_k, etc.) |
## Example Usage
### Basic Usage
```python
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.rerank import jina_rerank
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
"""Custom rerank function with all settings included"""
return await jina_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_jina_api_key_here",
top_k=top_k or 10, # Default top_k if not provided
**kwargs
)
async def main():
# Initialize with rerank enabled
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=gpt_4o_mini_complete,
embedding_func=openai_embed,
enable_rerank=True,
rerank_model_func=my_rerank_func,
)
await rag.initialize_storages()
await initialize_pipeline_status()
# Insert documents
await rag.ainsert([
"Document 1 content...",
"Document 2 content...",
])
# Query with rerank (automatically applied)
result = await rag.aquery(
"Your question here",
param=QueryParam(mode="hybrid", top_k=5) # This top_k is passed to rerank function
)
print(result)
asyncio.run(main())
```
### Direct Rerank Usage
```python
from lightrag.rerank import custom_rerank
async def test_rerank():
documents = [
{"content": "Text about topic A"},
{"content": "Text about topic B"},
{"content": "Text about topic C"},
]
reranked = await custom_rerank(
query="Tell me about topic A",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_k=2
)
for doc in reranked:
print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
```
## Best Practices
1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_k handling) within your rerank function
2. **Performance**: Use reranking selectively to balance the performance vs. quality tradeoff
3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function
4. **Fallback**: Always handle rerank failures gracefully (the integration returns the original results on error; see the sketch after this list)
5. **Top-k Handling**: Handle the top_k parameter appropriately within your rerank function
6. **Cost Management**: Consider rerank API costs in your budget planning
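As a minimal sketch of points 1 and 4 (the model name, endpoint, and key below are placeholders): a self-contained function that carries its own configuration and falls back to the original ordering if the call fails. `custom_rerank` already returns the original documents on API errors, so the outer `try/except` only adds a guard against client-side failures.

```python
from lightrag.rerank import custom_rerank

async def safe_rerank(query: str, documents: list, top_k: int = None, **kwargs):
    """Self-contained rerank function: configuration lives here, errors fall back."""
    try:
        return await custom_rerank(
            query=query,
            documents=documents,
            model="BAAI/bge-reranker-v2-m3",
            base_url="https://api.your-provider.com/v1/rerank",  # placeholder endpoint
            api_key="your_api_key_here",  # placeholder key
            top_k=top_k or 10,
            **kwargs,
        )
    except Exception:
        # Graceful fallback: keep the original retrieval order
        return documents[:top_k] if top_k else documents
```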
## Troubleshooting
### Common Issues
1. **API Key Missing**: Ensure API keys are properly configured within your rerank function
2. **Network Issues**: Check API endpoints and network connectivity
3. **Model Errors**: Verify the rerank model name is supported by your API
4. **Document Format**: Ensure documents have `content` or `text` fields (see the example below)
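For point 4: `generic_rerank_api` reads the `content` field first, then `text`, and finally falls back to `str(doc)`, so any of the shapes below are accepted:

```python
documents = [
    {"content": "Reranking improves retrieval quality."},  # preferred field
    {"text": "Vector search finds similar documents."},  # also accepted
    "Plain strings are converted with str() as a last resort.",
]
```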
### Debug Mode
Enable debug logging to see rerank operations:
```python
import logging
logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG)
```
### Error Handling
The rerank integration includes automatic fallback:
```python
# If rerank fails, original documents are returned
# No exceptions are raised to the user
# Errors are logged for debugging
```
## API Compatibility
The generic rerank API expects this response format:
```json
{
"results": [
{
"index": 0,
"relevance_score": 0.95
},
{
"index": 2,
"relevance_score": 0.87
}
]
}
```
This is compatible with:
- Jina AI Rerank API
- Cohere Rerank API
- Custom APIs following the same format
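For reference, the matching request body sent by `generic_rerank_api` (documents are flattened to plain strings before the call, and `top_k` is included only when set) looks like this:

```json
{
  "model": "BAAI/bge-reranker-v2-m3",
  "query": "What is the capital of France?",
  "documents": [
    "The capital of France is Paris.",
    "Tokyo is the capital of Japan."
  ],
  "top_k": 2
}
```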

View file

@@ -42,13 +42,31 @@ OLLAMA_EMULATING_MODEL_TAG=latest
### Logfile location (defaults to current working directory)
# LOG_DIR=/path/to/log/directory
### RAG Configuration
### Chunk size for document splitting, 500~1500 is recommended
# CHUNK_SIZE=1200
# CHUNK_OVERLAP_SIZE=100
# MAX_TOKEN_SUMMARY=500
### RAG Query Configuration
# HISTORY_TURNS=3
# MAX_TOKEN_TEXT_CHUNK=6000
# MAX_TOKEN_RELATION_DESC=4000
# MAX_TOKEN_ENTITY_DESC=4000
# COSINE_THRESHOLD=0.2
### Number of entities or relations to retrieve from KG
# TOP_K=60
### Number of text chunks to retrieve initially from vector search
# CHUNK_TOP_K=5
### Rerank Configuration
# ENABLE_RERANK=False
### Number of text chunks to keep after reranking (should be <= CHUNK_TOP_K)
# CHUNK_RERANK_TOP_K=5
### Rerank model configuration (required when ENABLE_RERANK=True)
# RERANK_MODEL=BAAI/bge-reranker-v2-m3
# RERANK_BINDING_HOST=https://api.your-rerank-provider.com/v1/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Entity and relation summarization configuration
### Language: English, Chinese, French, German ...
@@ -62,9 +80,6 @@ SUMMARY_LANGUAGE=English
### Number of parallel processing documents (Less than MAX_ASYNC/2 is recommended)
# MAX_PARALLEL_INSERT=2
### LLM Configuration
ENABLE_LLM_CACHE=true

233
examples/rerank_example.py Normal file
View file

@@ -0,0 +1,233 @@
"""
LightRAG Rerank Integration Example
This example demonstrates how to use rerank functionality with LightRAG
to improve retrieval quality across different query modes.
Configuration Required:
1. Set your LLM API key and base URL in llm_model_func()
2. Set your embedding API key and base URL in embedding_func()
3. Set your rerank API key and base URL in the rerank configuration
4. Or use environment variables (.env file):
- ENABLE_RERANK=True
"""
import asyncio
import os
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank, RerankModel
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc, setup_logger
from lightrag.kg.shared_storage import initialize_pipeline_status
# Set up your working directory
WORKING_DIR = "./test_rerank"
setup_logger("test_rerank")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key="your_llm_api_key_here",
base_url="https://api.your-llm-provider.com/v1",
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed(
texts,
model="text-embedding-3-large",
api_key="your_embedding_api_key_here",
base_url="https://api.your-embedding-provider.com/v1",
)
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
"""Custom rerank function with all settings included"""
return await custom_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-rerank-provider.com/v1/rerank",
api_key="your_rerank_api_key_here",
top_k=top_k or 10, # Default top_k if not provided
**kwargs,
)
async def create_rag_with_rerank():
"""Create LightRAG instance with rerank configuration"""
# Get embedding dimension
test_embedding = await embedding_func(["test"])
embedding_dim = test_embedding.shape[1]
print(f"Detected embedding dimension: {embedding_dim}")
# Method 1: Using custom rerank function
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dim,
max_token_size=8192,
func=embedding_func,
),
# Simplified Rerank Configuration
enable_rerank=True,
rerank_model_func=my_rerank_func,
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
async def create_rag_with_rerank_model():
"""Alternative: Create LightRAG instance using RerankModel wrapper"""
# Get embedding dimension
test_embedding = await embedding_func(["test"])
embedding_dim = test_embedding.shape[1]
print(f"Detected embedding dimension: {embedding_dim}")
# Method 2: Using RerankModel wrapper
rerank_model = RerankModel(
rerank_func=custom_rerank,
kwargs={
"model": "BAAI/bge-reranker-v2-m3",
"base_url": "https://api.your-rerank-provider.com/v1/rerank",
"api_key": "your_rerank_api_key_here",
},
)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dim,
max_token_size=8192,
func=embedding_func,
),
enable_rerank=True,
rerank_model_func=rerank_model.rerank,
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
async def test_rerank_with_different_topk():
"""
Test rerank functionality with different top_k settings
"""
print("🚀 Setting up LightRAG with Rerank functionality...")
rag = await create_rag_with_rerank()
# Insert sample documents
sample_docs = [
"Reranking improves retrieval quality by re-ordering documents based on relevance.",
"LightRAG is a powerful retrieval-augmented generation system with multiple query modes.",
"Vector databases enable efficient similarity search in high-dimensional embedding spaces.",
"Natural language processing has evolved with large language models and transformers.",
"Machine learning algorithms can learn patterns from data without explicit programming.",
]
print("📄 Inserting sample documents...")
await rag.ainsert(sample_docs)
query = "How does reranking improve retrieval quality?"
print(f"\n🔍 Testing query: '{query}'")
print("=" * 80)
# Test different top_k values to show parameter priority
top_k_values = [2, 5, 10]
for top_k in top_k_values:
print(f"\n📊 Testing with QueryParam(top_k={top_k}):")
# Test naive mode with specific top_k
result = await rag.aquery(query, param=QueryParam(mode="naive", top_k=top_k))
print(f" Result length: {len(result)} characters")
print(f" Preview: {result[:100]}...")
async def test_direct_rerank():
"""Test rerank function directly"""
print("\n🔧 Direct Rerank API Test")
print("=" * 40)
documents = [
{"content": "Reranking significantly improves retrieval quality"},
{"content": "LightRAG supports advanced reranking capabilities"},
{"content": "Vector search finds semantically similar documents"},
{"content": "Natural language processing with modern transformers"},
{"content": "The quick brown fox jumps over the lazy dog"},
]
query = "rerank improve quality"
print(f"Query: '{query}'")
print(f"Documents: {len(documents)}")
try:
reranked_docs = await custom_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-rerank-provider.com/v1/rerank",
api_key="your_rerank_api_key_here",
top_k=3,
)
print("\n✅ Rerank Results:")
for i, doc in enumerate(reranked_docs):
score = doc.get("rerank_score", "N/A")
content = doc.get("content", "")[:60]
print(f" {i+1}. Score: {score:.4f} | {content}...")
except Exception as e:
print(f"❌ Rerank failed: {e}")
async def main():
"""Main example function"""
print("🎯 LightRAG Rerank Integration Example")
print("=" * 60)
try:
# Test rerank with different top_k values
await test_rerank_with_different_topk()
# Test direct rerank
await test_direct_rerank()
print("\n✅ Example completed successfully!")
print("\n💡 Key Points:")
print(" ✓ All rerank configurations are contained within rerank_model_func")
print(" ✓ Rerank improves document relevance ordering")
print(" ✓ Configure API keys within your rerank function")
print(" ✓ Monitor API usage and costs when using rerank services")
except Exception as e:
print(f"\n❌ Example failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
asyncio.run(main())

View file

@@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "1.4.0"
__author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG"

View file

@@ -165,6 +165,24 @@ def parse_args() -> argparse.Namespace:
default=get_env_value("TOP_K", 60, int),
help="Number of most similar results to return (default: from env or 60)",
)
parser.add_argument(
"--chunk-top-k",
type=int,
default=get_env_value("CHUNK_TOP_K", 15, int),
help="Number of text chunks to retrieve initially from vector search (default: from env or 15)",
)
parser.add_argument(
"--chunk-rerank-top-k",
type=int,
default=get_env_value("CHUNK_RERANK_TOP_K", 5, int),
help="Number of text chunks to keep after reranking (default: from env or 5)",
)
parser.add_argument(
"--enable-rerank",
action="store_true",
default=get_env_value("ENABLE_RERANK", False, bool),
help="Enable rerank functionality (default: from env or False)",
)
parser.add_argument(
"--cosine-threshold",
type=float,
@@ -295,6 +313,11 @@ def parse_args() -> argparse.Namespace:
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
# Rerank model configuration
args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
return args

View file

@@ -291,6 +291,32 @@ def create_app(args):
),
)
# Configure rerank function if enabled
rerank_model_func = None
if args.enable_rerank and args.rerank_binding_api_key and args.rerank_binding_host:
from lightrag.rerank import custom_rerank
async def server_rerank_func(
query: str, documents: list, top_k: int = None, **kwargs
):
"""Server rerank function with configuration from environment variables"""
return await custom_rerank(
query=query,
documents=documents,
model=args.rerank_model,
base_url=args.rerank_binding_host,
api_key=args.rerank_binding_api_key,
top_k=top_k,
**kwargs,
)
rerank_model_func = server_rerank_func
logger.info(f"Rerank enabled with model: {args.rerank_model}")
elif args.enable_rerank:
logger.warning(
"Rerank enabled but RERANK_BINDING_API_KEY or RERANK_BINDING_HOST not configured. Rerank will be disabled."
)
# Initialize RAG
if args.llm_binding in ["lollms", "ollama", "openai"]:
rag = LightRAG(
@@ -324,6 +350,8 @@ def create_app(args):
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
enable_rerank=args.enable_rerank,
rerank_model_func=rerank_model_func,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
@@ -352,6 +380,8 @@ def create_app(args):
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
enable_rerank=args.enable_rerank,
rerank_model_func=rerank_model_func,
auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
@@ -478,6 +508,12 @@ def create_app(args):
"enable_llm_cache": args.enable_llm_cache,
"workspace": args.workspace,
"max_graph_nodes": args.max_graph_nodes,
# Rerank configuration
"enable_rerank": args.enable_rerank,
"rerank_model": args.rerank_model if args.enable_rerank else None,
"rerank_binding_host": args.rerank_binding_host
if args.enable_rerank
else None,
},
"auth_mode": auth_mode,
"pipeline_busy": pipeline_status.get("busy", False),

View file

@@ -49,6 +49,18 @@ class QueryRequest(BaseModel):
description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
)
chunk_top_k: Optional[int] = Field(
ge=1,
default=None,
description="Number of text chunks to retrieve initially from vector search.",
)
chunk_rerank_top_k: Optional[int] = Field(
ge=1,
default=None,
description="Number of text chunks to keep after reranking.",
)
max_token_for_text_unit: Optional[int] = Field(
gt=1,
default=None,

View file

@@ -60,7 +60,17 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "6000"))
"""Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_global_context: int = int(
@@ -280,21 +290,6 @@ class BaseKVStorage(StorageNameSpace, ABC):
False: if the cache drop failed, or the cache mode is not supported
"""
# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
# """Delete specific cache records from storage by chunk IDs
# Importance notes for in-memory storage:
# 1. Changes will be persisted to disk during the next index_done_callback
# 2. update flags to notify other processes that data persistence is needed
# Args:
# chunk_ids (list[str]): List of chunk IDs to be dropped from storage
# Returns:
# True: if the cache drop successfully
# False: if the cache drop failed, or the operation is not supported
# """
@dataclass
class BaseGraphStorage(StorageNameSpace, ABC):

View file

@@ -240,6 +240,17 @@ class LightRAG:
llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional keyword arguments passed to the LLM model function."""
# Rerank Configuration
# ---
enable_rerank: bool = field(
default=bool(os.getenv("ENABLE_RERANK", "False").lower() == "true")
)
"""Enable reranking for improved retrieval quality. Defaults to False."""
rerank_model_func: Callable[..., object] | None = field(default=None)
"""Function for reranking retrieved documents. All rerank configurations (model name, API keys, top_k, etc.) should be included in this function. Optional."""
# Storage
# ---
@@ -447,6 +458,14 @@ class LightRAG:
)
)
# Init Rerank
if self.enable_rerank and self.rerank_model_func:
logger.info("Rerank model initialized for improved retrieval quality")
elif self.enable_rerank and not self.rerank_model_func:
logger.warning(
"Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
)
self._storages_status = StoragesStatus.CREATED
if self.auto_manage_storages_states:

View file

@@ -1527,6 +1527,7 @@ async def kg_query(
    # Build context
    context = await _build_query_context(
        query,
        ll_keywords_str,
        hl_keywords_str,
        knowledge_graph_inst,
@@ -1746,84 +1747,52 @@ async def _get_vector_context(
async def _get_vector_context(
    query: str,
    chunks_vdb: BaseVectorStorage,
    query_param: QueryParam,
) -> list[dict]:
    """
    Retrieve text chunks from the vector database without reranking or truncation.

    This function performs vector search to find relevant text chunks for a query.
    Reranking and truncation will be handled later in the unified processing.

    Args:
        query: The query string to search for
        chunks_vdb: Vector database containing document chunks
        query_param: Query parameters including chunk_top_k and ids

    Returns:
        List of text chunks with metadata
    """
    try:
        # Use chunk_top_k if specified, otherwise fall back to top_k
        search_top_k = query_param.chunk_top_k or query_param.top_k

        results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids)
        if not results:
            return []

        valid_chunks = []
        for result in results:
            if "content" in result:
                chunk_with_metadata = {
                    "content": result["content"],
                    "created_at": result.get("created_at", None),
                    "file_path": result.get("file_path", "unknown_source"),
                    "source_type": "vector",  # Mark the source type
                }
                valid_chunks.append(chunk_with_metadata)

        logger.info(
            f"Naive query: {len(valid_chunks)} chunks (chunk_top_k: {search_top_k})"
        )
        return valid_chunks

    except Exception as e:
        logger.error(f"Error in _get_vector_context: {e}")
        return []


async def _build_query_context(
    query: str,
    ll_keywords: str,
    hl_keywords: str,
    knowledge_graph_inst: BaseGraphStorage,
@@ -1831,27 +1800,36 @@ async def _build_query_context(
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage,
    query_param: QueryParam,
    chunks_vdb: BaseVectorStorage = None,
):
    logger.info(f"Process {os.getpid()} building query context...")

    # Collect all chunks from different sources
    all_chunks = []
    entities_context = []
    relations_context = []

    # Handle local and global modes
    if query_param.mode == "local":
        entities_context, relations_context, entity_chunks = await _get_node_data(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
        all_chunks.extend(entity_chunks)

    elif query_param.mode == "global":
        entities_context, relations_context, relationship_chunks = await _get_edge_data(
            hl_keywords,
            knowledge_graph_inst,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )
        all_chunks.extend(relationship_chunks)

    else:  # hybrid or mix mode
        ll_data = await _get_node_data(
            ll_keywords,
@@ -1868,61 +1846,58 @@
            query_param,
        )

        (ll_entities_context, ll_relations_context, ll_chunks) = ll_data
        (hl_entities_context, hl_relations_context, hl_chunks) = hl_data

        # Collect chunks from entity and relationship sources
        all_chunks.extend(ll_chunks)
        all_chunks.extend(hl_chunks)

        # Get vector chunks if in mix mode
        if query_param.mode == "mix" and chunks_vdb:
            vector_chunks = await _get_vector_context(
                query,
                chunks_vdb,
                query_param,
            )
            all_chunks.extend(vector_chunks)

        # Combine entities and relations contexts
        entities_context = process_combine_contexts(
            hl_entities_context, ll_entities_context
        )
        relations_context = process_combine_contexts(
            hl_relations_context, ll_relations_context
        )

    # Process all chunks uniformly: deduplication, reranking, and token truncation
    processed_chunks = await process_chunks_unified(
        query=query,
        chunks=all_chunks,
        query_param=query_param,
        global_config=text_chunks_db.global_config,
        source_type="mixed",
    )

    # Build final text_units_context from processed chunks
    text_units_context = []
    for i, chunk in enumerate(processed_chunks):
        text_units_context.append(
            {
                "id": i + 1,
                "content": chunk["content"],
                "file_path": chunk.get("file_path", "unknown_source"),
            }
        )

    logger.info(
        f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks"
    )

    # not necessary to use LLM to generate a response
    if not entities_context and not relations_context:
        return None

    entities_str = json.dumps(entities_context, ensure_ascii=False)
    relations_str = json.dumps(relations_context, ensure_ascii=False)
    text_units_str = json.dumps(text_units_context, ensure_ascii=False)
@@ -2069,16 +2044,7 @@ async def _get_node_data(
        }
    )

    return entities_context, relations_context, use_text_units
async def _find_most_related_text_unit_from_entities(
@@ -2167,23 +2133,21 @@
    logger.warning("No valid text units found")
    return []

    # Sort by relation counts and order, but don't truncate
    all_text_units = sorted(
        all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
    )

    logger.debug(f"Found {len(all_text_units)} entity-related chunks")

    # Add source type marking and return chunk data
    result_chunks = []
    for t in all_text_units:
        chunk_data = t["data"].copy()
        chunk_data["source_type"] = "entity"
        result_chunks.append(chunk_data)

    return result_chunks
async def _find_most_related_edges_from_entities(
@@ -2485,21 +2449,16 @@ async def _find_related_text_unit_from_relationships(
    logger.warning("No valid text chunks after filtering")
    return []

    logger.debug(f"Found {len(valid_text_units)} relationship-related chunks")

    # Add source type marking and return chunk data
    result_chunks = []
    for t in valid_text_units:
        chunk_data = t["data"].copy()
        chunk_data["source_type"] = "relationship"
        result_chunks.append(chunk_data)

    return result_chunks
async def naive_query(
@@ -2527,12 +2486,32 @@ async def naive_query(
    tokenizer: Tokenizer = global_config["tokenizer"]

    chunks = await _get_vector_context(query, chunks_vdb, query_param)

    if chunks is None or len(chunks) == 0:
        return PROMPTS["fail_response"]

    # Process chunks using unified processing
    processed_chunks = await process_chunks_unified(
        query=query,
        chunks=chunks,
        query_param=query_param,
        global_config=global_config,
        source_type="vector",
    )

    logger.info(f"Final context: {len(processed_chunks)} chunks")

    # Build text_units_context from processed chunks
    text_units_context = []
    for i, chunk in enumerate(processed_chunks):
        text_units_context.append(
            {
                "id": i + 1,
                "content": chunk["content"],
                "file_path": chunk.get("file_path", "unknown_source"),
            }
        )

    text_units_str = json.dumps(text_units_context, ensure_ascii=False)

    if query_param.only_need_context:
@@ -2658,6 +2637,7 @@ async def kg_query_with_keywords(
    hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

    context = await _build_query_context(
        query,
        ll_keywords_str,
        hl_keywords_str,
        knowledge_graph_inst,
@@ -2780,8 +2760,6 @@ async def query_with_keywords(
        f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
    )

    # Use appropriate query method based on mode
    if param.mode in ["local", "global", "hybrid", "mix"]:
        return await kg_query_with_keywords(
@@ -2808,3 +2786,131 @@
        )
    else:
        raise ValueError(f"Unknown mode {param.mode}")
async def apply_rerank_if_enabled(
    query: str,
    retrieved_docs: list[dict],
    global_config: dict,
    top_k: int = None,
) -> list[dict]:
    """
    Apply reranking to retrieved documents if rerank is enabled.

    Args:
        query: The search query
        retrieved_docs: List of retrieved documents
        global_config: Global configuration containing rerank settings
        top_k: Number of top documents to return after reranking

    Returns:
        Reranked documents if rerank is enabled, otherwise original documents
    """
    if not global_config.get("enable_rerank", False) or not retrieved_docs:
        return retrieved_docs

    rerank_func = global_config.get("rerank_model_func")
    if not rerank_func:
        logger.debug(
            "Rerank is enabled but no rerank function provided, skipping rerank"
        )
        return retrieved_docs

    try:
        logger.debug(
            f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}"
        )
        # Apply reranking - let rerank_model_func handle top_k internally
        reranked_docs = await rerank_func(
            query=query,
            documents=retrieved_docs,
            top_k=top_k,
        )
        if reranked_docs and len(reranked_docs) > 0:
            # Guard against top_k=None before comparing lengths
            if top_k is not None and len(reranked_docs) > top_k:
                reranked_docs = reranked_docs[:top_k]
            logger.info(
                f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
            )
            return reranked_docs
        else:
            logger.warning("Rerank returned empty results, using original documents")
            return retrieved_docs

    except Exception as e:
        logger.error(f"Error during reranking: {e}, using original documents")
        return retrieved_docs


async def process_chunks_unified(
    query: str,
    chunks: list[dict],
    query_param: QueryParam,
    global_config: dict,
    source_type: str = "mixed",
) -> list[dict]:
    """
    Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation.

    Args:
        query: Search query for reranking
        chunks: List of text chunks to process
        query_param: Query parameters containing configuration
        global_config: Global configuration dictionary
        source_type: Source type for logging ("vector", "entity", "relationship", "mixed")

    Returns:
        Processed and filtered list of text chunks
    """
    if not chunks:
        return []

    # 1. Deduplication based on content
    seen_content = set()
    unique_chunks = []
    for chunk in chunks:
        content = chunk.get("content", "")
        if content and content not in seen_content:
            seen_content.add(content)
            unique_chunks.append(chunk)

    logger.debug(
        f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
    )

    # 2. Apply reranking if enabled and query is provided
    if global_config.get("enable_rerank", False) and query and unique_chunks:
        rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
        unique_chunks = await apply_rerank_if_enabled(
            query=query,
            retrieved_docs=unique_chunks,
            global_config=global_config,
            top_k=rerank_top_k,
        )
        logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")

    # 3. Apply chunk_top_k limiting if specified
    if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
        if len(unique_chunks) > query_param.chunk_top_k:
            unique_chunks = unique_chunks[: query_param.chunk_top_k]
        logger.debug(
            f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})"
        )

    # 4. Token-based final truncation
    tokenizer = global_config.get("tokenizer")
    if tokenizer and unique_chunks:
        original_count = len(unique_chunks)
        unique_chunks = truncate_list_by_token_size(
            unique_chunks,
            key=lambda x: x.get("content", ""),
            max_token_size=query_param.max_token_for_text_unit,
            tokenizer=tokenizer,
        )
        logger.debug(
            f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
            f"(max tokens: {query_param.max_token_for_text_unit}, source: {source_type})"
        )

    return unique_chunks

321
lightrag/rerank.py Normal file
View file

@@ -0,0 +1,321 @@
from __future__ import annotations
import os
import aiohttp
from typing import Callable, Any, List, Dict, Optional
from pydantic import BaseModel, Field
from .utils import logger
class RerankModel(BaseModel):
"""
Pydantic model class for defining a custom rerank model.
This class provides a convenient wrapper for rerank functions, allowing you to
encapsulate all rerank configurations (API keys, model settings, etc.) in one place.
Attributes:
rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents.
The function should take query and documents as input and return reranked results.
kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
This should include all necessary configurations such as model name, API key, base_url, etc.
Example usage:
Rerank model example with Jina:
```python
rerank_model = RerankModel(
rerank_func=jina_rerank,
kwargs={
"model": "BAAI/bge-reranker-v2-m3",
"api_key": "your_api_key_here",
"base_url": "https://api.jina.ai/v1/rerank"
}
)
# Use in LightRAG
rag = LightRAG(
enable_rerank=True,
rerank_model_func=rerank_model.rerank,
# ... other configurations
)
```
Or define a custom function directly:
```python
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
return await jina_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_api_key_here",
top_k=top_k or 10,
**kwargs
)
rag = LightRAG(
enable_rerank=True,
rerank_model_func=my_rerank_func,
# ... other configurations
)
```
"""
rerank_func: Callable[[Any], List[Dict]]
kwargs: Dict[str, Any] = Field(default_factory=dict)
async def rerank(
self,
query: str,
documents: List[Dict[str, Any]],
top_k: Optional[int] = None,
**extra_kwargs,
) -> List[Dict[str, Any]]:
"""Rerank documents using the configured model function."""
# Merge extra kwargs with model kwargs
kwargs = {**self.kwargs, **extra_kwargs}
return await self.rerank_func(
query=query, documents=documents, top_k=top_k, **kwargs
)
class MultiRerankModel(BaseModel):
"""Multiple rerank models for different modes/scenarios."""
# Primary rerank model (used if mode-specific models are not defined)
rerank_model: Optional[RerankModel] = None
# Mode-specific rerank models
entity_rerank_model: Optional[RerankModel] = None
relation_rerank_model: Optional[RerankModel] = None
chunk_rerank_model: Optional[RerankModel] = None
async def rerank(
self,
query: str,
documents: List[Dict[str, Any]],
mode: str = "default",
top_k: Optional[int] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""Rerank using the appropriate model based on mode."""
# Select model based on mode
if mode == "entity" and self.entity_rerank_model:
model = self.entity_rerank_model
elif mode == "relation" and self.relation_rerank_model:
model = self.relation_rerank_model
elif mode == "chunk" and self.chunk_rerank_model:
model = self.chunk_rerank_model
elif self.rerank_model:
model = self.rerank_model
else:
logger.warning(f"No rerank model available for mode: {mode}")
return documents
return await model.rerank(query, documents, top_k, **kwargs)
async def generic_rerank_api(
query: str,
documents: List[Dict[str, Any]],
model: str,
base_url: str,
api_key: str,
top_k: Optional[int] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Generic rerank function that works with Jina/Cohere compatible APIs.
Args:
query: The search query
documents: List of documents to rerank
model: Model identifier
base_url: API endpoint URL
api_key: API authentication key
top_k: Number of top results to return
**kwargs: Additional API-specific parameters
Returns:
List of reranked documents with relevance scores
"""
if not api_key:
logger.warning("No API key provided for rerank service")
return documents
if not documents:
return documents
# Prepare documents for reranking - handle both text and dict formats
prepared_docs = []
for doc in documents:
if isinstance(doc, dict):
# Use 'content' field if available, otherwise use 'text' or convert to string
text = doc.get("content") or doc.get("text") or str(doc)
else:
text = str(doc)
prepared_docs.append(text)
# Prepare request
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
if top_k is not None:
data["top_k"] = min(top_k, len(prepared_docs))
try:
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=data) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Rerank API error {response.status}: {error_text}")
return documents
result = await response.json()
# Extract reranked results
if "results" in result:
# Standard format: results contain index and relevance_score
reranked_docs = []
for item in result["results"]:
if "index" in item:
doc_idx = item["index"]
if 0 <= doc_idx < len(documents):
reranked_doc = documents[doc_idx].copy()
if "relevance_score" in item:
reranked_doc["rerank_score"] = item[
"relevance_score"
]
reranked_docs.append(reranked_doc)
return reranked_docs
else:
logger.warning("Unexpected rerank API response format")
return documents
except Exception as e:
logger.error(f"Error during reranking: {e}")
return documents
async def jina_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str = "BAAI/bge-reranker-v2-m3",
top_k: Optional[int] = None,
base_url: str = "https://api.jina.ai/v1/rerank",
api_key: Optional[str] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Jina AI API.
Args:
query: The search query
documents: List of documents to rerank
model: Jina rerank model name
top_k: Number of top results to return
base_url: Jina API endpoint
api_key: Jina API key
**kwargs: Additional parameters
Returns:
List of reranked documents with relevance scores
"""
if api_key is None:
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
return await generic_rerank_api(
query=query,
documents=documents,
model=model,
base_url=base_url,
api_key=api_key,
top_k=top_k,
**kwargs,
)
async def cohere_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str = "rerank-english-v2.0",
top_k: Optional[int] = None,
base_url: str = "https://api.cohere.ai/v1/rerank",
api_key: Optional[str] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Cohere API.
Args:
query: The search query
documents: List of documents to rerank
model: Cohere rerank model name
top_k: Number of top results to return
base_url: Cohere API endpoint
api_key: Cohere API key
**kwargs: Additional parameters
Returns:
List of reranked documents with relevance scores
"""
if api_key is None:
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")
return await generic_rerank_api(
query=query,
documents=documents,
model=model,
base_url=base_url,
api_key=api_key,
top_k=top_k,
**kwargs,
)
# Convenience function for custom API endpoints
async def custom_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str,
base_url: str,
api_key: str,
top_k: Optional[int] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Rerank documents using a custom API endpoint.
This is useful for self-hosted or custom rerank services.
"""
return await generic_rerank_api(
query=query,
documents=documents,
model=model,
base_url=base_url,
api_key=api_key,
top_k=top_k,
**kwargs,
)
if __name__ == "__main__":
import asyncio
async def main():
# Example usage
docs = [
{"content": "The capital of France is Paris."},
{"content": "Tokyo is the capital of Japan."},
{"content": "London is the capital of England."},
]
query = "What is the capital of France?"
result = await jina_rerank(
query=query, documents=docs, top_k=2, api_key="your-api-key-here"
)
print(result)
asyncio.run(main())

View file

@@ -111,7 +111,7 @@ const useSettingsStoreBase = create<SettingsState>()(
mode: 'global',
response_type: 'Multiple Paragraphs',
top_k: 10,
max_token_for_text_unit: 6000,
max_token_for_global_context: 4000,
max_token_for_local_context: 4000,
only_need_context: false,