Merge branch 'HKUDS:main' into main

This commit is contained in:
minh nhan nguyen 2025-07-10 11:25:51 +07:00 committed by GitHub
commit dd830e9b9d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 2206 additions and 165 deletions

View file

@ -294,6 +294,16 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60")) top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
"""Maximum number of tokens allowed for each retrieved text chunk.""" """Maximum number of tokens allowed for each retrieved text chunk."""
@ -849,6 +859,18 @@ rag = LightRAG(
</details> </details>
### LightRAG实例间的数据隔离
通过 workspace 参数可以不同实现不同LightRAG实例之间的存储数据隔离。LightRAG在初始化后workspace就已经确定之后修改workspace是无效的。下面是不同类型的存储实现工作空间的方式
- **对于本地基于文件的数据库,数据隔离通过工作空间子目录实现:** JsonKVStorage, JsonDocStatusStorage, NetworkXStorage, NanoVectorDBStorage, FaissVectorDBStorage。
- **对于将数据存储在集合collection中的数据库通过在集合名称前添加工作空间前缀来实现** RedisKVStorage, RedisDocStatusStorage, MilvusVectorDBStorage, QdrantVectorDBStorage, MongoKVStorage, MongoDocStatusStorage, MongoVectorDBStorage, MongoGraphStorage, PGGraphStorage。
- **对于关系型数据库,数据隔离通过向表中添加 `workspace` 字段进行数据的逻辑隔离:** PGKVStorage, PGVectorStorage, PGDocStatusStorage。
* **对于Neo4j图数据库通过label来实现数据的逻辑隔离**Neo4JStorage
为了保持对遗留数据的兼容在未配置工作空间时PostgreSQL的默认工作空间为`default`Neo4j的默认工作空间为`base`。对于所有的外部存储,系统都提供了专用的工作空间环境变量,用于覆盖公共的 `WORKSPACE`环境变量配置。这些适用于指定存储类型的工作空间环境变量为:`REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`
## 编辑实体和关系 ## 编辑实体和关系
LightRAG现在支持全面的知识图谱管理功能允许您在知识图谱中创建、编辑和删除实体和关系。 LightRAG现在支持全面的知识图谱管理功能允许您在知识图谱中创建、编辑和删除实体和关系。

View file

@ -153,7 +153,7 @@ curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_d
python examples/lightrag_openai_demo.py python examples/lightrag_openai_demo.py
``` ```
For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample codes LLM and embedding configurations accordingly. For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code's LLM and embedding configurations accordingly.
**Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory. **Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory.
@ -239,6 +239,7 @@ A full list of LightRAG init parameters:
| **Parameter** | **Type** | **Explanation** | **Default** | | **Parameter** | **Type** | **Explanation** | **Default** |
|--------------|----------|-----------------|-------------| |--------------|----------|-----------------|-------------|
| **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` | | **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
| **workspace** | str | Workspace name for data isolation between different LightRAG Instances | |
| **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` | | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
| **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` | | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
| **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` | | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
@ -300,6 +301,16 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60")) top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
"""Maximum number of tokens allowed for each retrieved text chunk.""" """Maximum number of tokens allowed for each retrieved text chunk."""
@ -860,6 +871,52 @@ rag = LightRAG(
</details> </details>
<details>
<summary> <b>Using Memgraph for Storage</b> </summary>
* Memgraph is a high-performance, in-memory graph database compatible with the Neo4j Bolt protocol.
* You can run Memgraph locally using Docker for easy testing:
* See: https://memgraph.com/download
```python
export MEMGRAPH_URI="bolt://localhost:7687"
# Setup logger for LightRAG
setup_logger("lightrag", level="INFO")
# When you launch the project, override the default KG: NetworkX
# by specifying kg="MemgraphStorage".
# Note: Default settings use NetworkX
# Initialize LightRAG with Memgraph implementation.
async def initialize_rag():
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
graph_storage="MemgraphStorage", #<-----------override KG default
)
# Initialize database connections
await rag.initialize_storages()
# Initialize pipeline status for document processing
await initialize_pipeline_status()
return rag
```
</details>
### Data Isolation Between LightRAG Instances
The `workspace` parameter ensures data isolation between different LightRAG instances. Once initialized, the `workspace` is immutable and cannot be changed.Here is how workspaces are implemented for different types of storage:
- **For local file-based databases, data isolation is achieved through workspace subdirectories:** `JsonKVStorage`, `JsonDocStatusStorage`, `NetworkXStorage`, `NanoVectorDBStorage`, `FaissVectorDBStorage`.
- **For databases that store data in collections, it's done by adding a workspace prefix to the collection name:** `RedisKVStorage`, `RedisDocStatusStorage`, `MilvusVectorDBStorage`, `QdrantVectorDBStorage`, `MongoKVStorage`, `MongoDocStatusStorage`, `MongoVectorDBStorage`, `MongoGraphStorage`, `PGGraphStorage`.
- **For relational databases, data isolation is achieved by adding a `workspace` field to the tables for logical data separation:** `PGKVStorage`, `PGVectorStorage`, `PGDocStatusStorage`.
- **For the Neo4j graph database, logical data isolation is achieved through labels:** `Neo4JStorage`
To maintain compatibility with legacy data, the default workspace for PostgreSQL is `default` and for Neo4j is `base` when no workspace is configured. For all external storages, the system provides dedicated workspace environment variables to override the common `WORKSPACE` environment variable configuration. These storage-specific workspace environment variables are: `REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`.
## Edit Entities and Relations ## Edit Entities and Relations
LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph. LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.

View file

@ -21,3 +21,6 @@ password = your_password
database = your_database database = your_database
workspace = default # 可选,默认为default workspace = default # 可选,默认为default
max_connections = 12 max_connections = 12
[memgraph]
uri = bolt://localhost:7687

275
docs/rerank_integration.md Normal file
View file

@ -0,0 +1,275 @@
# Rerank Integration in LightRAG
This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality.
## Overview
Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix).
## Architecture
The rerank integration follows a simplified design pattern:
- **Single Function Configuration**: All rerank settings (model, API keys, top_k, etc.) are contained within the rerank function
- **Async Processing**: Non-blocking rerank operations
- **Error Handling**: Graceful fallback to original results
- **Optional Feature**: Can be enabled/disabled via configuration
- **Code Reuse**: Single generic implementation for Jina/Cohere compatible APIs
## Configuration
### Environment Variables
Set this variable in your `.env` file or environment:
```bash
# Enable/disable reranking
ENABLE_RERANK=True
```
### Programmatic Configuration
```python
from lightrag import LightRAG
from lightrag.rerank import custom_rerank, RerankModel
# Method 1: Using a custom rerank function with all settings included
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
return await custom_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_k=top_k or 10, # Handle top_k within the function
**kwargs
)
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=your_llm_func,
embedding_func=your_embedding_func,
enable_rerank=True,
rerank_model_func=my_rerank_func,
)
# Method 2: Using RerankModel wrapper
rerank_model = RerankModel(
rerank_func=custom_rerank,
kwargs={
"model": "BAAI/bge-reranker-v2-m3",
"base_url": "https://api.your-provider.com/v1/rerank",
"api_key": "your_api_key_here",
}
)
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=your_llm_func,
embedding_func=your_embedding_func,
enable_rerank=True,
rerank_model_func=rerank_model.rerank,
)
```
## Supported Providers
### 1. Custom/Generic API (Recommended)
For Jina/Cohere compatible APIs:
```python
from lightrag.rerank import custom_rerank
# Your custom API endpoint
result = await custom_rerank(
query="your query",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_k=10
)
```
### 2. Jina AI
```python
from lightrag.rerank import jina_rerank
result = await jina_rerank(
query="your query",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_jina_api_key",
top_k=10
)
```
### 3. Cohere
```python
from lightrag.rerank import cohere_rerank
result = await cohere_rerank(
query="your query",
documents=documents,
model="rerank-english-v2.0",
api_key="your_cohere_api_key",
top_k=10
)
```
## Integration Points
Reranking is automatically applied at these key retrieval stages:
1. **Naive Mode**: After vector similarity search in `_get_vector_context`
2. **Local Mode**: After entity retrieval in `_get_node_data`
3. **Global Mode**: After relationship retrieval in `_get_edge_data`
4. **Hybrid/Mix Modes**: Applied to all relevant components
## Configuration Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enable_rerank` | bool | False | Enable/disable reranking |
| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_k, etc.) |
## Example Usage
### Basic Usage
```python
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.rerank import jina_rerank
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
"""Custom rerank function with all settings included"""
return await jina_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_jina_api_key_here",
top_k=top_k or 10, # Default top_k if not provided
**kwargs
)
async def main():
# Initialize with rerank enabled
rag = LightRAG(
working_dir="./rag_storage",
llm_model_func=gpt_4o_mini_complete,
embedding_func=openai_embedding,
enable_rerank=True,
rerank_model_func=my_rerank_func,
)
await rag.initialize_storages()
await initialize_pipeline_status()
# Insert documents
await rag.ainsert([
"Document 1 content...",
"Document 2 content...",
])
# Query with rerank (automatically applied)
result = await rag.aquery(
"Your question here",
param=QueryParam(mode="hybrid", top_k=5) # This top_k is passed to rerank function
)
print(result)
asyncio.run(main())
```
### Direct Rerank Usage
```python
from lightrag.rerank import custom_rerank
async def test_rerank():
documents = [
{"content": "Text about topic A"},
{"content": "Text about topic B"},
{"content": "Text about topic C"},
]
reranked = await custom_rerank(
query="Tell me about topic A",
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-provider.com/v1/rerank",
api_key="your_api_key_here",
top_k=2
)
for doc in reranked:
print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
```
## Best Practices
1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_k handling) within your rerank function
2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff
3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function
4. **Fallback**: Always handle rerank failures gracefully (returns original results)
5. **Top-k Handling**: Handle top_k parameter appropriately within your rerank function
6. **Cost Management**: Consider rerank API costs in your budget planning
## Troubleshooting
### Common Issues
1. **API Key Missing**: Ensure API keys are properly configured within your rerank function
2. **Network Issues**: Check API endpoints and network connectivity
3. **Model Errors**: Verify the rerank model name is supported by your API
4. **Document Format**: Ensure documents have `content` or `text` fields
### Debug Mode
Enable debug logging to see rerank operations:
```python
import logging
logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG)
```
### Error Handling
The rerank integration includes automatic fallback:
```python
# If rerank fails, original documents are returned
# No exceptions are raised to the user
# Errors are logged for debugging
```
## API Compatibility
The generic rerank API expects this response format:
```json
{
"results": [
{
"index": 0,
"relevance_score": 0.95
},
{
"index": 2,
"relevance_score": 0.87
}
]
}
```
This is compatible with:
- Jina AI Rerank API
- Cohere Rerank API
- Custom APIs following the same format

View file

@ -42,13 +42,31 @@ OLLAMA_EMULATING_MODEL_TAG=latest
### Logfile location (defaults to current working directory) ### Logfile location (defaults to current working directory)
# LOG_DIR=/path/to/log/directory # LOG_DIR=/path/to/log/directory
### Settings for RAG query ### RAG Configuration
### Chunk size for document splitting, 500~1500 is recommended
# CHUNK_SIZE=1200
# CHUNK_OVERLAP_SIZE=100
# MAX_TOKEN_SUMMARY=500
### RAG Query Configuration
# HISTORY_TURNS=3 # HISTORY_TURNS=3
# COSINE_THRESHOLD=0.2 # MAX_TOKEN_TEXT_CHUNK=6000
# TOP_K=60
# MAX_TOKEN_TEXT_CHUNK=4000
# MAX_TOKEN_RELATION_DESC=4000 # MAX_TOKEN_RELATION_DESC=4000
# MAX_TOKEN_ENTITY_DESC=4000 # MAX_TOKEN_ENTITY_DESC=4000
# COSINE_THRESHOLD=0.2
### Number of entities or relations to retrieve from KG
# TOP_K=60
### Number of text chunks to retrieve initially from vector search
# CHUNK_TOP_K=5
### Rerank Configuration
# ENABLE_RERANK=False
### Number of text chunks to keep after reranking (should be <= CHUNK_TOP_K)
# CHUNK_RERANK_TOP_K=5
### Rerank model configuration (required when ENABLE_RERANK=True)
# RERANK_MODEL=BAAI/bge-reranker-v2-m3
# RERANK_BINDING_HOST=https://api.your-rerank-provider.com/v1/rerank
# RERANK_BINDING_API_KEY=your_rerank_api_key_here
### Entity and relation summarization configuration ### Entity and relation summarization configuration
### Language: English, Chinese, French, German ... ### Language: English, Chinese, French, German ...
@ -62,9 +80,6 @@ SUMMARY_LANGUAGE=English
### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended) ### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended)
# MAX_PARALLEL_INSERT=2 # MAX_PARALLEL_INSERT=2
### Chunk size for document splitting, 500~1500 is recommended
# CHUNK_SIZE=1200
# CHUNK_OVERLAP_SIZE=100
### LLM Configuration ### LLM Configuration
ENABLE_LLM_CACHE=true ENABLE_LLM_CACHE=true
@ -134,13 +149,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
# LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
### Graph Storage (Recommended for production deployment) ### Graph Storage (Recommended for production deployment)
# LIGHTRAG_GRAPH_STORAGE=Neo4JStorage # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
# LIGHTRAG_GRAPH_STORAGE=MemgraphStorage
#################################################################### ####################################################################
### Default workspace for all storage types ### Default workspace for all storage types
### For the purpose of isolation of data for each LightRAG instance ### For the purpose of isolation of data for each LightRAG instance
### Valid characters: a-z, A-Z, 0-9, and _ ### Valid characters: a-z, A-Z, 0-9, and _
#################################################################### ####################################################################
# WORKSPACE=doc— # WORKSPACE=space1
### PostgreSQL Configuration ### PostgreSQL Configuration
POSTGRES_HOST=localhost POSTGRES_HOST=localhost
@ -179,3 +195,10 @@ QDRANT_URL=http://localhost:6333
### Redis ### Redis
REDIS_URI=redis://localhost:6379 REDIS_URI=redis://localhost:6379
# REDIS_WORKSPACE=forced_workspace_name # REDIS_WORKSPACE=forced_workspace_name
### Memgraph Configuration
MEMGRAPH_URI=bolt://localhost:7687
MEMGRAPH_USERNAME=
MEMGRAPH_PASSWORD=
MEMGRAPH_DATABASE=memgraph
# MEMGRAPH_WORKSPACE=forced_workspace_name

233
examples/rerank_example.py Normal file
View file

@ -0,0 +1,233 @@
"""
LightRAG Rerank Integration Example
This example demonstrates how to use rerank functionality with LightRAG
to improve retrieval quality across different query modes.
Configuration Required:
1. Set your LLM API key and base URL in llm_model_func()
2. Set your embedding API key and base URL in embedding_func()
3. Set your rerank API key and base URL in the rerank configuration
4. Or use environment variables (.env file):
- ENABLE_RERANK=True
"""
import asyncio
import os
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank, RerankModel
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc, setup_logger
from lightrag.kg.shared_storage import initialize_pipeline_status
# Set up your working directory
WORKING_DIR = "./test_rerank"
setup_logger("test_rerank")
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key="your_llm_api_key_here",
base_url="https://api.your-llm-provider.com/v1",
**kwargs,
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed(
texts,
model="text-embedding-3-large",
api_key="your_embedding_api_key_here",
base_url="https://api.your-embedding-provider.com/v1",
)
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
"""Custom rerank function with all settings included"""
return await custom_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-rerank-provider.com/v1/rerank",
api_key="your_rerank_api_key_here",
top_k=top_k or 10, # Default top_k if not provided
**kwargs,
)
async def create_rag_with_rerank():
"""Create LightRAG instance with rerank configuration"""
# Get embedding dimension
test_embedding = await embedding_func(["test"])
embedding_dim = test_embedding.shape[1]
print(f"Detected embedding dimension: {embedding_dim}")
# Method 1: Using custom rerank function
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dim,
max_token_size=8192,
func=embedding_func,
),
# Simplified Rerank Configuration
enable_rerank=True,
rerank_model_func=my_rerank_func,
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
async def create_rag_with_rerank_model():
"""Alternative: Create LightRAG instance using RerankModel wrapper"""
# Get embedding dimension
test_embedding = await embedding_func(["test"])
embedding_dim = test_embedding.shape[1]
print(f"Detected embedding dimension: {embedding_dim}")
# Method 2: Using RerankModel wrapper
rerank_model = RerankModel(
rerank_func=custom_rerank,
kwargs={
"model": "BAAI/bge-reranker-v2-m3",
"base_url": "https://api.your-rerank-provider.com/v1/rerank",
"api_key": "your_rerank_api_key_here",
},
)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=embedding_dim,
max_token_size=8192,
func=embedding_func,
),
enable_rerank=True,
rerank_model_func=rerank_model.rerank,
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
async def test_rerank_with_different_topk():
"""
Test rerank functionality with different top_k settings
"""
print("🚀 Setting up LightRAG with Rerank functionality...")
rag = await create_rag_with_rerank()
# Insert sample documents
sample_docs = [
"Reranking improves retrieval quality by re-ordering documents based on relevance.",
"LightRAG is a powerful retrieval-augmented generation system with multiple query modes.",
"Vector databases enable efficient similarity search in high-dimensional embedding spaces.",
"Natural language processing has evolved with large language models and transformers.",
"Machine learning algorithms can learn patterns from data without explicit programming.",
]
print("📄 Inserting sample documents...")
await rag.ainsert(sample_docs)
query = "How does reranking improve retrieval quality?"
print(f"\n🔍 Testing query: '{query}'")
print("=" * 80)
# Test different top_k values to show parameter priority
top_k_values = [2, 5, 10]
for top_k in top_k_values:
print(f"\n📊 Testing with QueryParam(top_k={top_k}):")
# Test naive mode with specific top_k
result = await rag.aquery(query, param=QueryParam(mode="naive", top_k=top_k))
print(f" Result length: {len(result)} characters")
print(f" Preview: {result[:100]}...")
async def test_direct_rerank():
"""Test rerank function directly"""
print("\n🔧 Direct Rerank API Test")
print("=" * 40)
documents = [
{"content": "Reranking significantly improves retrieval quality"},
{"content": "LightRAG supports advanced reranking capabilities"},
{"content": "Vector search finds semantically similar documents"},
{"content": "Natural language processing with modern transformers"},
{"content": "The quick brown fox jumps over the lazy dog"},
]
query = "rerank improve quality"
print(f"Query: '{query}'")
print(f"Documents: {len(documents)}")
try:
reranked_docs = await custom_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
base_url="https://api.your-rerank-provider.com/v1/rerank",
api_key="your_rerank_api_key_here",
top_k=3,
)
print("\n✅ Rerank Results:")
for i, doc in enumerate(reranked_docs):
score = doc.get("rerank_score", "N/A")
content = doc.get("content", "")[:60]
print(f" {i+1}. Score: {score:.4f} | {content}...")
except Exception as e:
print(f"❌ Rerank failed: {e}")
async def main():
"""Main example function"""
print("🎯 LightRAG Rerank Integration Example")
print("=" * 60)
try:
# Test rerank with different top_k values
await test_rerank_with_different_topk()
# Test direct rerank
await test_direct_rerank()
print("\n✅ Example completed successfully!")
print("\n💡 Key Points:")
print(" ✓ All rerank configurations are contained within rerank_model_func")
print(" ✓ Rerank improves document relevance ordering")
print(" ✓ Configure API keys within your rerank function")
print(" ✓ Monitor API usage and costs when using rerank services")
except Exception as e:
print(f"\n❌ Example failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
asyncio.run(main())

View file

@ -1,5 +1,5 @@
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
__version__ = "1.3.10" __version__ = "1.4.0"
__author__ = "Zirui Guo" __author__ = "Zirui Guo"
__url__ = "https://github.com/HKUDS/LightRAG" __url__ = "https://github.com/HKUDS/LightRAG"

View file

@ -165,6 +165,24 @@ def parse_args() -> argparse.Namespace:
default=get_env_value("TOP_K", 60, int), default=get_env_value("TOP_K", 60, int),
help="Number of most similar results to return (default: from env or 60)", help="Number of most similar results to return (default: from env or 60)",
) )
parser.add_argument(
"--chunk-top-k",
type=int,
default=get_env_value("CHUNK_TOP_K", 15, int),
help="Number of text chunks to retrieve initially from vector search (default: from env or 15)",
)
parser.add_argument(
"--chunk-rerank-top-k",
type=int,
default=get_env_value("CHUNK_RERANK_TOP_K", 5, int),
help="Number of text chunks to keep after reranking (default: from env or 5)",
)
parser.add_argument(
"--enable-rerank",
action="store_true",
default=get_env_value("ENABLE_RERANK", False, bool),
help="Enable rerank functionality (default: from env or False)",
)
parser.add_argument( parser.add_argument(
"--cosine-threshold", "--cosine-threshold",
type=float, type=float,
@ -295,6 +313,11 @@ def parse_args() -> argparse.Namespace:
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int) args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256") args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
# Rerank model configuration
args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
return args return args

View file

@ -291,6 +291,32 @@ def create_app(args):
), ),
) )
# Configure rerank function if enabled
rerank_model_func = None
if args.enable_rerank and args.rerank_binding_api_key and args.rerank_binding_host:
from lightrag.rerank import custom_rerank
async def server_rerank_func(
query: str, documents: list, top_k: int = None, **kwargs
):
"""Server rerank function with configuration from environment variables"""
return await custom_rerank(
query=query,
documents=documents,
model=args.rerank_model,
base_url=args.rerank_binding_host,
api_key=args.rerank_binding_api_key,
top_k=top_k,
**kwargs,
)
rerank_model_func = server_rerank_func
logger.info(f"Rerank enabled with model: {args.rerank_model}")
elif args.enable_rerank:
logger.warning(
"Rerank enabled but RERANK_BINDING_API_KEY or RERANK_BINDING_HOST not configured. Rerank will be disabled."
)
# Initialize RAG # Initialize RAG
if args.llm_binding in ["lollms", "ollama", "openai"]: if args.llm_binding in ["lollms", "ollama", "openai"]:
rag = LightRAG( rag = LightRAG(
@ -324,6 +350,8 @@ def create_app(args):
}, },
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache, enable_llm_cache=args.enable_llm_cache,
enable_rerank=args.enable_rerank,
rerank_model_func=rerank_model_func,
auto_manage_storages_states=False, auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert, max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes, max_graph_nodes=args.max_graph_nodes,
@ -352,6 +380,8 @@ def create_app(args):
}, },
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache, enable_llm_cache=args.enable_llm_cache,
enable_rerank=args.enable_rerank,
rerank_model_func=rerank_model_func,
auto_manage_storages_states=False, auto_manage_storages_states=False,
max_parallel_insert=args.max_parallel_insert, max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes, max_graph_nodes=args.max_graph_nodes,
@ -478,6 +508,12 @@ def create_app(args):
"enable_llm_cache": args.enable_llm_cache, "enable_llm_cache": args.enable_llm_cache,
"workspace": args.workspace, "workspace": args.workspace,
"max_graph_nodes": args.max_graph_nodes, "max_graph_nodes": args.max_graph_nodes,
# Rerank configuration
"enable_rerank": args.enable_rerank,
"rerank_model": args.rerank_model if args.enable_rerank else None,
"rerank_binding_host": args.rerank_binding_host
if args.enable_rerank
else None,
}, },
"auth_mode": auth_mode, "auth_mode": auth_mode,
"pipeline_busy": pipeline_status.get("busy", False), "pipeline_busy": pipeline_status.get("busy", False),

View file

@ -49,6 +49,18 @@ class QueryRequest(BaseModel):
description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.", description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
) )
chunk_top_k: Optional[int] = Field(
ge=1,
default=None,
description="Number of text chunks to retrieve initially from vector search.",
)
chunk_rerank_top_k: Optional[int] = Field(
ge=1,
default=None,
description="Number of text chunks to keep after reranking.",
)
max_token_for_text_unit: Optional[int] = Field( max_token_for_text_unit: Optional[int] = Field(
gt=1, gt=1,
default=None, default=None,

View file

@ -60,7 +60,17 @@ class QueryParam:
top_k: int = int(os.getenv("TOP_K", "60")) top_k: int = int(os.getenv("TOP_K", "60"))
"""Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
"""Number of text chunks to retrieve initially from vector search.
If None, defaults to top_k value.
"""
chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
"""Number of text chunks to keep after reranking.
If None, keeps all chunks returned from initial retrieval.
"""
max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "6000"))
"""Maximum number of tokens allowed for each retrieved text chunk.""" """Maximum number of tokens allowed for each retrieved text chunk."""
max_token_for_global_context: int = int( max_token_for_global_context: int = int(
@ -280,21 +290,6 @@ class BaseKVStorage(StorageNameSpace, ABC):
False: if the cache drop failed, or the cache mode is not supported False: if the cache drop failed, or the cache mode is not supported
""" """
# async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
# """Delete specific cache records from storage by chunk IDs
# Importance notes for in-memory storage:
# 1. Changes will be persisted to disk during the next index_done_callback
# 2. update flags to notify other processes that data persistence is needed
# Args:
# chunk_ids (list[str]): List of chunk IDs to be dropped from storage
# Returns:
# True: if the cache drop successfully
# False: if the cache drop failed, or the operation is not supported
# """
@dataclass @dataclass
class BaseGraphStorage(StorageNameSpace, ABC): class BaseGraphStorage(StorageNameSpace, ABC):

View file

@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = {
"Neo4JStorage", "Neo4JStorage",
"PGGraphStorage", "PGGraphStorage",
"MongoGraphStorage", "MongoGraphStorage",
"MemgraphStorage",
# "AGEStorage", # "AGEStorage",
# "TiDBGraphStorage", # "TiDBGraphStorage",
# "GremlinStorage", # "GremlinStorage",
@ -57,6 +58,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = {
"NetworkXStorage": [], "NetworkXStorage": [],
"Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"],
"MongoGraphStorage": [], "MongoGraphStorage": [],
"MemgraphStorage": ["MEMGRAPH_URI"],
# "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"],
"AGEStorage": [ "AGEStorage": [
"AGE_POSTGRES_DB", "AGE_POSTGRES_DB",
@ -111,6 +113,7 @@ STORAGES = {
"PGDocStatusStorage": ".kg.postgres_impl", "PGDocStatusStorage": ".kg.postgres_impl",
"FaissVectorDBStorage": ".kg.faiss_impl", "FaissVectorDBStorage": ".kg.faiss_impl",
"QdrantVectorDBStorage": ".kg.qdrant_impl", "QdrantVectorDBStorage": ".kg.qdrant_impl",
"MemgraphStorage": ".kg.memgraph_impl",
} }

View file

@ -0,0 +1,906 @@
import os
from dataclasses import dataclass
from typing import final
import configparser
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP
import pipmaster as pm
if not pm.is_installed("neo4j"):
pm.install("neo4j")
from neo4j import (
AsyncGraphDatabase,
AsyncManagedTransaction,
)
from dotenv import load_dotenv
# use the .env that is inside the current folder
load_dotenv(dotenv_path=".env", override=False)
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
@final
@dataclass
class MemgraphStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func, workspace=None):
memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE")
if memgraph_workspace and memgraph_workspace.strip():
workspace = memgraph_workspace
super().__init__(
namespace=namespace,
workspace=workspace or "",
global_config=global_config,
embedding_func=embedding_func,
)
self._driver = None
def _get_workspace_label(self) -> str:
"""Get workspace label, return 'base' for compatibility when workspace is empty"""
workspace = getattr(self, "workspace", None)
return workspace if workspace else "base"
async def initialize(self):
URI = os.environ.get(
"MEMGRAPH_URI",
config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
)
USERNAME = os.environ.get(
"MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
)
PASSWORD = os.environ.get(
"MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
)
DATABASE = os.environ.get(
"MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")
)
self._driver = AsyncGraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
)
self._DATABASE = DATABASE
try:
async with self._driver.session(database=DATABASE) as session:
# Create index for base nodes on entity_id if it doesn't exist
try:
workspace_label = self._get_workspace_label()
await session.run(
f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
)
logger.info(
f"Created index on :{workspace_label}(entity_id) in Memgraph."
)
except Exception as e:
# Index may already exist, which is not an error
logger.warning(
f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
)
await session.run("RETURN 1")
logger.info(f"Connected to Memgraph at {URI}")
except Exception as e:
logger.error(f"Failed to connect to Memgraph at {URI}: {e}")
raise
async def finalize(self):
if self._driver is not None:
await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb):
await self.finalize()
async def index_done_callback(self):
# Memgraph handles persistence automatically
pass
async def has_node(self, node_id: str) -> bool:
"""
Check if a node exists in the graph.
Args:
node_id: The ID of the node to check.
Returns:
bool: True if the node exists, False otherwise.
Raises:
Exception: If there is an error checking the node existence.
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists"
result = await session.run(query, entity_id=node_id)
single_result = await result.single()
await result.consume() # Ensure result is fully consumed
return (
single_result["node_exists"] if single_result is not None else False
)
except Exception as e:
logger.error(f"Error checking node existence for {node_id}: {str(e)}")
await result.consume() # Ensure the result is consumed even on error
raise
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
"""
Check if an edge exists between two nodes in the graph.
Args:
source_node_id: The ID of the source node.
target_node_id: The ID of the target node.
Returns:
bool: True if the edge exists, False otherwise.
Raises:
Exception: If there is an error checking the edge existence.
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
workspace_label = self._get_workspace_label()
query = (
f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) "
"RETURN COUNT(r) > 0 AS edgeExists"
)
result = await session.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
) # type: ignore
single_result = await result.single()
await result.consume() # Ensure result is fully consumed
return (
single_result["edgeExists"] if single_result is not None else False
)
except Exception as e:
logger.error(
f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
)
await result.consume() # Ensure the result is consumed even on error
raise
async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties
Args:
node_id: The node label to look up
Returns:
dict: Node properties if found
None: If node not found
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
workspace_label = self._get_workspace_label()
query = (
f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n"
)
result = await session.run(query, entity_id=node_id)
try:
records = await result.fetch(
2
) # Get 2 records for duplication check
if len(records) > 1:
logger.warning(
f"Multiple nodes found with label '{node_id}'. Using first node."
)
if records:
node = records[0]["n"]
node_dict = dict(node)
# Remove workspace label from labels list if it exists
if "labels" in node_dict:
node_dict["labels"] = [
label
for label in node_dict["labels"]
if label != workspace_label
]
return node_dict
return None
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e:
logger.error(f"Error getting node for {node_id}: {str(e)}")
raise
async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node.
If no node is found, returns 0.
Args:
node_id: The label of the node
Returns:
int: The number of relationships the node has, or 0 if no node found
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
workspace_label = self._get_workspace_label()
query = f"""
MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
OPTIONAL MATCH (n)-[r]-()
RETURN COUNT(r) AS degree
"""
result = await session.run(query, entity_id=node_id)
try:
record = await result.single()
if not record:
logger.warning(f"No node found with label '{node_id}'")
return 0
degree = record["degree"]
return degree
finally:
await result.consume() # Ensure result is fully consumed
except Exception as e:
logger.error(f"Error getting node degree for {node_id}: {str(e)}")
raise
async def get_all_labels(self) -> list[str]:
"""
Get all existing node labels in the database
Returns:
["Person", "Company", ...] # Alphabetically sorted label list
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
workspace_label = self._get_workspace_label()
query = f"""
MATCH (n:`{workspace_label}`)
WHERE n.entity_id IS NOT NULL
RETURN DISTINCT n.entity_id AS label
ORDER BY label
"""
result = await session.run(query)
labels = []
async for record in result:
labels.append(record["label"])
await result.consume()
return labels
except Exception as e:
logger.error(f"Error getting all labels: {str(e)}")
await result.consume() # Ensure the result is consumed even on error
raise
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label.
Args:
source_node_id: Label of the node to get edges for
Returns:
list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
None: If no edges found
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try:
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
workspace_label = self._get_workspace_label()
query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`)
WHERE connected.entity_id IS NOT NULL
RETURN n, r, connected"""
results = await session.run(query, entity_id=source_node_id)
edges = []
async for record in results:
source_node = record["n"]
connected_node = record["connected"]
# Skip if either node is None
if not source_node or not connected_node:
continue
source_label = (
source_node.get("entity_id")
if source_node.get("entity_id")
else None
)
target_label = (
connected_node.get("entity_id")
if connected_node.get("entity_id")
else None
)
if source_label and target_label:
edges.append((source_label, target_label))
await results.consume() # Ensure results are consumed
return edges
except Exception as e:
logger.error(
f"Error getting edges for node {source_node_id}: {str(e)}"
)
await results.consume() # Ensure results are consumed even on error
raise
except Exception as e:
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
raise
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
"""Get edge properties between two nodes.
Args:
source_node_id: Label of the source node
target_node_id: Label of the target node
Returns:
dict: Edge properties if found, default properties if not found or on error
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
workspace_label = self._get_workspace_label()
query = f"""
MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}})
RETURN properties(r) as edge_properties
"""
result = await session.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
)
records = await result.fetch(2)
await result.consume()
if records:
edge_result = dict(records[0]["edge_properties"])
for key, default_value in {
"weight": 0.0,
"source_id": None,
"description": None,
"keywords": None,
}.items():
if key not in edge_result:
edge_result[key] = default_value
logger.warning(
f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}"
)
return edge_result
return None
except Exception as e:
logger.error(
f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}"
)
await result.consume() # Ensure the result is consumed even on error
raise
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""
Upsert a node in the Neo4j database.
Args:
node_id: The unique identifier for the node (used as label)
node_data: Dictionary of node properties
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
properties = node_data
entity_type = properties["entity_type"]
if "entity_id" not in properties:
raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
try:
async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label()
async def execute_upsert(tx: AsyncManagedTransaction):
query = f"""
MERGE (n:`{workspace_label}` {{entity_id: $entity_id}})
SET n += $properties
SET n:`{entity_type}`
"""
result = await tx.run(
query, entity_id=node_id, properties=properties
)
await result.consume() # Ensure result is fully consumed
await session.execute_write(execute_upsert)
except Exception as e:
logger.error(f"Error during upsert: {str(e)}")
raise
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
"""
Upsert an edge and its properties between two nodes identified by their labels.
Ensures both source and target nodes exist and are unique before creating the edge.
Uses entity_id property to uniquely identify nodes.
Args:
source_node_id (str): Label of the source node (used as identifier)
target_node_id (str): Label of the target node (used as identifier)
edge_data (dict): Dictionary of properties to set on the edge
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try:
edge_properties = edge_data
async with self._driver.session(database=self._DATABASE) as session:
async def execute_upsert(tx: AsyncManagedTransaction):
workspace_label = self._get_workspace_label()
query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})
WITH source
MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}})
MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties
RETURN r, source, target
"""
result = await tx.run(
query,
source_entity_id=source_node_id,
target_entity_id=target_node_id,
properties=edge_properties,
)
try:
await result.fetch(2)
finally:
await result.consume() # Ensure result is consumed
await session.execute_write(execute_upsert)
except Exception as e:
logger.error(f"Error during edge upsert: {str(e)}")
raise
async def delete_node(self, node_id: str) -> None:
"""Delete a node with the specified label
Args:
node_id: The label of the node to delete
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
async def _do_delete(tx: AsyncManagedTransaction):
workspace_label = self._get_workspace_label()
query = f"""
MATCH (n:`{workspace_label}` {{entity_id: $entity_id}})
DETACH DELETE n
"""
result = await tx.run(query, entity_id=node_id)
logger.debug(f"Deleted node with label {node_id}")
await result.consume()
try:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete)
except Exception as e:
logger.error(f"Error during node deletion: {str(e)}")
raise
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node labels to be deleted
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
for node in nodes:
await self.delete_node(node)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to be deleted, each edge is a (source, target) tuple
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
for source, target in edges:
async def _do_delete_edge(tx: AsyncManagedTransaction):
workspace_label = self._get_workspace_label()
query = f"""
MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}})
DELETE r
"""
result = await tx.run(
query, source_entity_id=source, target_entity_id=target
)
logger.debug(f"Deleted edge from '{source}' to '{target}'")
await result.consume() # Ensure result is fully consumed
try:
async with self._driver.session(database=self._DATABASE) as session:
await session.execute_write(_do_delete_edge)
except Exception as e:
logger.error(f"Error during edge deletion: {str(e)}")
raise
async def drop(self) -> dict[str, str]:
"""Drop all data from the current workspace and clean up resources
This method will delete all nodes and relationships in the Memgraph database.
Returns:
dict[str, str]: Operation status and message
- On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"}
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
try:
async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label()
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
result = await session.run(query)
await result.consume()
logger.info(
f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
)
return {"status": "success", "message": "workspace data dropped"}
except Exception as e:
logger.error(
f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)}
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes.
Args:
src_id: Label of the source node
tgt_id: Label of the target node
Returns:
int: Sum of the degrees of both nodes
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
src_degree = await self.node_degree(src_id)
trg_degree = await self.node_degree(tgt_id)
# Convert None to 0 for addition
src_degree = 0 if src_degree is None else src_degree
trg_degree = 0 if trg_degree is None else trg_degree
degrees = int(src_degree) + int(trg_degree)
return degrees
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all nodes that are associated with the given chunk_ids.
Args:
chunk_ids: List of chunk IDs to find associated nodes for
Returns:
list[dict]: A list of nodes, where each node is a dictionary of its properties.
An empty list if no matching nodes are found.
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f"""
UNWIND $chunk_ids AS chunk_id
MATCH (n:`{workspace_label}`)
WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep)
RETURN DISTINCT n
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
nodes = []
async for record in result:
node = record["n"]
node_dict = dict(node)
node_dict["id"] = node_dict.get("entity_id")
nodes.append(node_dict)
await result.consume()
return nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
"""Get all edges that are associated with the given chunk_ids.
Args:
chunk_ids: List of chunk IDs to find associated edges for
Returns:
list[dict]: A list of edges, where each edge is a dictionary of its properties.
An empty list if no matching edges are found.
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
query = f"""
UNWIND $chunk_ids AS chunk_id
MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`)
WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep)
WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id
// Ensure we only return each unique edge once by ordering the source and target
WITH a, b, r,
CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source,
CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target
RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties
"""
result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP)
edges = []
async for record in result:
edge_properties = record["properties"]
edge_properties["source"] = record["source"]
edge_properties["target"] = record["target"]
edges.append(edge_properties)
await result.consume()
return edges
async def get_knowledge_graph(
self,
node_label: str,
max_depth: int = 3,
max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
"""
Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
Args:
node_label: Label of the starting node, * means all nodes
max_depth: Maximum depth of the subgraph, Defaults to 3
max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
Returns:
KnowledgeGraph object containing nodes and edges, with an is_truncated flag
indicating whether the graph was truncated due to max_nodes limit
Raises:
Exception: If there is an error executing the query
"""
if self._driver is None:
raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first."
)
result = KnowledgeGraph()
seen_nodes = set()
seen_edges = set()
workspace_label = self._get_workspace_label()
async with self._driver.session(
database=self._DATABASE, default_access_mode="READ"
) as session:
try:
if node_label == "*":
# First check if database has any nodes
count_query = "MATCH (n) RETURN count(n) as total"
count_result = None
total_count = 0
try:
count_result = await session.run(count_query)
count_record = await count_result.single()
if count_record:
total_count = count_record["total"]
if total_count == 0:
logger.debug("No nodes found in database")
return result
if total_count > max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: {total_count} nodes found, limited to {max_nodes}"
)
finally:
if count_result:
await count_result.consume()
# Run the main query to get nodes with highest degree
main_query = f"""
MATCH (n:`{workspace_label}`)
OPTIONAL MATCH (n)-[r]-()
WITH n, COALESCE(count(r), 0) AS degree
ORDER BY degree DESC
LIMIT $max_nodes
WITH collect(n) AS kept_nodes
MATCH (a)-[r]-(b)
WHERE a IN kept_nodes AND b IN kept_nodes
RETURN [node IN kept_nodes | {{node: node}}] AS node_info,
collect(DISTINCT r) AS relationships
"""
result_set = None
try:
result_set = await session.run(
main_query, {"max_nodes": max_nodes}
)
record = await result_set.single()
if not record:
logger.debug("No record returned from main query")
return result
finally:
if result_set:
await result_set.consume()
else:
bfs_query = f"""
MATCH (start:`{workspace_label}`)
WHERE start.entity_id = $entity_id
WITH start
CALL {{
WITH start
MATCH path = (start)-[*0..{max_depth}]-(node)
WITH nodes(path) AS path_nodes, relationships(path) AS path_rels
UNWIND path_nodes AS n
WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists
WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels
RETURN all_nodes, all_rels
}}
WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes
WITH
CASE
WHEN total_nodes <= {max_nodes} THEN nodes
ELSE nodes[0..{max_nodes}]
END AS limited_nodes,
relationships,
total_nodes,
total_nodes > {max_nodes} AS is_truncated
RETURN
[node IN limited_nodes | {{node: node}}] AS node_info,
relationships,
total_nodes,
is_truncated
"""
result_set = None
try:
result_set = await session.run(
bfs_query,
{
"entity_id": node_label,
},
)
record = await result_set.single()
if not record:
logger.debug(f"No nodes found for entity_id: {node_label}")
return result
# Check if the query indicates truncation
if "is_truncated" in record and record["is_truncated"]:
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
)
finally:
if result_set:
await result_set.consume()
# Process the record if it exists
if record and record["node_info"]:
for node_info in record["node_info"]:
node = node_info["node"]
node_id = node.id
if node_id not in seen_nodes:
seen_nodes.add(node_id)
result.nodes.append(
KnowledgeGraphNode(
id=f"{node_id}",
labels=[node.get("entity_id")],
properties=dict(node),
)
)
for rel in record["relationships"]:
edge_id = rel.id
if edge_id not in seen_edges:
seen_edges.add(edge_id)
start = rel.start_node
end = rel.end_node
result.edges.append(
KnowledgeGraphEdge(
id=f"{edge_id}",
type=rel.type,
source=f"{start.id}",
target=f"{end.id}",
properties=dict(rel),
)
)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
except Exception as e:
logger.error(f"Error getting knowledge graph: {str(e)}")
# Return empty but properly initialized KnowledgeGraph on error
return KnowledgeGraph()
return result

View file

@ -240,6 +240,17 @@ class LightRAG:
llm_model_kwargs: dict[str, Any] = field(default_factory=dict) llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
"""Additional keyword arguments passed to the LLM model function.""" """Additional keyword arguments passed to the LLM model function."""
# Rerank Configuration
# ---
enable_rerank: bool = field(
default=bool(os.getenv("ENABLE_RERANK", "False").lower() == "true")
)
"""Enable reranking for improved retrieval quality. Defaults to False."""
rerank_model_func: Callable[..., object] | None = field(default=None)
"""Function for reranking retrieved documents. All rerank configurations (model name, API keys, top_k, etc.) should be included in this function. Optional."""
# Storage # Storage
# --- # ---
@ -447,6 +458,14 @@ class LightRAG:
) )
) )
# Init Rerank
if self.enable_rerank and self.rerank_model_func:
logger.info("Rerank model initialized for improved retrieval quality")
elif self.enable_rerank and not self.rerank_model_func:
logger.warning(
"Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
)
self._storages_status = StoragesStatus.CREATED self._storages_status = StoragesStatus.CREATED
if self.auto_manage_storages_states: if self.auto_manage_storages_states:
@ -900,9 +919,15 @@ class LightRAG:
# Get first document's file path and total count for job name # Get first document's file path and total count for job name
first_doc_id, first_doc = next(iter(to_process_docs.items())) first_doc_id, first_doc = next(iter(to_process_docs.items()))
first_doc_path = first_doc.file_path first_doc_path = first_doc.file_path
path_prefix = first_doc_path[:20] + (
"..." if len(first_doc_path) > 20 else "" # Handle cases where first_doc_path is None
) if first_doc_path:
path_prefix = first_doc_path[:20] + (
"..." if len(first_doc_path) > 20 else ""
)
else:
path_prefix = "unknown_source"
total_files = len(to_process_docs) total_files = len(to_process_docs)
job_name = f"{path_prefix}[{total_files} files]" job_name = f"{path_prefix}[{total_files} files]"
pipeline_status["job_name"] = job_name pipeline_status["job_name"] = job_name

View file

@ -1527,6 +1527,7 @@ async def kg_query(
# Build context # Build context
context = await _build_query_context( context = await _build_query_context(
query,
ll_keywords_str, ll_keywords_str,
hl_keywords_str, hl_keywords_str,
knowledge_graph_inst, knowledge_graph_inst,
@ -1746,84 +1747,52 @@ async def _get_vector_context(
query: str, query: str,
chunks_vdb: BaseVectorStorage, chunks_vdb: BaseVectorStorage,
query_param: QueryParam, query_param: QueryParam,
tokenizer: Tokenizer, ) -> list[dict]:
) -> tuple[list, list, list] | None:
""" """
Retrieve vector context from the vector database. Retrieve text chunks from the vector database without reranking or truncation.
This function performs vector search to find relevant text chunks for a query, This function performs vector search to find relevant text chunks for a query.
formats them with file path and creation time information. Reranking and truncation will be handled later in the unified processing.
Args: Args:
query: The query string to search for query: The query string to search for
chunks_vdb: Vector database containing document chunks chunks_vdb: Vector database containing document chunks
query_param: Query parameters including top_k and ids query_param: Query parameters including chunk_top_k and ids
tokenizer: Tokenizer for counting tokens
Returns: Returns:
Tuple (empty_entities, empty_relations, text_units) for combine_contexts, List of text chunks with metadata
compatible with _get_edge_data and _get_node_data format
""" """
try: try:
results = await chunks_vdb.query( # Use chunk_top_k if specified, otherwise fall back to top_k
query, top_k=query_param.top_k, ids=query_param.ids search_top_k = query_param.chunk_top_k or query_param.top_k
)
results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids)
if not results: if not results:
return [], [], [] return []
valid_chunks = [] valid_chunks = []
for result in results: for result in results:
if "content" in result: if "content" in result:
# Directly use content from chunks_vdb.query result chunk_with_metadata = {
chunk_with_time = {
"content": result["content"], "content": result["content"],
"created_at": result.get("created_at", None), "created_at": result.get("created_at", None),
"file_path": result.get("file_path", "unknown_source"), "file_path": result.get("file_path", "unknown_source"),
"source_type": "vector", # Mark the source type
} }
valid_chunks.append(chunk_with_time) valid_chunks.append(chunk_with_metadata)
if not valid_chunks:
return [], [], []
maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug(
f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
)
logger.info( logger.info(
f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}" f"Naive query: {len(valid_chunks)} chunks (chunk_top_k: {search_top_k})"
) )
return valid_chunks
if not maybe_trun_chunks:
return [], [], []
# Create empty entities and relations contexts
entities_context = []
relations_context = []
# Create text_units_context directly as a list of dictionaries
text_units_context = []
for i, chunk in enumerate(maybe_trun_chunks):
text_units_context.append(
{
"id": i + 1,
"content": chunk["content"],
"file_path": chunk["file_path"],
}
)
return entities_context, relations_context, text_units_context
except Exception as e: except Exception as e:
logger.error(f"Error in _get_vector_context: {e}") logger.error(f"Error in _get_vector_context: {e}")
return [], [], [] return []
async def _build_query_context( async def _build_query_context(
query: str,
ll_keywords: str, ll_keywords: str,
hl_keywords: str, hl_keywords: str,
knowledge_graph_inst: BaseGraphStorage, knowledge_graph_inst: BaseGraphStorage,
@ -1831,27 +1800,36 @@ async def _build_query_context(
relationships_vdb: BaseVectorStorage, relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage, text_chunks_db: BaseKVStorage,
query_param: QueryParam, query_param: QueryParam,
chunks_vdb: BaseVectorStorage = None, # Add chunks_vdb parameter for mix mode chunks_vdb: BaseVectorStorage = None,
): ):
logger.info(f"Process {os.getpid()} building query context...") logger.info(f"Process {os.getpid()} building query context...")
# Handle local and global modes as before # Collect all chunks from different sources
all_chunks = []
entities_context = []
relations_context = []
# Handle local and global modes
if query_param.mode == "local": if query_param.mode == "local":
entities_context, relations_context, text_units_context = await _get_node_data( entities_context, relations_context, entity_chunks = await _get_node_data(
ll_keywords, ll_keywords,
knowledge_graph_inst, knowledge_graph_inst,
entities_vdb, entities_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
) )
all_chunks.extend(entity_chunks)
elif query_param.mode == "global": elif query_param.mode == "global":
entities_context, relations_context, text_units_context = await _get_edge_data( entities_context, relations_context, relationship_chunks = await _get_edge_data(
hl_keywords, hl_keywords,
knowledge_graph_inst, knowledge_graph_inst,
relationships_vdb, relationships_vdb,
text_chunks_db, text_chunks_db,
query_param, query_param,
) )
all_chunks.extend(relationship_chunks)
else: # hybrid or mix mode else: # hybrid or mix mode
ll_data = await _get_node_data( ll_data = await _get_node_data(
ll_keywords, ll_keywords,
@ -1868,61 +1846,58 @@ async def _build_query_context(
query_param, query_param,
) )
( (ll_entities_context, ll_relations_context, ll_chunks) = ll_data
ll_entities_context, (hl_entities_context, hl_relations_context, hl_chunks) = hl_data
ll_relations_context,
ll_text_units_context,
) = ll_data
( # Collect chunks from entity and relationship sources
hl_entities_context, all_chunks.extend(ll_chunks)
hl_relations_context, all_chunks.extend(hl_chunks)
hl_text_units_context,
) = hl_data
# Initialize vector data with empty lists # Get vector chunks if in mix mode
vector_entities_context, vector_relations_context, vector_text_units_context = ( if query_param.mode == "mix" and chunks_vdb:
[], vector_chunks = await _get_vector_context(
[], query,
[],
)
# Only get vector data if in mix mode
if query_param.mode == "mix" and hasattr(query_param, "original_query"):
# Get tokenizer from text_chunks_db
tokenizer = text_chunks_db.global_config.get("tokenizer")
# Get vector context in triple format
vector_data = await _get_vector_context(
query_param.original_query, # We need to pass the original query
chunks_vdb, chunks_vdb,
query_param, query_param,
tokenizer,
) )
all_chunks.extend(vector_chunks)
# If vector_data is not None, unpack it # Combine entities and relations contexts
if vector_data is not None:
(
vector_entities_context,
vector_relations_context,
vector_text_units_context,
) = vector_data
# Combine and deduplicate the entities, relationships, and sources
entities_context = process_combine_contexts( entities_context = process_combine_contexts(
hl_entities_context, ll_entities_context, vector_entities_context hl_entities_context, ll_entities_context
) )
relations_context = process_combine_contexts( relations_context = process_combine_contexts(
hl_relations_context, ll_relations_context, vector_relations_context hl_relations_context, ll_relations_context
) )
text_units_context = process_combine_contexts(
hl_text_units_context, ll_text_units_context, vector_text_units_context # Process all chunks uniformly: deduplication, reranking, and token truncation
processed_chunks = await process_chunks_unified(
query=query,
chunks=all_chunks,
query_param=query_param,
global_config=text_chunks_db.global_config,
source_type="mixed",
)
# Build final text_units_context from processed chunks
text_units_context = []
for i, chunk in enumerate(processed_chunks):
text_units_context.append(
{
"id": i + 1,
"content": chunk["content"],
"file_path": chunk.get("file_path", "unknown_source"),
}
) )
logger.info(
f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks"
)
# not necessary to use LLM to generate a response # not necessary to use LLM to generate a response
if not entities_context and not relations_context: if not entities_context and not relations_context:
return None return None
# 转换为 JSON 字符串
entities_str = json.dumps(entities_context, ensure_ascii=False) entities_str = json.dumps(entities_context, ensure_ascii=False)
relations_str = json.dumps(relations_context, ensure_ascii=False) relations_str = json.dumps(relations_context, ensure_ascii=False)
text_units_str = json.dumps(text_units_context, ensure_ascii=False) text_units_str = json.dumps(text_units_context, ensure_ascii=False)
@ -2069,16 +2044,7 @@ async def _get_node_data(
} }
) )
text_units_context = [] return entities_context, relations_context, use_text_units
for i, t in enumerate(use_text_units):
text_units_context.append(
{
"id": i + 1,
"content": t["content"],
"file_path": t.get("file_path", "unknown_source"),
}
)
return entities_context, relations_context, text_units_context
async def _find_most_related_text_unit_from_entities( async def _find_most_related_text_unit_from_entities(
@ -2167,23 +2133,21 @@ async def _find_most_related_text_unit_from_entities(
logger.warning("No valid text units found") logger.warning("No valid text units found")
return [] return []
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") # Sort by relation counts and order, but don't truncate
all_text_units = sorted( all_text_units = sorted(
all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
) )
all_text_units = truncate_list_by_token_size(
all_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug( logger.debug(f"Found {len(all_text_units)} entity-related chunks")
f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
)
all_text_units = [t["data"] for t in all_text_units] # Add source type marking and return chunk data
return all_text_units result_chunks = []
for t in all_text_units:
chunk_data = t["data"].copy()
chunk_data["source_type"] = "entity"
result_chunks.append(chunk_data)
return result_chunks
async def _find_most_related_edges_from_entities( async def _find_most_related_edges_from_entities(
@ -2485,21 +2449,16 @@ async def _find_related_text_unit_from_relationships(
logger.warning("No valid text chunks after filtering") logger.warning("No valid text chunks after filtering")
return [] return []
tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") logger.debug(f"Found {len(valid_text_units)} relationship-related chunks")
truncated_text_units = truncate_list_by_token_size(
valid_text_units,
key=lambda x: x["data"]["content"],
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug( # Add source type marking and return chunk data
f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})" result_chunks = []
) for t in valid_text_units:
chunk_data = t["data"].copy()
chunk_data["source_type"] = "relationship"
result_chunks.append(chunk_data)
all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units] return result_chunks
return all_text_units
async def naive_query( async def naive_query(
@ -2527,12 +2486,32 @@ async def naive_query(
tokenizer: Tokenizer = global_config["tokenizer"] tokenizer: Tokenizer = global_config["tokenizer"]
_, _, text_units_context = await _get_vector_context( chunks = await _get_vector_context(query, chunks_vdb, query_param)
query, chunks_vdb, query_param, tokenizer
if chunks is None or len(chunks) == 0:
return PROMPTS["fail_response"]
# Process chunks using unified processing
processed_chunks = await process_chunks_unified(
query=query,
chunks=chunks,
query_param=query_param,
global_config=global_config,
source_type="vector",
) )
if text_units_context is None or len(text_units_context) == 0: logger.info(f"Final context: {len(processed_chunks)} chunks")
return PROMPTS["fail_response"]
# Build text_units_context from processed chunks
text_units_context = []
for i, chunk in enumerate(processed_chunks):
text_units_context.append(
{
"id": i + 1,
"content": chunk["content"],
"file_path": chunk.get("file_path", "unknown_source"),
}
)
text_units_str = json.dumps(text_units_context, ensure_ascii=False) text_units_str = json.dumps(text_units_context, ensure_ascii=False)
if query_param.only_need_context: if query_param.only_need_context:
@ -2658,6 +2637,7 @@ async def kg_query_with_keywords(
hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
context = await _build_query_context( context = await _build_query_context(
query,
ll_keywords_str, ll_keywords_str,
hl_keywords_str, hl_keywords_str,
knowledge_graph_inst, knowledge_graph_inst,
@ -2780,8 +2760,6 @@ async def query_with_keywords(
f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}" f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
) )
param.original_query = query
# Use appropriate query method based on mode # Use appropriate query method based on mode
if param.mode in ["local", "global", "hybrid", "mix"]: if param.mode in ["local", "global", "hybrid", "mix"]:
return await kg_query_with_keywords( return await kg_query_with_keywords(
@ -2808,3 +2786,131 @@ async def query_with_keywords(
) )
else: else:
raise ValueError(f"Unknown mode {param.mode}") raise ValueError(f"Unknown mode {param.mode}")
async def apply_rerank_if_enabled(
query: str,
retrieved_docs: list[dict],
global_config: dict,
top_k: int = None,
) -> list[dict]:
"""
Apply reranking to retrieved documents if rerank is enabled.
Args:
query: The search query
retrieved_docs: List of retrieved documents
global_config: Global configuration containing rerank settings
top_k: Number of top documents to return after reranking
Returns:
Reranked documents if rerank is enabled, otherwise original documents
"""
if not global_config.get("enable_rerank", False) or not retrieved_docs:
return retrieved_docs
rerank_func = global_config.get("rerank_model_func")
if not rerank_func:
logger.debug(
"Rerank is enabled but no rerank function provided, skipping rerank"
)
return retrieved_docs
try:
logger.debug(
f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}"
)
# Apply reranking - let rerank_model_func handle top_k internally
reranked_docs = await rerank_func(
query=query,
documents=retrieved_docs,
top_k=top_k,
)
if reranked_docs and len(reranked_docs) > 0:
if len(reranked_docs) > top_k:
reranked_docs = reranked_docs[:top_k]
logger.info(
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
)
return reranked_docs
else:
logger.warning("Rerank returned empty results, using original documents")
return retrieved_docs
except Exception as e:
logger.error(f"Error during reranking: {e}, using original documents")
return retrieved_docs
async def process_chunks_unified(
query: str,
chunks: list[dict],
query_param: QueryParam,
global_config: dict,
source_type: str = "mixed",
) -> list[dict]:
"""
Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation.
Args:
query: Search query for reranking
chunks: List of text chunks to process
query_param: Query parameters containing configuration
global_config: Global configuration dictionary
source_type: Source type for logging ("vector", "entity", "relationship", "mixed")
Returns:
Processed and filtered list of text chunks
"""
if not chunks:
return []
# 1. Deduplication based on content
seen_content = set()
unique_chunks = []
for chunk in chunks:
content = chunk.get("content", "")
if content and content not in seen_content:
seen_content.add(content)
unique_chunks.append(chunk)
logger.debug(
f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
)
# 2. Apply reranking if enabled and query is provided
if global_config.get("enable_rerank", False) and query and unique_chunks:
rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
unique_chunks = await apply_rerank_if_enabled(
query=query,
retrieved_docs=unique_chunks,
global_config=global_config,
top_k=rerank_top_k,
)
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
# 3. Apply chunk_top_k limiting if specified
if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
if len(unique_chunks) > query_param.chunk_top_k:
unique_chunks = unique_chunks[: query_param.chunk_top_k]
logger.debug(
f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})"
)
# 4. Token-based final truncation
tokenizer = global_config.get("tokenizer")
if tokenizer and unique_chunks:
original_count = len(unique_chunks)
unique_chunks = truncate_list_by_token_size(
unique_chunks,
key=lambda x: x.get("content", ""),
max_token_size=query_param.max_token_for_text_unit,
tokenizer=tokenizer,
)
logger.debug(
f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
f"(max tokens: {query_param.max_token_for_text_unit}, source: {source_type})"
)
return unique_chunks

321
lightrag/rerank.py Normal file
View file

@ -0,0 +1,321 @@
from __future__ import annotations
import os
import aiohttp
from typing import Callable, Any, List, Dict, Optional
from pydantic import BaseModel, Field
from .utils import logger
class RerankModel(BaseModel):
"""
Pydantic model class for defining a custom rerank model.
This class provides a convenient wrapper for rerank functions, allowing you to
encapsulate all rerank configurations (API keys, model settings, etc.) in one place.
Attributes:
rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents.
The function should take query and documents as input and return reranked results.
kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
This should include all necessary configurations such as model name, API key, base_url, etc.
Example usage:
Rerank model example with Jina:
```python
rerank_model = RerankModel(
rerank_func=jina_rerank,
kwargs={
"model": "BAAI/bge-reranker-v2-m3",
"api_key": "your_api_key_here",
"base_url": "https://api.jina.ai/v1/rerank"
}
)
# Use in LightRAG
rag = LightRAG(
enable_rerank=True,
rerank_model_func=rerank_model.rerank,
# ... other configurations
)
```
Or define a custom function directly:
```python
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
return await jina_rerank(
query=query,
documents=documents,
model="BAAI/bge-reranker-v2-m3",
api_key="your_api_key_here",
top_k=top_k or 10,
**kwargs
)
rag = LightRAG(
enable_rerank=True,
rerank_model_func=my_rerank_func,
# ... other configurations
)
```
"""
rerank_func: Callable[[Any], List[Dict]]
kwargs: Dict[str, Any] = Field(default_factory=dict)
async def rerank(
self,
query: str,
documents: List[Dict[str, Any]],
top_k: Optional[int] = None,
**extra_kwargs,
) -> List[Dict[str, Any]]:
"""Rerank documents using the configured model function."""
# Merge extra kwargs with model kwargs
kwargs = {**self.kwargs, **extra_kwargs}
return await self.rerank_func(
query=query, documents=documents, top_k=top_k, **kwargs
)
class MultiRerankModel(BaseModel):
"""Multiple rerank models for different modes/scenarios."""
# Primary rerank model (used if mode-specific models are not defined)
rerank_model: Optional[RerankModel] = None
# Mode-specific rerank models
entity_rerank_model: Optional[RerankModel] = None
relation_rerank_model: Optional[RerankModel] = None
chunk_rerank_model: Optional[RerankModel] = None
async def rerank(
self,
query: str,
documents: List[Dict[str, Any]],
mode: str = "default",
top_k: Optional[int] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""Rerank using the appropriate model based on mode."""
# Select model based on mode
if mode == "entity" and self.entity_rerank_model:
model = self.entity_rerank_model
elif mode == "relation" and self.relation_rerank_model:
model = self.relation_rerank_model
elif mode == "chunk" and self.chunk_rerank_model:
model = self.chunk_rerank_model
elif self.rerank_model:
model = self.rerank_model
else:
logger.warning(f"No rerank model available for mode: {mode}")
return documents
return await model.rerank(query, documents, top_k, **kwargs)
async def generic_rerank_api(
query: str,
documents: List[Dict[str, Any]],
model: str,
base_url: str,
api_key: str,
top_k: Optional[int] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Generic rerank function that works with Jina/Cohere compatible APIs.
Args:
query: The search query
documents: List of documents to rerank
model: Model identifier
base_url: API endpoint URL
api_key: API authentication key
top_k: Number of top results to return
**kwargs: Additional API-specific parameters
Returns:
List of reranked documents with relevance scores
"""
if not api_key:
logger.warning("No API key provided for rerank service")
return documents
if not documents:
return documents
# Prepare documents for reranking - handle both text and dict formats
prepared_docs = []
for doc in documents:
if isinstance(doc, dict):
# Use 'content' field if available, otherwise use 'text' or convert to string
text = doc.get("content") or doc.get("text") or str(doc)
else:
text = str(doc)
prepared_docs.append(text)
# Prepare request
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
if top_k is not None:
data["top_k"] = min(top_k, len(prepared_docs))
try:
async with aiohttp.ClientSession() as session:
async with session.post(base_url, headers=headers, json=data) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Rerank API error {response.status}: {error_text}")
return documents
result = await response.json()
# Extract reranked results
if "results" in result:
# Standard format: results contain index and relevance_score
reranked_docs = []
for item in result["results"]:
if "index" in item:
doc_idx = item["index"]
if 0 <= doc_idx < len(documents):
reranked_doc = documents[doc_idx].copy()
if "relevance_score" in item:
reranked_doc["rerank_score"] = item[
"relevance_score"
]
reranked_docs.append(reranked_doc)
return reranked_docs
else:
logger.warning("Unexpected rerank API response format")
return documents
except Exception as e:
logger.error(f"Error during reranking: {e}")
return documents
async def jina_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str = "BAAI/bge-reranker-v2-m3",
top_k: Optional[int] = None,
base_url: str = "https://api.jina.ai/v1/rerank",
api_key: Optional[str] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Jina AI API.
Args:
query: The search query
documents: List of documents to rerank
model: Jina rerank model name
top_k: Number of top results to return
base_url: Jina API endpoint
api_key: Jina API key
**kwargs: Additional parameters
Returns:
List of reranked documents with relevance scores
"""
if api_key is None:
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
return await generic_rerank_api(
query=query,
documents=documents,
model=model,
base_url=base_url,
api_key=api_key,
top_k=top_k,
**kwargs,
)
async def cohere_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str = "rerank-english-v2.0",
top_k: Optional[int] = None,
base_url: str = "https://api.cohere.ai/v1/rerank",
api_key: Optional[str] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Rerank documents using Cohere API.
Args:
query: The search query
documents: List of documents to rerank
model: Cohere rerank model name
top_k: Number of top results to return
base_url: Cohere API endpoint
api_key: Cohere API key
**kwargs: Additional parameters
Returns:
List of reranked documents with relevance scores
"""
if api_key is None:
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")
return await generic_rerank_api(
query=query,
documents=documents,
model=model,
base_url=base_url,
api_key=api_key,
top_k=top_k,
**kwargs,
)
# Convenience function for custom API endpoints
async def custom_rerank(
query: str,
documents: List[Dict[str, Any]],
model: str,
base_url: str,
api_key: str,
top_k: Optional[int] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Rerank documents using a custom API endpoint.
This is useful for self-hosted or custom rerank services.
"""
return await generic_rerank_api(
query=query,
documents=documents,
model=model,
base_url=base_url,
api_key=api_key,
top_k=top_k,
**kwargs,
)
if __name__ == "__main__":
import asyncio
async def main():
# Example usage
docs = [
{"content": "The capital of France is Paris."},
{"content": "Tokyo is the capital of Japan."},
{"content": "London is the capital of England."},
]
query = "What is the capital of France?"
result = await jina_rerank(
query=query, documents=docs, top_k=2, api_key="your-api-key-here"
)
print(result)
asyncio.run(main())

View file

@ -111,7 +111,7 @@ const useSettingsStoreBase = create<SettingsState>()(
mode: 'global', mode: 'global',
response_type: 'Multiple Paragraphs', response_type: 'Multiple Paragraphs',
top_k: 10, top_k: 10,
max_token_for_text_unit: 4000, max_token_for_text_unit: 6000,
max_token_for_global_context: 4000, max_token_for_global_context: 4000,
max_token_for_local_context: 4000, max_token_for_local_context: 4000,
only_need_context: false, only_need_context: false,

View file

@ -10,6 +10,7 @@
- Neo4JStorage - Neo4JStorage
- MongoDBStorage - MongoDBStorage
- PGGraphStorage - PGGraphStorage
- MemgraphStorage
""" """
import asyncio import asyncio