diff --git a/.clinerules/01-basic.md b/.clinerules/01-basic.md index 955afa83..15997330 100644 --- a/.clinerules/01-basic.md +++ b/.clinerules/01-basic.md @@ -127,6 +127,31 @@ for key, value in matching_items: 4. **Implement caching strategically** - Cache expensive operations 5. **Monitor memory usage** - Prevent memory leaks +### 5. Testing Workflow (CRITICAL) +**Pattern**: All tests must use pytest markers for proper CI/CD execution +**Test Categories**: +- **Offline Tests**: Use `@pytest.mark.offline` - No external dependencies (runs in CI) +- **Integration Tests**: Use `@pytest.mark.integration` - Requires databases/APIs (skipped by default) + +**Commands**: +- `pytest tests/ -m offline -v` - CI default (~3 seconds for 21 tests) +- `pytest tests/ --run-integration -v` - Full test suite (all 46 tests) + +**Best Practices**: +1. **Prefer offline tests** - Use mocks for LLM, embeddings, databases +2. **Mock external dependencies** - AsyncMock for async functions +3. **Test isolation** - Each test should be independent +4. **Documentation** - Add docstrings explaining purpose and scope + +**Configuration**: +- `tests/pytest.ini` - Marker definitions and test discovery +- `tests/conftest.py` - Fixtures and custom options +- `.github/workflows/tests.yml` - CI/CD workflow (Python 3.10/3.11/3.12) + +**Documentation**: See `memory-bank/testing-guidelines.md` for complete testing guidelines + +**Impact**: Ensures all tests run reliably in CI without external services while maintaining comprehensive integration test coverage for local development + ## Technology Stack Intelligence ### 1. LLM Integration diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..e7d00f4a --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,54 @@ +name: Offline Unit Tests + +on: + push: + branches: [ main, dev ] + pull_request: + branches: [ main, dev ] + +jobs: + offline-tests: + name: Offline Tests + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ['3.10', '3.11', '3.12'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt', '**/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[api]" + pip install pytest pytest-asyncio + + - name: Run offline tests + run: | + # Run only tests marked as 'offline' (no external dependencies) + # Integration tests requiring databases/APIs are skipped by default + pytest tests/ -m offline -v --tb=short + + - name: Upload test results + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-results-py${{ matrix.python-version }} + path: | + .pytest_cache/ + test-results.xml + retention-days: 7 diff --git a/.gitignore b/.gitignore index 8a5059c8..3c676aaf 100644 --- a/.gitignore +++ b/.gitignore @@ -72,3 +72,5 @@ download_models_hf.py # Cline files memory-bank +.claude/CLAUDE.md +.claude/ diff --git a/README-zh.md b/README-zh.md index 8dcbc0e5..57eb9e4a 100644 --- a/README-zh.md +++ b/README-zh.md @@ -222,6 +222,10 @@ python examples/lightrag_openai_demo.py > ⚠️ **如果您希望将LightRAG集成到您的项目中,建议您使用LightRAG Server提供的REST API**。LightRAG Core通常用于嵌入式应用,或供希望进行研究与评估的学者使用。 +### ⚠️ 重要:初始化要求 + +LightRAG 在使用前需要显式初始化。 创建 LightRAG 
实例后,您必须调用 await rag.initialize_storages(),否则将出现错误。 + ### 一个简单程序 以下Python代码片段演示了如何初始化LightRAG、插入文本并进行查询: @@ -231,7 +235,6 @@ import os import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed -from lightrag.kg.shared_storage import initialize_pipeline_status from lightrag.utils import setup_logger setup_logger("lightrag", level="INFO") @@ -246,9 +249,7 @@ async def initialize_rag(): embedding_func=openai_embed, llm_model_func=gpt_4o_mini_complete, ) - await rag.initialize_storages() - await initialize_pipeline_status() - return rag + await rag.initialize_storages() return rag async def main(): try: @@ -442,8 +443,6 @@ async def initialize_rag(): ) await rag.initialize_storages() - await initialize_pipeline_status() - return rag ``` @@ -572,7 +571,6 @@ from lightrag import LightRAG from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_index_embed from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.llms.openai import OpenAI -from lightrag.kg.shared_storage import initialize_pipeline_status from lightrag.utils import setup_logger # 为LightRAG设置日志处理程序 @@ -589,8 +587,6 @@ async def initialize_rag(): ) await rag.initialize_storages() - await initialize_pipeline_status() - return rag def main(): @@ -840,8 +836,6 @@ async def initialize_rag(): # 初始化数据库连接 await rag.initialize_storages() # 初始化文档处理的管道状态 - await initialize_pipeline_status() - return rag ``` diff --git a/README.md b/README.md index 376d1154..9b3e3c70 100644 --- a/README.md +++ b/README.md @@ -224,10 +224,7 @@ For a streaming response implementation example, please see `examples/lightrag_o ### ⚠️ Important: Initialization Requirements -**LightRAG requires explicit initialization before use.** You must call both `await rag.initialize_storages()` and `await initialize_pipeline_status()` after creating a LightRAG instance, otherwise you will encounter errors like: - -- `AttributeError: __aenter__` - if storages are not initialized -- `KeyError: 'history_messages'` - if pipeline status is not initialized +**LightRAG requires explicit initialization before use.** You must call `await rag.initialize_storages()` after creating a LightRAG instance, otherwise you will encounter errors. ### A Simple Program @@ -238,7 +235,6 @@ import os import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed -from lightrag.kg.shared_storage import initialize_pipeline_status from lightrag.utils import setup_logger setup_logger("lightrag", level="INFO") @@ -254,9 +250,7 @@ async def initialize_rag(): llm_model_func=gpt_4o_mini_complete, ) # IMPORTANT: Both initialization calls are required! 
- await rag.initialize_storages() # Initialize storage backends - await initialize_pipeline_status() # Initialize processing pipeline - return rag + await rag.initialize_storages() # Initialize storage backends return rag async def main(): try: @@ -445,8 +439,6 @@ async def initialize_rag(): ) await rag.initialize_storages() - await initialize_pipeline_status() - return rag ``` @@ -577,7 +569,6 @@ from lightrag import LightRAG from lightrag.llm.llama_index_impl import llama_index_complete_if_cache, llama_index_embed from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.llms.openai import OpenAI -from lightrag.kg.shared_storage import initialize_pipeline_status from lightrag.utils import setup_logger # Setup log handler for LightRAG @@ -594,8 +585,6 @@ async def initialize_rag(): ) await rag.initialize_storages() - await initialize_pipeline_status() - return rag def main(): @@ -847,8 +836,6 @@ async def initialize_rag(): # Initialize database connections await rag.initialize_storages() # Initialize pipeline status for document processing - await initialize_pipeline_status() - return rag ``` @@ -933,8 +920,6 @@ async def initialize_rag(): # Initialize database connections await rag.initialize_storages() # Initialize pipeline status for document processing - await initialize_pipeline_status() - return rag ``` @@ -1542,16 +1527,13 @@ If you encounter these errors when using LightRAG: 2. **`KeyError: 'history_messages'`** - **Cause**: Pipeline status not initialized - - **Solution**: Call `await initialize_pipeline_status()` after initializing storages - + - **Solution**: Call ` 3. **Both errors in sequence** - **Cause**: Neither initialization method was called - **Solution**: Always follow this pattern: ```python rag = LightRAG(...) 
- await rag.initialize_storages() - await initialize_pipeline_status() - ``` + await rag.initialize_storages() ``` ### Model Switching Issues diff --git a/env.example b/env.example index 60aaf0ed..042a30b9 100644 --- a/env.example +++ b/env.example @@ -349,7 +349,8 @@ POSTGRES_USER=your_username POSTGRES_PASSWORD='your_password' POSTGRES_DATABASE=your_database POSTGRES_MAX_CONNECTIONS=12 -# POSTGRES_WORKSPACE=forced_workspace_name +### DB specific workspace should not be set, keep for compatible only +### POSTGRES_WORKSPACE=forced_workspace_name ### PostgreSQL Vector Storage Configuration ### Vector storage type: HNSW, IVFFlat @@ -395,7 +396,8 @@ NEO4J_MAX_TRANSACTION_RETRY_TIME=30 NEO4J_MAX_CONNECTION_LIFETIME=300 NEO4J_LIVENESS_CHECK_TIMEOUT=30 NEO4J_KEEP_ALIVE=true -# NEO4J_WORKSPACE=forced_workspace_name +### DB specific workspace should not be set, keep for compatible only +### NEO4J_WORKSPACE=forced_workspace_name ### MongoDB Configuration MONGO_URI=mongodb://root:root@localhost:27017/ @@ -409,12 +411,14 @@ MILVUS_DB_NAME=lightrag # MILVUS_USER=root # MILVUS_PASSWORD=your_password # MILVUS_TOKEN=your_token -# MILVUS_WORKSPACE=forced_workspace_name +### DB specific workspace should not be set, keep for compatible only +### MILVUS_WORKSPACE=forced_workspace_name ### Qdrant QDRANT_URL=http://localhost:6333 # QDRANT_API_KEY=your-api-key -# QDRANT_WORKSPACE=forced_workspace_name +### DB specific workspace should not be set, keep for compatible only +### QDRANT_WORKSPACE=forced_workspace_name ### Redis REDIS_URI=redis://localhost:6379 @@ -422,14 +426,16 @@ REDIS_SOCKET_TIMEOUT=30 REDIS_CONNECT_TIMEOUT=10 REDIS_MAX_CONNECTIONS=100 REDIS_RETRY_ATTEMPTS=3 -# REDIS_WORKSPACE=forced_workspace_name +### DB specific workspace should not be set, keep for compatible only +### REDIS_WORKSPACE=forced_workspace_name ### Memgraph Configuration MEMGRAPH_URI=bolt://localhost:7687 MEMGRAPH_USERNAME= MEMGRAPH_PASSWORD= MEMGRAPH_DATABASE=memgraph -# MEMGRAPH_WORKSPACE=forced_workspace_name +### DB specific workspace should not be set, keep for compatible only +### MEMGRAPH_WORKSPACE=forced_workspace_name ############################ ### Evaluation Configuration diff --git a/examples/lightrag_azure_openai_demo.py b/examples/lightrag_azure_openai_demo.py index c101383d..99c54e35 100644 --- a/examples/lightrag_azure_openai_demo.py +++ b/examples/lightrag_azure_openai_demo.py @@ -6,7 +6,6 @@ import numpy as np from dotenv import load_dotenv import logging from openai import AzureOpenAI -from lightrag.kg.shared_storage import initialize_pipeline_status logging.basicConfig(level=logging.INFO) @@ -93,9 +92,7 @@ async def initialize_rag(): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/lightrag_ollama_demo.py b/examples/lightrag_ollama_demo.py index 18fcc790..cb51e433 100644 --- a/examples/lightrag_ollama_demo.py +++ b/examples/lightrag_ollama_demo.py @@ -6,7 +6,6 @@ import logging.config from lightrag import LightRAG, QueryParam from lightrag.llm.ollama import ollama_model_complete, ollama_embed from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug -from lightrag.kg.shared_storage import initialize_pipeline_status from dotenv import load_dotenv @@ -104,9 +103,7 @@ async def initialize_rag(): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git 
a/examples/lightrag_openai_compatible_demo.py b/examples/lightrag_openai_compatible_demo.py index 15187d25..abeb6347 100644 --- a/examples/lightrag_openai_compatible_demo.py +++ b/examples/lightrag_openai_compatible_demo.py @@ -7,7 +7,6 @@ from lightrag import LightRAG, QueryParam from lightrag.llm.openai import openai_complete_if_cache from lightrag.llm.ollama import ollama_embed from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug -from lightrag.kg.shared_storage import initialize_pipeline_status from dotenv import load_dotenv @@ -120,9 +119,7 @@ async def initialize_rag(): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/lightrag_openai_demo.py b/examples/lightrag_openai_demo.py index fa0b37f1..f79d5feb 100644 --- a/examples/lightrag_openai_demo.py +++ b/examples/lightrag_openai_demo.py @@ -4,7 +4,6 @@ import logging import logging.config from lightrag import LightRAG, QueryParam from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed -from lightrag.kg.shared_storage import initialize_pipeline_status from lightrag.utils import logger, set_verbose_debug WORKING_DIR = "./dickens" @@ -84,8 +83,7 @@ async def initialize_rag(): llm_model_func=gpt_4o_mini_complete, ) - await rag.initialize_storages() - await initialize_pipeline_status() + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/lightrag_openai_mongodb_graph_demo.py b/examples/lightrag_openai_mongodb_graph_demo.py index 67c51892..df8a455d 100644 --- a/examples/lightrag_openai_mongodb_graph_demo.py +++ b/examples/lightrag_openai_mongodb_graph_demo.py @@ -4,7 +4,6 @@ from lightrag import LightRAG, QueryParam from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed from lightrag.utils import EmbeddingFunc import numpy as np -from lightrag.kg.shared_storage import initialize_pipeline_status ######### # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert() @@ -61,9 +60,7 @@ async def initialize_rag(): log_level="DEBUG", ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/modalprocessors_example.py b/examples/modalprocessors_example.py index b25c12c2..31eaa672 100644 --- a/examples/modalprocessors_example.py +++ b/examples/modalprocessors_example.py @@ -7,7 +7,6 @@ This example demonstrates how to use LightRAG's modal processors directly withou import asyncio import argparse from lightrag.llm.openai import openai_complete_if_cache, openai_embed -from lightrag.kg.shared_storage import initialize_pipeline_status from lightrag import LightRAG from lightrag.utils import EmbeddingFunc from raganything.modalprocessors import ( @@ -190,9 +189,7 @@ async def initialize_rag(api_key: str, base_url: str = None): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/rerank_example.py b/examples/rerank_example.py index c7db6656..da3d0efe 100644 --- a/examples/rerank_example.py +++ b/examples/rerank_example.py @@ -29,7 +29,6 @@ import numpy as np from lightrag import LightRAG, QueryParam from lightrag.llm.openai import openai_complete_if_cache, openai_embed from lightrag.utils import EmbeddingFunc, setup_logger -from lightrag.kg.shared_storage 
import initialize_pipeline_status from functools import partial from lightrag.rerank import cohere_rerank @@ -94,9 +93,7 @@ async def create_rag_with_rerank(): rerank_model_func=rerank_model_func, ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/unofficial-sample/lightrag_bedrock_demo.py b/examples/unofficial-sample/lightrag_bedrock_demo.py index c7f41677..88c46538 100644 --- a/examples/unofficial-sample/lightrag_bedrock_demo.py +++ b/examples/unofficial-sample/lightrag_bedrock_demo.py @@ -8,7 +8,6 @@ import logging from lightrag import LightRAG, QueryParam from lightrag.llm.bedrock import bedrock_complete, bedrock_embed from lightrag.utils import EmbeddingFunc -from lightrag.kg.shared_storage import initialize_pipeline_status import asyncio import nest_asyncio @@ -32,9 +31,7 @@ async def initialize_rag(): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/unofficial-sample/lightrag_cloudflare_demo.py b/examples/unofficial-sample/lightrag_cloudflare_demo.py index b53e6714..55be6d28 100644 --- a/examples/unofficial-sample/lightrag_cloudflare_demo.py +++ b/examples/unofficial-sample/lightrag_cloudflare_demo.py @@ -5,7 +5,6 @@ import logging import logging.config from lightrag import LightRAG, QueryParam from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug -from lightrag.kg.shared_storage import initialize_pipeline_status import requests import numpy as np @@ -221,9 +220,7 @@ async def initialize_rag(): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/unofficial-sample/lightrag_hf_demo.py b/examples/unofficial-sample/lightrag_hf_demo.py index f2abbb2f..68216b2a 100644 --- a/examples/unofficial-sample/lightrag_hf_demo.py +++ b/examples/unofficial-sample/lightrag_hf_demo.py @@ -4,7 +4,6 @@ from lightrag import LightRAG, QueryParam from lightrag.llm.hf import hf_model_complete, hf_embed from lightrag.utils import EmbeddingFunc from transformers import AutoModel, AutoTokenizer -from lightrag.kg.shared_storage import initialize_pipeline_status import asyncio import nest_asyncio @@ -37,9 +36,7 @@ async def initialize_rag(): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/unofficial-sample/lightrag_llamaindex_direct_demo.py b/examples/unofficial-sample/lightrag_llamaindex_direct_demo.py index d5e3f617..1226f1c4 100644 --- a/examples/unofficial-sample/lightrag_llamaindex_direct_demo.py +++ b/examples/unofficial-sample/lightrag_llamaindex_direct_demo.py @@ -12,7 +12,6 @@ import nest_asyncio nest_asyncio.apply() -from lightrag.kg.shared_storage import initialize_pipeline_status # Configure working directory WORKING_DIR = "./index_default" @@ -94,9 +93,7 @@ async def initialize_rag(): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/unofficial-sample/lightrag_llamaindex_litellm_demo.py b/examples/unofficial-sample/lightrag_llamaindex_litellm_demo.py index 3d0c69db..b8ce2957 100644 --- a/examples/unofficial-sample/lightrag_llamaindex_litellm_demo.py 
+++ b/examples/unofficial-sample/lightrag_llamaindex_litellm_demo.py @@ -12,7 +12,6 @@ import nest_asyncio nest_asyncio.apply() -from lightrag.kg.shared_storage import initialize_pipeline_status # Configure working directory WORKING_DIR = "./index_default" @@ -96,9 +95,7 @@ async def initialize_rag(): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/unofficial-sample/lightrag_llamaindex_litellm_opik_demo.py b/examples/unofficial-sample/lightrag_llamaindex_litellm_opik_demo.py index 700f6209..97537b37 100644 --- a/examples/unofficial-sample/lightrag_llamaindex_litellm_opik_demo.py +++ b/examples/unofficial-sample/lightrag_llamaindex_litellm_opik_demo.py @@ -12,7 +12,6 @@ import nest_asyncio nest_asyncio.apply() -from lightrag.kg.shared_storage import initialize_pipeline_status # Configure working directory WORKING_DIR = "./index_default" @@ -107,9 +106,7 @@ async def initialize_rag(): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/unofficial-sample/lightrag_lmdeploy_demo.py b/examples/unofficial-sample/lightrag_lmdeploy_demo.py index ba118fc9..3f2062aa 100644 --- a/examples/unofficial-sample/lightrag_lmdeploy_demo.py +++ b/examples/unofficial-sample/lightrag_lmdeploy_demo.py @@ -5,7 +5,6 @@ from lightrag.llm.lmdeploy import lmdeploy_model_if_cache from lightrag.llm.hf import hf_embed from lightrag.utils import EmbeddingFunc from transformers import AutoModel, AutoTokenizer -from lightrag.kg.shared_storage import initialize_pipeline_status import asyncio import nest_asyncio @@ -62,9 +61,7 @@ async def initialize_rag(): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/unofficial-sample/lightrag_nvidia_demo.py b/examples/unofficial-sample/lightrag_nvidia_demo.py index 97cfc38a..ca63c8ac 100644 --- a/examples/unofficial-sample/lightrag_nvidia_demo.py +++ b/examples/unofficial-sample/lightrag_nvidia_demo.py @@ -9,7 +9,6 @@ from lightrag.llm import ( ) from lightrag.utils import EmbeddingFunc import numpy as np -from lightrag.kg.shared_storage import initialize_pipeline_status # for custom llm_model_func from lightrag.utils import locate_json_string_body_from_string @@ -115,9 +114,7 @@ async def initialize_rag(): ), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/examples/unofficial-sample/lightrag_openai_neo4j_milvus_redis_demo.py b/examples/unofficial-sample/lightrag_openai_neo4j_milvus_redis_demo.py index 00845796..509c7059 100644 --- a/examples/unofficial-sample/lightrag_openai_neo4j_milvus_redis_demo.py +++ b/examples/unofficial-sample/lightrag_openai_neo4j_milvus_redis_demo.py @@ -3,7 +3,6 @@ import asyncio from lightrag import LightRAG, QueryParam from lightrag.llm.ollama import ollama_embed, openai_complete_if_cache from lightrag.utils import EmbeddingFunc -from lightrag.kg.shared_storage import initialize_pipeline_status # WorkingDir ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -66,9 +65,7 @@ async def initialize_rag(): doc_status_storage="RedisKVStorage", ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # 
Auto-initializes pipeline_status return rag diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 41a07f7f..b29e39b2 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -56,7 +56,8 @@ from lightrag.api.routers.ollama_api import OllamaAPI from lightrag.utils import logger, set_verbose_debug from lightrag.kg.shared_storage import ( get_namespace_data, - initialize_pipeline_status, + get_default_workspace, + # set_default_workspace, cleanup_keyed_lock, finalize_share_data, ) @@ -350,8 +351,8 @@ def create_app(args): try: # Initialize database connections + # Note: initialize_storages() now auto-initializes pipeline_status for rag.workspace await rag.initialize_storages() - await initialize_pipeline_status() # Data migration regardless of storage implementation await rag.check_and_migrate_data() @@ -452,6 +453,28 @@ def create_app(args): # Create combined auth dependency for all endpoints combined_auth = get_combined_auth_dependency(api_key) + def get_workspace_from_request(request: Request) -> str | None: + """ + Extract workspace from HTTP request header or use default. + + This enables multi-workspace API support by checking the custom + 'LIGHTRAG-WORKSPACE' header. If not present, falls back to the + server's default workspace configuration. + + Args: + request: FastAPI Request object + + Returns: + Workspace identifier (may be empty string for global namespace) + """ + # Check custom header first + workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() + + if not workspace: + workspace = None + + return workspace + # Create working directory if it doesn't exist Path(args.working_dir).mkdir(parents=True, exist_ok=True) @@ -1113,10 +1136,16 @@ def create_app(args): } @app.get("/health", dependencies=[Depends(combined_auth)]) - async def get_status(): + async def get_status(request: Request): """Get current system status""" try: - pipeline_status = await get_namespace_data("pipeline_status") + workspace = get_workspace_from_request(request) + default_workspace = get_default_workspace() + if workspace is None: + workspace = default_workspace + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=workspace + ) if not auth_configured: auth_mode = "disabled" @@ -1147,7 +1176,7 @@ def create_app(args): "vector_storage": args.vector_storage, "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, "enable_llm_cache": args.enable_llm_cache, - "workspace": args.workspace, + "workspace": default_workspace, "max_graph_nodes": args.max_graph_nodes, # Rerank configuration "enable_rerank": rerank_model_func is not None, diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index a59f10c7..a0c2f0dd 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -1641,11 +1641,15 @@ async def background_delete_documents( """Background task to delete multiple documents""" from lightrag.kg.shared_storage import ( get_namespace_data, - get_pipeline_status_lock, + get_namespace_lock, ) - pipeline_status = await get_namespace_data("pipeline_status") - pipeline_status_lock = get_pipeline_status_lock() + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace + ) total_docs = len(doc_ids) successful_deletions = [] @@ -1661,6 +1665,7 @@ async def background_delete_documents( pipeline_status.update( { 
"busy": True, + # Job name can not be changed, it's verified in adelete_by_doc_id() "job_name": f"Deleting {total_docs} Documents", "job_start": datetime.now().isoformat(), "docs": total_docs, @@ -2134,12 +2139,16 @@ def create_document_routes( """ from lightrag.kg.shared_storage import ( get_namespace_data, - get_pipeline_status_lock, + get_namespace_lock, ) # Get pipeline status and lock - pipeline_status = await get_namespace_data("pipeline_status") - pipeline_status_lock = get_pipeline_status_lock() + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace + ) # Check and set status with lock async with pipeline_status_lock: @@ -2330,13 +2339,19 @@ def create_document_routes( try: from lightrag.kg.shared_storage import ( get_namespace_data, + get_namespace_lock, get_all_update_flags_status, ) - pipeline_status = await get_namespace_data("pipeline_status") + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace + ) # Get update flags status for all namespaces - update_status = await get_all_update_flags_status() + update_status = await get_all_update_flags_status(workspace=rag.workspace) # Convert MutableBoolean objects to regular boolean values processed_update_status = {} @@ -2350,8 +2365,9 @@ def create_document_routes( processed_flags.append(bool(flag)) processed_update_status[namespace] = processed_flags - # Convert to regular dict if it's a Manager.dict - status_dict = dict(pipeline_status) + async with pipeline_status_lock: + # Convert to regular dict if it's a Manager.dict + status_dict = dict(pipeline_status) # Add processed update_status to the status dictionary status_dict["update_status"] = processed_update_status @@ -2538,17 +2554,26 @@ def create_document_routes( doc_ids = delete_request.doc_ids try: - from lightrag.kg.shared_storage import get_namespace_data + from lightrag.kg.shared_storage import ( + get_namespace_data, + get_namespace_lock, + ) - pipeline_status = await get_namespace_data("pipeline_status") + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace + ) - # Check if pipeline is busy - if pipeline_status.get("busy", False): - return DeleteDocByIdResponse( - status="busy", - message="Cannot delete documents while pipeline is busy", - doc_id=", ".join(doc_ids), - ) + # Check if pipeline is busy with proper lock + async with pipeline_status_lock: + if pipeline_status.get("busy", False): + return DeleteDocByIdResponse( + status="busy", + message="Cannot delete documents while pipeline is busy", + doc_id=", ".join(doc_ids), + ) # Add deletion task to background tasks background_tasks.add_task( @@ -2944,11 +2969,15 @@ def create_document_routes( try: from lightrag.kg.shared_storage import ( get_namespace_data, - get_pipeline_status_lock, + get_namespace_lock, ) - pipeline_status = await get_namespace_data("pipeline_status") - pipeline_status_lock = get_pipeline_status_lock() + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace + ) async with pipeline_status_lock: if not pipeline_status.get("busy", False): diff --git a/lightrag/exceptions.py 
b/lightrag/exceptions.py index 303391c2..e6a616cd 100644 --- a/lightrag/exceptions.py +++ b/lightrag/exceptions.py @@ -68,10 +68,7 @@ class StorageNotInitializedError(RuntimeError): f"{storage_type} not initialized. Please ensure proper initialization:\n" f"\n" f" rag = LightRAG(...)\n" - f" await rag.initialize_storages() # Required\n" - f" \n" - f" from lightrag.kg.shared_storage import initialize_pipeline_status\n" - f" await initialize_pipeline_status() # Required for pipeline operations\n" + f" await rag.initialize_storages() # Required - auto-initializes pipeline_status\n" f"\n" f"See: https://github.com/HKUDS/LightRAG#important-initialization-requirements" ) @@ -82,18 +79,21 @@ class PipelineNotInitializedError(KeyError): def __init__(self, namespace: str = ""): msg = ( - f"Pipeline namespace '{namespace}' not found. " - f"This usually means pipeline status was not initialized.\n" + f"Pipeline namespace '{namespace}' not found.\n" f"\n" - f"Please call 'await initialize_pipeline_status()' after initializing storages:\n" + f"Pipeline status should be auto-initialized by initialize_storages().\n" + f"If you see this error, please ensure:\n" f"\n" + f" 1. You called await rag.initialize_storages()\n" + f" 2. For multi-workspace setups, each LightRAG instance was properly initialized\n" + f"\n" + f"Standard initialization:\n" + f" rag = LightRAG(workspace='your_workspace')\n" + f" await rag.initialize_storages() # Auto-initializes pipeline_status\n" + f"\n" + f"If you need manual control (advanced):\n" f" from lightrag.kg.shared_storage import initialize_pipeline_status\n" - f" await initialize_pipeline_status()\n" - f"\n" - f"Full initialization sequence:\n" - f" rag = LightRAG(...)\n" - f" await rag.initialize_storages()\n" - f" await initialize_pipeline_status()" + f" await initialize_pipeline_status(workspace='your_workspace')" ) super().__init__(msg) diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index 2f10ab1a..adb0058b 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -10,7 +10,7 @@ from lightrag.utils import logger, compute_mdhash_id from lightrag.base import BaseVectorStorage from .shared_storage import ( - get_storage_lock, + get_namespace_lock, get_update_flag, set_all_update_flags, ) @@ -42,13 +42,11 @@ class FaissVectorDBStorage(BaseVectorStorage): if self.workspace: # Include workspace in the file path for data isolation workspace_dir = os.path.join(working_dir, self.workspace) - self.final_namespace = f"{self.workspace}_{self.namespace}" else: # Default behavior when workspace is empty - self.final_namespace = self.namespace - self.workspace = "_" workspace_dir = working_dir + self.workspace = "" os.makedirs(workspace_dir, exist_ok=True) self._faiss_index_file = os.path.join( @@ -73,9 +71,13 @@ class FaissVectorDBStorage(BaseVectorStorage): async def initialize(self): """Initialize storage data""" # Get the update flag for cross-process update notification - self.storage_updated = await get_update_flag(self.final_namespace) + self.storage_updated = await get_update_flag( + self.namespace, workspace=self.workspace + ) # Get the storage lock for use in other methods - self._storage_lock = get_storage_lock() + self._storage_lock = get_namespace_lock( + self.namespace, workspace=self.workspace + ) async def _get_index(self): """Check if the shtorage should be reloaded""" @@ -400,7 +402,7 @@ class FaissVectorDBStorage(BaseVectorStorage): # Save data to disk self._save_faiss_index() # Notify other processes that data has been updated 
- await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) # Reset own update flag to avoid self-reloading self.storage_updated.value = False except Exception as e: @@ -527,7 +529,7 @@ class FaissVectorDBStorage(BaseVectorStorage): self._load_faiss_index() # Notify other processes - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) self.storage_updated.value = False logger.info( diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index bf6e7b17..df6502ee 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -16,7 +16,7 @@ from lightrag.utils import ( from lightrag.exceptions import StorageNotInitializedError from .shared_storage import ( get_namespace_data, - get_storage_lock, + get_namespace_lock, get_data_init_lock, get_update_flag, set_all_update_flags, @@ -35,12 +35,10 @@ class JsonDocStatusStorage(DocStatusStorage): if self.workspace: # Include workspace in the file path for data isolation workspace_dir = os.path.join(working_dir, self.workspace) - self.final_namespace = f"{self.workspace}_{self.namespace}" else: # Default behavior when workspace is empty - self.final_namespace = self.namespace - self.workspace = "_" workspace_dir = working_dir + self.workspace = "" os.makedirs(workspace_dir, exist_ok=True) self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json") @@ -50,12 +48,20 @@ class JsonDocStatusStorage(DocStatusStorage): async def initialize(self): """Initialize storage data""" - self._storage_lock = get_storage_lock() - self.storage_updated = await get_update_flag(self.final_namespace) + self._storage_lock = get_namespace_lock( + self.namespace, workspace=self.workspace + ) + self.storage_updated = await get_update_flag( + self.namespace, workspace=self.workspace + ) async with get_data_init_lock(): # check need_init must before get_namespace_data - need_init = await try_initialize_namespace(self.final_namespace) - self._data = await get_namespace_data(self.final_namespace) + need_init = await try_initialize_namespace( + self.namespace, workspace=self.workspace + ) + self._data = await get_namespace_data( + self.namespace, workspace=self.workspace + ) if need_init: loaded_data = load_json(self._file_name) or {} async with self._storage_lock: @@ -175,7 +181,7 @@ class JsonDocStatusStorage(DocStatusStorage): self._data.clear() self._data.update(cleaned_data) - await clear_all_update_flags(self.final_namespace) + await clear_all_update_flags(self.namespace, workspace=self.workspace) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ @@ -196,7 +202,7 @@ class JsonDocStatusStorage(DocStatusStorage): if "chunks_list" not in doc_data: doc_data["chunks_list"] = [] self._data.update(data) - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) await self.index_done_callback() @@ -350,7 +356,7 @@ class JsonDocStatusStorage(DocStatusStorage): any_deleted = True if any_deleted: - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]: """Get document by file path @@ -389,7 +395,7 @@ class JsonDocStatusStorage(DocStatusStorage): try: async with self._storage_lock: self._data.clear() - await set_all_update_flags(self.final_namespace) 
+ await set_all_update_flags(self.namespace, workspace=self.workspace) await self.index_done_callback() logger.info( diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index f9adb20f..8435c989 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -13,7 +13,7 @@ from lightrag.utils import ( from lightrag.exceptions import StorageNotInitializedError from .shared_storage import ( get_namespace_data, - get_storage_lock, + get_namespace_lock, get_data_init_lock, get_update_flag, set_all_update_flags, @@ -30,12 +30,10 @@ class JsonKVStorage(BaseKVStorage): if self.workspace: # Include workspace in the file path for data isolation workspace_dir = os.path.join(working_dir, self.workspace) - self.final_namespace = f"{self.workspace}_{self.namespace}" else: # Default behavior when workspace is empty workspace_dir = working_dir - self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" os.makedirs(workspace_dir, exist_ok=True) self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json") @@ -46,12 +44,20 @@ class JsonKVStorage(BaseKVStorage): async def initialize(self): """Initialize storage data""" - self._storage_lock = get_storage_lock() - self.storage_updated = await get_update_flag(self.final_namespace) + self._storage_lock = get_namespace_lock( + self.namespace, workspace=self.workspace + ) + self.storage_updated = await get_update_flag( + self.namespace, workspace=self.workspace + ) async with get_data_init_lock(): # check need_init must before get_namespace_data - need_init = await try_initialize_namespace(self.final_namespace) - self._data = await get_namespace_data(self.final_namespace) + need_init = await try_initialize_namespace( + self.namespace, workspace=self.workspace + ) + self._data = await get_namespace_data( + self.namespace, workspace=self.workspace + ) if need_init: loaded_data = load_json(self._file_name) or {} async with self._storage_lock: @@ -95,7 +101,7 @@ class JsonKVStorage(BaseKVStorage): self._data.clear() self._data.update(cleaned_data) - await clear_all_update_flags(self.final_namespace) + await clear_all_update_flags(self.namespace, workspace=self.workspace) async def get_by_id(self, id: str) -> dict[str, Any] | None: async with self._storage_lock: @@ -168,7 +174,7 @@ class JsonKVStorage(BaseKVStorage): v["_id"] = k self._data.update(data) - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) async def delete(self, ids: list[str]) -> None: """Delete specific records from storage by their IDs @@ -191,7 +197,7 @@ class JsonKVStorage(BaseKVStorage): any_deleted = True if any_deleted: - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) async def is_empty(self) -> bool: """Check if the storage is empty @@ -219,7 +225,7 @@ class JsonKVStorage(BaseKVStorage): try: async with self._storage_lock: self._data.clear() - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) await self.index_done_callback() logger.info( diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index e82aceec..6fd6841c 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -8,7 +8,7 @@ import configparser from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge -from ..kg.shared_storage 
import get_data_init_lock, get_graph_db_lock +from ..kg.shared_storage import get_data_init_lock import pipmaster as pm if not pm.is_installed("neo4j"): @@ -101,10 +101,9 @@ class MemgraphStorage(BaseGraphStorage): raise async def finalize(self): - async with get_graph_db_lock(): - if self._driver is not None: - await self._driver.close() - self._driver = None + if self._driver is not None: + await self._driver.close() + self._driver = None async def __aexit__(self, exc_type, exc, tb): await self.finalize() @@ -762,22 +761,21 @@ class MemgraphStorage(BaseGraphStorage): raise RuntimeError( "Memgraph driver is not initialized. Call 'await initialize()' first." ) - async with get_graph_db_lock(): - try: - async with self._driver.session(database=self._DATABASE) as session: - workspace_label = self._get_workspace_label() - query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" - result = await session.run(query) - await result.consume() - logger.info( - f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" - ) - return {"status": "success", "message": "workspace data dropped"} - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" + try: + async with self._driver.session(database=self._DATABASE) as session: + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" + result = await session.run(query) + await result.consume() + logger.info( + f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" ) - return {"status": "error", "message": str(e)} + return {"status": "success", "message": "workspace data dropped"} + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" + ) + return {"status": "error", "message": str(e)} async def edge_degree(self, src_id: str, tgt_id: str) -> int: """Get the total degree (sum of relationships) of two nodes. 
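
The storage hunks in this diff all apply the same refactor: instead of sharing one global `get_storage_lock()` / `get_graph_db_lock()`, each backend now asks shared storage for a lock and an update flag scoped to its own `(namespace, workspace)` pair, and an empty workspace string replaces the old `"_"` placeholder. The sketch below condenses that pattern from the `JsonKVStorage` and `FaissVectorDBStorage` hunks above; the class itself is hypothetical and the call signatures are taken from this diff, not from a published API.

```python
# Minimal sketch of the per-namespace locking pattern used after this refactor.
# SketchStorage is illustrative only; the shared_storage functions and their
# workspace= keyword are assumptions read off the hunks in this diff.
from lightrag.kg.shared_storage import (
    get_namespace_lock,
    get_update_flag,
    set_all_update_flags,
)


class SketchStorage:
    def __init__(self, namespace: str, workspace: str = ""):
        # Workspace is kept separate from the namespace; an empty string now
        # means "no workspace" (previously the placeholder "_" was used).
        self.namespace = namespace
        self.workspace = workspace

    async def initialize(self):
        # Cross-process update flag and lock are both scoped to
        # (namespace, workspace) instead of one process-wide storage lock.
        self.storage_updated = await get_update_flag(
            self.namespace, workspace=self.workspace
        )
        self._storage_lock = get_namespace_lock(
            self.namespace, workspace=self.workspace
        )

    async def index_done_callback(self):
        async with self._storage_lock:
            # ... persist this namespace's data to disk or the backing DB ...
            # Then notify other processes that this namespace changed and
            # reset our own flag so we do not reload our own write.
            await set_all_update_flags(self.namespace, workspace=self.workspace)
            self.storage_updated.value = False
```

The same `workspace=` keyword appears on `get_namespace_data`, `try_initialize_namespace`, `clear_all_update_flags`, and `get_all_update_flags_status` throughout the rest of the diff, and `initialize_storages()` is described as auto-initializing the `pipeline_status` namespace for `rag.workspace`, which is what makes the separate `initialize_pipeline_status()` call unnecessary.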
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 3c621c06..d42c91a7 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -6,7 +6,7 @@ import numpy as np from lightrag.utils import logger, compute_mdhash_id from ..base import BaseVectorStorage from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH -from ..kg.shared_storage import get_data_init_lock, get_storage_lock +from ..kg.shared_storage import get_data_init_lock import pipmaster as pm if not pm.is_installed("pymilvus"): @@ -961,8 +961,8 @@ class MilvusVectorDBStorage(BaseVectorStorage): else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace + self.workspace = "" logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") - self.workspace = "_" kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -1351,21 +1351,20 @@ class MilvusVectorDBStorage(BaseVectorStorage): - On success: {"status": "success", "message": "data dropped"} - On failure: {"status": "error", "message": ""} """ - async with get_storage_lock(): - try: - # Drop the collection and recreate it - if self._client.has_collection(self.final_namespace): - self._client.drop_collection(self.final_namespace) + try: + # Drop the collection and recreate it + if self._client.has_collection(self.final_namespace): + self._client.drop_collection(self.final_namespace) - # Recreate the collection - self._create_collection_if_not_exist() + # Recreate the collection + self._create_collection_if_not_exist() - logger.info( - f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}" - ) - return {"status": "success", "message": "data dropped"} - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}" - ) - return {"status": "error", "message": str(e)} + logger.info( + f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}" + ) + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 30452c74..e11e6411 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -19,7 +19,7 @@ from ..base import ( from ..utils import logger, compute_mdhash_id from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..constants import GRAPH_FIELD_SEP -from ..kg.shared_storage import get_data_init_lock, get_storage_lock, get_graph_db_lock +from ..kg.shared_storage import get_data_init_lock import pipmaster as pm @@ -120,7 +120,7 @@ class MongoKVStorage(BaseKVStorage): else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" logger.debug( f"[{self.workspace}] Final namespace (no workspace): '{self.namespace}'" ) @@ -138,11 +138,10 @@ class MongoKVStorage(BaseKVStorage): ) async def finalize(self): - async with get_storage_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - self._data = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + self._data = None async def get_by_id(self, id: str) -> dict[str, Any] | None: # Unified handling for 
flattened keys @@ -263,23 +262,22 @@ class MongoKVStorage(BaseKVStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with get_storage_lock(): - try: - result = await self._data.delete_many({}) - deleted_count = result.deleted_count + try: + result = await self._data.delete_many({}) + deleted_count = result.deleted_count - logger.info( - f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" - ) - return { - "status": "success", - "message": f"{deleted_count} documents dropped", - } - except PyMongoError as e: - logger.error( - f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" - ) - return {"status": "error", "message": str(e)} + logger.info( + f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" + ) + return { + "status": "success", + "message": f"{deleted_count} documents dropped", + } + except PyMongoError as e: + logger.error( + f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" + ) + return {"status": "error", "message": str(e)} @final @@ -350,7 +348,7 @@ class MongoDocStatusStorage(DocStatusStorage): else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") self._collection_name = self.final_namespace @@ -370,11 +368,10 @@ class MongoDocStatusStorage(DocStatusStorage): ) async def finalize(self): - async with get_storage_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - self._data = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + self._data = None async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: return await self._data.find_one({"_id": id}) @@ -484,23 +481,22 @@ class MongoDocStatusStorage(DocStatusStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with get_storage_lock(): - try: - result = await self._data.delete_many({}) - deleted_count = result.deleted_count + try: + result = await self._data.delete_many({}) + deleted_count = result.deleted_count - logger.info( - f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" - ) - return { - "status": "success", - "message": f"{deleted_count} documents dropped", - } - except PyMongoError as e: - logger.error( - f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" - ) - return {"status": "error", "message": str(e)} + logger.info( + f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" + ) + return { + "status": "success", + "message": f"{deleted_count} documents dropped", + } + except PyMongoError as e: + logger.error( + f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" + ) + return {"status": "error", "message": str(e)} async def delete(self, ids: list[str]) -> None: await self._data.delete_many({"_id": {"$in": ids}}) @@ -517,7 +513,7 @@ class MongoDocStatusStorage(DocStatusStorage): collation_config = {"locale": "zh", "numericOrdering": True} # Use workspace-specific index names to avoid cross-workspace conflicts - workspace_prefix = f"{self.workspace}_" if self.workspace != "_" else "" + workspace_prefix = f"{self.workspace}_" if self.workspace != "" else "" # 1. 
Define all indexes needed with workspace-specific names all_indexes = [ @@ -775,7 +771,7 @@ class MongoGraphStorage(BaseGraphStorage): else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") self._collection_name = self.final_namespace @@ -801,12 +797,11 @@ class MongoGraphStorage(BaseGraphStorage): ) async def finalize(self): - async with get_graph_db_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - self.collection = None - self.edge_collection = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + self.collection = None + self.edge_collection = None # Sample entity document # "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP @@ -2015,30 +2010,29 @@ class MongoGraphStorage(BaseGraphStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with get_graph_db_lock(): - try: - result = await self.collection.delete_many({}) - deleted_count = result.deleted_count + try: + result = await self.collection.delete_many({}) + deleted_count = result.deleted_count - logger.info( - f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}" - ) + logger.info( + f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}" + ) - result = await self.edge_collection.delete_many({}) - edge_count = result.deleted_count - logger.info( - f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}" - ) + result = await self.edge_collection.delete_many({}) + edge_count = result.deleted_count + logger.info( + f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}" + ) - return { - "status": "success", - "message": f"{deleted_count} documents and {edge_count} edges dropped", - } - except PyMongoError as e: - logger.error( - f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}" - ) - return {"status": "error", "message": str(e)} + return { + "status": "success", + "message": f"{deleted_count} documents and {edge_count} edges dropped", + } + except PyMongoError as e: + logger.error( + f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}" + ) + return {"status": "error", "message": str(e)} @final @@ -2089,7 +2083,7 @@ class MongoVectorDBStorage(BaseVectorStorage): else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") # Set index name based on workspace for backward compatibility @@ -2125,11 +2119,10 @@ class MongoVectorDBStorage(BaseVectorStorage): ) async def finalize(self): - async with get_storage_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - self._data = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + self._data = None async def create_vector_index_if_not_exists(self): """Creates an Atlas Vector Search index.""" @@ -2452,27 +2445,26 @@ class MongoVectorDBStorage(BaseVectorStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with get_storage_lock(): - try: - # Delete all documents - result = await 
self._data.delete_many({}) - deleted_count = result.deleted_count + try: + # Delete all documents + result = await self._data.delete_many({}) + deleted_count = result.deleted_count - # Recreate vector index - await self.create_vector_index_if_not_exists() + # Recreate vector index + await self.create_vector_index_if_not_exists() - logger.info( - f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index" - ) - return { - "status": "success", - "message": f"{deleted_count} documents dropped and vector index recreated", - } - except PyMongoError as e: - logger.error( - f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}" - ) - return {"status": "error", "message": str(e)} + logger.info( + f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index" + ) + return { + "status": "success", + "message": f"{deleted_count} documents dropped and vector index recreated", + } + except PyMongoError as e: + logger.error( + f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}" + ) + return {"status": "error", "message": str(e)} async def get_or_create_collection(db: AsyncDatabase, collection_name: str): diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index 1185241c..d390c37b 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -15,7 +15,7 @@ from lightrag.utils import ( from lightrag.base import BaseVectorStorage from nano_vectordb import NanoVectorDB from .shared_storage import ( - get_storage_lock, + get_namespace_lock, get_update_flag, set_all_update_flags, ) @@ -47,7 +47,7 @@ class NanoVectorDBStorage(BaseVectorStorage): else: # Default behavior when workspace is empty self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" workspace_dir = working_dir os.makedirs(workspace_dir, exist_ok=True) @@ -65,9 +65,13 @@ class NanoVectorDBStorage(BaseVectorStorage): async def initialize(self): """Initialize storage data""" # Get the update flag for cross-process update notification - self.storage_updated = await get_update_flag(self.final_namespace) + self.storage_updated = await get_update_flag( + self.namespace, workspace=self.workspace + ) # Get the storage lock for use in other methods - self._storage_lock = get_storage_lock(enable_logging=False) + self._storage_lock = get_namespace_lock( + self.namespace, workspace=self.workspace + ) async def _get_client(self): """Check if the storage should be reloaded""" @@ -288,7 +292,7 @@ class NanoVectorDBStorage(BaseVectorStorage): # Save data to disk self._client.save() # Notify other processes that data has been updated - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) # Reset own update flag to avoid self-reloading self.storage_updated.value = False return True # Return success @@ -410,7 +414,7 @@ class NanoVectorDBStorage(BaseVectorStorage): ) # Notify other processes that data has been updated - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) # Reset own update flag to avoid self-reloading self.storage_updated.value = False diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 31df4623..256656d8 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -16,7 +16,7 @@ import logging from ..utils 
import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge -from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock +from ..kg.shared_storage import get_data_init_lock import pipmaster as pm if not pm.is_installed("neo4j"): @@ -340,10 +340,9 @@ class Neo4JStorage(BaseGraphStorage): async def finalize(self): """Close the Neo4j driver and release all resources""" - async with get_graph_db_lock(): - if self._driver: - await self._driver.close() - self._driver = None + if self._driver: + await self._driver.close() + self._driver = None async def __aexit__(self, exc_type, exc, tb): """Ensure driver is closed when context manager exits""" @@ -1773,24 +1772,23 @@ class Neo4JStorage(BaseGraphStorage): - On success: {"status": "success", "message": "workspace data dropped"} - On failure: {"status": "error", "message": ""} """ - async with get_graph_db_lock(): - workspace_label = self._get_workspace_label() - try: - async with self._driver.session(database=self._DATABASE) as session: - # Delete all nodes and relationships in current workspace only - query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" - result = await session.run(query) - await result.consume() # Ensure result is fully consumed + workspace_label = self._get_workspace_label() + try: + async with self._driver.session(database=self._DATABASE) as session: + # Delete all nodes and relationships in current workspace only + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" + result = await session.run(query) + await result.consume() # Ensure result is fully consumed - # logger.debug( - # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" - # ) - return { - "status": "success", - "message": f"workspace '{workspace_label}' data dropped", - } - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" - ) - return {"status": "error", "message": str(e)} + # logger.debug( + # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" + # ) + return { + "status": "success", + "message": f"workspace '{workspace_label}' data dropped", + } + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" + ) + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index 48a2d2af..145b9c01 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -7,7 +7,7 @@ from lightrag.utils import logger from lightrag.base import BaseGraphStorage import networkx as nx from .shared_storage import ( - get_storage_lock, + get_namespace_lock, get_update_flag, set_all_update_flags, ) @@ -41,12 +41,10 @@ class NetworkXStorage(BaseGraphStorage): if self.workspace: # Include workspace in the file path for data isolation workspace_dir = os.path.join(working_dir, self.workspace) - self.final_namespace = f"{self.workspace}_{self.namespace}" else: # Default behavior when workspace is empty - self.final_namespace = self.namespace workspace_dir = working_dir - self.workspace = "_" + self.workspace = "" os.makedirs(workspace_dir, exist_ok=True) self._graphml_xml_file = os.path.join( @@ -71,9 +69,13 @@ class NetworkXStorage(BaseGraphStorage): async def initialize(self): """Initialize storage data""" # Get the update 
flag for cross-process update notification - self.storage_updated = await get_update_flag(self.final_namespace) + self.storage_updated = await get_update_flag( + self.namespace, workspace=self.workspace + ) # Get the storage lock for use in other methods - self._storage_lock = get_storage_lock() + self._storage_lock = get_namespace_lock( + self.namespace, workspace=self.workspace + ) async def _get_graph(self): """Check if the storage should be reloaded""" @@ -522,7 +524,7 @@ class NetworkXStorage(BaseGraphStorage): self._graph, self._graphml_xml_file, self.workspace ) # Notify other processes that data has been updated - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) # Reset own update flag to avoid self-reloading self.storage_updated.value = False return True # Return success @@ -553,7 +555,7 @@ class NetworkXStorage(BaseGraphStorage): os.remove(self._graphml_xml_file) self._graph = nx.Graph() # Notify other processes that data has been updated - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) # Reset own update flag to avoid self-reloading self.storage_updated.value = False logger.info( diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index d043176e..62078459 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -33,7 +33,7 @@ from ..base import ( ) from ..namespace import NameSpace, is_namespace from ..utils import logger -from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock +from ..kg.shared_storage import get_data_init_lock import pipmaster as pm @@ -1702,10 +1702,9 @@ class PGKVStorage(BaseKVStorage): self.workspace = "default" async def finalize(self): - async with get_storage_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None ################ QUERY METHODS ################ async def get_by_id(self, id: str) -> dict[str, Any] | None: @@ -2147,22 +2146,21 @@ class PGKVStorage(BaseKVStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - async with get_storage_lock(): - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} @final @@ -2197,10 +2195,9 @@ class PGVectorStorage(BaseVectorStorage): self.workspace = "default" async def finalize(self): - async with get_storage_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None 
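# Illustrative sketch (not part of this diff) of the per-namespace lock pattern that
# replaces the removed global get_storage_lock(): the lock is keyed by
# (workspace, namespace), so storage instances in different workspaces proceed in
# parallel while instances sharing the same workspace and namespace stay serialized.
# This mirrors the NetworkXStorage change above; the class and method names here are
# hypothetical, and shared storage is assumed to be initialized by LightRAG startup.
from lightrag.kg.shared_storage import get_namespace_lock


class ExampleFileBackedStorage:
    def __init__(self, namespace: str, workspace: str = ""):
        self.namespace = namespace
        self.workspace = workspace
        self._storage_lock = None

    async def initialize(self):
        # Lock scoped to this instance's workspace and namespace
        self._storage_lock = get_namespace_lock(
            self.namespace, workspace=self.workspace
        )

    async def index_done_callback(self) -> bool:
        async with self._storage_lock:
            # ... persist in-memory data for this workspace only ...
            return True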
def _upsert_chunks( self, item: dict[str, Any], current_time: datetime.datetime @@ -2536,22 +2533,21 @@ class PGVectorStorage(BaseVectorStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - async with get_storage_lock(): - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} @final @@ -2586,10 +2582,9 @@ class PGDocStatusStorage(DocStatusStorage): self.workspace = "default" async def finalize(self): - async with get_storage_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" @@ -3164,22 +3159,21 @@ class PGDocStatusStorage(DocStatusStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - async with get_storage_lock(): - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} class PGGraphQueryException(Exception): @@ -3311,10 +3305,9 @@ class PGGraphStorage(BaseGraphStorage): ) async def finalize(self): - async with get_graph_db_lock(): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None async def index_done_callback(self) -> None: # PG handles persistence automatically @@ -4714,21 +4707,20 @@ class PGGraphStorage(BaseGraphStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - async with get_graph_db_lock(): - try: - drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n) - DETACH DELETE n - $$) AS (result agtype)""" + try: + drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n) + DETACH DELETE n + $$) AS (result agtype)""" - await 
self._query(drop_query, readonly=False) - return { - "status": "success", - "message": f"workspace '{self.workspace}' graph data dropped", - } - except Exception as e: - logger.error(f"[{self.workspace}] Error dropping graph: {e}") - return {"status": "error", "message": str(e)} + await self._query(drop_query, readonly=False) + return { + "status": "success", + "message": f"workspace '{self.workspace}' graph data dropped", + } + except Exception as e: + logger.error(f"[{self.workspace}] Error dropping graph: {e}") + return {"status": "error", "message": str(e)} # Note: Order matters! More specific namespaces (e.g., "full_entities") must come before diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index d51d8898..75de2613 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -11,7 +11,7 @@ import pipmaster as pm from ..base import BaseVectorStorage from ..exceptions import QdrantMigrationError -from ..kg.shared_storage import get_data_init_lock, get_storage_lock +from ..kg.shared_storage import get_data_init_lock from ..utils import compute_mdhash_id, logger if not pm.is_installed("qdrant-client"): @@ -698,25 +698,25 @@ class QdrantVectorDBStorage(BaseVectorStorage): - On success: {"status": "success", "message": "data dropped"} - On failure: {"status": "error", "message": ""} """ - async with get_storage_lock(): - try: - # Delete all points for the current workspace - self._client.delete( - collection_name=self.final_namespace, - points_selector=models.FilterSelector( - filter=models.Filter( - must=[workspace_filter_condition(self.effective_workspace)] - ) - ), - wait=True, - ) + # No need to lock: data integrity is ensured by allowing only one process to hold pipeline at a time + try: + # Delete all points for the current workspace + self._client.delete( + collection_name=self.final_namespace, + points_selector=models.FilterSelector( + filter=models.Filter( + must=[workspace_filter_condition(self.effective_workspace)] + ) + ), + wait=True, + ) - logger.info( - f"[{self.workspace}] Process {os.getpid()} dropped workspace data from Qdrant collection {self.namespace}" - ) - return {"status": "success", "message": "data dropped"} - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping workspace data from Qdrant collection {self.namespace}: {e}" - ) - return {"status": "error", "message": str(e)} + logger.info( + f"[{self.workspace}] Process {os.getpid()} dropped workspace data from Qdrant collection {self.namespace}" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping workspace data from Qdrant collection {self.namespace}: {e}" + ) + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 2e9a7d43..a254d4ee 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -21,7 +21,7 @@ from lightrag.base import ( DocStatus, DocProcessingStatus, ) -from ..kg.shared_storage import get_data_init_lock, get_storage_lock +from ..kg.shared_storage import get_data_init_lock import json # Import tenacity for retry logic @@ -153,7 +153,7 @@ class RedisKVStorage(BaseKVStorage): else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") self._redis_url = os.environ.get( @@ -401,42 +401,39 @@ class 
RedisKVStorage(BaseKVStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with get_storage_lock(): - async with self._get_redis_connection() as redis: - try: - # Use SCAN to find all keys with the namespace prefix - pattern = f"{self.final_namespace}:*" - cursor = 0 - deleted_count = 0 + async with self._get_redis_connection() as redis: + try: + # Use SCAN to find all keys with the namespace prefix + pattern = f"{self.final_namespace}:*" + cursor = 0 + deleted_count = 0 - while True: - cursor, keys = await redis.scan( - cursor, match=pattern, count=1000 - ) - if keys: - # Delete keys in batches - pipe = redis.pipeline() - for key in keys: - pipe.delete(key) - results = await pipe.execute() - deleted_count += sum(results) + while True: + cursor, keys = await redis.scan(cursor, match=pattern, count=1000) + if keys: + # Delete keys in batches + pipe = redis.pipeline() + for key in keys: + pipe.delete(key) + results = await pipe.execute() + deleted_count += sum(results) - if cursor == 0: - break + if cursor == 0: + break - logger.info( - f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}" - ) - return { - "status": "success", - "message": f"{deleted_count} keys dropped", - } + logger.info( + f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}" + ) + return { + "status": "success", + "message": f"{deleted_count} keys dropped", + } - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}" - ) - return {"status": "error", "message": str(e)} + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}" + ) + return {"status": "error", "message": str(e)} async def _migrate_legacy_cache_structure(self): """Migrate legacy nested cache structure to flattened structure for Redis @@ -1091,35 +1088,32 @@ class RedisDocStatusStorage(DocStatusStorage): async def drop(self) -> dict[str, str]: """Drop all document status data from storage and clean up resources""" - async with get_storage_lock(): - try: - async with self._get_redis_connection() as redis: - # Use SCAN to find all keys with the namespace prefix - pattern = f"{self.final_namespace}:*" - cursor = 0 - deleted_count = 0 + try: + async with self._get_redis_connection() as redis: + # Use SCAN to find all keys with the namespace prefix + pattern = f"{self.final_namespace}:*" + cursor = 0 + deleted_count = 0 - while True: - cursor, keys = await redis.scan( - cursor, match=pattern, count=1000 - ) - if keys: - # Delete keys in batches - pipe = redis.pipeline() - for key in keys: - pipe.delete(key) - results = await pipe.execute() - deleted_count += sum(results) + while True: + cursor, keys = await redis.scan(cursor, match=pattern, count=1000) + if keys: + # Delete keys in batches + pipe = redis.pipeline() + for key in keys: + pipe.delete(key) + results = await pipe.execute() + deleted_count += sum(results) - if cursor == 0: - break + if cursor == 0: + break - logger.info( - f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}" - ) - return {"status": "success", "message": "data dropped"} - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}" + logger.info( + f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}" ) - return {"status": "error", "message": str(e)} + return {"status": "success", "message": "data dropped"} + except Exception as e: + 
logger.error( + f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}" + ) + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 0abcf719..834cdc8f 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -6,6 +6,7 @@ from multiprocessing.synchronize import Lock as ProcessLock from multiprocessing import Manager import time import logging +from contextvars import ContextVar from typing import Any, Dict, List, Optional, Union, TypeVar, Generic from lightrag.exceptions import PipelineNotInitializedError @@ -75,16 +76,16 @@ _last_mp_cleanup_time: Optional[float] = None _initialized = None +# Default workspace for backward compatibility +_default_workspace: Optional[str] = None + # shared data for storage across processes _shared_dicts: Optional[Dict[str, Any]] = None _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized _update_flags: Optional[Dict[str, bool]] = None # namespace -> updated # locks for mutex access -_storage_lock: Optional[LockType] = None _internal_lock: Optional[LockType] = None -_pipeline_status_lock: Optional[LockType] = None -_graph_db_lock: Optional[LockType] = None _data_init_lock: Optional[LockType] = None # Manager for all keyed locks _storage_keyed_lock: Optional["KeyedUnifiedLock"] = None @@ -95,6 +96,22 @@ _async_locks: Optional[Dict[str, asyncio.Lock]] = None _debug_n_locks_acquired: int = 0 +def get_final_namespace(namespace: str, workspace: str | None = None): + global _default_workspace + if workspace is None: + workspace = _default_workspace + + if workspace is None: + direct_log( + f"Error: Invoke namespace operation without workspace, pid={os.getpid()}", + level="ERROR", + ) + raise ValueError("Invoke namespace operation without workspace") + + final_namespace = f"{workspace}:{namespace}" if workspace else f"{namespace}" + return final_namespace + + def inc_debug_n_locks_acquired(): global _debug_n_locks_acquired if DEBUG_LOCKS: @@ -1053,40 +1070,10 @@ def get_internal_lock(enable_logging: bool = False) -> UnifiedLock: ) -def get_storage_lock(enable_logging: bool = False) -> UnifiedLock: - """return unified storage lock for data consistency""" - async_lock = _async_locks.get("storage_lock") if _is_multiprocess else None - return UnifiedLock( - lock=_storage_lock, - is_async=not _is_multiprocess, - name="storage_lock", - enable_logging=enable_logging, - async_lock=async_lock, - ) - - -def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock: - """return unified storage lock for data consistency""" - async_lock = _async_locks.get("pipeline_status_lock") if _is_multiprocess else None - return UnifiedLock( - lock=_pipeline_status_lock, - is_async=not _is_multiprocess, - name="pipeline_status_lock", - enable_logging=enable_logging, - async_lock=async_lock, - ) - - -def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock: - """return unified graph database lock for ensuring atomic operations""" - async_lock = _async_locks.get("graph_db_lock") if _is_multiprocess else None - return UnifiedLock( - lock=_graph_db_lock, - is_async=not _is_multiprocess, - name="graph_db_lock", - enable_logging=enable_logging, - async_lock=async_lock, - ) +# Workspace based storage_lock is implemented by get_storage_keyed_lock instead. +# Workspace based pipeline_status_lock is implemented by get_storage_keyed_lock instead. 
+# No need to implement graph_db_lock: +# data integrity is ensured by entity level keyed-lock and allowing only one process to hold pipeline at a time. def get_storage_keyed_lock( @@ -1190,14 +1177,11 @@ def initialize_share_data(workers: int = 1): _manager, \ _workers, \ _is_multiprocess, \ - _storage_lock, \ _lock_registry, \ _lock_registry_count, \ _lock_cleanup_data, \ _registry_guard, \ _internal_lock, \ - _pipeline_status_lock, \ - _graph_db_lock, \ _data_init_lock, \ _shared_dicts, \ _init_flags, \ @@ -1225,9 +1209,6 @@ def initialize_share_data(workers: int = 1): _lock_cleanup_data = _manager.dict() _registry_guard = _manager.RLock() _internal_lock = _manager.Lock() - _storage_lock = _manager.Lock() - _pipeline_status_lock = _manager.Lock() - _graph_db_lock = _manager.Lock() _data_init_lock = _manager.Lock() _shared_dicts = _manager.dict() _init_flags = _manager.dict() @@ -1238,8 +1219,6 @@ def initialize_share_data(workers: int = 1): # Initialize async locks for multiprocess mode _async_locks = { "internal_lock": asyncio.Lock(), - "storage_lock": asyncio.Lock(), - "pipeline_status_lock": asyncio.Lock(), "graph_db_lock": asyncio.Lock(), "data_init_lock": asyncio.Lock(), } @@ -1250,9 +1229,6 @@ def initialize_share_data(workers: int = 1): else: _is_multiprocess = False _internal_lock = asyncio.Lock() - _storage_lock = asyncio.Lock() - _pipeline_status_lock = asyncio.Lock() - _graph_db_lock = asyncio.Lock() _data_init_lock = asyncio.Lock() _shared_dicts = {} _init_flags = {} @@ -1270,12 +1246,19 @@ def initialize_share_data(workers: int = 1): _initialized = True -async def initialize_pipeline_status(): +async def initialize_pipeline_status(workspace: str | None = None): """ - Initialize pipeline namespace with default values. - This function is called during FASTAPI lifespan for each worker. + Initialize pipeline_status share data with default values. + This function could be called before during FASTAPI lifespan for each worker. + + Args: + workspace: Optional workspace identifier for pipeline_status of specific workspace. + If None or empty string, uses the default workspace set by + set_default_workspace(). """ - pipeline_namespace = await get_namespace_data("pipeline_status", first_init=True) + pipeline_namespace = await get_namespace_data( + "pipeline_status", first_init=True, workspace=workspace + ) async with get_internal_lock(): # Check if already initialized by checking for required fields @@ -1298,10 +1281,14 @@ async def initialize_pipeline_status(): "history_messages": history_messages, # 使用共享列表对象 } ) - direct_log(f"Process {os.getpid()} Pipeline namespace initialized") + + final_namespace = get_final_namespace("pipeline_status", workspace) + direct_log( + f"Process {os.getpid()} Pipeline namespace '{final_namespace}' initialized" + ) -async def get_update_flag(namespace: str): +async def get_update_flag(namespace: str, workspace: str | None = None): """ Create a namespace's update flag for a workers. Returen the update flag to caller for referencing or reset. 
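# Illustrative sketch (not part of this diff): how a caller scopes pipeline status to a
# workspace after this change, mirroring the calls made in lightrag.py later in this
# diff. The workspace value and function name are arbitrary examples; shared storage is
# assumed to be initialized via initialize_share_data() / LightRAG startup.
from lightrag.kg.shared_storage import (
    get_namespace_data,
    get_namespace_lock,
    initialize_pipeline_status,
)


async def report_pipeline_progress(workspace: str, message: str) -> None:
    # Creates this workspace's pipeline_status entry if it does not exist yet
    await initialize_pipeline_status(workspace=workspace)
    pipeline_status = await get_namespace_data("pipeline_status", workspace=workspace)
    pipeline_status_lock = get_namespace_lock("pipeline_status", workspace=workspace)
    async with pipeline_status_lock:
        pipeline_status["latest_message"] = message
        pipeline_status["history_messages"].append(message)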
@@ -1310,14 +1297,16 @@ async def get_update_flag(namespace: str): if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") + final_namespace = get_final_namespace(namespace, workspace) + async with get_internal_lock(): - if namespace not in _update_flags: + if final_namespace not in _update_flags: if _is_multiprocess and _manager is not None: - _update_flags[namespace] = _manager.list() + _update_flags[final_namespace] = _manager.list() else: - _update_flags[namespace] = [] + _update_flags[final_namespace] = [] direct_log( - f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]" + f"Process {os.getpid()} initialized updated flags for namespace: [{final_namespace}]" ) if _is_multiprocess and _manager is not None: @@ -1330,39 +1319,43 @@ async def get_update_flag(namespace: str): new_update_flag = MutableBoolean(False) - _update_flags[namespace].append(new_update_flag) + _update_flags[final_namespace].append(new_update_flag) return new_update_flag -async def set_all_update_flags(namespace: str): +async def set_all_update_flags(namespace: str, workspace: str | None = None): """Set all update flag of namespace indicating all workers need to reload data from files""" global _update_flags if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") + final_namespace = get_final_namespace(namespace, workspace) + async with get_internal_lock(): - if namespace not in _update_flags: - raise ValueError(f"Namespace {namespace} not found in update flags") + if final_namespace not in _update_flags: + raise ValueError(f"Namespace {final_namespace} not found in update flags") # Update flags for both modes - for i in range(len(_update_flags[namespace])): - _update_flags[namespace][i].value = True + for i in range(len(_update_flags[final_namespace])): + _update_flags[final_namespace][i].value = True -async def clear_all_update_flags(namespace: str): +async def clear_all_update_flags(namespace: str, workspace: str | None = None): """Clear all update flag of namespace indicating all workers need to reload data from files""" global _update_flags if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") + final_namespace = get_final_namespace(namespace, workspace) + async with get_internal_lock(): - if namespace not in _update_flags: - raise ValueError(f"Namespace {namespace} not found in update flags") + if final_namespace not in _update_flags: + raise ValueError(f"Namespace {final_namespace} not found in update flags") # Update flags for both modes - for i in range(len(_update_flags[namespace])): - _update_flags[namespace][i].value = False + for i in range(len(_update_flags[final_namespace])): + _update_flags[final_namespace][i].value = False -async def get_all_update_flags_status() -> Dict[str, list]: +async def get_all_update_flags_status(workspace: str | None = None) -> Dict[str, list]: """ Get update flags status for all namespaces. 
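# Illustrative sketch (not part of this diff): update flags for the same logical
# namespace are now stored under "<workspace>:<namespace>", so setting flags in one
# workspace does not notify workers in another. The namespace string "graph_data" is
# an arbitrary example; flags expose a boolean .value as used by NetworkXStorage above.
from lightrag.kg.shared_storage import get_update_flag, set_all_update_flags


async def demo_update_flag_isolation() -> None:
    flag_a = await get_update_flag("graph_data", workspace="space1")
    flag_b = await get_update_flag("graph_data", workspace="space2")

    await set_all_update_flags("graph_data", workspace="space1")

    assert flag_a.value is True   # workers in "space1" must reload
    assert flag_b.value is False  # "space2" is unaffected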
@@ -1372,9 +1365,26 @@ async def get_all_update_flags_status() -> Dict[str, list]: if _update_flags is None: return {} + if workspace is None: + workspace = get_default_workspace() + result = {} async with get_internal_lock(): for namespace, flags in _update_flags.items(): + # Check if namespace has a workspace prefix (contains ':') + if ":" in namespace: + # Namespace has workspace prefix like "space1:pipeline_status" + # Only include if workspace matches the prefix + # Use rsplit to split from the right since workspace can contain colons + namespace_split = namespace.rsplit(":", 1) + if not workspace or namespace_split[0] != workspace: + continue + else: + # Namespace has no workspace prefix like "pipeline_status" + # Only include if we're querying the default (empty) workspace + if workspace: + continue + worker_statuses = [] for flag in flags: if _is_multiprocess: @@ -1386,7 +1396,9 @@ async def get_all_update_flags_status() -> Dict[str, list]: return result -async def try_initialize_namespace(namespace: str) -> bool: +async def try_initialize_namespace( + namespace: str, workspace: str | None = None +) -> bool: """ Returns True if the current worker(process) gets initialization permission for loading data later. The worker does not get the permission is prohibited to load data from files. @@ -1396,52 +1408,161 @@ async def try_initialize_namespace(namespace: str) -> bool: if _init_flags is None: raise ValueError("Try to create nanmespace before Shared-Data is initialized") + final_namespace = get_final_namespace(namespace, workspace) + async with get_internal_lock(): - if namespace not in _init_flags: - _init_flags[namespace] = True + if final_namespace not in _init_flags: + _init_flags[final_namespace] = True direct_log( - f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]" + f"Process {os.getpid()} ready to initialize storage namespace: [{final_namespace}]" ) return True direct_log( - f"Process {os.getpid()} storage namespace already initialized: [{namespace}]" + f"Process {os.getpid()} storage namespace already initialized: [{final_namespace}]" ) return False async def get_namespace_data( - namespace: str, first_init: bool = False + namespace: str, first_init: bool = False, workspace: str | None = None ) -> Dict[str, Any]: """get the shared data reference for specific namespace Args: namespace: The namespace to retrieve - allow_create: If True, allows creation of the namespace if it doesn't exist. - Used internally by initialize_pipeline_status(). + first_init: If True, allows pipeline_status namespace to create namespace if it doesn't exist. + Prevent getting pipeline_status namespace without initialize_pipeline_status(). + This parameter is used internally by initialize_pipeline_status(). 
+ workspace: Workspace identifier (may be empty string for global namespace) """ if _shared_dicts is None: direct_log( - f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}", + f"Error: Try to getnanmespace before it is initialized, pid={os.getpid()}", level="ERROR", ) raise ValueError("Shared dictionaries not initialized") + final_namespace = get_final_namespace(namespace, workspace) + async with get_internal_lock(): - if namespace not in _shared_dicts: + if final_namespace not in _shared_dicts: # Special handling for pipeline_status namespace - if namespace == "pipeline_status" and not first_init: + if ( + final_namespace.endswith(":pipeline_status") + or final_namespace == "pipeline_status" + ) and not first_init: # Check if pipeline_status should have been initialized but wasn't - # This helps users understand they need to call initialize_pipeline_status() - raise PipelineNotInitializedError(namespace) + # This helps users to call initialize_pipeline_status() before get_namespace_data() + raise PipelineNotInitializedError(final_namespace) # For other namespaces or when allow_create=True, create them dynamically if _is_multiprocess and _manager is not None: - _shared_dicts[namespace] = _manager.dict() + _shared_dicts[final_namespace] = _manager.dict() else: - _shared_dicts[namespace] = {} + _shared_dicts[final_namespace] = {} - return _shared_dicts[namespace] + return _shared_dicts[final_namespace] + + +class NamespaceLock: + """ + Reusable namespace lock wrapper that creates a fresh context on each use. + + This class solves the lock re-entrance and concurrent coroutine issues by using + contextvars.ContextVar to provide per-coroutine storage. Each coroutine gets its + own independent lock context, preventing state interference between concurrent + coroutines using the same NamespaceLock instance. 
+ + Example: + lock = NamespaceLock("my_namespace", "workspace1") + + # Can be used multiple times safely + async with lock: + await do_something() + + # Can even be used concurrently without deadlock + await asyncio.gather( + coroutine_1(lock), # Each gets its own context + coroutine_2(lock) # No state interference + ) + """ + + def __init__( + self, namespace: str, workspace: str | None = None, enable_logging: bool = False + ): + self._namespace = namespace + self._workspace = workspace + self._enable_logging = enable_logging + # Use ContextVar to provide per-coroutine storage for lock context + # This ensures each coroutine has its own independent context + self._ctx_var: ContextVar[Optional[_KeyedLockContext]] = ContextVar( + "lock_ctx", default=None + ) + + async def __aenter__(self): + """Create a fresh context each time we enter""" + # Check if this coroutine already has an active lock context + if self._ctx_var.get() is not None: + raise RuntimeError( + "NamespaceLock already acquired in current coroutine context" + ) + + final_namespace = get_final_namespace(self._namespace, self._workspace) + ctx = get_storage_keyed_lock( + ["default_key"], + namespace=final_namespace, + enable_logging=self._enable_logging, + ) + + # Acquire the lock first, then store context only after successful acquisition + # This prevents the ContextVar from being set if acquisition fails (e.g., due to cancellation), + # which would permanently brick the lock + result = await ctx.__aenter__() + self._ctx_var.set(ctx) + return result + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit the current context and clean up""" + # Retrieve this coroutine's context + ctx = self._ctx_var.get() + if ctx is None: + raise RuntimeError("NamespaceLock exited without being entered") + + result = await ctx.__aexit__(exc_type, exc_val, exc_tb) + # Clear this coroutine's context + self._ctx_var.set(None) + return result + + +def get_namespace_lock( + namespace: str, workspace: str | None = None, enable_logging: bool = False +) -> NamespaceLock: + """Get a reusable namespace lock wrapper. + + This function returns a NamespaceLock instance that can be used multiple times + safely, even in concurrent scenarios. Each use creates a fresh lock context + internally, preventing lock re-entrance errors. + + Args: + namespace: The namespace to get the lock for. 
+ workspace: Workspace identifier (may be empty string for global namespace) + enable_logging: Whether to enable lock operation logging + + Returns: + NamespaceLock: A reusable lock wrapper that can be used with 'async with' + + Example: + lock = get_namespace_lock("pipeline_status", workspace="space1") + + # Can be used multiple times + async with lock: + await do_something() + + async with lock: + await do_something_else() + """ + return NamespaceLock(namespace, workspace, enable_logging) def finalize_share_data(): @@ -1457,16 +1578,14 @@ def finalize_share_data(): global \ _manager, \ _is_multiprocess, \ - _storage_lock, \ _internal_lock, \ - _pipeline_status_lock, \ - _graph_db_lock, \ _data_init_lock, \ _shared_dicts, \ _init_flags, \ _initialized, \ _update_flags, \ - _async_locks + _async_locks, \ + _default_workspace # Check if already initialized if not _initialized: @@ -1525,12 +1644,42 @@ def finalize_share_data(): _is_multiprocess = None _shared_dicts = None _init_flags = None - _storage_lock = None _internal_lock = None - _pipeline_status_lock = None - _graph_db_lock = None _data_init_lock = None _update_flags = None _async_locks = None + _default_workspace = None direct_log(f"Process {os.getpid()} storage data finalization complete") + + +def set_default_workspace(workspace: str | None = None): + """ + Set default workspace for namespace operations for backward compatibility. + + This allows get_namespace_data(),get_namespace_lock() or initialize_pipeline_status() to + automatically use the correct workspace when called without workspace parameters, + maintaining compatibility with legacy code that doesn't pass workspace explicitly. + + Args: + workspace: Workspace identifier (may be empty string for global namespace) + """ + global _default_workspace + if workspace is None: + workspace = "" + _default_workspace = workspace + direct_log( + f"Default workspace set to: '{_default_workspace}' (empty means global)", + level="DEBUG", + ) + + +def get_default_workspace() -> str: + """ + Get default workspace for backward compatibility. + + Returns: + The default workspace string. Empty string means global namespace. None means not set. 
+ """ + global _default_workspace + return _default_workspace diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index c53c98ac..c0fa8627 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -64,9 +64,10 @@ from lightrag.kg import ( from lightrag.kg.shared_storage import ( get_namespace_data, - get_pipeline_status_lock, - get_graph_db_lock, get_data_init_lock, + get_default_workspace, + set_default_workspace, + get_namespace_lock, ) from lightrag.base import ( @@ -658,6 +659,22 @@ class LightRAG: async def initialize_storages(self): """Storage initialization must be called one by one to prevent deadlock""" if self._storages_status == StoragesStatus.CREATED: + # Set the first initialized workspace will set the default workspace + # Allows namespace operation without specifying workspace for backward compatibility + default_workspace = get_default_workspace() + if default_workspace is None: + set_default_workspace(self.workspace) + elif default_workspace != self.workspace: + logger.info( + f"Creating LightRAG instance with workspace='{self.workspace}' " + f"while default workspace is set to '{default_workspace}'" + ) + + # Auto-initialize pipeline_status for this workspace + from lightrag.kg.shared_storage import initialize_pipeline_status + + await initialize_pipeline_status(workspace=self.workspace) + for storage in ( self.full_docs, self.text_chunks, @@ -1592,8 +1609,12 @@ class LightRAG: """ # Get pipeline status shared data and lock - pipeline_status = await get_namespace_data("pipeline_status") - pipeline_status_lock = get_pipeline_status_lock() + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=self.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=self.workspace + ) # Check if another process is already processing the queue async with pipeline_status_lock: @@ -2927,6 +2948,26 @@ class LightRAG: data across different storage layers are removed or rebuiled. If entities or relationships are partially affected, they will be rebuilded using LLM cached from remaining documents. + **Concurrency Control Design:** + + This function implements a pipeline-based concurrency control to prevent data corruption: + + 1. **Single Document Deletion** (when WE acquire pipeline): + - Sets job_name to "Single document deletion" (NOT starting with "deleting") + - Prevents other adelete_by_doc_id calls from running concurrently + - Ensures exclusive access to graph operations for this deletion + + 2. **Batch Document Deletion** (when background_delete_documents acquires pipeline): + - Sets job_name to "Deleting {N} Documents" (starts with "deleting") + - Allows multiple adelete_by_doc_id calls to join the deletion queue + - Each call validates the job name to ensure it's part of a deletion operation + + The validation logic `if not job_name.startswith("deleting") or "document" not in job_name` + ensures that: + - adelete_by_doc_id can only run when pipeline is idle OR during batch deletion + - Prevents concurrent single deletions that could cause race conditions + - Rejects operations when pipeline is busy with non-deletion tasks + Args: doc_id (str): The unique identifier of the document to be deleted. delete_llm_cache (bool): Whether to delete cached LLM extraction results @@ -2934,20 +2975,62 @@ class LightRAG: Returns: DeletionResult: An object containing the outcome of the deletion process. - - `status` (str): "success", "not_found", or "failure". + - `status` (str): "success", "not_found", "not_allowed", or "failure". 
- `doc_id` (str): The ID of the document attempted to be deleted. - `message` (str): A summary of the operation's result. - - `status_code` (int): HTTP status code (e.g., 200, 404, 500). + - `status_code` (int): HTTP status code (e.g., 200, 404, 403, 500). - `file_path` (str | None): The file path of the deleted document, if available. """ + # Get pipeline status shared data and lock for validation + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=self.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=self.workspace + ) + + # Track whether WE acquired the pipeline + we_acquired_pipeline = False + + # Check and acquire pipeline if needed + async with pipeline_status_lock: + if not pipeline_status.get("busy", False): + # Pipeline is idle - WE acquire it for this deletion + we_acquired_pipeline = True + pipeline_status.update( + { + "busy": True, + "job_name": "Single document deletion", + "job_start": datetime.now(timezone.utc).isoformat(), + "docs": 1, + "batchs": 1, + "cur_batch": 0, + "request_pending": False, + "cancellation_requested": False, + "latest_message": f"Starting deletion for document: {doc_id}", + } + ) + # Initialize history messages + pipeline_status["history_messages"][:] = [ + f"Starting deletion for document: {doc_id}" + ] + else: + # Pipeline already busy - verify it's a deletion job + job_name = pipeline_status.get("job_name", "").lower() + if not job_name.startswith("deleting") or "document" not in job_name: + return DeletionResult( + status="not_allowed", + doc_id=doc_id, + message=f"Deletion not allowed: current job '{pipeline_status.get('job_name')}' is not a document deletion job", + status_code=403, + file_path=None, + ) + # Pipeline is busy with deletion - proceed without acquiring + deletion_operations_started = False original_exception = None doc_llm_cache_ids: list[str] = [] - # Get pipeline status shared data and lock for status updates - pipeline_status = await get_namespace_data("pipeline_status") - pipeline_status_lock = get_pipeline_status_lock() - async with pipeline_status_lock: log_message = f"Starting deletion process for document {doc_id}" logger.info(log_message) @@ -3300,31 +3383,111 @@ class LightRAG: logger.error(f"Failed to process graph analysis results: {e}") raise Exception(f"Failed to process graph dependencies: {e}") from e - # Use graph database lock to prevent dirty read - graph_db_lock = get_graph_db_lock(enable_logging=False) - async with graph_db_lock: - # 5. Delete chunks from storage - if chunk_ids: - try: - await self.chunks_vdb.delete(chunk_ids) - await self.text_chunks.delete(chunk_ids) + # Data integrity is ensured by allowing only one process to hold pipeline at a time(no graph db lock is needed anymore) - async with pipeline_status_lock: - log_message = f"Successfully deleted {len(chunk_ids)} chunks from storage" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + # 5. 
Delete chunks from storage + if chunk_ids: + try: + await self.chunks_vdb.delete(chunk_ids) + await self.text_chunks.delete(chunk_ids) - except Exception as e: - logger.error(f"Failed to delete chunks: {e}") - raise Exception(f"Failed to delete document chunks: {e}") from e + async with pipeline_status_lock: + log_message = ( + f"Successfully deleted {len(chunk_ids)} chunks from storage" + ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) - # 6. Delete relationships that have no remaining sources - if relationships_to_delete: - try: - # Delete from relation vdb + except Exception as e: + logger.error(f"Failed to delete chunks: {e}") + raise Exception(f"Failed to delete document chunks: {e}") from e + + # 6. Delete relationships that have no remaining sources + if relationships_to_delete: + try: + # Delete from relation vdb + rel_ids_to_delete = [] + for src, tgt in relationships_to_delete: + rel_ids_to_delete.extend( + [ + compute_mdhash_id(src + tgt, prefix="rel-"), + compute_mdhash_id(tgt + src, prefix="rel-"), + ] + ) + await self.relationships_vdb.delete(rel_ids_to_delete) + + # Delete from graph + await self.chunk_entity_relation_graph.remove_edges( + list(relationships_to_delete) + ) + + # Delete from relation_chunks storage + if self.relation_chunks: + relation_storage_keys = [ + make_relation_chunk_key(src, tgt) + for src, tgt in relationships_to_delete + ] + await self.relation_chunks.delete(relation_storage_keys) + + async with pipeline_status_lock: + log_message = f"Successfully deleted {len(relationships_to_delete)} relations" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + except Exception as e: + logger.error(f"Failed to delete relationships: {e}") + raise Exception(f"Failed to delete relationships: {e}") from e + + # 7. 
Delete entities that have no remaining sources + if entities_to_delete: + try: + # Batch get all edges for entities to avoid N+1 query problem + nodes_edges_dict = ( + await self.chunk_entity_relation_graph.get_nodes_edges_batch( + list(entities_to_delete) + ) + ) + + # Debug: Check and log all edges before deleting nodes + edges_to_delete = set() + edges_still_exist = 0 + + for entity, edges in nodes_edges_dict.items(): + if edges: + for src, tgt in edges: + # Normalize edge representation (sorted for consistency) + edge_tuple = tuple(sorted((src, tgt))) + edges_to_delete.add(edge_tuple) + + if ( + src in entities_to_delete + and tgt in entities_to_delete + ): + logger.warning( + f"Edge still exists: {src} <-> {tgt}" + ) + elif src in entities_to_delete: + logger.warning( + f"Edge still exists: {src} --> {tgt}" + ) + else: + logger.warning( + f"Edge still exists: {src} <-- {tgt}" + ) + edges_still_exist += 1 + + if edges_still_exist: + logger.warning( + f"⚠️ {edges_still_exist} entities still has edges before deletion" + ) + + # Clean residual edges from VDB and storage before deleting nodes + if edges_to_delete: + # Delete from relationships_vdb rel_ids_to_delete = [] - for src, tgt in relationships_to_delete: + for src, tgt in edges_to_delete: rel_ids_to_delete.extend( [ compute_mdhash_id(src + tgt, prefix="rel-"), @@ -3333,123 +3496,48 @@ class LightRAG: ) await self.relationships_vdb.delete(rel_ids_to_delete) - # Delete from graph - await self.chunk_entity_relation_graph.remove_edges( - list(relationships_to_delete) - ) - # Delete from relation_chunks storage if self.relation_chunks: relation_storage_keys = [ make_relation_chunk_key(src, tgt) - for src, tgt in relationships_to_delete + for src, tgt in edges_to_delete ] await self.relation_chunks.delete(relation_storage_keys) - async with pipeline_status_lock: - log_message = f"Successfully deleted {len(relationships_to_delete)} relations" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - - except Exception as e: - logger.error(f"Failed to delete relationships: {e}") - raise Exception(f"Failed to delete relationships: {e}") from e - - # 7. 
Delete entities that have no remaining sources - if entities_to_delete: - try: - # Batch get all edges for entities to avoid N+1 query problem - nodes_edges_dict = await self.chunk_entity_relation_graph.get_nodes_edges_batch( - list(entities_to_delete) + logger.info( + f"Cleaned {len(edges_to_delete)} residual edges from VDB and chunk-tracking storage" ) - # Debug: Check and log all edges before deleting nodes - edges_to_delete = set() - edges_still_exist = 0 + # Delete from graph (edges will be auto-deleted with nodes) + await self.chunk_entity_relation_graph.remove_nodes( + list(entities_to_delete) + ) - for entity, edges in nodes_edges_dict.items(): - if edges: - for src, tgt in edges: - # Normalize edge representation (sorted for consistency) - edge_tuple = tuple(sorted((src, tgt))) - edges_to_delete.add(edge_tuple) + # Delete from vector vdb + entity_vdb_ids = [ + compute_mdhash_id(entity, prefix="ent-") + for entity in entities_to_delete + ] + await self.entities_vdb.delete(entity_vdb_ids) - if ( - src in entities_to_delete - and tgt in entities_to_delete - ): - logger.warning( - f"Edge still exists: {src} <-> {tgt}" - ) - elif src in entities_to_delete: - logger.warning( - f"Edge still exists: {src} --> {tgt}" - ) - else: - logger.warning( - f"Edge still exists: {src} <-- {tgt}" - ) - edges_still_exist += 1 + # Delete from entity_chunks storage + if self.entity_chunks: + await self.entity_chunks.delete(list(entities_to_delete)) - if edges_still_exist: - logger.warning( - f"⚠️ {edges_still_exist} entities still has edges before deletion" - ) - - # Clean residual edges from VDB and storage before deleting nodes - if edges_to_delete: - # Delete from relationships_vdb - rel_ids_to_delete = [] - for src, tgt in edges_to_delete: - rel_ids_to_delete.extend( - [ - compute_mdhash_id(src + tgt, prefix="rel-"), - compute_mdhash_id(tgt + src, prefix="rel-"), - ] - ) - await self.relationships_vdb.delete(rel_ids_to_delete) - - # Delete from relation_chunks storage - if self.relation_chunks: - relation_storage_keys = [ - make_relation_chunk_key(src, tgt) - for src, tgt in edges_to_delete - ] - await self.relation_chunks.delete(relation_storage_keys) - - logger.info( - f"Cleaned {len(edges_to_delete)} residual edges from VDB and chunk-tracking storage" - ) - - # Delete from graph (edges will be auto-deleted with nodes) - await self.chunk_entity_relation_graph.remove_nodes( - list(entities_to_delete) + async with pipeline_status_lock: + log_message = ( + f"Successfully deleted {len(entities_to_delete)} entities" ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) - # Delete from vector vdb - entity_vdb_ids = [ - compute_mdhash_id(entity, prefix="ent-") - for entity in entities_to_delete - ] - await self.entities_vdb.delete(entity_vdb_ids) + except Exception as e: + logger.error(f"Failed to delete entities: {e}") + raise Exception(f"Failed to delete entities: {e}") from e - # Delete from entity_chunks storage - if self.entity_chunks: - await self.entity_chunks.delete(list(entities_to_delete)) - - async with pipeline_status_lock: - log_message = f"Successfully deleted {len(entities_to_delete)} entities" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) - - except Exception as e: - logger.error(f"Failed to delete entities: {e}") - raise Exception(f"Failed to delete entities: {e}") from e - - # Persist changes to graph database before 
releasing graph database lock - await self._insert_done() + # Persist changes to graph database before entity and relationship rebuild + await self._insert_done() # 8. Rebuild entities and relationships from remaining chunks if entities_to_rebuild or relationships_to_rebuild: @@ -3555,6 +3643,18 @@ class LightRAG: f"No deletion operations were started for document {doc_id}, skipping persistence" ) + # Release pipeline only if WE acquired it + if we_acquired_pipeline: + async with pipeline_status_lock: + pipeline_status["busy"] = False + pipeline_status["cancellation_requested"] = False + completion_msg = ( + f"Deletion process completed for document: {doc_id}" + ) + pipeline_status["latest_message"] = completion_msg + pipeline_status["history_messages"].append(completion_msg) + logger.info(completion_msg) + async def adelete_by_entity(self, entity_name: str) -> DeletionResult: """Asynchronously delete an entity and all its relationships. diff --git a/lightrag/tools/check_initialization.py b/lightrag/tools/check_initialization.py index 6bcb17e3..ee650824 100644 --- a/lightrag/tools/check_initialization.py +++ b/lightrag/tools/check_initialization.py @@ -3,10 +3,17 @@ Diagnostic tool to check LightRAG initialization status. This tool helps developers verify that their LightRAG instance is properly -initialized before use, preventing common initialization errors. +initialized and ready to use. It should be called AFTER initialize_storages() +to validate that all components are correctly set up. Usage: - python -m lightrag.tools.check_initialization + # Basic usage in your code: + rag = LightRAG(...) + await rag.initialize_storages() + await check_lightrag_setup(rag, verbose=True) + + # Run demo from command line: + python -m lightrag.tools.check_initialization --demo """ import asyncio @@ -82,11 +89,11 @@ async def check_lightrag_setup(rag_instance: LightRAG, verbose: bool = False) -> try: from lightrag.kg.shared_storage import get_namespace_data - get_namespace_data("pipeline_status") + get_namespace_data("pipeline_status", workspace=rag_instance.workspace) print("✅ Pipeline status: INITIALIZED") except KeyError: issues.append( - "Pipeline status not initialized - call initialize_pipeline_status()" + "Pipeline status not initialized - call rag.initialize_storages() first" ) except Exception as e: issues.append(f"Error checking pipeline status: {str(e)}") @@ -101,8 +108,6 @@ async def check_lightrag_setup(rag_instance: LightRAG, verbose: bool = False) -> print("\n📝 To fix, run this initialization sequence:\n") print(" await rag.initialize_storages()") - print(" from lightrag.kg.shared_storage import initialize_pipeline_status") - print(" await initialize_pipeline_status()") print( "\n📚 Documentation: https://github.com/HKUDS/LightRAG#important-initialization-requirements" ) @@ -127,7 +132,6 @@ async def check_lightrag_setup(rag_instance: LightRAG, verbose: bool = False) -> async def demo(): """Demonstrate the diagnostic tool with a test instance.""" from lightrag.llm.openai import openai_embed, gpt_4o_mini_complete - from lightrag.kg.shared_storage import initialize_pipeline_status print("=" * 50) print("LightRAG Initialization Diagnostic Tool") @@ -140,15 +144,10 @@ async def demo(): llm_model_func=gpt_4o_mini_complete, ) - print("\n🔴 BEFORE initialization:\n") - await check_lightrag_setup(rag, verbose=True) + print("\n🔄 Initializing storages...\n") + await rag.initialize_storages() # Auto-initializes pipeline_status - print("\n" + "=" * 50) - print("\n🔄 Initializing...\n") - await 
rag.initialize_storages() - await initialize_pipeline_status() - - print("\n🟢 AFTER initialization:\n") + print("\n🔍 Checking initialization status:\n") await check_lightrag_setup(rag, verbose=True) # Cleanup diff --git a/lightrag/tools/clean_llm_query_cache.py b/lightrag/tools/clean_llm_query_cache.py index eca658c7..dbe2e455 100644 --- a/lightrag/tools/clean_llm_query_cache.py +++ b/lightrag/tools/clean_llm_query_cache.py @@ -463,7 +463,9 @@ class CleanupTool: # CRITICAL: Set update flag so changes persist to disk # Without this, deletions remain in-memory only and are lost on exit - await set_all_update_flags(storage.final_namespace) + await set_all_update_flags( + storage.namespace, workspace=storage.workspace + ) # Success stats.successful_batches += 1 diff --git a/pyproject.toml b/pyproject.toml index 3c7450f4..e40452e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,13 @@ dependencies = [ ] [project.optional-dependencies] +# Test framework dependencies (for CI/CD and testing) +pytest = [ + "pytest>=8.4.2", + "pytest-asyncio>=1.2.0", + "pre-commit", +] + api = [ # Core dependencies "aiohttp", @@ -125,12 +132,14 @@ offline = [ ] evaluation = [ + # Test framework dependencies (for evaluation) + "pytest>=8.4.2", + "pytest-asyncio>=1.2.0", + "pre-commit", # RAG evaluation dependencies (RAGAS framework) "ragas>=0.3.7", "datasets>=4.3.0", "httpx>=0.28.1", - "pytest>=8.4.2", - "pytest-asyncio>=1.2.0", ] observability = [ @@ -162,5 +171,13 @@ version = {attr = "lightrag.__version__"} [tool.setuptools.package-data] lightrag = ["api/webui/**/*", "api/static/**/*"] +[tool.pytest.ini_options] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] + [tool.ruff] target-version = "py310" diff --git a/reproduce/Step_1.py b/reproduce/Step_1.py index c94015ad..933bfffa 100644 --- a/reproduce/Step_1.py +++ b/reproduce/Step_1.py @@ -4,7 +4,6 @@ import time import asyncio from lightrag import LightRAG -from lightrag.kg.shared_storage import initialize_pipeline_status def insert_text(rag, file_path): @@ -35,9 +34,7 @@ if not os.path.exists(WORKING_DIR): async def initialize_rag(): rag = LightRAG(working_dir=WORKING_DIR) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/reproduce/Step_1_openai_compatible.py b/reproduce/Step_1_openai_compatible.py index 8093a9ee..434ab594 100644 --- a/reproduce/Step_1_openai_compatible.py +++ b/reproduce/Step_1_openai_compatible.py @@ -7,7 +7,6 @@ import numpy as np from lightrag import LightRAG from lightrag.utils import EmbeddingFunc from lightrag.llm.openai import openai_complete_if_cache, openai_embed -from lightrag.kg.shared_storage import initialize_pipeline_status ## For Upstage API @@ -70,9 +69,7 @@ async def initialize_rag(): embedding_func=EmbeddingFunc(embedding_dim=4096, func=embedding_func), ) - await rag.initialize_storages() - await initialize_pipeline_status() - + await rag.initialize_storages() # Auto-initializes pipeline_status return rag diff --git a/tests/README_WORKSPACE_ISOLATION_TESTS.md b/tests/README_WORKSPACE_ISOLATION_TESTS.md new file mode 100644 index 00000000..bf11e4ac --- /dev/null +++ b/tests/README_WORKSPACE_ISOLATION_TESTS.md @@ -0,0 +1,265 @@ +# Workspace Isolation Test Suite + +## Overview +Comprehensive test coverage for LightRAG's workspace isolation feature, ensuring that different 
workspaces (projects) can coexist independently without data contamination or resource conflicts. + +## Test Architecture + +### Design Principles +1. **Concurrency-Based Assertions**: Instead of timing-based tests (which are flaky), we measure actual concurrent lock holders +2. **Timeline Validation**: Finite state machine validates proper sequential execution +3. **Performance Metrics**: Each test reports execution metrics for debugging and optimization +4. **Configurable Stress Testing**: Environment variables control test intensity + +## Test Categories + +### 1. Data Isolation Tests +**Tests:** 1, 4, 8, 9, 10 +**Purpose:** Verify that data in one workspace doesn't leak into another + +- **Test 1: Pipeline Status Isolation** - Core shared data structures remain separate +- **Test 4: Multi-Workspace Concurrency** - Concurrent operations don't interfere +- **Test 8: Update Flags Isolation** - Flag management respects workspace boundaries +- **Test 9: Empty Workspace Standardization** - Edge case handling for empty workspace strings +- **Test 10: JsonKVStorage Integration** - Storage layer properly isolates data + +### 2. Lock Mechanism Tests +**Tests:** 2, 5, 6 +**Purpose:** Validate that locking mechanisms allow parallelism across workspaces while enforcing serialization within workspaces + +- **Test 2: Lock Mechanism** - Different workspaces run in parallel, same workspace serializes +- **Test 5: Re-entrance Protection** - Prevent deadlocks from re-entrant lock acquisition +- **Test 6: Namespace Lock Isolation** - Different namespaces within same workspace are independent + +### 3. Backward Compatibility Tests +**Test:** 3 +**Purpose:** Ensure legacy code without workspace parameters still functions correctly + +- Default workspace fallback behavior +- Empty workspace handling +- None vs empty string normalization + +### 4. Error Handling Tests +**Test:** 7 +**Purpose:** Validate guardrails for invalid configurations + +- Missing workspace validation +- Workspace normalization +- Edge case handling + +### 5. 
End-to-End Integration Tests +**Test:** 11 +**Purpose:** Validate complete LightRAG workflows maintain isolation + +- Full document insertion pipeline +- File system separation +- Data content verification + +## Running Tests + +### Basic Usage +```bash +# Run all workspace isolation tests +pytest tests/test_workspace_isolation.py -v + +# Run specific test +pytest tests/test_workspace_isolation.py::test_lock_mechanism -v + +# Run with detailed output +pytest tests/test_workspace_isolation.py -v -s +``` + +### Environment Configuration + +#### Stress Testing +Enable stress testing with configurable number of workers: +```bash +# Enable stress mode with default 3 workers +LIGHTRAG_STRESS_TEST=true pytest tests/test_workspace_isolation.py -v + +# Custom number of workers (e.g., 10) +LIGHTRAG_STRESS_TEST=true LIGHTRAG_TEST_WORKERS=10 pytest tests/test_workspace_isolation.py -v +``` + +#### Keep Test Artifacts +Preserve temporary directories for manual inspection: +```bash +# Keep test artifacts (useful for debugging) +LIGHTRAG_KEEP_ARTIFACTS=true pytest tests/test_workspace_isolation.py -v +``` + +#### Combined Example +```bash +# Stress test with 20 workers and keep artifacts +LIGHTRAG_STRESS_TEST=true \ +LIGHTRAG_TEST_WORKERS=20 \ +LIGHTRAG_KEEP_ARTIFACTS=true \ +pytest tests/test_workspace_isolation.py::test_lock_mechanism -v -s +``` + +### CI/CD Integration +```bash +# Recommended CI/CD command (no artifacts, default workers) +pytest tests/test_workspace_isolation.py -v --tb=short +``` + +## Test Implementation Details + +### Helper Functions + +#### `_measure_lock_parallelism` +Measures actual concurrency rather than wall-clock time. + +**Returns:** +- `max_parallel`: Peak number of concurrent lock holders +- `timeline`: Ordered list of (task_name, event) tuples +- `metrics`: Dict with performance data (duration, concurrency, workers) + +**Example:** +```python +workload = [ + ("task1", "workspace1", "namespace"), + ("task2", "workspace2", "namespace"), +] +max_parallel, timeline, metrics = await _measure_lock_parallelism(workload) + +# Assert on actual behavior, not timing +assert max_parallel >= 2 # Two different workspaces should run concurrently +``` + +#### `_assert_no_timeline_overlap` +Validates sequential execution using finite state machine. 
+
+**Validates:**
+- No overlapping lock acquisitions
+- Proper lock release ordering
+- All locks properly released
+
+**Example:**
+```python
+timeline = [
+    ("task1", "start"),
+    ("task1", "end"),
+    ("task2", "start"),
+    ("task2", "end"),
+]
+_assert_no_timeline_overlap(timeline)  # Passes - no overlap
+
+timeline_bad = [
+    ("task1", "start"),
+    ("task2", "start"),  # ERROR: task2 started before task1 ended
+    ("task1", "end"),
+]
+_assert_no_timeline_overlap(timeline_bad)  # Raises AssertionError
+```
+
+## Configuration Variables
+
+| Variable | Type | Default | Description |
+|----------|------|---------|-------------|
+| `LIGHTRAG_STRESS_TEST` | bool | `false` | Enable stress testing mode |
+| `LIGHTRAG_TEST_WORKERS` | int | `3` | Number of parallel workers in stress mode |
+| `LIGHTRAG_KEEP_ARTIFACTS` | bool | `false` | Keep temporary test directories |
+
+## Performance Benchmarks
+
+### Expected Performance (Reference System)
+- **Tests 1-9**: < 1s each
+- **Test 10**: < 2s (includes file I/O)
+- **Test 11**: < 5s (includes full RAG pipeline)
+- **Total Suite**: < 15s
+
+### Stress Test Performance
+With `LIGHTRAG_TEST_WORKERS=10`:
+- **Test 2 (Parallel)**: ~0.05s (10 workers, all concurrent)
+- **Test 2 (Serial)**: ~0.10s (2 workers, serialized)
+
+## Troubleshooting
+
+### Common Issues
+
+#### Flaky Test Failures
+**Symptom:** Tests pass locally but fail in CI/CD
+**Cause:** System under heavy load, timing-based assertions
+**Solution:** Our tests use concurrency-based assertions, not timing. If failures persist, check the `timeline` output in error messages.
+
+#### Resource Cleanup Errors
+**Symptom:** "Directory not empty" or "Cannot remove directory"
+**Cause:** Concurrent test execution or OS file locking
+**Solution:** Run tests serially (plain `pytest`, or `pytest -n 1` when using `pytest-xdist`) or use `LIGHTRAG_KEEP_ARTIFACTS=true` to inspect state
+
+#### Lock Timeout Errors
+**Symptom:** "Lock acquisition timeout"
+**Cause:** Deadlock or resource starvation
+**Solution:** Check test output for deadlock patterns, review lock acquisition order
+
+### Debug Tips
+
+1. **Enable verbose output:**
+   ```bash
+   pytest tests/test_workspace_isolation.py -v -s
+   ```
+
+2. **Run single test with artifacts:**
+   ```bash
+   LIGHTRAG_KEEP_ARTIFACTS=true pytest tests/test_workspace_isolation.py::test_json_kv_storage_workspace_isolation -v -s
+   ```
+
+3. **Check performance metrics:**
+   Look for the "Performance:" lines in test output showing duration and concurrency.
+
+4. **Inspect timeline on failure:**
+   Timeline data is included in assertion error messages.
+
+## Contributing
+
+### Adding New Tests
+
+1. **Follow naming convention:** `test_<component>_<behavior>` (e.g. `test_update_flags_workspace_isolation`)
+2. **Add purpose/scope comments:** Explain what and why
+3. **Use helper functions:** `_measure_lock_parallelism`, `_assert_no_timeline_overlap`
+4. **Document assertions:** Explain expected behavior in assertions
+5. **Update this README:** Add test to appropriate category
+
+### Test Template
+```python
+@pytest.mark.offline  # or @pytest.mark.integration for tests needing external services
+@pytest.mark.asyncio
+async def test_new_feature():
+    """
+    Brief description of what this test validates.
+    """
+    # Purpose: Why this test exists
+    # Scope: What functions/classes this tests
+    print("\n" + "=" * 60)
+    print("TEST N: Feature Name")
+    print("=" * 60)
+
+    # Test implementation
+    # ...
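+    # For example (hypothetical placeholders, adapt to the feature under test):
+    #   data_a = await get_namespace_data("pipeline_status", workspace="ws_a")
+    #   data_b = await get_namespace_data("pipeline_status", workspace="ws_b")
+    #   assert data_a is not data_b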
+ + print("✅ PASSED: Feature Name") + print(f" Validation details") +``` + +## Related Documentation + +- [Workspace Isolation Design Doc](../docs/LightRAG_concurrent_explain.md) +- [Project Intelligence](.clinerules/01-basic.md) +- [Memory Bank](../.memory-bank/) + +## Test Coverage Matrix + +| Component | Data Isolation | Lock Mechanism | Backward Compat | Error Handling | E2E | +|-----------|:--------------:|:--------------:|:---------------:|:--------------:|:---:| +| shared_storage | ✅ T1, T4 | ✅ T2, T5, T6 | ✅ T3 | ✅ T7 | ✅ T11 | +| update_flags | ✅ T8 | - | - | - | - | +| JsonKVStorage | ✅ T10 | - | - | - | ✅ T11 | +| LightRAG Core | - | - | - | - | ✅ T11 | +| Namespace | ✅ T9 | - | ✅ T3 | ✅ T7 | - | + +**Legend:** T# = Test number + +## Version History + +- **v2.0** (2025-01-18): Added performance metrics, stress testing, configurable cleanup +- **v1.0** (Initial): Basic workspace isolation tests with timing-based assertions diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..09769fd6 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,143 @@ +""" +Pytest configuration for LightRAG tests. + +This file provides command-line options and fixtures for test configuration. +""" + +import pytest + + +def pytest_configure(config): + """Register custom markers for LightRAG tests.""" + config.addinivalue_line( + "markers", "offline: marks tests as offline (no external dependencies)" + ) + config.addinivalue_line( + "markers", + "integration: marks tests requiring external services (skipped by default)", + ) + config.addinivalue_line("markers", "requires_db: marks tests requiring database") + config.addinivalue_line( + "markers", "requires_api: marks tests requiring LightRAG API server" + ) + + +def pytest_addoption(parser): + """Add custom command-line options for LightRAG tests.""" + + parser.addoption( + "--keep-artifacts", + action="store_true", + default=False, + help="Keep test artifacts (temporary directories and files) after test completion for inspection", + ) + + parser.addoption( + "--stress-test", + action="store_true", + default=False, + help="Enable stress test mode with more intensive workloads", + ) + + parser.addoption( + "--test-workers", + action="store", + default=3, + type=int, + help="Number of parallel workers for stress tests (default: 3)", + ) + + parser.addoption( + "--run-integration", + action="store_true", + default=False, + help="Run integration tests that require external services (database, API server, etc.)", + ) + + +def pytest_collection_modifyitems(config, items): + """Modify test collection to skip integration tests by default. + + Integration tests are skipped unless --run-integration flag is provided. + This allows running offline tests quickly without needing external services. + """ + if config.getoption("--run-integration"): + # If --run-integration is specified, run all tests + return + + skip_integration = pytest.mark.skip( + reason="Requires external services(DB/API), use --run-integration to run" + ) + + for item in items: + if "integration" in item.keywords: + item.add_marker(skip_integration) + + +@pytest.fixture(scope="session") +def keep_test_artifacts(request): + """ + Fixture to determine whether to keep test artifacts. 
+ + Priority: CLI option > Environment variable > Default (False) + """ + import os + + # Check CLI option first + if request.config.getoption("--keep-artifacts"): + return True + + # Fall back to environment variable + return os.getenv("LIGHTRAG_KEEP_ARTIFACTS", "false").lower() == "true" + + +@pytest.fixture(scope="session") +def stress_test_mode(request): + """ + Fixture to determine whether stress test mode is enabled. + + Priority: CLI option > Environment variable > Default (False) + """ + import os + + # Check CLI option first + if request.config.getoption("--stress-test"): + return True + + # Fall back to environment variable + return os.getenv("LIGHTRAG_STRESS_TEST", "false").lower() == "true" + + +@pytest.fixture(scope="session") +def parallel_workers(request): + """ + Fixture to determine the number of parallel workers for stress tests. + + Priority: CLI option > Environment variable > Default (3) + """ + import os + + # Check CLI option first + cli_workers = request.config.getoption("--test-workers") + if cli_workers != 3: # Non-default value provided + return cli_workers + + # Fall back to environment variable + return int(os.getenv("LIGHTRAG_TEST_WORKERS", "3")) + + +@pytest.fixture(scope="session") +def run_integration_tests(request): + """ + Fixture to determine whether to run integration tests. + + Priority: CLI option > Environment variable > Default (False) + """ + import os + + # Check CLI option first + if request.config.getoption("--run-integration"): + return True + + # Fall back to environment variable + return os.getenv("LIGHTRAG_RUN_INTEGRATION", "false").lower() == "true" diff --git a/tests/test_aquery_data_endpoint.py b/tests/test_aquery_data_endpoint.py index 8845cb79..4866c779 100644 --- a/tests/test_aquery_data_endpoint.py +++ b/tests/test_aquery_data_endpoint.py @@ -9,6 +9,7 @@ Updated to handle the new data format where: - Includes backward compatibility with legacy format """ +import pytest import requests import time import json @@ -84,6 +85,8 @@ def parse_streaming_response( return references, response_chunks, errors +@pytest.mark.integration +@pytest.mark.requires_api def test_query_endpoint_references(): """Test /query endpoint references functionality""" @@ -187,6 +190,8 @@ def test_query_endpoint_references(): return True +@pytest.mark.integration +@pytest.mark.requires_api def test_query_stream_endpoint_references(): """Test /query/stream endpoint references functionality""" @@ -322,6 +327,8 @@ def test_query_stream_endpoint_references(): return True +@pytest.mark.integration +@pytest.mark.requires_api def test_references_consistency(): """Test references consistency across all endpoints""" @@ -472,6 +479,8 @@ def test_references_consistency(): return consistency_passed +@pytest.mark.integration +@pytest.mark.requires_api def test_aquery_data_endpoint(): """Test the /query/data endpoint""" @@ -654,6 +663,8 @@ def print_query_results(data: Dict[str, Any]): print("=" * 60) +@pytest.mark.integration +@pytest.mark.requires_api def compare_with_regular_query(): """Compare results between regular query and data query""" @@ -690,6 +701,8 @@ def compare_with_regular_query(): print(f" Regular query error: {str(e)}") +@pytest.mark.integration +@pytest.mark.requires_api def run_all_reference_tests(): """Run all reference-related tests""" diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index c6932384..64ed5dd5 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -18,6 +18,7 @@ import os import sys import importlib 
import numpy as np +import pytest from dotenv import load_dotenv from ascii_colors import ASCIIColors @@ -111,7 +112,6 @@ async def initialize_graph_storage(): } # Initialize shared_storage for all storage types (required for locks) - # All graph storage implementations use locks like get_data_init_lock() and get_graph_db_lock() initialize_share_data() # Use single-process mode (workers=1) try: @@ -130,6 +130,8 @@ async def initialize_graph_storage(): return None +@pytest.mark.integration +@pytest.mark.requires_db async def test_graph_basic(storage): """ Test basic graph database operations: @@ -255,6 +257,8 @@ async def test_graph_basic(storage): return False +@pytest.mark.integration +@pytest.mark.requires_db async def test_graph_advanced(storage): """ Test advanced graph database operations: @@ -475,6 +479,8 @@ async def test_graph_advanced(storage): return False +@pytest.mark.integration +@pytest.mark.requires_db async def test_graph_batch_operations(storage): """ Test batch operations of the graph database: @@ -828,6 +834,8 @@ async def test_graph_batch_operations(storage): return False +@pytest.mark.integration +@pytest.mark.requires_db async def test_graph_special_characters(storage): """ Test the graph database's handling of special characters: @@ -982,6 +990,8 @@ async def test_graph_special_characters(storage): return False +@pytest.mark.integration +@pytest.mark.requires_db async def test_graph_undirected_property(storage): """ Specifically test the undirected graph property of the storage: diff --git a/tests/test_lightrag_ollama_chat.py b/tests/test_lightrag_ollama_chat.py index 80038928..02dc9550 100644 --- a/tests/test_lightrag_ollama_chat.py +++ b/tests/test_lightrag_ollama_chat.py @@ -9,6 +9,7 @@ This script tests the LightRAG's Ollama compatibility interface, including: All responses use the JSON Lines format, complying with the Ollama API specification. 
""" +import pytest import requests import json import argparse @@ -75,8 +76,8 @@ class OutputControl: @dataclass -class TestResult: - """Test result data class""" +class ExecutionResult: + """Test execution result data class""" name: str success: bool @@ -89,14 +90,14 @@ class TestResult: self.timestamp = datetime.now().isoformat() -class TestStats: - """Test statistics""" +class ExecutionStats: + """Test execution statistics""" def __init__(self): - self.results: List[TestResult] = [] + self.results: List[ExecutionResult] = [] self.start_time = datetime.now() - def add_result(self, result: TestResult): + def add_result(self, result: ExecutionResult): self.results.append(result) def export_results(self, path: str = "test_results.json"): @@ -273,7 +274,7 @@ def create_generate_request_data( # Global test statistics -STATS = TestStats() +STATS = ExecutionStats() def run_test(func: Callable, name: str) -> None: @@ -286,13 +287,15 @@ def run_test(func: Callable, name: str) -> None: try: func() duration = time.time() - start_time - STATS.add_result(TestResult(name, True, duration)) + STATS.add_result(ExecutionResult(name, True, duration)) except Exception as e: duration = time.time() - start_time - STATS.add_result(TestResult(name, False, duration, str(e))) + STATS.add_result(ExecutionResult(name, False, duration, str(e))) raise +@pytest.mark.integration +@pytest.mark.requires_api def test_non_stream_chat() -> None: """Test non-streaming call to /api/chat endpoint""" url = get_base_url() @@ -317,6 +320,8 @@ def test_non_stream_chat() -> None: ) +@pytest.mark.integration +@pytest.mark.requires_api def test_stream_chat() -> None: """Test streaming call to /api/chat endpoint @@ -377,6 +382,8 @@ def test_stream_chat() -> None: print() +@pytest.mark.integration +@pytest.mark.requires_api def test_query_modes() -> None: """Test different query mode prefixes @@ -436,6 +443,8 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]: return error_data.get(error_type, error_data["empty_messages"]) +@pytest.mark.integration +@pytest.mark.requires_api def test_stream_error_handling() -> None: """Test error handling for streaming responses @@ -482,6 +491,8 @@ def test_stream_error_handling() -> None: response.close() +@pytest.mark.integration +@pytest.mark.requires_api def test_error_handling() -> None: """Test error handling for non-streaming responses @@ -529,6 +540,8 @@ def test_error_handling() -> None: print_json_response(response.json(), "Error message") +@pytest.mark.integration +@pytest.mark.requires_api def test_non_stream_generate() -> None: """Test non-streaming call to /api/generate endpoint""" url = get_base_url("generate") @@ -548,6 +561,8 @@ def test_non_stream_generate() -> None: print(json.dumps(response_json, ensure_ascii=False, indent=2)) +@pytest.mark.integration +@pytest.mark.requires_api def test_stream_generate() -> None: """Test streaming call to /api/generate endpoint""" url = get_base_url("generate") @@ -588,6 +603,8 @@ def test_stream_generate() -> None: print() +@pytest.mark.integration +@pytest.mark.requires_api def test_generate_with_system() -> None: """Test generate with system prompt""" url = get_base_url("generate") @@ -616,6 +633,8 @@ def test_generate_with_system() -> None: ) +@pytest.mark.integration +@pytest.mark.requires_api def test_generate_error_handling() -> None: """Test error handling for generate endpoint""" url = get_base_url("generate") @@ -641,6 +660,8 @@ def test_generate_error_handling() -> None: print_json_response(response.json(), "Error message") 
+@pytest.mark.integration +@pytest.mark.requires_api def test_generate_concurrent() -> None: """Test concurrent generate requests""" import asyncio diff --git a/tests/test_postgres_retry_integration.py b/tests/test_postgres_retry_integration.py index 515f3072..2c7b3499 100644 --- a/tests/test_postgres_retry_integration.py +++ b/tests/test_postgres_retry_integration.py @@ -24,6 +24,8 @@ asyncpg = pytest.importorskip("asyncpg") load_dotenv(dotenv_path=".env", override=False) +@pytest.mark.integration +@pytest.mark.requires_db class TestPostgresRetryIntegration: """Integration tests for PostgreSQL retry mechanism with real database.""" diff --git a/tests/test_workspace_isolation.py b/tests/test_workspace_isolation.py new file mode 100644 index 00000000..68f7f8ec --- /dev/null +++ b/tests/test_workspace_isolation.py @@ -0,0 +1,1193 @@ +#!/usr/bin/env python +""" +Test script for Workspace Isolation Feature + +Comprehensive test suite covering workspace isolation in LightRAG: +1. Pipeline Status Isolation - Data isolation between workspaces +2. Lock Mechanism - Parallel execution for different workspaces, serial for same workspace +3. Backward Compatibility - Legacy code without workspace parameters +4. Multi-Workspace Concurrency - Concurrent operations on different workspaces +5. NamespaceLock Re-entrance Protection - Prevents deadlocks +6. Different Namespace Lock Isolation - Locks isolated by namespace +7. Error Handling - Invalid workspace configurations +8. Update Flags Workspace Isolation - Update flags properly isolated +9. Empty Workspace Standardization - Empty workspace handling +10. JsonKVStorage Workspace Isolation - Integration test for KV storage +11. LightRAG End-to-End Workspace Isolation - Complete E2E test with two instances + +Total: 11 test scenarios +""" + +import asyncio +import time +import os +import shutil +import numpy as np +import pytest +from pathlib import Path +from typing import List, Tuple, Dict +from lightrag.kg.shared_storage import ( + get_final_namespace, + get_namespace_lock, + get_default_workspace, + set_default_workspace, + initialize_share_data, + finalize_share_data, + initialize_pipeline_status, + get_namespace_data, + set_all_update_flags, + clear_all_update_flags, + get_all_update_flags_status, + get_update_flag, +) + + +# ============================================================================= +# Test Configuration +# ============================================================================= + +# Test configuration is handled via pytest fixtures in conftest.py +# - Use CLI options: --keep-artifacts, --stress-test, --test-workers=N +# - Or environment variables: LIGHTRAG_KEEP_ARTIFACTS, LIGHTRAG_STRESS_TEST, LIGHTRAG_TEST_WORKERS +# Priority: CLI options > Environment variables > Default values + + +# ============================================================================= +# Pytest Fixtures +# ============================================================================= + + +@pytest.fixture(autouse=True) +def setup_shared_data(): + """Initialize shared data before each test""" + initialize_share_data() + yield + finalize_share_data() + + +async def _measure_lock_parallelism( + workload: List[Tuple[str, str, str]], hold_time: float = 0.05 +) -> Tuple[int, List[Tuple[str, str]], Dict[str, float]]: + """Run lock acquisition workload and capture peak concurrency and timeline. 
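+
+    Each worker coroutine acquires the keyed namespace lock for its
+    (workspace, namespace) pair, so the reported peak overlap reflects whether
+    the locks are actually independent rather than how fast each worker
+    happened to run.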
+ + Args: + workload: List of (name, workspace, namespace) tuples + hold_time: How long each worker holds the lock (seconds) + + Returns: + Tuple of (max_parallel, timeline, metrics) where: + - max_parallel: Peak number of concurrent lock holders + - timeline: List of (name, event) tuples tracking execution order + - metrics: Dict with performance metrics (total_duration, max_concurrency, etc.) + """ + + running = 0 + max_parallel = 0 + timeline: List[Tuple[str, str]] = [] + start_time = time.time() + + async def worker(name: str, workspace: str, namespace: str) -> None: + nonlocal running, max_parallel + lock = get_namespace_lock(namespace, workspace) + async with lock: + running += 1 + max_parallel = max(max_parallel, running) + timeline.append((name, "start")) + await asyncio.sleep(hold_time) + timeline.append((name, "end")) + running -= 1 + + await asyncio.gather(*(worker(*args) for args in workload)) + + metrics = { + "total_duration": time.time() - start_time, + "max_concurrency": max_parallel, + "avg_hold_time": hold_time, + "num_workers": len(workload), + } + + return max_parallel, timeline, metrics + + +def _assert_no_timeline_overlap(timeline: List[Tuple[str, str]]) -> None: + """Ensure that timeline events never overlap for sequential execution. + + This function implements a finite state machine that validates: + - No overlapping lock acquisitions (only one task active at a time) + - Proper lock release order (task releases its own lock) + - All locks are properly released + + Args: + timeline: List of (name, event) tuples where event is "start" or "end" + + Raises: + AssertionError: If timeline shows overlapping execution or improper locking + """ + + active_task = None + for name, event in timeline: + if event == "start": + if active_task is not None: + raise AssertionError( + f"Task '{name}' started before '{active_task}' released the lock" + ) + active_task = name + else: + if active_task != name: + raise AssertionError( + f"Task '{name}' finished while '{active_task}' was expected to hold the lock" + ) + active_task = None + + if active_task is not None: + raise AssertionError(f"Task '{active_task}' did not release the lock properly") + + +# ============================================================================= +# Test 1: Pipeline Status Isolation Test +# ============================================================================= + + +@pytest.mark.offline +@pytest.mark.asyncio +async def test_pipeline_status_isolation(): + """ + Test that pipeline status is isolated between different workspaces. + """ + # Purpose: Ensure pipeline_status shared data remains unique per workspace. + # Scope: initialize_pipeline_status and get_namespace_data interactions. 
+ print("\n" + "=" * 60) + print("TEST 1: Pipeline Status Isolation") + print("=" * 60) + + # Initialize shared storage + initialize_share_data() + + # Initialize pipeline status for two different workspaces + workspace1 = "test_workspace_1" + workspace2 = "test_workspace_2" + + await initialize_pipeline_status(workspace1) + await initialize_pipeline_status(workspace2) + + # Get pipeline status data for both workspaces + data1 = await get_namespace_data("pipeline_status", workspace=workspace1) + data2 = await get_namespace_data("pipeline_status", workspace=workspace2) + + # Verify they are independent objects + assert ( + data1 is not data2 + ), "Pipeline status data objects are the same (should be different)" + + # Modify workspace1's data and verify workspace2 is not affected + data1["test_key"] = "workspace1_value" + + # Re-fetch to ensure we get the latest data + data1_check = await get_namespace_data("pipeline_status", workspace=workspace1) + data2_check = await get_namespace_data("pipeline_status", workspace=workspace2) + + assert "test_key" in data1_check, "test_key not found in workspace1" + assert ( + data1_check["test_key"] == "workspace1_value" + ), f"workspace1 test_key value incorrect: {data1_check.get('test_key')}" + assert ( + "test_key" not in data2_check + ), f"test_key leaked to workspace2: {data2_check.get('test_key')}" + + print("✅ PASSED: Pipeline Status Isolation") + print(" Different workspaces have isolated pipeline status") + + +# ============================================================================= +# Test 2: Lock Mechanism Test (No Deadlocks) +# ============================================================================= + + +@pytest.mark.offline +@pytest.mark.asyncio +async def test_lock_mechanism(stress_test_mode, parallel_workers): + """ + Test that the new keyed lock mechanism works correctly without deadlocks. + Tests both parallel execution for different workspaces and serialization + for the same workspace. + """ + # Purpose: Validate that keyed locks isolate workspaces while serializing + # requests within the same workspace. Scope: get_namespace_lock scheduling + # semantics for both cross-workspace and single-workspace cases. 
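+    # Note: in stress mode the parallel workload size comes from the parallel_workers
+    # fixture in conftest.py (--test-workers CLI option or LIGHTRAG_TEST_WORKERS).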
+ print("\n" + "=" * 60) + print("TEST 2: Lock Mechanism (No Deadlocks)") + print("=" * 60) + + # Test 2.1: Different workspaces should run in parallel + print("\nTest 2.1: Different workspaces locks should be parallel") + + # Support stress testing with configurable number of workers + num_workers = parallel_workers if stress_test_mode else 3 + parallel_workload = [ + (f"ws_{chr(97+i)}", f"ws_{chr(97+i)}", "test_namespace") + for i in range(num_workers) + ] + + max_parallel, timeline_parallel, metrics = await _measure_lock_parallelism( + parallel_workload + ) + assert max_parallel >= 2, ( + "Locks for distinct workspaces should overlap; " + f"observed max concurrency: {max_parallel}, timeline={timeline_parallel}" + ) + + print("✅ PASSED: Lock Mechanism - Parallel (Different Workspaces)") + print( + f" Locks overlapped for different workspaces (max concurrency={max_parallel})" + ) + print( + f" Performance: {metrics['total_duration']:.3f}s for {metrics['num_workers']} workers" + ) + + # Test 2.2: Same workspace should serialize + print("\nTest 2.2: Same workspace locks should serialize") + serial_workload = [ + ("serial_run_1", "ws_same", "test_namespace"), + ("serial_run_2", "ws_same", "test_namespace"), + ] + ( + max_parallel_serial, + timeline_serial, + metrics_serial, + ) = await _measure_lock_parallelism(serial_workload) + assert max_parallel_serial == 1, ( + "Same workspace locks should not overlap; " + f"observed {max_parallel_serial} with timeline {timeline_serial}" + ) + _assert_no_timeline_overlap(timeline_serial) + + print("✅ PASSED: Lock Mechanism - Serial (Same Workspace)") + print(" Same workspace operations executed sequentially with no overlap") + print( + f" Performance: {metrics_serial['total_duration']:.3f}s for {metrics_serial['num_workers']} tasks" + ) + + +# ============================================================================= +# Test 3: Backward Compatibility Test +# ============================================================================= + + +@pytest.mark.offline +@pytest.mark.asyncio +async def test_backward_compatibility(): + """ + Test that legacy code without workspace parameter still works correctly. + """ + # Purpose: Validate backward-compatible defaults when workspace arguments + # are omitted. Scope: get_final_namespace, set/get_default_workspace and + # initialize_pipeline_status fallback behavior. 
+ print("\n" + "=" * 60) + print("TEST 3: Backward Compatibility") + print("=" * 60) + + # Test 3.1: get_final_namespace with None should use default workspace + print("\nTest 3.1: get_final_namespace with workspace=None") + + set_default_workspace("my_default_workspace") + final_ns = get_final_namespace("pipeline_status") + expected = "my_default_workspace:pipeline_status" + + assert final_ns == expected, f"Expected {expected}, got {final_ns}" + + print("✅ PASSED: Backward Compatibility - get_final_namespace") + print(f" Correctly uses default workspace: {final_ns}") + + # Test 3.2: get_default_workspace + print("\nTest 3.2: get/set default workspace") + + set_default_workspace("test_default") + retrieved = get_default_workspace() + + assert retrieved == "test_default", f"Expected 'test_default', got {retrieved}" + + print("✅ PASSED: Backward Compatibility - default workspace") + print(f" Default workspace set/get correctly: {retrieved}") + + # Test 3.3: Empty workspace handling + print("\nTest 3.3: Empty workspace handling") + + set_default_workspace("") + final_ns_empty = get_final_namespace("pipeline_status", workspace=None) + expected_empty = "pipeline_status" # Should be just the namespace without ':' + + assert ( + final_ns_empty == expected_empty + ), f"Expected '{expected_empty}', got '{final_ns_empty}'" + + print("✅ PASSED: Backward Compatibility - empty workspace") + print(f" Empty workspace handled correctly: '{final_ns_empty}'") + + # Test 3.4: None workspace with default set + print("\nTest 3.4: initialize_pipeline_status with workspace=None") + set_default_workspace("compat_test_workspace") + initialize_share_data() + await initialize_pipeline_status(workspace=None) # Should use default + + # Try to get data using the default workspace explicitly + data = await get_namespace_data( + "pipeline_status", workspace="compat_test_workspace" + ) + + assert ( + data is not None + ), "Failed to initialize pipeline status with default workspace" + + print("✅ PASSED: Backward Compatibility - pipeline init with None") + print(" Pipeline status initialized with default workspace") + + +# ============================================================================= +# Test 4: Multi-Workspace Concurrency Test +# ============================================================================= + + +@pytest.mark.offline +@pytest.mark.asyncio +async def test_multi_workspace_concurrency(): + """ + Test that multiple workspaces can operate concurrently without interference. + Simulates concurrent operations on different workspaces. + """ + # Purpose: Simulate concurrent workloads touching pipeline_status across + # workspaces. Scope: initialize_pipeline_status, get_namespace_lock, and + # shared dictionary mutation while ensuring isolation. 
+ print("\n" + "=" * 60) + print("TEST 4: Multi-Workspace Concurrency") + print("=" * 60) + + initialize_share_data() + + async def workspace_operations(workspace_id): + """Simulate operations on a specific workspace""" + print(f"\n [{workspace_id}] Starting operations") + + # Initialize pipeline status + await initialize_pipeline_status(workspace_id) + + # Get lock and perform operations + lock = get_namespace_lock("test_operations", workspace_id) + async with lock: + # Get workspace data + data = await get_namespace_data("pipeline_status", workspace=workspace_id) + + # Modify data + data[f"{workspace_id}_key"] = f"{workspace_id}_value" + data["timestamp"] = time.time() + + # Simulate some work + await asyncio.sleep(0.1) + + print(f" [{workspace_id}] Completed operations") + + return workspace_id + + # Run multiple workspaces concurrently + workspaces = ["concurrent_ws_1", "concurrent_ws_2", "concurrent_ws_3"] + + start = time.time() + results_list = await asyncio.gather( + *[workspace_operations(ws) for ws in workspaces] + ) + elapsed = time.time() - start + + print(f"\n All workspaces completed in {elapsed:.2f}s") + + # Verify all workspaces completed + assert set(results_list) == set(workspaces), "Not all workspaces completed" + + print("✅ PASSED: Multi-Workspace Concurrency - Execution") + print( + f" All {len(workspaces)} workspaces completed successfully in {elapsed:.2f}s" + ) + + # Verify data isolation - each workspace should have its own data + print("\n Verifying data isolation...") + + for ws in workspaces: + data = await get_namespace_data("pipeline_status", workspace=ws) + expected_key = f"{ws}_key" + expected_value = f"{ws}_value" + + assert ( + expected_key in data + ), f"Data not properly isolated for {ws}: missing {expected_key}" + assert ( + data[expected_key] == expected_value + ), f"Data not properly isolated for {ws}: {expected_key}={data[expected_key]} (expected {expected_value})" + print(f" [{ws}] Data correctly isolated: {expected_key}={data[expected_key]}") + + print("✅ PASSED: Multi-Workspace Concurrency - Data Isolation") + print(" All workspaces have properly isolated data") + + +# ============================================================================= +# Test 5: NamespaceLock Re-entrance Protection +# ============================================================================= + + +@pytest.mark.offline +@pytest.mark.asyncio +async def test_namespace_lock_reentrance(): + """ + Test that NamespaceLock prevents re-entrance in the same coroutine + and allows concurrent use in different coroutines. + """ + # Purpose: Ensure NamespaceLock enforces single entry per coroutine while + # allowing concurrent reuse through ContextVar isolation. Scope: lock + # re-entrance checks and concurrent gather semantics. 
+ print("\n" + "=" * 60) + print("TEST 5: NamespaceLock Re-entrance Protection") + print("=" * 60) + + # Test 5.1: Same coroutine re-entrance should fail + print("\nTest 5.1: Same coroutine re-entrance should raise RuntimeError") + + lock = get_namespace_lock("test_reentrance", "test_ws") + + reentrance_failed_correctly = False + try: + async with lock: + print(" Acquired lock first time") + # Try to acquire the same lock again in the same coroutine + async with lock: + print(" ERROR: Should not reach here - re-entrance succeeded!") + except RuntimeError as e: + if "already acquired" in str(e).lower(): + print(f" ✓ Re-entrance correctly blocked: {e}") + reentrance_failed_correctly = True + else: + raise + + assert reentrance_failed_correctly, "Re-entrance protection not working" + + print("✅ PASSED: NamespaceLock Re-entrance Protection") + print(" Re-entrance correctly raises RuntimeError") + + # Test 5.2: Same NamespaceLock instance in different coroutines should succeed + print("\nTest 5.2: Same NamespaceLock instance in different coroutines") + + shared_lock = get_namespace_lock("test_concurrent", "test_ws") + concurrent_results = [] + + async def use_shared_lock(coroutine_id): + """Use the same NamespaceLock instance""" + async with shared_lock: + concurrent_results.append(f"coroutine_{coroutine_id}_start") + await asyncio.sleep(0.1) + concurrent_results.append(f"coroutine_{coroutine_id}_end") + + # This should work because each coroutine gets its own ContextVar + await asyncio.gather( + use_shared_lock(1), + use_shared_lock(2), + ) + + # Both coroutines should have completed + expected_entries = 4 # 2 starts + 2 ends + assert ( + len(concurrent_results) == expected_entries + ), f"Expected {expected_entries} entries, got {len(concurrent_results)}" + + print("✅ PASSED: NamespaceLock Concurrent Reuse") + print( + f" Same NamespaceLock instance used successfully in {expected_entries//2} concurrent coroutines" + ) + + +# ============================================================================= +# Test 6: Different Namespace Lock Isolation +# ============================================================================= + + +@pytest.mark.offline +@pytest.mark.asyncio +async def test_different_namespace_lock_isolation(): + """ + Test that locks for different namespaces (same workspace) are independent. + """ + # Purpose: Confirm that namespace isolation is enforced even when workspace + # is the same. Scope: get_namespace_lock behavior when namespaces differ. 
+ print("\n" + "=" * 60) + print("TEST 6: Different Namespace Lock Isolation") + print("=" * 60) + + print("\nTesting locks with same workspace but different namespaces") + + workload = [ + ("ns_a", "same_ws", "namespace_a"), + ("ns_b", "same_ws", "namespace_b"), + ("ns_c", "same_ws", "namespace_c"), + ] + max_parallel, timeline, metrics = await _measure_lock_parallelism(workload) + + assert max_parallel >= 2, ( + "Different namespaces within the same workspace should run concurrently; " + f"observed max concurrency {max_parallel} with timeline {timeline}" + ) + + print("✅ PASSED: Different Namespace Lock Isolation") + print( + f" Different namespace locks ran in parallel (max concurrency={max_parallel})" + ) + print( + f" Performance: {metrics['total_duration']:.3f}s for {metrics['num_workers']} namespaces" + ) + + +# ============================================================================= +# Test 7: Error Handling +# ============================================================================= + + +@pytest.mark.offline +@pytest.mark.asyncio +async def test_error_handling(): + """ + Test error handling for invalid workspace configurations. + """ + # Purpose: Validate guardrails for workspace normalization and namespace + # derivation. Scope: set_default_workspace conversions and get_final_namespace + # failure paths when configuration is invalid. + print("\n" + "=" * 60) + print("TEST 7: Error Handling") + print("=" * 60) + + # Test 7.0: Missing default workspace should raise ValueError + print("\nTest 7.0: Missing workspace raises ValueError") + with pytest.raises(ValueError): + get_final_namespace("test_namespace", workspace=None) + + # Test 7.1: set_default_workspace(None) converts to empty string + print("\nTest 7.1: set_default_workspace(None) converts to empty string") + + set_default_workspace(None) + default_ws = get_default_workspace() + + # Should convert None to "" automatically + assert default_ws == "", f"Expected empty string, got: '{default_ws}'" + + print("✅ PASSED: Error Handling - None to Empty String") + print( + f" set_default_workspace(None) correctly converts to empty string: '{default_ws}'" + ) + + # Test 7.2: Empty string workspace behavior + print("\nTest 7.2: Empty string workspace creates valid namespace") + + # With empty workspace, should create namespace without colon + final_ns = get_final_namespace("test_namespace", workspace="") + assert final_ns == "test_namespace", f"Unexpected namespace: '{final_ns}'" + + print("✅ PASSED: Error Handling - Empty Workspace Namespace") + print(f" Empty workspace creates valid namespace: '{final_ns}'") + + # Restore default workspace for other tests + set_default_workspace("") + + +# ============================================================================= +# Test 8: Update Flags Workspace Isolation +# ============================================================================= + + +@pytest.mark.offline +@pytest.mark.asyncio +async def test_update_flags_workspace_isolation(): + """ + Test that update flags are properly isolated between workspaces. + """ + # Purpose: Confirm update flag setters/readers respect workspace scoping. + # Scope: set_all_update_flags, clear_all_update_flags, get_all_update_flags_status, + # and get_update_flag interactions across namespaces. 
+ print("\n" + "=" * 60) + print("TEST 8: Update Flags Workspace Isolation") + print("=" * 60) + + initialize_share_data() + + workspace1 = "update_flags_ws1" + workspace2 = "update_flags_ws2" + test_namespace = "test_update_flags_ns" + + # Initialize namespaces for both workspaces + await initialize_pipeline_status(workspace1) + await initialize_pipeline_status(workspace2) + + # Test 8.1: set_all_update_flags isolation + print("\nTest 8.1: set_all_update_flags workspace isolation") + + # Create flags for both workspaces (simulating workers) + flag1_obj = await get_update_flag(test_namespace, workspace=workspace1) + flag2_obj = await get_update_flag(test_namespace, workspace=workspace2) + + # Initial state should be False + assert flag1_obj.value is False, "Flag1 initial value should be False" + assert flag2_obj.value is False, "Flag2 initial value should be False" + + # Set all flags for workspace1 + await set_all_update_flags(test_namespace, workspace=workspace1) + + # Check that only workspace1's flags are set + assert ( + flag1_obj.value is True + ), f"Flag1 should be True after set_all_update_flags, got {flag1_obj.value}" + assert ( + flag2_obj.value is False + ), f"Flag2 should still be False, got {flag2_obj.value}" + + print("✅ PASSED: Update Flags - set_all_update_flags Isolation") + print( + f" set_all_update_flags isolated: ws1={flag1_obj.value}, ws2={flag2_obj.value}" + ) + + # Test 8.2: clear_all_update_flags isolation + print("\nTest 8.2: clear_all_update_flags workspace isolation") + + # Set flags for both workspaces + await set_all_update_flags(test_namespace, workspace=workspace1) + await set_all_update_flags(test_namespace, workspace=workspace2) + + # Verify both are set + assert flag1_obj.value is True, "Flag1 should be True" + assert flag2_obj.value is True, "Flag2 should be True" + + # Clear only workspace1 + await clear_all_update_flags(test_namespace, workspace=workspace1) + + # Check that only workspace1's flags are cleared + assert ( + flag1_obj.value is False + ), f"Flag1 should be False after clear, got {flag1_obj.value}" + assert flag2_obj.value is True, f"Flag2 should still be True, got {flag2_obj.value}" + + print("✅ PASSED: Update Flags - clear_all_update_flags Isolation") + print( + f" clear_all_update_flags isolated: ws1={flag1_obj.value}, ws2={flag2_obj.value}" + ) + + # Test 8.3: get_all_update_flags_status workspace filtering + print("\nTest 8.3: get_all_update_flags_status workspace filtering") + + # Initialize more namespaces for testing + await get_update_flag("ns_a", workspace=workspace1) + await get_update_flag("ns_b", workspace=workspace1) + await get_update_flag("ns_c", workspace=workspace2) + + # Set flags for workspace1 + await set_all_update_flags("ns_a", workspace=workspace1) + await set_all_update_flags("ns_b", workspace=workspace1) + + # Set flags for workspace2 + await set_all_update_flags("ns_c", workspace=workspace2) + + # Get status for workspace1 only + status1 = await get_all_update_flags_status(workspace=workspace1) + + # Check that workspace1's namespaces are present + # The keys should include workspace1's namespaces but not workspace2's + workspace1_keys = [k for k in status1.keys() if workspace1 in k] + workspace2_keys = [k for k in status1.keys() if workspace2 in k] + + assert ( + len(workspace1_keys) > 0 + ), f"workspace1 keys should be present, got {len(workspace1_keys)}" + assert ( + len(workspace2_keys) == 0 + ), f"workspace2 keys should not be present, got {len(workspace2_keys)}" + for key, values in status1.items(): + assert 
all(values), f"All flags in {key} should be True, got {values}" + + # Workspace2 query should only surface workspace2 namespaces + status2 = await get_all_update_flags_status(workspace=workspace2) + expected_ws2_keys = { + f"{workspace2}:{test_namespace}", + f"{workspace2}:ns_c", + } + assert ( + set(status2.keys()) == expected_ws2_keys + ), f"Unexpected namespaces for workspace2: {status2.keys()}" + for key, values in status2.items(): + assert all(values), f"All flags in {key} should be True, got {values}" + + print("✅ PASSED: Update Flags - get_all_update_flags_status Filtering") + print( + f" Status correctly filtered: ws1 keys={len(workspace1_keys)}, ws2 keys={len(workspace2_keys)}" + ) + + +# ============================================================================= +# Test 9: Empty Workspace Standardization +# ============================================================================= + + +@pytest.mark.offline +@pytest.mark.asyncio +async def test_empty_workspace_standardization(): + """ + Test that empty workspace is properly standardized to "" instead of "_". + """ + # Purpose: Verify namespace formatting when workspace is an empty string. + # Scope: get_final_namespace output and initialize_pipeline_status behavior + # between empty and non-empty workspaces. + print("\n" + "=" * 60) + print("TEST 9: Empty Workspace Standardization") + print("=" * 60) + + # Test 9.1: Empty string workspace creates namespace without colon + print("\nTest 9.1: Empty string workspace namespace format") + + set_default_workspace("") + final_ns = get_final_namespace("test_namespace", workspace=None) + + # Should be just "test_namespace" without colon prefix + assert ( + final_ns == "test_namespace" + ), f"Unexpected namespace format: '{final_ns}' (expected 'test_namespace')" + + print("✅ PASSED: Empty Workspace Standardization - Format") + print(f" Empty workspace creates correct namespace: '{final_ns}'") + + # Test 9.2: Empty workspace vs non-empty workspace behavior + print("\nTest 9.2: Empty vs non-empty workspace behavior") + + initialize_share_data() + + # Initialize with empty workspace + await initialize_pipeline_status(workspace="") + data_empty = await get_namespace_data("pipeline_status", workspace="") + + # Initialize with non-empty workspace + await initialize_pipeline_status(workspace="test_ws") + data_nonempty = await get_namespace_data("pipeline_status", workspace="test_ws") + + # They should be different objects + assert ( + data_empty is not data_nonempty + ), "Empty and non-empty workspaces share data (should be independent)" + + print("✅ PASSED: Empty Workspace Standardization - Behavior") + print(" Empty and non-empty workspaces have independent data") + + +# ============================================================================= +# Test 10: JsonKVStorage Workspace Isolation (Integration Test) +# ============================================================================= + + +@pytest.mark.offline +@pytest.mark.asyncio +async def test_json_kv_storage_workspace_isolation(keep_test_artifacts): + """ + Integration test: Verify JsonKVStorage properly isolates data between workspaces. + Creates two JsonKVStorage instances with different workspaces, writes different data, + and verifies they don't mix. + """ + # Purpose: Ensure JsonKVStorage respects workspace-specific directories and data. + # Scope: storage initialization, upsert/get_by_id operations, and filesystem layout + # inside the temporary working directory. 
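+    # Expected on-disk layout after persistence (verified below in Test 10.4):
+    #   <working_dir>/workspace1/...  and  <working_dir>/workspace2/...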
+ print("\n" + "=" * 60) + print("TEST 10: JsonKVStorage Workspace Isolation (Integration)") + print("=" * 60) + + # Create temporary test directory under project temp/ + test_dir = str( + Path(__file__).parent.parent / "temp/test_json_kv_storage_workspace_isolation" + ) + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + os.makedirs(test_dir, exist_ok=True) + print(f"\n Using test directory: {test_dir}") + + try: + initialize_share_data() + + # Mock embedding function + async def mock_embedding_func(texts: list[str]) -> np.ndarray: + return np.random.rand(len(texts), 384) # 384-dimensional vectors + + # Global config + global_config = { + "working_dir": test_dir, + "embedding_batch_num": 10, + } + + # Test 10.1: Create two JsonKVStorage instances with different workspaces + print( + "\nTest 10.1: Create two JsonKVStorage instances with different workspaces" + ) + + from lightrag.kg.json_kv_impl import JsonKVStorage + + storage1 = JsonKVStorage( + namespace="entities", + workspace="workspace1", + global_config=global_config, + embedding_func=mock_embedding_func, + ) + + storage2 = JsonKVStorage( + namespace="entities", + workspace="workspace2", + global_config=global_config, + embedding_func=mock_embedding_func, + ) + + # Initialize both storages + await storage1.initialize() + await storage2.initialize() + + print(" Storage1 created: workspace=workspace1, namespace=entities") + print(" Storage2 created: workspace=workspace2, namespace=entities") + + # Test 10.2: Write different data to each storage + print("\nTest 10.2: Write different data to each storage") + + # Write to storage1 (upsert expects dict[str, dict]) + await storage1.upsert( + { + "entity1": { + "content": "Data from workspace1 - AI Research", + "type": "entity", + }, + "entity2": { + "content": "Data from workspace1 - Machine Learning", + "type": "entity", + }, + } + ) + print(" Written to storage1: entity1, entity2") + # Persist data to disk + await storage1.index_done_callback() + print(" Persisted storage1 data to disk") + + # Write to storage2 + await storage2.upsert( + { + "entity1": { + "content": "Data from workspace2 - Deep Learning", + "type": "entity", + }, + "entity2": { + "content": "Data from workspace2 - Neural Networks", + "type": "entity", + }, + } + ) + print(" Written to storage2: entity1, entity2") + # Persist data to disk + await storage2.index_done_callback() + print(" Persisted storage2 data to disk") + + # Test 10.3: Read data from each storage and verify isolation + print("\nTest 10.3: Read data and verify isolation") + + # Read from storage1 + result1_entity1 = await storage1.get_by_id("entity1") + result1_entity2 = await storage1.get_by_id("entity2") + + # Read from storage2 + result2_entity1 = await storage2.get_by_id("entity1") + result2_entity2 = await storage2.get_by_id("entity2") + + print(f" Storage1 entity1: {result1_entity1}") + print(f" Storage1 entity2: {result1_entity2}") + print(f" Storage2 entity1: {result2_entity1}") + print(f" Storage2 entity2: {result2_entity2}") + + # Verify isolation (get_by_id returns dict) + assert result1_entity1 is not None, "Storage1 entity1 should not be None" + assert result1_entity2 is not None, "Storage1 entity2 should not be None" + assert result2_entity1 is not None, "Storage2 entity1 should not be None" + assert result2_entity2 is not None, "Storage2 entity2 should not be None" + assert ( + result1_entity1.get("content") == "Data from workspace1 - AI Research" + ), "Storage1 entity1 content mismatch" + assert ( + result1_entity2.get("content") 
== "Data from workspace1 - Machine Learning" + ), "Storage1 entity2 content mismatch" + assert ( + result2_entity1.get("content") == "Data from workspace2 - Deep Learning" + ), "Storage2 entity1 content mismatch" + assert ( + result2_entity2.get("content") == "Data from workspace2 - Neural Networks" + ), "Storage2 entity2 content mismatch" + assert result1_entity1.get("content") != result2_entity1.get( + "content" + ), "Storage1 and Storage2 entity1 should have different content" + assert result1_entity2.get("content") != result2_entity2.get( + "content" + ), "Storage1 and Storage2 entity2 should have different content" + + print("✅ PASSED: JsonKVStorage - Data Isolation") + print( + " Two storage instances correctly isolated: ws1 and ws2 have different data" + ) + + # Test 10.4: Verify file structure + print("\nTest 10.4: Verify file structure") + ws1_dir = Path(test_dir) / "workspace1" + ws2_dir = Path(test_dir) / "workspace2" + + ws1_exists = ws1_dir.exists() + ws2_exists = ws2_dir.exists() + + print(f" workspace1 directory exists: {ws1_exists}") + print(f" workspace2 directory exists: {ws2_exists}") + + assert ws1_exists, "workspace1 directory should exist" + assert ws2_exists, "workspace2 directory should exist" + + print("✅ PASSED: JsonKVStorage - File Structure") + print(f" Workspace directories correctly created: {ws1_dir} and {ws2_dir}") + + finally: + # Cleanup test directory (unless keep_test_artifacts is set) + if os.path.exists(test_dir) and not keep_test_artifacts: + shutil.rmtree(test_dir) + print(f"\n Cleaned up test directory: {test_dir}") + elif keep_test_artifacts: + print(f"\n Kept test directory for inspection: {test_dir}") + + +# ============================================================================= +# Test 11: LightRAG End-to-End Integration Test +# ============================================================================= + + +@pytest.mark.offline +@pytest.mark.asyncio +async def test_lightrag_end_to_end_workspace_isolation(keep_test_artifacts): + """ + End-to-end test: Create two LightRAG instances with different workspaces, + insert different data, and verify file separation. + Uses mock LLM and embedding functions to avoid external API calls. + """ + # Purpose: Validate that full LightRAG flows keep artifacts scoped per workspace. + # Scope: LightRAG.initialize_storages + ainsert side effects plus filesystem + # verification for generated storage files. 
+ print("\n" + "=" * 60) + print("TEST 11: LightRAG End-to-End Workspace Isolation") + print("=" * 60) + + # Create temporary test directory under project temp/ + test_dir = str( + Path(__file__).parent.parent + / "temp/test_lightrag_end_to_end_workspace_isolation" + ) + if os.path.exists(test_dir): + shutil.rmtree(test_dir) + os.makedirs(test_dir, exist_ok=True) + print(f"\n Using test directory: {test_dir}") + + try: + # Factory function to create different mock LLM functions for each workspace + def create_mock_llm_func(workspace_name): + """Create a mock LLM function that returns different content based on workspace""" + + async def mock_llm_func( + prompt, system_prompt=None, history_messages=[], **kwargs + ) -> str: + # Add coroutine switching to simulate async I/O and allow concurrent execution + await asyncio.sleep(0) + + # Return different responses based on workspace + # Format: entity<|#|>entity_name<|#|>entity_type<|#|>entity_description + # Format: relation<|#|>source_entity<|#|>target_entity<|#|>keywords<|#|>description + if workspace_name == "project_a": + return """entity<|#|>Artificial Intelligence<|#|>concept<|#|>AI is a field of computer science focused on creating intelligent machines. +entity<|#|>Machine Learning<|#|>concept<|#|>Machine Learning is a subset of AI that enables systems to learn from data. +relation<|#|>Machine Learning<|#|>Artificial Intelligence<|#|>subset, related field<|#|>Machine Learning is a key component and subset of Artificial Intelligence. +<|COMPLETE|>""" + else: # project_b + return """entity<|#|>Deep Learning<|#|>concept<|#|>Deep Learning is a subset of machine learning using neural networks with multiple layers. +entity<|#|>Neural Networks<|#|>concept<|#|>Neural Networks are computing systems inspired by biological neural networks. +relation<|#|>Deep Learning<|#|>Neural Networks<|#|>uses, composed of<|#|>Deep Learning uses multiple layers of Neural Networks to learn representations. 
+<|COMPLETE|>""" + + return mock_llm_func + + # Mock embedding function + async def mock_embedding_func(texts: list[str]) -> np.ndarray: + # Add coroutine switching to simulate async I/O and allow concurrent execution + await asyncio.sleep(0) + return np.random.rand(len(texts), 384) # 384-dimensional vectors + + # Test 11.1: Create two LightRAG instances with different workspaces + print("\nTest 11.1: Create two LightRAG instances with different workspaces") + + from lightrag import LightRAG + from lightrag.utils import EmbeddingFunc, Tokenizer + + # Create different mock LLM functions for each workspace + mock_llm_func_a = create_mock_llm_func("project_a") + mock_llm_func_b = create_mock_llm_func("project_b") + + class _SimpleTokenizerImpl: + def encode(self, content: str) -> list[int]: + return [ord(ch) for ch in content] + + def decode(self, tokens: list[int]) -> str: + return "".join(chr(t) for t in tokens) + + tokenizer = Tokenizer("mock-tokenizer", _SimpleTokenizerImpl()) + + rag1 = LightRAG( + working_dir=test_dir, + workspace="project_a", + llm_model_func=mock_llm_func_a, + embedding_func=EmbeddingFunc( + embedding_dim=384, + max_token_size=8192, + func=mock_embedding_func, + ), + tokenizer=tokenizer, + ) + + rag2 = LightRAG( + working_dir=test_dir, + workspace="project_b", + llm_model_func=mock_llm_func_b, + embedding_func=EmbeddingFunc( + embedding_dim=384, + max_token_size=8192, + func=mock_embedding_func, + ), + tokenizer=tokenizer, + ) + + # Initialize storages + await rag1.initialize_storages() + await rag2.initialize_storages() + + print(" RAG1 created: workspace=project_a") + print(" RAG2 created: workspace=project_b") + + # Test 11.2: Insert different data to each RAG instance (CONCURRENTLY) + print("\nTest 11.2: Insert different data to each RAG instance (concurrently)") + + text_for_project_a = "This document is about Artificial Intelligence and Machine Learning. AI is transforming the world." + text_for_project_b = "This document is about Deep Learning and Neural Networks. Deep learning uses multiple layers." 
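+        # The two inputs deliberately use disjoint topics so that the substring
+        # checks in Test 11.4 can detect any cross-workspace leakage.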
+ + # Insert to both projects concurrently to test workspace isolation under concurrent load + print(" Starting concurrent insert operations...") + start_time = time.time() + await asyncio.gather( + rag1.ainsert(text_for_project_a), rag2.ainsert(text_for_project_b) + ) + elapsed_time = time.time() - start_time + + print(f" Inserted to project_a: {len(text_for_project_a)} chars (concurrent)") + print(f" Inserted to project_b: {len(text_for_project_b)} chars (concurrent)") + print(f" Total concurrent execution time: {elapsed_time:.3f}s") + + # Test 11.3: Verify file structure + print("\nTest 11.3: Verify workspace directory structure") + + project_a_dir = Path(test_dir) / "project_a" + project_b_dir = Path(test_dir) / "project_b" + + project_a_exists = project_a_dir.exists() + project_b_exists = project_b_dir.exists() + + print(f" project_a directory: {project_a_dir}") + print(f" project_a exists: {project_a_exists}") + print(f" project_b directory: {project_b_dir}") + print(f" project_b exists: {project_b_exists}") + + assert project_a_exists, "project_a directory should exist" + assert project_b_exists, "project_b directory should exist" + + # List files in each directory + print("\n Files in project_a/:") + for file in sorted(project_a_dir.glob("*")): + if file.is_file(): + size = file.stat().st_size + print(f" - {file.name} ({size} bytes)") + + print("\n Files in project_b/:") + for file in sorted(project_b_dir.glob("*")): + if file.is_file(): + size = file.stat().st_size + print(f" - {file.name} ({size} bytes)") + + print("✅ PASSED: LightRAG E2E - File Structure") + print(" Workspace directories correctly created and separated") + + # Test 11.4: Verify data isolation by checking file contents + print("\nTest 11.4: Verify data isolation (check file contents)") + + # Check if full_docs storage files exist and contain different content + docs_a_file = project_a_dir / "kv_store_full_docs.json" + docs_b_file = project_b_dir / "kv_store_full_docs.json" + + if docs_a_file.exists() and docs_b_file.exists(): + import json + + with open(docs_a_file, "r") as f: + docs_a_content = json.load(f) + + with open(docs_b_file, "r") as f: + docs_b_content = json.load(f) + + print(f" project_a doc count: {len(docs_a_content)}") + print(f" project_b doc count: {len(docs_b_content)}") + + # Verify they contain different data + assert ( + docs_a_content != docs_b_content + ), "Document storage not properly isolated" + + # Verify each workspace contains its own text content + docs_a_str = json.dumps(docs_a_content) + docs_b_str = json.dumps(docs_b_content) + + # Check project_a contains its text and NOT project_b's text + assert ( + "Artificial Intelligence" in docs_a_str + ), "project_a should contain 'Artificial Intelligence'" + assert ( + "Machine Learning" in docs_a_str + ), "project_a should contain 'Machine Learning'" + assert ( + "Deep Learning" not in docs_a_str + ), "project_a should NOT contain 'Deep Learning' from project_b" + assert ( + "Neural Networks" not in docs_a_str + ), "project_a should NOT contain 'Neural Networks' from project_b" + + # Check project_b contains its text and NOT project_a's text + assert ( + "Deep Learning" in docs_b_str + ), "project_b should contain 'Deep Learning'" + assert ( + "Neural Networks" in docs_b_str + ), "project_b should contain 'Neural Networks'" + assert ( + "Artificial Intelligence" not in docs_b_str + ), "project_b should NOT contain 'Artificial Intelligence' from project_a" + # Note: "Machine Learning" might appear in project_b's text, so we skip that 
+
+            print("✅ PASSED: LightRAG E2E - Data Isolation")
+            print(" Document storage correctly isolated between workspaces")
+            print(" project_a contains only its own data")
+            print(" project_b contains only its own data")
+        else:
+            print(" Document storage files not found (may not be created yet)")
+            print("✅ PASSED: LightRAG E2E - Data Isolation")
+            print(" Skipped file content check (files not created)")
+
+        print("\n ✓ Test complete - workspace isolation verified at E2E level")
+
+    finally:
+        # Cleanup test directory (unless keep_test_artifacts is set)
+        if os.path.exists(test_dir) and not keep_test_artifacts:
+            shutil.rmtree(test_dir)
+            print(f"\n Cleaned up test directory: {test_dir}")
+        elif keep_test_artifacts:
+            print(f"\n Kept test directory for inspection: {test_dir}")
diff --git a/tests/test_write_json_optimization.py b/tests/test_write_json_optimization.py
index 0a92904f..32dcfb5e 100644
--- a/tests/test_write_json_optimization.py
+++ b/tests/test_write_json_optimization.py
@@ -11,9 +11,11 @@ This test verifies:
 import os
 import json
 import tempfile
+import pytest
 from lightrag.utils import write_json, load_json, SanitizingJSONEncoder
 
 
+@pytest.mark.offline
 class TestWriteJsonOptimization:
     """Test write_json optimization with two-stage approach"""
 
diff --git a/uv.lock b/uv.lock
index 6408bd92..97703af0 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2611,6 +2611,7 @@ docling = [
 evaluation = [
     { name = "datasets" },
     { name = "httpx" },
+    { name = "pre-commit" },
     { name = "pytest" },
     { name = "pytest-asyncio" },
     { name = "ragas" },
@@ -2695,6 +2696,11 @@ offline-storage = [
     { name = "qdrant-client", version = "1.15.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13'" },
     { name = "redis" },
 ]
+pytest = [
+    { name = "pre-commit" },
+    { name = "pytest" },
+    { name = "pytest-asyncio" },
+]
 
 [package.metadata]
 requires-dist = [
@@ -2729,6 +2735,7 @@ requires-dist = [
     { name = "json-repair", marker = "extra == 'api'" },
     { name = "langfuse", marker = "extra == 'observability'", specifier = ">=3.8.1" },
     { name = "lightrag-hku", extras = ["api", "offline-llm", "offline-storage"], marker = "extra == 'offline'" },
+    { name = "lightrag-hku", extras = ["pytest"], marker = "extra == 'evaluation'" },
     { name = "llama-index", marker = "extra == 'offline-llm'", specifier = ">=0.9.0,<1.0.0" },
     { name = "nano-vectordb" },
     { name = "nano-vectordb", marker = "extra == 'api'" },
@@ -2746,6 +2753,7 @@ requires-dist = [
     { name = "passlib", extras = ["bcrypt"], marker = "extra == 'api'" },
     { name = "pipmaster" },
     { name = "pipmaster", marker = "extra == 'api'" },
+    { name = "pre-commit", marker = "extra == 'pytest'" },
     { name = "psutil", marker = "extra == 'api'" },
     { name = "pycryptodome", marker = "extra == 'api'", specifier = ">=3.0.0,<4.0.0" },
     { name = "pydantic" },
@@ -2756,8 +2764,8 @@ requires-dist = [
     { name = "pypdf", marker = "extra == 'api'", specifier = ">=6.1.0" },
     { name = "pypinyin" },
     { name = "pypinyin", marker = "extra == 'api'" },
-    { name = "pytest", marker = "extra == 'evaluation'", specifier = ">=8.4.2" },
-    { name = "pytest-asyncio", marker = "extra == 'evaluation'", specifier = ">=1.2.0" },
+    { name = "pytest", marker = "extra == 'pytest'", specifier = ">=8.4.2" },
+    { name = "pytest-asyncio", marker = "extra == 'pytest'", specifier = ">=1.2.0" },
     { name = "python-docx", marker = "extra == 'api'", specifier = ">=0.8.11,<2.0.0" },
     { name = "python-dotenv" },
     { name = "python-dotenv", marker = "extra == 'api'" },
@@ -2780,7 +2788,7 @@ requires-dist = [
     { name = "xlsxwriter", marker = "extra == 'api'", specifier = ">=3.1.0" },
     { name = "zhipuai", marker = "extra == 'offline-llm'", specifier = ">=2.0.0,<3.0.0" },
 ]
-provides-extras = ["api", "docling", "offline-storage", "offline-llm", "offline", "evaluation", "observability"]
+provides-extras = ["pytest", "api", "docling", "offline-storage", "offline-llm", "offline", "evaluation", "observability"]
 
 [[package]]
 name = "llama-cloud"