diff --git a/README-zh.md b/README-zh.md index e9599099..d6aef2c8 100644 --- a/README-zh.md +++ b/README-zh.md @@ -294,6 +294,16 @@ class QueryParam: top_k: int = int(os.getenv("TOP_K", "60")) """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" + chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5")) + """Number of text chunks to retrieve initially from vector search. + If None, defaults to top_k value. + """ + + chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5")) + """Number of text chunks to keep after reranking. + If None, keeps all chunks returned from initial retrieval. + """ + max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) """Maximum number of tokens allowed for each retrieved text chunk.""" @@ -849,6 +859,18 @@ rag = LightRAG( +### LightRAG实例间的数据隔离 + +通过 workspace 参数可以不同实现不同LightRAG实例之间的存储数据隔离。LightRAG在初始化后workspace就已经确定,之后修改workspace是无效的。下面是不同类型的存储实现工作空间的方式: + +- **对于本地基于文件的数据库,数据隔离通过工作空间子目录实现:** JsonKVStorage, JsonDocStatusStorage, NetworkXStorage, NanoVectorDBStorage, FaissVectorDBStorage。 +- **对于将数据存储在集合(collection)中的数据库,通过在集合名称前添加工作空间前缀来实现:** RedisKVStorage, RedisDocStatusStorage, MilvusVectorDBStorage, QdrantVectorDBStorage, MongoKVStorage, MongoDocStatusStorage, MongoVectorDBStorage, MongoGraphStorage, PGGraphStorage。 +- **对于关系型数据库,数据隔离通过向表中添加 `workspace` 字段进行数据的逻辑隔离:** PGKVStorage, PGVectorStorage, PGDocStatusStorage。 + +* **对于Neo4j图数据库,通过label来实现数据的逻辑隔离**:Neo4JStorage + +为了保持对遗留数据的兼容,在未配置工作空间时PostgreSQL的默认工作空间为`default`,Neo4j的默认工作空间为`base`。对于所有的外部存储,系统都提供了专用的工作空间环境变量,用于覆盖公共的 `WORKSPACE`环境变量配置。这些适用于指定存储类型的工作空间环境变量为:`REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`。 + ## 编辑实体和关系 LightRAG现在支持全面的知识图谱管理功能,允许您在知识图谱中创建、编辑和删除实体和关系。 diff --git a/README.md b/README.md index e812e8df..5fb6149b 100644 --- a/README.md +++ b/README.md @@ -153,7 +153,7 @@ curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_d python examples/lightrag_openai_demo.py ``` -For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code’s LLM and embedding configurations accordingly. +For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code's LLM and embedding configurations accordingly. **Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory. @@ -239,6 +239,7 @@ A full list of LightRAG init parameters: | **Parameter** | **Type** | **Explanation** | **Default** | |--------------|----------|-----------------|-------------| | **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` | +| **workspace** | str | Workspace name for data isolation between different LightRAG Instances | | | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` | | **vector_storage** | `str` | Storage type for embedding vectors. 
Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` | | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` | @@ -300,6 +301,16 @@ class QueryParam: top_k: int = int(os.getenv("TOP_K", "60")) """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" + chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5")) + """Number of text chunks to retrieve initially from vector search. + If None, defaults to top_k value. + """ + + chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5")) + """Number of text chunks to keep after reranking. + If None, keeps all chunks returned from initial retrieval. + """ + max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) """Maximum number of tokens allowed for each retrieved text chunk.""" @@ -860,6 +871,52 @@ rag = LightRAG( +
+ Using Memgraph for Storage
+
+* Memgraph is a high-performance, in-memory graph database compatible with the Neo4j Bolt protocol.
+* You can run Memgraph locally using Docker for easy testing:
+* See: https://memgraph.com/download
+
+```python
+# Set the connection string in your shell (or .env) before launching:
+#   export MEMGRAPH_URI="bolt://localhost:7687"
+
+from lightrag import LightRAG
+from lightrag.llm.openai import gpt_4o_mini_complete
+from lightrag.utils import setup_logger
+from lightrag.kg.shared_storage import initialize_pipeline_status
+
+WORKING_DIR = "./rag_storage"
+
+# Setup logger for LightRAG
+setup_logger("lightrag", level="INFO")
+
+# The default graph storage is NetworkX; override it by passing
+# graph_storage="MemgraphStorage" when constructing LightRAG.
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
+        graph_storage="MemgraphStorage",  # <-- override the default graph storage
+    )
+
+    # Initialize database connections
+    await rag.initialize_storages()
+    # Initialize pipeline status for document processing
+    await initialize_pipeline_status()
+
+    return rag
+```
+
+&#13;
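+Once `initialize_rag()` returns, the Memgraph-backed instance is queried like any other LightRAG setup. Below is a minimal usage sketch; the inserted document and the question are placeholders for illustration only:
+
+```python
+import asyncio
+from lightrag import QueryParam
+
+async def main():
+    rag = await initialize_rag()
+
+    # Placeholder document, for illustration only
+    await rag.ainsert(["Memgraph stores the LightRAG knowledge graph over the Bolt protocol."])
+
+    # Query the graph built on top of Memgraph (placeholder question)
+    answer = await rag.aquery(
+        "How does LightRAG use Memgraph?",
+        param=QueryParam(mode="hybrid"),
+    )
+    print(answer)
+
+asyncio.run(main())
+```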
+
+### Data Isolation Between LightRAG Instances
+
+The `workspace` parameter ensures data isolation between different LightRAG instances. Once initialized, the `workspace` is immutable and cannot be changed. Here is how workspaces are implemented for different types of storage:
+
+- **For local file-based databases, data isolation is achieved through workspace subdirectories:** `JsonKVStorage`, `JsonDocStatusStorage`, `NetworkXStorage`, `NanoVectorDBStorage`, `FaissVectorDBStorage`.
+- **For databases that store data in collections, data isolation is achieved by adding a workspace prefix to the collection name:** `RedisKVStorage`, `RedisDocStatusStorage`, `MilvusVectorDBStorage`, `QdrantVectorDBStorage`, `MongoKVStorage`, `MongoDocStatusStorage`, `MongoVectorDBStorage`, `MongoGraphStorage`, `PGGraphStorage`.
+- **For relational databases, data isolation is achieved by adding a `workspace` field to the tables for logical data separation:** `PGKVStorage`, `PGVectorStorage`, `PGDocStatusStorage`.
+- **For the Neo4j graph database, logical data isolation is achieved through labels:** `Neo4JStorage`.
+
+To maintain compatibility with legacy data, the default workspace for PostgreSQL is `default` and for Neo4j is `base` when no workspace is configured. For all external storages, the system provides dedicated workspace environment variables that override the common `WORKSPACE` environment variable. These storage-specific workspace environment variables are: `REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`.
+
 ## Edit Entities and Relations

 LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.

diff --git a/config.ini.example b/config.ini.example
index 63d9c2c0..94d300a1 100644
--- a/config.ini.example
+++ b/config.ini.example
@@ -21,3 +21,6 @@ password = your_password
 database = your_database
 workspace = default # optional, defaults to default
 max_connections = 12
+
+[memgraph]
+uri = bolt://localhost:7687
diff --git a/docs/rerank_integration.md b/docs/rerank_integration.md
new file mode 100644
index 00000000..fdaebfa5
--- /dev/null
+++ b/docs/rerank_integration.md
@@ -0,0 +1,275 @@
+# Rerank Integration in LightRAG
+
+This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality.
+
+## Overview
+
+Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix).
+
+## Architecture
+
+The rerank integration follows a simplified design pattern:
+
+- **Single Function Configuration**: All rerank settings (model, API keys, top_k, etc.)
are contained within the rerank function +- **Async Processing**: Non-blocking rerank operations +- **Error Handling**: Graceful fallback to original results +- **Optional Feature**: Can be enabled/disabled via configuration +- **Code Reuse**: Single generic implementation for Jina/Cohere compatible APIs + +## Configuration + +### Environment Variables + +Set this variable in your `.env` file or environment: + +```bash +# Enable/disable reranking +ENABLE_RERANK=True +``` + +### Programmatic Configuration + +```python +from lightrag import LightRAG +from lightrag.rerank import custom_rerank, RerankModel + +# Method 1: Using a custom rerank function with all settings included +async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + return await custom_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-provider.com/v1/rerank", + api_key="your_api_key_here", + top_k=top_k or 10, # Handle top_k within the function + **kwargs + ) + +rag = LightRAG( + working_dir="./rag_storage", + llm_model_func=your_llm_func, + embedding_func=your_embedding_func, + enable_rerank=True, + rerank_model_func=my_rerank_func, +) + +# Method 2: Using RerankModel wrapper +rerank_model = RerankModel( + rerank_func=custom_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "base_url": "https://api.your-provider.com/v1/rerank", + "api_key": "your_api_key_here", + } +) + +rag = LightRAG( + working_dir="./rag_storage", + llm_model_func=your_llm_func, + embedding_func=your_embedding_func, + enable_rerank=True, + rerank_model_func=rerank_model.rerank, +) +``` + +## Supported Providers + +### 1. Custom/Generic API (Recommended) + +For Jina/Cohere compatible APIs: + +```python +from lightrag.rerank import custom_rerank + +# Your custom API endpoint +result = await custom_rerank( + query="your query", + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-provider.com/v1/rerank", + api_key="your_api_key_here", + top_k=10 +) +``` + +### 2. Jina AI + +```python +from lightrag.rerank import jina_rerank + +result = await jina_rerank( + query="your query", + documents=documents, + model="BAAI/bge-reranker-v2-m3", + api_key="your_jina_api_key", + top_k=10 +) +``` + +### 3. Cohere + +```python +from lightrag.rerank import cohere_rerank + +result = await cohere_rerank( + query="your query", + documents=documents, + model="rerank-english-v2.0", + api_key="your_cohere_api_key", + top_k=10 +) +``` + +## Integration Points + +Reranking is automatically applied at these key retrieval stages: + +1. **Naive Mode**: After vector similarity search in `_get_vector_context` +2. **Local Mode**: After entity retrieval in `_get_node_data` +3. **Global Mode**: After relationship retrieval in `_get_edge_data` +4. **Hybrid/Mix Modes**: Applied to all relevant components + +## Configuration Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `enable_rerank` | bool | False | Enable/disable reranking | +| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_k, etc.) 
| + +## Example Usage + +### Basic Usage + +```python +import asyncio +from lightrag import LightRAG, QueryParam +from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding +from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.rerank import jina_rerank + +async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + """Custom rerank function with all settings included""" + return await jina_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + api_key="your_jina_api_key_here", + top_k=top_k or 10, # Default top_k if not provided + **kwargs + ) + +async def main(): + # Initialize with rerank enabled + rag = LightRAG( + working_dir="./rag_storage", + llm_model_func=gpt_4o_mini_complete, + embedding_func=openai_embedding, + enable_rerank=True, + rerank_model_func=my_rerank_func, + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + + # Insert documents + await rag.ainsert([ + "Document 1 content...", + "Document 2 content...", + ]) + + # Query with rerank (automatically applied) + result = await rag.aquery( + "Your question here", + param=QueryParam(mode="hybrid", top_k=5) # This top_k is passed to rerank function + ) + + print(result) + +asyncio.run(main()) +``` + +### Direct Rerank Usage + +```python +from lightrag.rerank import custom_rerank + +async def test_rerank(): + documents = [ + {"content": "Text about topic A"}, + {"content": "Text about topic B"}, + {"content": "Text about topic C"}, + ] + + reranked = await custom_rerank( + query="Tell me about topic A", + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-provider.com/v1/rerank", + api_key="your_api_key_here", + top_k=2 + ) + + for doc in reranked: + print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}") +``` + +## Best Practices + +1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_k handling) within your rerank function +2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff +3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function +4. **Fallback**: Always handle rerank failures gracefully (returns original results) +5. **Top-k Handling**: Handle top_k parameter appropriately within your rerank function +6. **Cost Management**: Consider rerank API costs in your budget planning + +## Troubleshooting + +### Common Issues + +1. **API Key Missing**: Ensure API keys are properly configured within your rerank function +2. **Network Issues**: Check API endpoints and network connectivity +3. **Model Errors**: Verify the rerank model name is supported by your API +4. 
**Document Format**: Ensure documents have `content` or `text` fields + +### Debug Mode + +Enable debug logging to see rerank operations: + +```python +import logging +logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG) +``` + +### Error Handling + +The rerank integration includes automatic fallback: + +```python +# If rerank fails, original documents are returned +# No exceptions are raised to the user +# Errors are logged for debugging +``` + +## API Compatibility + +The generic rerank API expects this response format: + +```json +{ + "results": [ + { + "index": 0, + "relevance_score": 0.95 + }, + { + "index": 2, + "relevance_score": 0.87 + } + ] +} +``` + +This is compatible with: +- Jina AI Rerank API +- Cohere Rerank API +- Custom APIs following the same format diff --git a/env.example b/env.example index f759ea92..4515fe34 100644 --- a/env.example +++ b/env.example @@ -42,13 +42,31 @@ OLLAMA_EMULATING_MODEL_TAG=latest ### Logfile location (defaults to current working directory) # LOG_DIR=/path/to/log/directory -### Settings for RAG query +### RAG Configuration +### Chunk size for document splitting, 500~1500 is recommended +# CHUNK_SIZE=1200 +# CHUNK_OVERLAP_SIZE=100 +# MAX_TOKEN_SUMMARY=500 + +### RAG Query Configuration # HISTORY_TURNS=3 -# COSINE_THRESHOLD=0.2 -# TOP_K=60 -# MAX_TOKEN_TEXT_CHUNK=4000 +# MAX_TOKEN_TEXT_CHUNK=6000 # MAX_TOKEN_RELATION_DESC=4000 # MAX_TOKEN_ENTITY_DESC=4000 +# COSINE_THRESHOLD=0.2 +### Number of entities or relations to retrieve from KG +# TOP_K=60 +### Number of text chunks to retrieve initially from vector search +# CHUNK_TOP_K=5 + +### Rerank Configuration +# ENABLE_RERANK=False +### Number of text chunks to keep after reranking (should be <= CHUNK_TOP_K) +# CHUNK_RERANK_TOP_K=5 +### Rerank model configuration (required when ENABLE_RERANK=True) +# RERANK_MODEL=BAAI/bge-reranker-v2-m3 +# RERANK_BINDING_HOST=https://api.your-rerank-provider.com/v1/rerank +# RERANK_BINDING_API_KEY=your_rerank_api_key_here ### Entity and relation summarization configuration ### Language: English, Chinese, French, German ... 
@@ -62,9 +80,6 @@ SUMMARY_LANGUAGE=English ### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended) # MAX_PARALLEL_INSERT=2 -### Chunk size for document splitting, 500~1500 is recommended -# CHUNK_SIZE=1200 -# CHUNK_OVERLAP_SIZE=100 ### LLM Configuration ENABLE_LLM_CACHE=true @@ -134,13 +149,14 @@ EMBEDDING_BINDING_HOST=http://localhost:11434 # LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage ### Graph Storage (Recommended for production deployment) # LIGHTRAG_GRAPH_STORAGE=Neo4JStorage +# LIGHTRAG_GRAPH_STORAGE=MemgraphStorage #################################################################### ### Default workspace for all storage types ### For the purpose of isolation of data for each LightRAG instance ### Valid characters: a-z, A-Z, 0-9, and _ #################################################################### -# WORKSPACE=doc— +# WORKSPACE=space1 ### PostgreSQL Configuration POSTGRES_HOST=localhost @@ -179,3 +195,10 @@ QDRANT_URL=http://localhost:6333 ### Redis REDIS_URI=redis://localhost:6379 # REDIS_WORKSPACE=forced_workspace_name + +### Memgraph Configuration +MEMGRAPH_URI=bolt://localhost:7687 +MEMGRAPH_USERNAME= +MEMGRAPH_PASSWORD= +MEMGRAPH_DATABASE=memgraph +# MEMGRAPH_WORKSPACE=forced_workspace_name diff --git a/examples/rerank_example.py b/examples/rerank_example.py new file mode 100644 index 00000000..e0e361a5 --- /dev/null +++ b/examples/rerank_example.py @@ -0,0 +1,233 @@ +""" +LightRAG Rerank Integration Example + +This example demonstrates how to use rerank functionality with LightRAG +to improve retrieval quality across different query modes. + +Configuration Required: +1. Set your LLM API key and base URL in llm_model_func() +2. Set your embedding API key and base URL in embedding_func() +3. Set your rerank API key and base URL in the rerank configuration +4. 
Or use environment variables (.env file): + - ENABLE_RERANK=True +""" + +import asyncio +import os +import numpy as np + +from lightrag import LightRAG, QueryParam +from lightrag.rerank import custom_rerank, RerankModel +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.utils import EmbeddingFunc, setup_logger +from lightrag.kg.shared_storage import initialize_pipeline_status + +# Set up your working directory +WORKING_DIR = "./test_rerank" +setup_logger("test_rerank") + +if not os.path.exists(WORKING_DIR): + os.mkdir(WORKING_DIR) + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + return await openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key="your_llm_api_key_here", + base_url="https://api.your-llm-provider.com/v1", + **kwargs, + ) + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embed( + texts, + model="text-embedding-3-large", + api_key="your_embedding_api_key_here", + base_url="https://api.your-embedding-provider.com/v1", + ) + + +async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + """Custom rerank function with all settings included""" + return await custom_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-rerank-provider.com/v1/rerank", + api_key="your_rerank_api_key_here", + top_k=top_k or 10, # Default top_k if not provided + **kwargs, + ) + + +async def create_rag_with_rerank(): + """Create LightRAG instance with rerank configuration""" + + # Get embedding dimension + test_embedding = await embedding_func(["test"]) + embedding_dim = test_embedding.shape[1] + print(f"Detected embedding dimension: {embedding_dim}") + + # Method 1: Using custom rerank function + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dim, + max_token_size=8192, + func=embedding_func, + ), + # Simplified Rerank Configuration + enable_rerank=True, + rerank_model_func=my_rerank_func, + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + + return rag + + +async def create_rag_with_rerank_model(): + """Alternative: Create LightRAG instance using RerankModel wrapper""" + + # Get embedding dimension + test_embedding = await embedding_func(["test"]) + embedding_dim = test_embedding.shape[1] + print(f"Detected embedding dimension: {embedding_dim}") + + # Method 2: Using RerankModel wrapper + rerank_model = RerankModel( + rerank_func=custom_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "base_url": "https://api.your-rerank-provider.com/v1/rerank", + "api_key": "your_rerank_api_key_here", + }, + ) + + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + embedding_func=EmbeddingFunc( + embedding_dim=embedding_dim, + max_token_size=8192, + func=embedding_func, + ), + enable_rerank=True, + rerank_model_func=rerank_model.rerank, + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + + return rag + + +async def test_rerank_with_different_topk(): + """ + Test rerank functionality with different top_k settings + """ + print("🚀 Setting up LightRAG with Rerank functionality...") + + rag = await create_rag_with_rerank() + + # Insert sample documents + sample_docs = [ + "Reranking improves retrieval quality by re-ordering documents based on relevance.", + "LightRAG 
is a powerful retrieval-augmented generation system with multiple query modes.", + "Vector databases enable efficient similarity search in high-dimensional embedding spaces.", + "Natural language processing has evolved with large language models and transformers.", + "Machine learning algorithms can learn patterns from data without explicit programming.", + ] + + print("📄 Inserting sample documents...") + await rag.ainsert(sample_docs) + + query = "How does reranking improve retrieval quality?" + print(f"\n🔍 Testing query: '{query}'") + print("=" * 80) + + # Test different top_k values to show parameter priority + top_k_values = [2, 5, 10] + + for top_k in top_k_values: + print(f"\n📊 Testing with QueryParam(top_k={top_k}):") + + # Test naive mode with specific top_k + result = await rag.aquery(query, param=QueryParam(mode="naive", top_k=top_k)) + print(f" Result length: {len(result)} characters") + print(f" Preview: {result[:100]}...") + + +async def test_direct_rerank(): + """Test rerank function directly""" + print("\n🔧 Direct Rerank API Test") + print("=" * 40) + + documents = [ + {"content": "Reranking significantly improves retrieval quality"}, + {"content": "LightRAG supports advanced reranking capabilities"}, + {"content": "Vector search finds semantically similar documents"}, + {"content": "Natural language processing with modern transformers"}, + {"content": "The quick brown fox jumps over the lazy dog"}, + ] + + query = "rerank improve quality" + print(f"Query: '{query}'") + print(f"Documents: {len(documents)}") + + try: + reranked_docs = await custom_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + base_url="https://api.your-rerank-provider.com/v1/rerank", + api_key="your_rerank_api_key_here", + top_k=3, + ) + + print("\n✅ Rerank Results:") + for i, doc in enumerate(reranked_docs): + score = doc.get("rerank_score", "N/A") + content = doc.get("content", "")[:60] + print(f" {i+1}. 
Score: {score:.4f} | {content}...") + + except Exception as e: + print(f"❌ Rerank failed: {e}") + + +async def main(): + """Main example function""" + print("🎯 LightRAG Rerank Integration Example") + print("=" * 60) + + try: + # Test rerank with different top_k values + await test_rerank_with_different_topk() + + # Test direct rerank + await test_direct_rerank() + + print("\n✅ Example completed successfully!") + print("\n💡 Key Points:") + print(" ✓ All rerank configurations are contained within rerank_model_func") + print(" ✓ Rerank improves document relevance ordering") + print(" ✓ Configure API keys within your rerank function") + print(" ✓ Monitor API usage and costs when using rerank services") + + except Exception as e: + print(f"\n❌ Example failed: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/lightrag/__init__.py b/lightrag/__init__.py index 392b3f60..e72f906a 100644 --- a/lightrag/__init__.py +++ b/lightrag/__init__.py @@ -1,5 +1,5 @@ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam -__version__ = "1.3.10" +__version__ = "1.4.0" __author__ = "Zirui Guo" __url__ = "https://github.com/HKUDS/LightRAG" diff --git a/lightrag/api/config.py b/lightrag/api/config.py index ad0e670b..70147bde 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -165,6 +165,24 @@ def parse_args() -> argparse.Namespace: default=get_env_value("TOP_K", 60, int), help="Number of most similar results to return (default: from env or 60)", ) + parser.add_argument( + "--chunk-top-k", + type=int, + default=get_env_value("CHUNK_TOP_K", 15, int), + help="Number of text chunks to retrieve initially from vector search (default: from env or 15)", + ) + parser.add_argument( + "--chunk-rerank-top-k", + type=int, + default=get_env_value("CHUNK_RERANK_TOP_K", 5, int), + help="Number of text chunks to keep after reranking (default: from env or 5)", + ) + parser.add_argument( + "--enable-rerank", + action="store_true", + default=get_env_value("ENABLE_RERANK", False, bool), + help="Enable rerank functionality (default: from env or False)", + ) parser.add_argument( "--cosine-threshold", type=float, @@ -295,6 +313,11 @@ def parse_args() -> argparse.Namespace: args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int) args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256") + # Rerank model configuration + args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3") + args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None) + args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None) + ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name return args diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index cd87af22..b43c66d9 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -291,6 +291,32 @@ def create_app(args): ), ) + # Configure rerank function if enabled + rerank_model_func = None + if args.enable_rerank and args.rerank_binding_api_key and args.rerank_binding_host: + from lightrag.rerank import custom_rerank + + async def server_rerank_func( + query: str, documents: list, top_k: int = None, **kwargs + ): + """Server rerank function with configuration from environment variables""" + return await custom_rerank( + query=query, + documents=documents, + model=args.rerank_model, + base_url=args.rerank_binding_host, + api_key=args.rerank_binding_api_key, + top_k=top_k, + **kwargs, + ) + 
+ rerank_model_func = server_rerank_func + logger.info(f"Rerank enabled with model: {args.rerank_model}") + elif args.enable_rerank: + logger.warning( + "Rerank enabled but RERANK_BINDING_API_KEY or RERANK_BINDING_HOST not configured. Rerank will be disabled." + ) + # Initialize RAG if args.llm_binding in ["lollms", "ollama", "openai"]: rag = LightRAG( @@ -324,6 +350,8 @@ def create_app(args): }, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, + enable_rerank=args.enable_rerank, + rerank_model_func=rerank_model_func, auto_manage_storages_states=False, max_parallel_insert=args.max_parallel_insert, max_graph_nodes=args.max_graph_nodes, @@ -352,6 +380,8 @@ def create_app(args): }, enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract, enable_llm_cache=args.enable_llm_cache, + enable_rerank=args.enable_rerank, + rerank_model_func=rerank_model_func, auto_manage_storages_states=False, max_parallel_insert=args.max_parallel_insert, max_graph_nodes=args.max_graph_nodes, @@ -478,6 +508,12 @@ def create_app(args): "enable_llm_cache": args.enable_llm_cache, "workspace": args.workspace, "max_graph_nodes": args.max_graph_nodes, + # Rerank configuration + "enable_rerank": args.enable_rerank, + "rerank_model": args.rerank_model if args.enable_rerank else None, + "rerank_binding_host": args.rerank_binding_host + if args.enable_rerank + else None, }, "auth_mode": auth_mode, "pipeline_busy": pipeline_status.get("busy", False), diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 69aa32d8..0a0c6227 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -49,6 +49,18 @@ class QueryRequest(BaseModel): description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.", ) + chunk_top_k: Optional[int] = Field( + ge=1, + default=None, + description="Number of text chunks to retrieve initially from vector search.", + ) + + chunk_rerank_top_k: Optional[int] = Field( + ge=1, + default=None, + description="Number of text chunks to keep after reranking.", + ) + max_token_for_text_unit: Optional[int] = Field( gt=1, default=None, diff --git a/lightrag/base.py b/lightrag/base.py index 57cb2ac6..97564ac2 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -60,7 +60,17 @@ class QueryParam: top_k: int = int(os.getenv("TOP_K", "60")) """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" - max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000")) + chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5")) + """Number of text chunks to retrieve initially from vector search. + If None, defaults to top_k value. + """ + + chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5")) + """Number of text chunks to keep after reranking. + If None, keeps all chunks returned from initial retrieval. + """ + + max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "6000")) """Maximum number of tokens allowed for each retrieved text chunk.""" max_token_for_global_context: int = int( @@ -280,21 +290,6 @@ class BaseKVStorage(StorageNameSpace, ABC): False: if the cache drop failed, or the cache mode is not supported """ - # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool: - # """Delete specific cache records from storage by chunk IDs - - # Importance notes for in-memory storage: - # 1. 
Changes will be persisted to disk during the next index_done_callback - # 2. update flags to notify other processes that data persistence is needed - - # Args: - # chunk_ids (list[str]): List of chunk IDs to be dropped from storage - - # Returns: - # True: if the cache drop successfully - # False: if the cache drop failed, or the operation is not supported - # """ - @dataclass class BaseGraphStorage(StorageNameSpace, ABC): diff --git a/lightrag/kg/__init__.py b/lightrag/kg/__init__.py index 1f5fd56f..b2a93e82 100644 --- a/lightrag/kg/__init__.py +++ b/lightrag/kg/__init__.py @@ -15,6 +15,7 @@ STORAGE_IMPLEMENTATIONS = { "Neo4JStorage", "PGGraphStorage", "MongoGraphStorage", + "MemgraphStorage", # "AGEStorage", # "TiDBGraphStorage", # "GremlinStorage", @@ -57,6 +58,7 @@ STORAGE_ENV_REQUIREMENTS: dict[str, list[str]] = { "NetworkXStorage": [], "Neo4JStorage": ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"], "MongoGraphStorage": [], + "MemgraphStorage": ["MEMGRAPH_URI"], # "TiDBGraphStorage": ["TIDB_USER", "TIDB_PASSWORD", "TIDB_DATABASE"], "AGEStorage": [ "AGE_POSTGRES_DB", @@ -111,6 +113,7 @@ STORAGES = { "PGDocStatusStorage": ".kg.postgres_impl", "FaissVectorDBStorage": ".kg.faiss_impl", "QdrantVectorDBStorage": ".kg.qdrant_impl", + "MemgraphStorage": ".kg.memgraph_impl", } diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py new file mode 100644 index 00000000..8c6d6574 --- /dev/null +++ b/lightrag/kg/memgraph_impl.py @@ -0,0 +1,906 @@ +import os +from dataclasses import dataclass +from typing import final +import configparser + +from ..utils import logger +from ..base import BaseGraphStorage +from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge +from ..constants import GRAPH_FIELD_SEP +import pipmaster as pm + +if not pm.is_installed("neo4j"): + pm.install("neo4j") + +from neo4j import ( + AsyncGraphDatabase, + AsyncManagedTransaction, +) + +from dotenv import load_dotenv + +# use the .env that is inside the current folder +load_dotenv(dotenv_path=".env", override=False) + +MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + +config = configparser.ConfigParser() +config.read("config.ini", "utf-8") + + +@final +@dataclass +class MemgraphStorage(BaseGraphStorage): + def __init__(self, namespace, global_config, embedding_func, workspace=None): + memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE") + if memgraph_workspace and memgraph_workspace.strip(): + workspace = memgraph_workspace + super().__init__( + namespace=namespace, + workspace=workspace or "", + global_config=global_config, + embedding_func=embedding_func, + ) + self._driver = None + + def _get_workspace_label(self) -> str: + """Get workspace label, return 'base' for compatibility when workspace is empty""" + workspace = getattr(self, "workspace", None) + return workspace if workspace else "base" + + async def initialize(self): + URI = os.environ.get( + "MEMGRAPH_URI", + config.get("memgraph", "uri", fallback="bolt://localhost:7687"), + ) + USERNAME = os.environ.get( + "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="") + ) + PASSWORD = os.environ.get( + "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="") + ) + DATABASE = os.environ.get( + "MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph") + ) + + self._driver = AsyncGraphDatabase.driver( + URI, + auth=(USERNAME, PASSWORD), + ) + self._DATABASE = DATABASE + try: + async with self._driver.session(database=DATABASE) as session: + # Create index for base nodes on 
entity_id if it doesn't exist + try: + workspace_label = self._get_workspace_label() + await session.run( + f"""CREATE INDEX ON :{workspace_label}(entity_id)""" + ) + logger.info( + f"Created index on :{workspace_label}(entity_id) in Memgraph." + ) + except Exception as e: + # Index may already exist, which is not an error + logger.warning( + f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}" + ) + await session.run("RETURN 1") + logger.info(f"Connected to Memgraph at {URI}") + except Exception as e: + logger.error(f"Failed to connect to Memgraph at {URI}: {e}") + raise + + async def finalize(self): + if self._driver is not None: + await self._driver.close() + self._driver = None + + async def __aexit__(self, exc_type, exc, tb): + await self.finalize() + + async def index_done_callback(self): + # Memgraph handles persistence automatically + pass + + async def has_node(self, node_id: str) -> bool: + """ + Check if a node exists in the graph. + + Args: + node_id: The ID of the node to check. + + Returns: + bool: True if the node exists, False otherwise. + + Raises: + Exception: If there is an error checking the node existence. + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" + result = await session.run(query, entity_id=node_id) + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return ( + single_result["node_exists"] if single_result is not None else False + ) + except Exception as e: + logger.error(f"Error checking node existence for {node_id}: {str(e)}") + await result.consume() # Ensure the result is consumed even on error + raise + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """ + Check if an edge exists between two nodes in the graph. + + Args: + source_node_id: The ID of the source node. + target_node_id: The ID of the target node. + + Returns: + bool: True if the edge exists, False otherwise. + + Raises: + Exception: If there is an error checking the edge existence. + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = ( + f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " + "RETURN COUNT(r) > 0 AS edgeExists" + ) + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) # type: ignore + single_result = await result.single() + await result.consume() # Ensure result is fully consumed + return ( + single_result["edgeExists"] if single_result is not None else False + ) + except Exception as e: + logger.error( + f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" + ) + await result.consume() # Ensure the result is consumed even on error + raise + + async def get_node(self, node_id: str) -> dict[str, str] | None: + """Get node by its label identifier, return only node properties + + Args: + node_id: The node label to look up + + Returns: + dict: Node properties if found + None: If node not found + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = ( + f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" + ) + result = await session.run(query, entity_id=node_id) + try: + records = await result.fetch( + 2 + ) # Get 2 records for duplication check + + if len(records) > 1: + logger.warning( + f"Multiple nodes found with label '{node_id}'. Using first node." + ) + if records: + node = records[0]["n"] + node_dict = dict(node) + # Remove workspace label from labels list if it exists + if "labels" in node_dict: + node_dict["labels"] = [ + label + for label in node_dict["labels"] + if label != workspace_label + ] + return node_dict + return None + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node for {node_id}: {str(e)}") + raise + + async def node_degree(self, node_id: str) -> int: + """Get the degree (number of relationships) of a node with the given label. + If multiple nodes have the same label, returns the degree of the first node. + If no node is found, returns 0. + + Args: + node_id: The label of the node + + Returns: + int: The number of relationships the node has, or 0 if no node found + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + OPTIONAL MATCH (n)-[r]-() + RETURN COUNT(r) AS degree + """ + result = await session.run(query, entity_id=node_id) + try: + record = await result.single() + + if not record: + logger.warning(f"No node found with label '{node_id}'") + return 0 + + degree = record["degree"] + return degree + finally: + await result.consume() # Ensure result is fully consumed + except Exception as e: + logger.error(f"Error getting node degree for {node_id}: {str(e)}") + raise + + async def get_all_labels(self) -> list[str]: + """ + Get all existing node labels in the database + Returns: + ["Person", "Company", ...] # Alphabetically sorted label list + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}`) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY label + """ + result = await session.run(query) + labels = [] + async for record in result: + labels.append(record["label"]) + await result.consume() + return labels + except Exception as e: + logger.error(f"Error getting all labels: {str(e)}") + await result.consume() # Ensure the result is consumed even on error + raise + + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: + """Retrieves all edges (relationships) for a particular node identified by its label. + + Args: + source_node_id: Label of the node to get edges for + + Returns: + list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges + None: If no edges found + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) + try: + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`) + WHERE connected.entity_id IS NOT NULL + RETURN n, r, connected""" + results = await session.run(query, entity_id=source_node_id) + + edges = [] + async for record in results: + source_node = record["n"] + connected_node = record["connected"] + + # Skip if either node is None + if not source_node or not connected_node: + continue + + source_label = ( + source_node.get("entity_id") + if source_node.get("entity_id") + else None + ) + target_label = ( + connected_node.get("entity_id") + if connected_node.get("entity_id") + else None + ) + + if source_label and target_label: + edges.append((source_label, target_label)) + + await results.consume() # Ensure results are consumed + return edges + except Exception as e: + logger.error( + f"Error getting edges for node {source_node_id}: {str(e)}" + ) + await results.consume() # Ensure results are consumed even on error + raise + except Exception as e: + logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") + raise + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> dict[str, str] | None: + """Get edge properties between two nodes. + + Args: + source_node_id: Label of the source node + target_node_id: Label of the target node + + Returns: + dict: Edge properties if found, default properties if not found or on error + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + workspace_label = self._get_workspace_label() + query = f""" + MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}}) + RETURN properties(r) as edge_properties + """ + result = await session.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + ) + records = await result.fetch(2) + await result.consume() + if records: + edge_result = dict(records[0]["edge_properties"]) + for key, default_value in { + "weight": 0.0, + "source_id": None, + "description": None, + "keywords": None, + }.items(): + if key not in edge_result: + edge_result[key] = default_value + logger.warning( + f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}" + ) + return edge_result + return None + except Exception as e: + logger.error( + f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}" + ) + await result.consume() # Ensure the result is consumed even on error + raise + + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + """ + Upsert a node in the Neo4j database. + + Args: + node_id: The unique identifier for the node (used as label) + node_data: Dictionary of node properties + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) + properties = node_data + entity_type = properties["entity_type"] + if "entity_id" not in properties: + raise ValueError("Neo4j: node properties must contain an 'entity_id' field") + + try: + async with self._driver.session(database=self._DATABASE) as session: + workspace_label = self._get_workspace_label() + + async def execute_upsert(tx: AsyncManagedTransaction): + query = f""" + MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) + SET n += $properties + SET n:`{entity_type}` + """ + result = await tx.run( + query, entity_id=node_id, properties=properties + ) + await result.consume() # Ensure result is fully consumed + + await session.execute_write(execute_upsert) + except Exception as e: + logger.error(f"Error during upsert: {str(e)}") + raise + + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: + """ + Upsert an edge and its properties between two nodes identified by their labels. + Ensures both source and target nodes exist and are unique before creating the edge. + Uses entity_id property to uniquely identify nodes. + + Args: + source_node_id (str): Label of the source node (used as identifier) + target_node_id (str): Label of the target node (used as identifier) + edge_data (dict): Dictionary of properties to set on the edge + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + try: + edge_properties = edge_data + async with self._driver.session(database=self._DATABASE) as session: + + async def execute_upsert(tx: AsyncManagedTransaction): + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}}) + WITH source + MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}}) + MERGE (source)-[r:DIRECTED]-(target) + SET r += $properties + RETURN r, source, target + """ + result = await tx.run( + query, + source_entity_id=source_node_id, + target_entity_id=target_node_id, + properties=edge_properties, + ) + try: + await result.fetch(2) + finally: + await result.consume() # Ensure result is consumed + + await session.execute_write(execute_upsert) + except Exception as e: + logger.error(f"Error during edge upsert: {str(e)}") + raise + + async def delete_node(self, node_id: str) -> None: + """Delete a node with the specified label + + Args: + node_id: The label of the node to delete + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + + async def _do_delete(tx: AsyncManagedTransaction): + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + DETACH DELETE n + """ + result = await tx.run(query, entity_id=node_id) + logger.debug(f"Deleted node with label {node_id}") + await result.consume() + + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete) + except Exception as e: + logger.error(f"Error during node deletion: {str(e)}") + raise + + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Args: + nodes: List of node labels to be deleted + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) + for node in nodes: + await self.delete_node(node) + + async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + for source, target in edges: + + async def _do_delete_edge(tx: AsyncManagedTransaction): + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}}) + DELETE r + """ + result = await tx.run( + query, source_entity_id=source, target_entity_id=target + ) + logger.debug(f"Deleted edge from '{source}' to '{target}'") + await result.consume() # Ensure result is fully consumed + + try: + async with self._driver.session(database=self._DATABASE) as session: + await session.execute_write(_do_delete_edge) + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise + + async def drop(self) -> dict[str, str]: + """Drop all data from the current workspace and clean up resources + + This method will delete all nodes and relationships in the Memgraph database. + + Returns: + dict[str, str]: Operation status and message + - On success: {"status": "success", "message": "data dropped"} + - On failure: {"status": "error", "message": ""} + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + try: + async with self._driver.session(database=self._DATABASE) as session: + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" + result = await session.run(query) + await result.consume() + logger.info( + f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" + ) + return {"status": "success", "message": "workspace data dropped"} + except Exception as e: + logger.error( + f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" + ) + return {"status": "error", "message": str(e)} + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """Get the total degree (sum of relationships) of two nodes. + + Args: + src_id: Label of the source node + tgt_id: Label of the target node + + Returns: + int: Sum of the degrees of both nodes + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) + + # Convert None to 0 for addition + src_degree = 0 if src_degree is None else src_degree + trg_degree = 0 if trg_degree is None else trg_degree + + degrees = int(src_degree) + int(trg_degree) + return degrees + + async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all nodes that are associated with the given chunk_ids. + + Args: + chunk_ids: List of chunk IDs to find associated nodes for + + Returns: + list[dict]: A list of nodes, where each node is a dictionary of its properties. + An empty list if no matching nodes are found. + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + UNWIND $chunk_ids AS chunk_id + MATCH (n:`{workspace_label}`) + WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) + RETURN DISTINCT n + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + nodes = [] + async for record in result: + node = record["n"] + node_dict = dict(node) + node_dict["id"] = node_dict.get("entity_id") + nodes.append(node_dict) + await result.consume() + return nodes + + async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + """Get all edges that are associated with the given chunk_ids. + + Args: + chunk_ids: List of chunk IDs to find associated edges for + + Returns: + list[dict]: A list of edges, where each edge is a dictionary of its properties. + An empty list if no matching edges are found. + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." + ) + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + query = f""" + UNWIND $chunk_ids AS chunk_id + MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) + WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) + WITH a, b, r, a.entity_id AS source_id, b.entity_id AS target_id + // Ensure we only return each unique edge once by ordering the source and target + WITH a, b, r, + CASE WHEN source_id <= target_id THEN source_id ELSE target_id END AS ordered_source, + CASE WHEN source_id <= target_id THEN target_id ELSE source_id END AS ordered_target + RETURN DISTINCT ordered_source AS source, ordered_target AS target, properties(r) AS properties + """ + result = await session.run(query, chunk_ids=chunk_ids, sep=GRAPH_FIELD_SEP) + edges = [] + async for record in result: + edge_properties = record["properties"] + edge_properties["source"] = record["source"] + edge_properties["target"] = record["target"] + edges.append(edge_properties) + await result.consume() + return edges + + async def get_knowledge_graph( + self, + node_label: str, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, + ) -> KnowledgeGraph: + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. + + Args: + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000 + + Returns: + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + + Raises: + Exception: If there is an error executing the query + """ + if self._driver is None: + raise RuntimeError( + "Memgraph driver is not initialized. Call 'await initialize()' first." 
+ ) + + result = KnowledgeGraph() + seen_nodes = set() + seen_edges = set() + workspace_label = self._get_workspace_label() + async with self._driver.session( + database=self._DATABASE, default_access_mode="READ" + ) as session: + try: + if node_label == "*": + # First check if database has any nodes + count_query = "MATCH (n) RETURN count(n) as total" + count_result = None + total_count = 0 + try: + count_result = await session.run(count_query) + count_record = await count_result.single() + if count_record: + total_count = count_record["total"] + if total_count == 0: + logger.debug("No nodes found in database") + return result + if total_count > max_nodes: + result.is_truncated = True + logger.info( + f"Graph truncated: {total_count} nodes found, limited to {max_nodes}" + ) + finally: + if count_result: + await count_result.consume() + + # Run the main query to get nodes with highest degree + main_query = f""" + MATCH (n:`{workspace_label}`) + OPTIONAL MATCH (n)-[r]-() + WITH n, COALESCE(count(r), 0) AS degree + ORDER BY degree DESC + LIMIT $max_nodes + WITH collect(n) AS kept_nodes + MATCH (a)-[r]-(b) + WHERE a IN kept_nodes AND b IN kept_nodes + RETURN [node IN kept_nodes | {{node: node}}] AS node_info, + collect(DISTINCT r) AS relationships + """ + result_set = None + try: + result_set = await session.run( + main_query, {"max_nodes": max_nodes} + ) + record = await result_set.single() + if not record: + logger.debug("No record returned from main query") + return result + finally: + if result_set: + await result_set.consume() + + else: + bfs_query = f""" + MATCH (start:`{workspace_label}`) + WHERE start.entity_id = $entity_id + WITH start + CALL {{ + WITH start + MATCH path = (start)-[*0..{max_depth}]-(node) + WITH nodes(path) AS path_nodes, relationships(path) AS path_rels + UNWIND path_nodes AS n + WITH collect(DISTINCT n) AS all_nodes, collect(DISTINCT path_rels) AS all_rel_lists + WITH all_nodes, reduce(r = [], x IN all_rel_lists | r + x) AS all_rels + RETURN all_nodes, all_rels + }} + WITH all_nodes AS nodes, all_rels AS relationships, size(all_nodes) AS total_nodes + WITH + CASE + WHEN total_nodes <= {max_nodes} THEN nodes + ELSE nodes[0..{max_nodes}] + END AS limited_nodes, + relationships, + total_nodes, + total_nodes > {max_nodes} AS is_truncated + RETURN + [node IN limited_nodes | {{node: node}}] AS node_info, + relationships, + total_nodes, + is_truncated + """ + result_set = None + try: + result_set = await session.run( + bfs_query, + { + "entity_id": node_label, + }, + ) + record = await result_set.single() + if not record: + logger.debug(f"No nodes found for entity_id: {node_label}") + return result + + # Check if the query indicates truncation + if "is_truncated" in record and record["is_truncated"]: + result.is_truncated = True + logger.info( + f"Graph truncated: breadth-first search limited to {max_nodes} nodes" + ) + + finally: + if result_set: + await result_set.consume() + + # Process the record if it exists + if record and record["node_info"]: + for node_info in record["node_info"]: + node = node_info["node"] + node_id = node.id + if node_id not in seen_nodes: + seen_nodes.add(node_id) + result.nodes.append( + KnowledgeGraphNode( + id=f"{node_id}", + labels=[node.get("entity_id")], + properties=dict(node), + ) + ) + + for rel in record["relationships"]: + edge_id = rel.id + if edge_id not in seen_edges: + seen_edges.add(edge_id) + start = rel.start_node + end = rel.end_node + result.edges.append( + KnowledgeGraphEdge( + id=f"{edge_id}", + type=rel.type, + 
source=f"{start.id}", + target=f"{end.id}", + properties=dict(rel), + ) + ) + + logger.info( + f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + ) + + except Exception as e: + logger.error(f"Error getting knowledge graph: {str(e)}") + # Return empty but properly initialized KnowledgeGraph on error + return KnowledgeGraph() + + return result diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 1f61a42e..bc3c289a 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -240,6 +240,17 @@ class LightRAG: llm_model_kwargs: dict[str, Any] = field(default_factory=dict) """Additional keyword arguments passed to the LLM model function.""" + # Rerank Configuration + # --- + + enable_rerank: bool = field( + default=bool(os.getenv("ENABLE_RERANK", "False").lower() == "true") + ) + """Enable reranking for improved retrieval quality. Defaults to False.""" + + rerank_model_func: Callable[..., object] | None = field(default=None) + """Function for reranking retrieved documents. All rerank configurations (model name, API keys, top_k, etc.) should be included in this function. Optional.""" + # Storage # --- @@ -447,6 +458,14 @@ class LightRAG: ) ) + # Init Rerank + if self.enable_rerank and self.rerank_model_func: + logger.info("Rerank model initialized for improved retrieval quality") + elif self.enable_rerank and not self.rerank_model_func: + logger.warning( + "Rerank is enabled but no rerank_model_func provided. Reranking will be skipped." + ) + self._storages_status = StoragesStatus.CREATED if self.auto_manage_storages_states: @@ -900,9 +919,15 @@ class LightRAG: # Get first document's file path and total count for job name first_doc_id, first_doc = next(iter(to_process_docs.items())) first_doc_path = first_doc.file_path - path_prefix = first_doc_path[:20] + ( - "..." if len(first_doc_path) > 20 else "" - ) + + # Handle cases where first_doc_path is None + if first_doc_path: + path_prefix = first_doc_path[:20] + ( + "..." if len(first_doc_path) > 20 else "" + ) + else: + path_prefix = "unknown_source" + total_files = len(to_process_docs) job_name = f"{path_prefix}[{total_files} files]" pipeline_status["job_name"] = job_name diff --git a/lightrag/operate.py b/lightrag/operate.py index 4e219cf8..be4499ab 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1527,6 +1527,7 @@ async def kg_query( # Build context context = await _build_query_context( + query, ll_keywords_str, hl_keywords_str, knowledge_graph_inst, @@ -1746,84 +1747,52 @@ async def _get_vector_context( query: str, chunks_vdb: BaseVectorStorage, query_param: QueryParam, - tokenizer: Tokenizer, -) -> tuple[list, list, list] | None: +) -> list[dict]: """ - Retrieve vector context from the vector database. + Retrieve text chunks from the vector database without reranking or truncation. - This function performs vector search to find relevant text chunks for a query, - formats them with file path and creation time information. + This function performs vector search to find relevant text chunks for a query. + Reranking and truncation will be handled later in the unified processing. 
Args: query: The query string to search for chunks_vdb: Vector database containing document chunks - query_param: Query parameters including top_k and ids - tokenizer: Tokenizer for counting tokens + query_param: Query parameters including chunk_top_k and ids Returns: - Tuple (empty_entities, empty_relations, text_units) for combine_contexts, - compatible with _get_edge_data and _get_node_data format + List of text chunks with metadata """ try: - results = await chunks_vdb.query( - query, top_k=query_param.top_k, ids=query_param.ids - ) + # Use chunk_top_k if specified, otherwise fall back to top_k + search_top_k = query_param.chunk_top_k or query_param.top_k + + results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids) if not results: - return [], [], [] + return [] valid_chunks = [] for result in results: if "content" in result: - # Directly use content from chunks_vdb.query result - chunk_with_time = { + chunk_with_metadata = { "content": result["content"], "created_at": result.get("created_at", None), "file_path": result.get("file_path", "unknown_source"), + "source_type": "vector", # Mark the source type } - valid_chunks.append(chunk_with_time) + valid_chunks.append(chunk_with_metadata) - if not valid_chunks: - return [], [], [] - - maybe_trun_chunks = truncate_list_by_token_size( - valid_chunks, - key=lambda x: x["content"], - max_token_size=query_param.max_token_for_text_unit, - tokenizer=tokenizer, - ) - - logger.debug( - f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})" - ) logger.info( - f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}" + f"Naive query: {len(valid_chunks)} chunks (chunk_top_k: {search_top_k})" ) + return valid_chunks - if not maybe_trun_chunks: - return [], [], [] - - # Create empty entities and relations contexts - entities_context = [] - relations_context = [] - - # Create text_units_context directly as a list of dictionaries - text_units_context = [] - for i, chunk in enumerate(maybe_trun_chunks): - text_units_context.append( - { - "id": i + 1, - "content": chunk["content"], - "file_path": chunk["file_path"], - } - ) - - return entities_context, relations_context, text_units_context except Exception as e: logger.error(f"Error in _get_vector_context: {e}") - return [], [], [] + return [] async def _build_query_context( + query: str, ll_keywords: str, hl_keywords: str, knowledge_graph_inst: BaseGraphStorage, @@ -1831,27 +1800,36 @@ async def _build_query_context( relationships_vdb: BaseVectorStorage, text_chunks_db: BaseKVStorage, query_param: QueryParam, - chunks_vdb: BaseVectorStorage = None, # Add chunks_vdb parameter for mix mode + chunks_vdb: BaseVectorStorage = None, ): logger.info(f"Process {os.getpid()} building query context...") - # Handle local and global modes as before + # Collect all chunks from different sources + all_chunks = [] + entities_context = [] + relations_context = [] + + # Handle local and global modes if query_param.mode == "local": - entities_context, relations_context, text_units_context = await _get_node_data( + entities_context, relations_context, entity_chunks = await _get_node_data( ll_keywords, knowledge_graph_inst, entities_vdb, text_chunks_db, query_param, ) + all_chunks.extend(entity_chunks) + elif query_param.mode == "global": - entities_context, relations_context, text_units_context = await _get_edge_data( + entities_context, relations_context, relationship_chunks = await _get_edge_data( hl_keywords, 
knowledge_graph_inst, relationships_vdb, text_chunks_db, query_param, ) + all_chunks.extend(relationship_chunks) + else: # hybrid or mix mode ll_data = await _get_node_data( ll_keywords, @@ -1868,61 +1846,58 @@ async def _build_query_context( query_param, ) - ( - ll_entities_context, - ll_relations_context, - ll_text_units_context, - ) = ll_data + (ll_entities_context, ll_relations_context, ll_chunks) = ll_data + (hl_entities_context, hl_relations_context, hl_chunks) = hl_data - ( - hl_entities_context, - hl_relations_context, - hl_text_units_context, - ) = hl_data + # Collect chunks from entity and relationship sources + all_chunks.extend(ll_chunks) + all_chunks.extend(hl_chunks) - # Initialize vector data with empty lists - vector_entities_context, vector_relations_context, vector_text_units_context = ( - [], - [], - [], - ) - - # Only get vector data if in mix mode - if query_param.mode == "mix" and hasattr(query_param, "original_query"): - # Get tokenizer from text_chunks_db - tokenizer = text_chunks_db.global_config.get("tokenizer") - - # Get vector context in triple format - vector_data = await _get_vector_context( - query_param.original_query, # We need to pass the original query + # Get vector chunks if in mix mode + if query_param.mode == "mix" and chunks_vdb: + vector_chunks = await _get_vector_context( + query, chunks_vdb, query_param, - tokenizer, ) + all_chunks.extend(vector_chunks) - # If vector_data is not None, unpack it - if vector_data is not None: - ( - vector_entities_context, - vector_relations_context, - vector_text_units_context, - ) = vector_data - - # Combine and deduplicate the entities, relationships, and sources + # Combine entities and relations contexts entities_context = process_combine_contexts( - hl_entities_context, ll_entities_context, vector_entities_context + hl_entities_context, ll_entities_context ) relations_context = process_combine_contexts( - hl_relations_context, ll_relations_context, vector_relations_context + hl_relations_context, ll_relations_context ) - text_units_context = process_combine_contexts( - hl_text_units_context, ll_text_units_context, vector_text_units_context + + # Process all chunks uniformly: deduplication, reranking, and token truncation + processed_chunks = await process_chunks_unified( + query=query, + chunks=all_chunks, + query_param=query_param, + global_config=text_chunks_db.global_config, + source_type="mixed", + ) + + # Build final text_units_context from processed chunks + text_units_context = [] + for i, chunk in enumerate(processed_chunks): + text_units_context.append( + { + "id": i + 1, + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + } ) + + logger.info( + f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks" + ) + # not necessary to use LLM to generate a response if not entities_context and not relations_context: return None - # 转换为 JSON 字符串 entities_str = json.dumps(entities_context, ensure_ascii=False) relations_str = json.dumps(relations_context, ensure_ascii=False) text_units_str = json.dumps(text_units_context, ensure_ascii=False) @@ -2069,16 +2044,7 @@ async def _get_node_data( } ) - text_units_context = [] - for i, t in enumerate(use_text_units): - text_units_context.append( - { - "id": i + 1, - "content": t["content"], - "file_path": t.get("file_path", "unknown_source"), - } - ) - return entities_context, relations_context, text_units_context + return entities_context, relations_context, 
use_text_units async def _find_most_related_text_unit_from_entities( @@ -2167,23 +2133,21 @@ async def _find_most_related_text_unit_from_entities( logger.warning("No valid text units found") return [] - tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") + # Sort by relation counts and order, but don't truncate all_text_units = sorted( all_text_units, key=lambda x: (x["order"], -x["relation_counts"]) ) - all_text_units = truncate_list_by_token_size( - all_text_units, - key=lambda x: x["data"]["content"], - max_token_size=query_param.max_token_for_text_unit, - tokenizer=tokenizer, - ) - logger.debug( - f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})" - ) + logger.debug(f"Found {len(all_text_units)} entity-related chunks") - all_text_units = [t["data"] for t in all_text_units] - return all_text_units + # Add source type marking and return chunk data + result_chunks = [] + for t in all_text_units: + chunk_data = t["data"].copy() + chunk_data["source_type"] = "entity" + result_chunks.append(chunk_data) + + return result_chunks async def _find_most_related_edges_from_entities( @@ -2485,21 +2449,16 @@ async def _find_related_text_unit_from_relationships( logger.warning("No valid text chunks after filtering") return [] - tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer") - truncated_text_units = truncate_list_by_token_size( - valid_text_units, - key=lambda x: x["data"]["content"], - max_token_size=query_param.max_token_for_text_unit, - tokenizer=tokenizer, - ) + logger.debug(f"Found {len(valid_text_units)} relationship-related chunks") - logger.debug( - f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})" - ) + # Add source type marking and return chunk data + result_chunks = [] + for t in valid_text_units: + chunk_data = t["data"].copy() + chunk_data["source_type"] = "relationship" + result_chunks.append(chunk_data) - all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units] - - return all_text_units + return result_chunks async def naive_query( @@ -2527,12 +2486,32 @@ async def naive_query( tokenizer: Tokenizer = global_config["tokenizer"] - _, _, text_units_context = await _get_vector_context( - query, chunks_vdb, query_param, tokenizer + chunks = await _get_vector_context(query, chunks_vdb, query_param) + + if chunks is None or len(chunks) == 0: + return PROMPTS["fail_response"] + + # Process chunks using unified processing + processed_chunks = await process_chunks_unified( + query=query, + chunks=chunks, + query_param=query_param, + global_config=global_config, + source_type="vector", ) - if text_units_context is None or len(text_units_context) == 0: - return PROMPTS["fail_response"] + logger.info(f"Final context: {len(processed_chunks)} chunks") + + # Build text_units_context from processed chunks + text_units_context = [] + for i, chunk in enumerate(processed_chunks): + text_units_context.append( + { + "id": i + 1, + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + } + ) text_units_str = json.dumps(text_units_context, ensure_ascii=False) if query_param.only_need_context: @@ -2658,6 +2637,7 @@ async def kg_query_with_keywords( hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" context = await _build_query_context( + query, ll_keywords_str, hl_keywords_str, knowledge_graph_inst, @@ -2780,8 +2760,6 @@ async def 
query_with_keywords( f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}" ) - param.original_query = query - # Use appropriate query method based on mode if param.mode in ["local", "global", "hybrid", "mix"]: return await kg_query_with_keywords( @@ -2808,3 +2786,131 @@ async def query_with_keywords( ) else: raise ValueError(f"Unknown mode {param.mode}") + + +async def apply_rerank_if_enabled( + query: str, + retrieved_docs: list[dict], + global_config: dict, + top_k: int = None, +) -> list[dict]: + """ + Apply reranking to retrieved documents if rerank is enabled. + + Args: + query: The search query + retrieved_docs: List of retrieved documents + global_config: Global configuration containing rerank settings + top_k: Number of top documents to return after reranking + + Returns: + Reranked documents if rerank is enabled, otherwise original documents + """ + if not global_config.get("enable_rerank", False) or not retrieved_docs: + return retrieved_docs + + rerank_func = global_config.get("rerank_model_func") + if not rerank_func: + logger.debug( + "Rerank is enabled but no rerank function provided, skipping rerank" + ) + return retrieved_docs + + try: + logger.debug( + f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}" + ) + + # Apply reranking - let rerank_model_func handle top_k internally + reranked_docs = await rerank_func( + query=query, + documents=retrieved_docs, + top_k=top_k, + ) + if reranked_docs and len(reranked_docs) > 0: + if len(reranked_docs) > top_k: + reranked_docs = reranked_docs[:top_k] + logger.info( + f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}" + ) + return reranked_docs + else: + logger.warning("Rerank returned empty results, using original documents") + return retrieved_docs + + except Exception as e: + logger.error(f"Error during reranking: {e}, using original documents") + return retrieved_docs + + +async def process_chunks_unified( + query: str, + chunks: list[dict], + query_param: QueryParam, + global_config: dict, + source_type: str = "mixed", +) -> list[dict]: + """ + Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation. + + Args: + query: Search query for reranking + chunks: List of text chunks to process + query_param: Query parameters containing configuration + global_config: Global configuration dictionary + source_type: Source type for logging ("vector", "entity", "relationship", "mixed") + + Returns: + Processed and filtered list of text chunks + """ + if not chunks: + return [] + + # 1. Deduplication based on content + seen_content = set() + unique_chunks = [] + for chunk in chunks: + content = chunk.get("content", "") + if content and content not in seen_content: + seen_content.add(content) + unique_chunks.append(chunk) + + logger.debug( + f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})" + ) + + # 2. Apply reranking if enabled and query is provided + if global_config.get("enable_rerank", False) and query and unique_chunks: + rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks) + unique_chunks = await apply_rerank_if_enabled( + query=query, + retrieved_docs=unique_chunks, + global_config=global_config, + top_k=rerank_top_k, + ) + logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})") + + # 3. 
Apply chunk_top_k limiting if specified + if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0: + if len(unique_chunks) > query_param.chunk_top_k: + unique_chunks = unique_chunks[: query_param.chunk_top_k] + logger.debug( + f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})" + ) + + # 4. Token-based final truncation + tokenizer = global_config.get("tokenizer") + if tokenizer and unique_chunks: + original_count = len(unique_chunks) + unique_chunks = truncate_list_by_token_size( + unique_chunks, + key=lambda x: x.get("content", ""), + max_token_size=query_param.max_token_for_text_unit, + tokenizer=tokenizer, + ) + logger.debug( + f"Token truncation: {len(unique_chunks)} chunks from {original_count} " + f"(max tokens: {query_param.max_token_for_text_unit}, source: {source_type})" + ) + + return unique_chunks diff --git a/lightrag/rerank.py b/lightrag/rerank.py new file mode 100644 index 00000000..59719bc9 --- /dev/null +++ b/lightrag/rerank.py @@ -0,0 +1,321 @@ +from __future__ import annotations + +import os +import aiohttp +from typing import Callable, Any, List, Dict, Optional +from pydantic import BaseModel, Field + +from .utils import logger + + +class RerankModel(BaseModel): + """ + Pydantic model class for defining a custom rerank model. + + This class provides a convenient wrapper for rerank functions, allowing you to + encapsulate all rerank configurations (API keys, model settings, etc.) in one place. + + Attributes: + rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents. + The function should take query and documents as input and return reranked results. + kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function. + This should include all necessary configurations such as model name, API key, base_url, etc. + + Example usage: + Rerank model example with Jina: + ```python + rerank_model = RerankModel( + rerank_func=jina_rerank, + kwargs={ + "model": "BAAI/bge-reranker-v2-m3", + "api_key": "your_api_key_here", + "base_url": "https://api.jina.ai/v1/rerank" + } + ) + + # Use in LightRAG + rag = LightRAG( + enable_rerank=True, + rerank_model_func=rerank_model.rerank, + # ... other configurations + ) + ``` + + Or define a custom function directly: + ```python + async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + return await jina_rerank( + query=query, + documents=documents, + model="BAAI/bge-reranker-v2-m3", + api_key="your_api_key_here", + top_k=top_k or 10, + **kwargs + ) + + rag = LightRAG( + enable_rerank=True, + rerank_model_func=my_rerank_func, + # ... 
other configurations + ) + ``` + """ + + rerank_func: Callable[[Any], List[Dict]] + kwargs: Dict[str, Any] = Field(default_factory=dict) + + async def rerank( + self, + query: str, + documents: List[Dict[str, Any]], + top_k: Optional[int] = None, + **extra_kwargs, + ) -> List[Dict[str, Any]]: + """Rerank documents using the configured model function.""" + # Merge extra kwargs with model kwargs + kwargs = {**self.kwargs, **extra_kwargs} + return await self.rerank_func( + query=query, documents=documents, top_k=top_k, **kwargs + ) + + +class MultiRerankModel(BaseModel): + """Multiple rerank models for different modes/scenarios.""" + + # Primary rerank model (used if mode-specific models are not defined) + rerank_model: Optional[RerankModel] = None + + # Mode-specific rerank models + entity_rerank_model: Optional[RerankModel] = None + relation_rerank_model: Optional[RerankModel] = None + chunk_rerank_model: Optional[RerankModel] = None + + async def rerank( + self, + query: str, + documents: List[Dict[str, Any]], + mode: str = "default", + top_k: Optional[int] = None, + **kwargs, + ) -> List[Dict[str, Any]]: + """Rerank using the appropriate model based on mode.""" + + # Select model based on mode + if mode == "entity" and self.entity_rerank_model: + model = self.entity_rerank_model + elif mode == "relation" and self.relation_rerank_model: + model = self.relation_rerank_model + elif mode == "chunk" and self.chunk_rerank_model: + model = self.chunk_rerank_model + elif self.rerank_model: + model = self.rerank_model + else: + logger.warning(f"No rerank model available for mode: {mode}") + return documents + + return await model.rerank(query, documents, top_k, **kwargs) + + +async def generic_rerank_api( + query: str, + documents: List[Dict[str, Any]], + model: str, + base_url: str, + api_key: str, + top_k: Optional[int] = None, + **kwargs, +) -> List[Dict[str, Any]]: + """ + Generic rerank function that works with Jina/Cohere compatible APIs. 
+ + Args: + query: The search query + documents: List of documents to rerank + model: Model identifier + base_url: API endpoint URL + api_key: API authentication key + top_k: Number of top results to return + **kwargs: Additional API-specific parameters + + Returns: + List of reranked documents with relevance scores + """ + if not api_key: + logger.warning("No API key provided for rerank service") + return documents + + if not documents: + return documents + + # Prepare documents for reranking - handle both text and dict formats + prepared_docs = [] + for doc in documents: + if isinstance(doc, dict): + # Use 'content' field if available, otherwise use 'text' or convert to string + text = doc.get("content") or doc.get("text") or str(doc) + else: + text = str(doc) + prepared_docs.append(text) + + # Prepare request + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + + data = {"model": model, "query": query, "documents": prepared_docs, **kwargs} + + if top_k is not None: + data["top_k"] = min(top_k, len(prepared_docs)) + + try: + async with aiohttp.ClientSession() as session: + async with session.post(base_url, headers=headers, json=data) as response: + if response.status != 200: + error_text = await response.text() + logger.error(f"Rerank API error {response.status}: {error_text}") + return documents + + result = await response.json() + + # Extract reranked results + if "results" in result: + # Standard format: results contain index and relevance_score + reranked_docs = [] + for item in result["results"]: + if "index" in item: + doc_idx = item["index"] + if 0 <= doc_idx < len(documents): + reranked_doc = documents[doc_idx].copy() + if "relevance_score" in item: + reranked_doc["rerank_score"] = item[ + "relevance_score" + ] + reranked_docs.append(reranked_doc) + return reranked_docs + else: + logger.warning("Unexpected rerank API response format") + return documents + + except Exception as e: + logger.error(f"Error during reranking: {e}") + return documents + + +async def jina_rerank( + query: str, + documents: List[Dict[str, Any]], + model: str = "BAAI/bge-reranker-v2-m3", + top_k: Optional[int] = None, + base_url: str = "https://api.jina.ai/v1/rerank", + api_key: Optional[str] = None, + **kwargs, +) -> List[Dict[str, Any]]: + """ + Rerank documents using Jina AI API. + + Args: + query: The search query + documents: List of documents to rerank + model: Jina rerank model name + top_k: Number of top results to return + base_url: Jina API endpoint + api_key: Jina API key + **kwargs: Additional parameters + + Returns: + List of reranked documents with relevance scores + """ + if api_key is None: + api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY") + + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_k=top_k, + **kwargs, + ) + + +async def cohere_rerank( + query: str, + documents: List[Dict[str, Any]], + model: str = "rerank-english-v2.0", + top_k: Optional[int] = None, + base_url: str = "https://api.cohere.ai/v1/rerank", + api_key: Optional[str] = None, + **kwargs, +) -> List[Dict[str, Any]]: + """ + Rerank documents using Cohere API. 
+ + Args: + query: The search query + documents: List of documents to rerank + model: Cohere rerank model name + top_k: Number of top results to return + base_url: Cohere API endpoint + api_key: Cohere API key + **kwargs: Additional parameters + + Returns: + List of reranked documents with relevance scores + """ + if api_key is None: + api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY") + + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_k=top_k, + **kwargs, + ) + + +# Convenience function for custom API endpoints +async def custom_rerank( + query: str, + documents: List[Dict[str, Any]], + model: str, + base_url: str, + api_key: str, + top_k: Optional[int] = None, + **kwargs, +) -> List[Dict[str, Any]]: + """ + Rerank documents using a custom API endpoint. + This is useful for self-hosted or custom rerank services. + """ + return await generic_rerank_api( + query=query, + documents=documents, + model=model, + base_url=base_url, + api_key=api_key, + top_k=top_k, + **kwargs, + ) + + +if __name__ == "__main__": + import asyncio + + async def main(): + # Example usage + docs = [ + {"content": "The capital of France is Paris."}, + {"content": "Tokyo is the capital of Japan."}, + {"content": "London is the capital of England."}, + ] + + query = "What is the capital of France?" + + result = await jina_rerank( + query=query, documents=docs, top_k=2, api_key="your-api-key-here" + ) + print(result) + + asyncio.run(main()) diff --git a/lightrag_webui/src/stores/settings.ts b/lightrag_webui/src/stores/settings.ts index 203502dc..5942ddca 100644 --- a/lightrag_webui/src/stores/settings.ts +++ b/lightrag_webui/src/stores/settings.ts @@ -111,7 +111,7 @@ const useSettingsStoreBase = create()( mode: 'global', response_type: 'Multiple Paragraphs', top_k: 10, - max_token_for_text_unit: 4000, + max_token_for_text_unit: 6000, max_token_for_global_context: 4000, max_token_for_local_context: 4000, only_need_context: false, diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index 258c8795..62f658ff 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -10,6 +10,7 @@ - Neo4JStorage - MongoDBStorage - PGGraphStorage +- MemgraphStorage """ import asyncio
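Taken together, the changes above add an opt-in rerank stage between retrieval and context assembly: `enable_rerank` and `rerank_model_func` on `LightRAG`, the helpers in `lightrag/rerank.py`, and the unified chunk processing (`process_chunks_unified`) driven by `chunk_top_k` / `chunk_rerank_top_k`. The following is a minimal usage sketch of how these pieces could be wired up; the model name, endpoint, working directory, and the OpenAI helper imports are illustrative assumptions based on the repo's other demos, not part of this diff:

```python
import os
import asyncio

from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank
# Assumed to follow the repo's OpenAI demo; swap in your own LLM/embedding functions.
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed


async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
    """Bundle all rerank settings (model, endpoint, key) into one callable,
    matching the (query, documents, top_k) signature the pipeline expects."""
    return await custom_rerank(
        query=query,
        documents=documents,
        model="BAAI/bge-reranker-v2-m3",                      # assumed model name
        base_url="https://api.your-provider.com/v1/rerank",   # placeholder endpoint
        api_key=os.getenv("RERANK_API_KEY"),
        top_k=top_k or 10,
        **kwargs,
    )


async def main():
    rag = LightRAG(
        working_dir="./rag_storage",          # placeholder path
        embedding_func=openai_embed,
        llm_model_func=gpt_4o_mini_complete,
        enable_rerank=True,
        rerank_model_func=my_rerank_func,
    )
    await rag.initialize_storages()

    result = await rag.aquery(
        "What are the main themes?",
        param=QueryParam(
            mode="mix",
            chunk_top_k=20,        # chunks fetched before reranking
            chunk_rerank_top_k=5,  # chunks kept after reranking
        ),
    )
    print(result)


if __name__ == "__main__":
    asyncio.run(main())
```

Note that, per the initialization check added in `lightrag.py`, setting `enable_rerank=True` without supplying a `rerank_model_func` only logs a warning and reranking is skipped.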