Refactor workspace handling to use default workspace and namespace locks

- Remove DB-specific workspace configs
- Add default workspace auto-setting
- Replace global locks with namespace locks
- Simplify pipeline status management
- Remove redundant graph DB locking

(cherry picked from commit 926960e957)
This commit is contained in:
yangdx 2025-11-17 02:32:00 +08:00 committed by Raphaël MANSUY
parent c01cfc3649
commit 94ae13a037
16 changed files with 591 additions and 449 deletions

View file

@ -29,7 +29,7 @@ WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
# OLLAMA_EMULATING_MODEL_NAME=lightrag # OLLAMA_EMULATING_MODEL_NAME=lightrag
OLLAMA_EMULATING_MODEL_TAG=latest OLLAMA_EMULATING_MODEL_TAG=latest
### Max nodes return from graph retrieval in webui ### Max nodes for graph retrieval (Ensure WebUI local settings are also updated, which is limited to this value)
# MAX_GRAPH_NODES=1000 # MAX_GRAPH_NODES=1000
### Logging level ### Logging level
@ -50,6 +50,8 @@ OLLAMA_EMULATING_MODEL_TAG=latest
# JWT_ALGORITHM=HS256 # JWT_ALGORITHM=HS256
### API-Key to access LightRAG Server API ### API-Key to access LightRAG Server API
### Use this key in HTTP requests with the 'X-API-Key' header
### Example: curl -H "X-API-Key: your-secure-api-key-here" http://localhost:9621/query
# LIGHTRAG_API_KEY=your-secure-api-key-here # LIGHTRAG_API_KEY=your-secure-api-key-here
# WHITELIST_PATHS=/health,/api/* # WHITELIST_PATHS=/health,/api/*
@ -119,6 +121,9 @@ ENABLE_LLM_CACHE_FOR_EXTRACT=true
### Document processing output language: English, Chinese, French, German ... ### Document processing output language: English, Chinese, French, German ...
SUMMARY_LANGUAGE=English SUMMARY_LANGUAGE=English
### PDF decryption password for protected PDF files
# PDF_DECRYPT_PASSWORD=your_pdf_password_here
### Entity types that the LLM will attempt to recognize ### Entity types that the LLM will attempt to recognize
# ENTITY_TYPES='["Person", "Creature", "Organization", "Location", "Event", "Concept", "Method", "Content", "Data", "Artifact", "NaturalObject"]' # ENTITY_TYPES='["Person", "Creature", "Organization", "Location", "Event", "Concept", "Method", "Content", "Data", "Artifact", "NaturalObject"]'
@ -138,10 +143,13 @@ SUMMARY_LANGUAGE=English
### control the maximum chunk_ids stored in vector and graph db ### control the maximum chunk_ids stored in vector and graph db
# MAX_SOURCE_IDS_PER_ENTITY=300 # MAX_SOURCE_IDS_PER_ENTITY=300
# MAX_SOURCE_IDS_PER_RELATION=300 # MAX_SOURCE_IDS_PER_RELATION=300
### control chunk_ids limitation method: KEEP, FIFO (KEEP: Keep oldest, FIFO: First in first out) ### control chunk_ids limitation method: FIFO, KEEP
# SOURCE_IDS_LIMIT_METHOD=KEEP ### FIFO: First in first out
### Maximum number of file paths stored in entity/relation file_path field ### KEEP: Keep oldest (less merge action and faster)
# MAX_FILE_PATHS=30 # SOURCE_IDS_LIMIT_METHOD=FIFO
# Maximum number of file paths stored in entity/relation file_path field (for display only; does not affect query performance)
# MAX_FILE_PATHS=100
### maximum number of related chunks per source entity or relation ### maximum number of related chunks per source entity or relation
### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph) ### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph)
@ -160,10 +168,13 @@ MAX_PARALLEL_INSERT=2
### Num of chunks send to Embedding in single request ### Num of chunks send to Embedding in single request
# EMBEDDING_BATCH_NUM=10 # EMBEDDING_BATCH_NUM=10
########################################################### ###########################################################################
### LLM Configuration ### LLM Configuration
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock ### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock, gemini
########################################################### ### LLM_BINDING_HOST: host only for Ollama, endpoint for other LLM service
### If LightRAG is deployed in Docker:
### use host.docker.internal instead of localhost in LLM_BINDING_HOST
###########################################################################
### LLM request timeout setting for all llm (0 means no timeout for Ollama) ### LLM request timeout setting for all llm (0 means no timeout for Ollama)
# LLM_TIMEOUT=180 # LLM_TIMEOUT=180
@ -172,7 +183,7 @@ LLM_MODEL=gpt-4o
LLM_BINDING_HOST=https://api.openai.com/v1 LLM_BINDING_HOST=https://api.openai.com/v1
LLM_BINDING_API_KEY=your_api_key LLM_BINDING_API_KEY=your_api_key
### Optional for Azure ### Env vars for Azure openai
# AZURE_OPENAI_API_VERSION=2024-08-01-preview # AZURE_OPENAI_API_VERSION=2024-08-01-preview
# AZURE_OPENAI_DEPLOYMENT=gpt-4o # AZURE_OPENAI_DEPLOYMENT=gpt-4o
@ -182,18 +193,21 @@ LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING_API_KEY=your_api_key # LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING=openai # LLM_BINDING=openai
### OpenAI Compatible API Specific Parameters ### Gemini example
### Increased temperature values may mitigate infinite inference loops in certain LLM, such as Qwen3-30B. # LLM_BINDING=gemini
# OPENAI_LLM_TEMPERATURE=0.9 # LLM_MODEL=gemini-flash-latest
### Set the max_tokens to mitigate endless output of some LLM (less than LLM_TIMEOUT * llm_output_tokens/second, i.e. 9000 = 180s * 50 tokens/s) # LLM_BINDING_API_KEY=your_gemini_api_key
### Typically, max_tokens does not include prompt content, though some models, such as Gemini Models, are exceptions # LLM_BINDING_HOST=https://generativelanguage.googleapis.com
### For vLLM/SGLang deployed models, or most of OpenAI compatible API provider
# OPENAI_LLM_MAX_TOKENS=9000
### For OpenAI o1-mini or newer models
OPENAI_LLM_MAX_COMPLETION_TOKENS=9000
#### OpenAI's new API utilizes max_completion_tokens instead of max_tokens ### use the following command to see all support options for OpenAI, azure_openai or OpenRouter
# OPENAI_LLM_MAX_COMPLETION_TOKENS=9000 ### lightrag-server --llm-binding gemini --help
### Gemini Specific Parameters
# GEMINI_LLM_MAX_OUTPUT_TOKENS=9000
# GEMINI_LLM_TEMPERATURE=0.7
### Enable Thinking
# GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": -1, "include_thoughts": true}'
### Disable Thinking
# GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}'
### use the following command to see all support options for OpenAI, azure_openai or OpenRouter ### use the following command to see all support options for OpenAI, azure_openai or OpenRouter
### lightrag-server --llm-binding openai --help ### lightrag-server --llm-binding openai --help
@ -204,6 +218,16 @@ OPENAI_LLM_MAX_COMPLETION_TOKENS=9000
### Qwen3 Specific Parameters deploy by vLLM ### Qwen3 Specific Parameters deploy by vLLM
# OPENAI_LLM_EXTRA_BODY='{"chat_template_kwargs": {"enable_thinking": false}}' # OPENAI_LLM_EXTRA_BODY='{"chat_template_kwargs": {"enable_thinking": false}}'
### OpenAI Compatible API Specific Parameters
### Increased temperature values may mitigate infinite inference loops in certain LLM, such as Qwen3-30B.
# OPENAI_LLM_TEMPERATURE=0.9
### Set the max_tokens to mitigate endless output of some LLM (less than LLM_TIMEOUT * llm_output_tokens/second, i.e. 9000 = 180s * 50 tokens/s)
### Typically, max_tokens does not include prompt content
### For vLLM/SGLang deployed models, or most of OpenAI compatible API provider
# OPENAI_LLM_MAX_TOKENS=9000
### OpenAI o1-mini and newer models use max_completion_tokens instead of max_tokens
OPENAI_LLM_MAX_COMPLETION_TOKENS=9000
### use the following command to see all support options for Ollama LLM ### use the following command to see all support options for Ollama LLM
### lightrag-server --llm-binding ollama --help ### lightrag-server --llm-binding ollama --help
### Ollama Server Specific Parameters ### Ollama Server Specific Parameters
@ -217,24 +241,37 @@ OLLAMA_LLM_NUM_CTX=32768
### Bedrock Specific Parameters ### Bedrock Specific Parameters
# BEDROCK_LLM_TEMPERATURE=1.0 # BEDROCK_LLM_TEMPERATURE=1.0
#################################################################################### #######################################################################################
### Embedding Configuration (Should not be changed after the first file processed) ### Embedding Configuration (Should not be changed after the first file processed)
### EMBEDDING_BINDING: ollama, openai, azure_openai, jina, lollms, aws_bedrock ### EMBEDDING_BINDING: ollama, openai, azure_openai, jina, lollms, aws_bedrock
#################################################################################### ### EMBEDDING_BINDING_HOST: host only for Ollama, endpoint for other Embedding service
### If LightRAG is deployed in Docker:
### use host.docker.internal instead of localhost in EMBEDDING_BINDING_HOST
#######################################################################################
# EMBEDDING_TIMEOUT=30 # EMBEDDING_TIMEOUT=30
EMBEDDING_BINDING=ollama
EMBEDDING_MODEL=bge-m3:latest
EMBEDDING_DIM=1024
EMBEDDING_BINDING_API_KEY=your_api_key
# If the embedding service is deployed within the same Docker stack, use host.docker.internal instead of localhost
EMBEDDING_BINDING_HOST=http://localhost:11434
### OpenAI compatible (VoyageAI embedding openai compatible) ### Control whether to send embedding_dim parameter to embedding API
# EMBEDDING_BINDING=openai ### IMPORTANT: Jina ALWAYS sends dimension parameter (API requirement) - this setting is ignored for Jina
# EMBEDDING_MODEL=text-embedding-3-large ### For OpenAI: Set to 'true' to enable dynamic dimension adjustment
# EMBEDDING_DIM=3072 ### For OpenAI: Set to 'false' (default) to disable sending dimension parameter
# EMBEDDING_BINDING_HOST=https://api.openai.com/v1 ### Note: Automatically ignored for backends that don't support dimension parameter (e.g., Ollama)
# Ollama embedding
# EMBEDDING_BINDING=ollama
# EMBEDDING_MODEL=bge-m3:latest
# EMBEDDING_DIM=1024
# EMBEDDING_BINDING_API_KEY=your_api_key # EMBEDDING_BINDING_API_KEY=your_api_key
### If LightRAG is deployed in Docker, use host.docker.internal instead of localhost
# EMBEDDING_BINDING_HOST=http://localhost:11434
### OpenAI compatible embedding
EMBEDDING_BINDING=openai
EMBEDDING_MODEL=text-embedding-3-large
EMBEDDING_DIM=3072
EMBEDDING_SEND_DIM=false
EMBEDDING_TOKEN_LIMIT=8192
EMBEDDING_BINDING_HOST=https://api.openai.com/v1
EMBEDDING_BINDING_API_KEY=your_api_key
### Optional for Azure ### Optional for Azure
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large # AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
@ -242,6 +279,16 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
# AZURE_EMBEDDING_ENDPOINT=your_endpoint # AZURE_EMBEDDING_ENDPOINT=your_endpoint
# AZURE_EMBEDDING_API_KEY=your_api_key # AZURE_EMBEDDING_API_KEY=your_api_key
### Gemini embedding
# EMBEDDING_BINDING=gemini
# EMBEDDING_MODEL=gemini-embedding-001
# EMBEDDING_DIM=1536
# EMBEDDING_TOKEN_LIMIT=2048
# EMBEDDING_BINDING_HOST=https://generativelanguage.googleapis.com
# EMBEDDING_BINDING_API_KEY=your_api_key
### Gemini embedding requires sending dimension to server
# EMBEDDING_SEND_DIM=true
### Jina AI Embedding ### Jina AI Embedding
# EMBEDDING_BINDING=jina # EMBEDDING_BINDING=jina
# EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings # EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings
@ -302,17 +349,15 @@ POSTGRES_USER=your_username
POSTGRES_PASSWORD='your_password' POSTGRES_PASSWORD='your_password'
POSTGRES_DATABASE=your_database POSTGRES_DATABASE=your_database
POSTGRES_MAX_CONNECTIONS=12 POSTGRES_MAX_CONNECTIONS=12
# POSTGRES_WORKSPACE=forced_workspace_name ### DB-specific workspace should not be set; kept for compatibility only
### POSTGRES_WORKSPACE=forced_workspace_name
### PostgreSQL Vector Storage Configuration ### PostgreSQL Vector Storage Configuration
### Vector storage type: HNSW, IVFFlat, VCHORDRQ ### Vector storage type: HNSW, IVFFlat
POSTGRES_VECTOR_INDEX_TYPE=HNSW POSTGRES_VECTOR_INDEX_TYPE=HNSW
POSTGRES_HNSW_M=16 POSTGRES_HNSW_M=16
POSTGRES_HNSW_EF=200 POSTGRES_HNSW_EF=200
POSTGRES_IVFFLAT_LISTS=100 POSTGRES_IVFFLAT_LISTS=100
POSTGRES_VCHORDRQ_BUILD_OPTIONS=
POSTGRES_VCHORDRQ_PROBES=
POSTGRES_VCHORDRQ_EPSILON=1.9
### PostgreSQL Connection Retry Configuration (Network Robustness) ### PostgreSQL Connection Retry Configuration (Network Robustness)
### Number of retry attempts (1-10, default: 3) ### Number of retry attempts (1-10, default: 3)
@ -351,7 +396,8 @@ NEO4J_MAX_TRANSACTION_RETRY_TIME=30
NEO4J_MAX_CONNECTION_LIFETIME=300 NEO4J_MAX_CONNECTION_LIFETIME=300
NEO4J_LIVENESS_CHECK_TIMEOUT=30 NEO4J_LIVENESS_CHECK_TIMEOUT=30
NEO4J_KEEP_ALIVE=true NEO4J_KEEP_ALIVE=true
# NEO4J_WORKSPACE=forced_workspace_name ### DB-specific workspace should not be set; kept for compatibility only
### NEO4J_WORKSPACE=forced_workspace_name
### MongoDB Configuration ### MongoDB Configuration
MONGO_URI=mongodb://root:root@localhost:27017/ MONGO_URI=mongodb://root:root@localhost:27017/
@ -365,12 +411,14 @@ MILVUS_DB_NAME=lightrag
# MILVUS_USER=root # MILVUS_USER=root
# MILVUS_PASSWORD=your_password # MILVUS_PASSWORD=your_password
# MILVUS_TOKEN=your_token # MILVUS_TOKEN=your_token
# MILVUS_WORKSPACE=forced_workspace_name ### DB-specific workspace should not be set; kept for compatibility only
### MILVUS_WORKSPACE=forced_workspace_name
### Qdrant ### Qdrant
QDRANT_URL=http://localhost:6333 QDRANT_URL=http://localhost:6333
# QDRANT_API_KEY=your-api-key # QDRANT_API_KEY=your-api-key
# QDRANT_WORKSPACE=forced_workspace_name ### DB-specific workspace should not be set; kept for compatibility only
### QDRANT_WORKSPACE=forced_workspace_name
### Redis ### Redis
REDIS_URI=redis://localhost:6379 REDIS_URI=redis://localhost:6379
@ -378,11 +426,45 @@ REDIS_SOCKET_TIMEOUT=30
REDIS_CONNECT_TIMEOUT=10 REDIS_CONNECT_TIMEOUT=10
REDIS_MAX_CONNECTIONS=100 REDIS_MAX_CONNECTIONS=100
REDIS_RETRY_ATTEMPTS=3 REDIS_RETRY_ATTEMPTS=3
# REDIS_WORKSPACE=forced_workspace_name ### DB-specific workspace should not be set; kept for compatibility only
### REDIS_WORKSPACE=forced_workspace_name
### Memgraph Configuration ### Memgraph Configuration
MEMGRAPH_URI=bolt://localhost:7687 MEMGRAPH_URI=bolt://localhost:7687
MEMGRAPH_USERNAME= MEMGRAPH_USERNAME=
MEMGRAPH_PASSWORD= MEMGRAPH_PASSWORD=
MEMGRAPH_DATABASE=memgraph MEMGRAPH_DATABASE=memgraph
# MEMGRAPH_WORKSPACE=forced_workspace_name ### DB-specific workspace should not be set; kept for compatibility only
### MEMGRAPH_WORKSPACE=forced_workspace_name
############################
### Evaluation Configuration
############################
### RAGAS evaluation models (used for RAG quality assessment)
### ⚠️ IMPORTANT: Both LLM and Embedding endpoints MUST be OpenAI-compatible
### Default uses OpenAI models for evaluation
### LLM Configuration for Evaluation
# EVAL_LLM_MODEL=gpt-4o-mini
### API key for LLM evaluation (fallback to OPENAI_API_KEY if not set)
# EVAL_LLM_BINDING_API_KEY=your_api_key
### Custom OpenAI-compatible endpoint for LLM evaluation (optional)
# EVAL_LLM_BINDING_HOST=https://api.openai.com/v1
### Embedding Configuration for Evaluation
# EVAL_EMBEDDING_MODEL=text-embedding-3-large
### API key for embeddings (fallback: EVAL_LLM_BINDING_API_KEY -> OPENAI_API_KEY)
# EVAL_EMBEDDING_BINDING_API_KEY=your_embedding_api_key
### Custom OpenAI-compatible endpoint for embeddings (fallback: EVAL_LLM_BINDING_HOST)
# EVAL_EMBEDDING_BINDING_HOST=https://api.openai.com/v1
### Performance Tuning
### Number of concurrent test case evaluations
### Lower values reduce API rate limit issues but increase evaluation time
# EVAL_MAX_CONCURRENT=2
### TOP_K query parameter of LightRAG (default: 10)
### Number of entities or relations retrieved from KG
# EVAL_QUERY_TOP_K=10
### LLM request retry and timeout settings for evaluation
# EVAL_LLM_MAX_RETRIES=5
# EVAL_LLM_TIMEOUT=180

View file

@ -455,7 +455,7 @@ def create_app(args):
# Create combined auth dependency for all endpoints # Create combined auth dependency for all endpoints
combined_auth = get_combined_auth_dependency(api_key) combined_auth = get_combined_auth_dependency(api_key)
def get_workspace_from_request(request: Request) -> str | None: def get_workspace_from_request(request: Request) -> str:
""" """
Extract workspace from HTTP request header or use default. Extract workspace from HTTP request header or use default.
@ -472,8 +472,9 @@ def create_app(args):
# Check custom header first # Check custom header first
workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()
# Fall back to server default if header not provided
if not workspace: if not workspace:
workspace = None workspace = args.workspace
return workspace return workspace
@ -1141,13 +1142,8 @@ def create_app(args):
async def get_status(request: Request): async def get_status(request: Request):
"""Get current system status""" """Get current system status"""
try: try:
workspace = get_workspace_from_request(request)
default_workspace = get_default_workspace() default_workspace = get_default_workspace()
if workspace is None: pipeline_status = await get_namespace_data("pipeline_status")
workspace = default_workspace
pipeline_status = await get_namespace_data(
"pipeline_status", workspace=workspace
)
if not auth_configured: if not auth_configured:
auth_mode = "disabled" auth_mode = "disabled"

View file

@ -1644,12 +1644,8 @@ async def background_delete_documents(
get_namespace_lock, get_namespace_lock,
) )
pipeline_status = await get_namespace_data( pipeline_status = await get_namespace_data("pipeline_status")
"pipeline_status", workspace=rag.workspace pipeline_status_lock = get_namespace_lock("pipeline_status")
)
pipeline_status_lock = get_namespace_lock(
"pipeline_status", workspace=rag.workspace
)
total_docs = len(doc_ids) total_docs = len(doc_ids)
successful_deletions = [] successful_deletions = []
@ -2142,12 +2138,8 @@ def create_document_routes(
) )
# Get pipeline status and lock # Get pipeline status and lock
pipeline_status = await get_namespace_data( pipeline_status = await get_namespace_data("pipeline_status")
"pipeline_status", workspace=rag.workspace pipeline_status_lock = get_namespace_lock("pipeline_status")
)
pipeline_status_lock = get_namespace_lock(
"pipeline_status", workspace=rag.workspace
)
# Check and set status with lock # Check and set status with lock
async with pipeline_status_lock: async with pipeline_status_lock:
@ -2342,12 +2334,8 @@ def create_document_routes(
get_all_update_flags_status, get_all_update_flags_status,
) )
pipeline_status = await get_namespace_data( pipeline_status = await get_namespace_data("pipeline_status")
"pipeline_status", workspace=rag.workspace pipeline_status_lock = get_namespace_lock("pipeline_status")
)
pipeline_status_lock = get_namespace_lock(
"pipeline_status", workspace=rag.workspace
)
# Get update flags status for all namespaces # Get update flags status for all namespaces
update_status = await get_all_update_flags_status() update_status = await get_all_update_flags_status()
@ -2558,12 +2546,8 @@ def create_document_routes(
get_namespace_lock, get_namespace_lock,
) )
pipeline_status = await get_namespace_data( pipeline_status = await get_namespace_data("pipeline_status")
"pipeline_status", workspace=rag.workspace pipeline_status_lock = get_namespace_lock("pipeline_status")
)
pipeline_status_lock = get_namespace_lock(
"pipeline_status", workspace=rag.workspace
)
# Check if pipeline is busy with proper lock # Check if pipeline is busy with proper lock
async with pipeline_status_lock: async with pipeline_status_lock:
@ -2971,12 +2955,8 @@ def create_document_routes(
get_namespace_lock, get_namespace_lock,
) )
pipeline_status = await get_namespace_data( pipeline_status = await get_namespace_data("pipeline_status")
"pipeline_status", workspace=rag.workspace pipeline_status_lock = get_namespace_lock("pipeline_status")
)
pipeline_status_lock = get_namespace_lock(
"pipeline_status", workspace=rag.workspace
)
async with pipeline_status_lock: async with pipeline_status_lock:
if not pipeline_status.get("busy", False): if not pipeline_status.get("busy", False):

View file

@ -38,9 +38,9 @@ class JsonDocStatusStorage(DocStatusStorage):
self.final_namespace = f"{self.workspace}_{self.namespace}" self.final_namespace = f"{self.workspace}_{self.namespace}"
else: else:
# Default behavior when workspace is empty # Default behavior when workspace is empty
workspace_dir = working_dir
self.final_namespace = self.namespace self.final_namespace = self.namespace
self.workspace = "" self.workspace = "_"
workspace_dir = working_dir
os.makedirs(workspace_dir, exist_ok=True) os.makedirs(workspace_dir, exist_ok=True)
self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json") self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json")
@ -51,18 +51,18 @@ class JsonDocStatusStorage(DocStatusStorage):
async def initialize(self): async def initialize(self):
"""Initialize storage data""" """Initialize storage data"""
self._storage_lock = get_namespace_lock( self._storage_lock = get_namespace_lock(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
self.storage_updated = await get_update_flag( self.storage_updated = await get_update_flag(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
async with get_data_init_lock(): async with get_data_init_lock():
# check need_init must before get_namespace_data # check need_init must before get_namespace_data
need_init = await try_initialize_namespace( need_init = await try_initialize_namespace(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
self._data = await get_namespace_data( self._data = await get_namespace_data(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
if need_init: if need_init:
loaded_data = load_json(self._file_name) or {} loaded_data = load_json(self._file_name) or {}
@ -183,7 +183,9 @@ class JsonDocStatusStorage(DocStatusStorage):
self._data.clear() self._data.clear()
self._data.update(cleaned_data) self._data.update(cleaned_data)
await clear_all_update_flags(self.namespace, workspace=self.workspace) await clear_all_update_flags(
self.final_namespace, workspace=self.workspace
)
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
""" """
@ -204,7 +206,7 @@ class JsonDocStatusStorage(DocStatusStorage):
if "chunks_list" not in doc_data: if "chunks_list" not in doc_data:
doc_data["chunks_list"] = [] doc_data["chunks_list"] = []
self._data.update(data) self._data.update(data)
await set_all_update_flags(self.namespace, workspace=self.workspace) await set_all_update_flags(self.final_namespace, workspace=self.workspace)
await self.index_done_callback() await self.index_done_callback()
@ -358,7 +360,9 @@ class JsonDocStatusStorage(DocStatusStorage):
any_deleted = True any_deleted = True
if any_deleted: if any_deleted:
await set_all_update_flags(self.namespace, workspace=self.workspace) await set_all_update_flags(
self.final_namespace, workspace=self.workspace
)
async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]: async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]:
"""Get document by file path """Get document by file path
@ -397,7 +401,9 @@ class JsonDocStatusStorage(DocStatusStorage):
try: try:
async with self._storage_lock: async with self._storage_lock:
self._data.clear() self._data.clear()
await set_all_update_flags(self.namespace, workspace=self.workspace) await set_all_update_flags(
self.final_namespace, workspace=self.workspace
)
await self.index_done_callback() await self.index_done_callback()
logger.info( logger.info(

View file

@ -35,7 +35,7 @@ class JsonKVStorage(BaseKVStorage):
# Default behavior when workspace is empty # Default behavior when workspace is empty
workspace_dir = working_dir workspace_dir = working_dir
self.final_namespace = self.namespace self.final_namespace = self.namespace
self.workspace = "" self.workspace = "_"
os.makedirs(workspace_dir, exist_ok=True) os.makedirs(workspace_dir, exist_ok=True)
self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json") self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json")
@ -47,18 +47,18 @@ class JsonKVStorage(BaseKVStorage):
async def initialize(self): async def initialize(self):
"""Initialize storage data""" """Initialize storage data"""
self._storage_lock = get_namespace_lock( self._storage_lock = get_namespace_lock(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
self.storage_updated = await get_update_flag( self.storage_updated = await get_update_flag(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
async with get_data_init_lock(): async with get_data_init_lock():
# check need_init must before get_namespace_data # check need_init must before get_namespace_data
need_init = await try_initialize_namespace( need_init = await try_initialize_namespace(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
self._data = await get_namespace_data( self._data = await get_namespace_data(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
if need_init: if need_init:
loaded_data = load_json(self._file_name) or {} loaded_data = load_json(self._file_name) or {}
@ -103,7 +103,9 @@ class JsonKVStorage(BaseKVStorage):
self._data.clear() self._data.clear()
self._data.update(cleaned_data) self._data.update(cleaned_data)
await clear_all_update_flags(self.namespace, workspace=self.workspace) await clear_all_update_flags(
self.final_namespace, workspace=self.workspace
)
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
async with self._storage_lock: async with self._storage_lock:
@ -176,7 +178,7 @@ class JsonKVStorage(BaseKVStorage):
v["_id"] = k v["_id"] = k
self._data.update(data) self._data.update(data)
await set_all_update_flags(self.namespace, workspace=self.workspace) await set_all_update_flags(self.final_namespace, workspace=self.workspace)
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
"""Delete specific records from storage by their IDs """Delete specific records from storage by their IDs
@ -199,7 +201,9 @@ class JsonKVStorage(BaseKVStorage):
any_deleted = True any_deleted = True
if any_deleted: if any_deleted:
await set_all_update_flags(self.namespace, workspace=self.workspace) await set_all_update_flags(
self.final_namespace, workspace=self.workspace
)
async def is_empty(self) -> bool: async def is_empty(self) -> bool:
"""Check if the storage is empty """Check if the storage is empty
@ -227,7 +231,9 @@ class JsonKVStorage(BaseKVStorage):
try: try:
async with self._storage_lock: async with self._storage_lock:
self._data.clear() self._data.clear()
await set_all_update_flags(self.namespace, workspace=self.workspace) await set_all_update_flags(
self.final_namespace, workspace=self.workspace
)
await self.index_done_callback() await self.index_done_callback()
logger.info( logger.info(

View file

@ -8,7 +8,7 @@ import configparser
from ..utils import logger from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock from ..kg.shared_storage import get_data_init_lock
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
@ -101,10 +101,9 @@ class MemgraphStorage(BaseGraphStorage):
raise raise
async def finalize(self): async def finalize(self):
async with get_graph_db_lock(): if self._driver is not None:
if self._driver is not None: await self._driver.close()
await self._driver.close() self._driver = None
self._driver = None
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
await self.finalize() await self.finalize()
@ -762,22 +761,21 @@ class MemgraphStorage(BaseGraphStorage):
raise RuntimeError( raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first." "Memgraph driver is not initialized. Call 'await initialize()' first."
) )
async with get_graph_db_lock(): try:
try: async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE) as session: workspace_label = self._get_workspace_label()
workspace_label = self._get_workspace_label() query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" result = await session.run(query)
result = await session.run(query) await result.consume()
await result.consume() logger.info(
logger.info( f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
)
return {"status": "success", "message": "workspace data dropped"}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "success", "message": "workspace data dropped"}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)}
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes. """Get the total degree (sum of relationships) of two nodes.

View file

@ -6,7 +6,7 @@ import numpy as np
from lightrag.utils import logger, compute_mdhash_id from lightrag.utils import logger, compute_mdhash_id
from ..base import BaseVectorStorage from ..base import BaseVectorStorage
from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH
from ..kg.shared_storage import get_data_init_lock, get_storage_lock from ..kg.shared_storage import get_data_init_lock
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("pymilvus"): if not pm.is_installed("pymilvus"):
@ -1357,21 +1357,20 @@ class MilvusVectorDBStorage(BaseVectorStorage):
- On success: {"status": "success", "message": "data dropped"} - On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"} - On failure: {"status": "error", "message": "<error details>"}
""" """
async with get_storage_lock(): try:
try: # Drop the collection and recreate it
# Drop the collection and recreate it if self._client.has_collection(self.final_namespace):
if self._client.has_collection(self.final_namespace): self._client.drop_collection(self.final_namespace)
self._client.drop_collection(self.final_namespace)
# Recreate the collection # Recreate the collection
self._create_collection_if_not_exist() self._create_collection_if_not_exist()
logger.info( logger.info(
f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}" f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}"
) )
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}" f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -19,7 +19,7 @@ from ..base import (
from ..utils import logger, compute_mdhash_id from ..utils import logger, compute_mdhash_id
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_data_init_lock, get_storage_lock, get_graph_db_lock from ..kg.shared_storage import get_data_init_lock
import pipmaster as pm import pipmaster as pm
@ -138,11 +138,10 @@ class MongoKVStorage(BaseKVStorage):
) )
async def finalize(self): async def finalize(self):
async with get_storage_lock(): if self.db is not None:
if self.db is not None: await ClientManager.release_client(self.db)
await ClientManager.release_client(self.db) self.db = None
self.db = None self._data = None
self._data = None
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
# Unified handling for flattened keys # Unified handling for flattened keys
@ -265,23 +264,22 @@ class MongoKVStorage(BaseKVStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
async with get_storage_lock(): try:
try: result = await self._data.delete_many({})
result = await self._data.delete_many({}) deleted_count = result.deleted_count
deleted_count = result.deleted_count
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents dropped", "message": f"{deleted_count} documents dropped",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@final @final
@ -372,11 +370,10 @@ class MongoDocStatusStorage(DocStatusStorage):
) )
async def finalize(self): async def finalize(self):
async with get_storage_lock(): if self.db is not None:
if self.db is not None: await ClientManager.release_client(self.db)
await ClientManager.release_client(self.db) self.db = None
self.db = None self._data = None
self._data = None
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return await self._data.find_one({"_id": id}) return await self._data.find_one({"_id": id})
@ -472,23 +469,22 @@ class MongoDocStatusStorage(DocStatusStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
async with get_storage_lock(): try:
try: result = await self._data.delete_many({})
result = await self._data.delete_many({}) deleted_count = result.deleted_count
deleted_count = result.deleted_count
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents dropped", "message": f"{deleted_count} documents dropped",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
await self._data.delete_many({"_id": {"$in": ids}}) await self._data.delete_many({"_id": {"$in": ids}})
@ -789,12 +785,11 @@ class MongoGraphStorage(BaseGraphStorage):
) )
async def finalize(self): async def finalize(self):
async with get_graph_db_lock(): if self.db is not None:
if self.db is not None: await ClientManager.release_client(self.db)
await ClientManager.release_client(self.db) self.db = None
self.db = None self.collection = None
self.collection = None self.edge_collection = None
self.edge_collection = None
# Sample entity document # Sample entity document
# "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP # "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP
@ -2042,30 +2037,29 @@ class MongoGraphStorage(BaseGraphStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
async with get_graph_db_lock(): try:
try: result = await self.collection.delete_many({})
result = await self.collection.delete_many({}) deleted_count = result.deleted_count
deleted_count = result.deleted_count
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}" f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}"
) )
result = await self.edge_collection.delete_many({}) result = await self.edge_collection.delete_many({})
edge_count = result.deleted_count edge_count = result.deleted_count
logger.info( logger.info(
f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}" f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents and {edge_count} edges dropped", "message": f"{deleted_count} documents and {edge_count} edges dropped",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}" f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@final @final
@ -2152,11 +2146,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
) )
async def finalize(self): async def finalize(self):
async with get_storage_lock(): if self.db is not None:
if self.db is not None: await ClientManager.release_client(self.db)
await ClientManager.release_client(self.db) self.db = None
self.db = None self._data = None
self._data = None
async def create_vector_index_if_not_exists(self): async def create_vector_index_if_not_exists(self):
"""Creates an Atlas Vector Search index.""" """Creates an Atlas Vector Search index."""
@ -2479,27 +2472,26 @@ class MongoVectorDBStorage(BaseVectorStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
async with get_storage_lock(): try:
try: # Delete all documents
# Delete all documents result = await self._data.delete_many({})
result = await self._data.delete_many({}) deleted_count = result.deleted_count
deleted_count = result.deleted_count
# Recreate vector index # Recreate vector index
await self.create_vector_index_if_not_exists() await self.create_vector_index_if_not_exists()
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index" f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents dropped and vector index recreated", "message": f"{deleted_count} documents dropped and vector index recreated",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}" f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def get_or_create_collection(db: AsyncDatabase, collection_name: str): async def get_or_create_collection(db: AsyncDatabase, collection_name: str):

View file

@ -66,11 +66,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
"""Initialize storage data""" """Initialize storage data"""
# Get the update flag for cross-process update notification # Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag( self.storage_updated = await get_update_flag(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
# Get the storage lock for use in other methods # Get the storage lock for use in other methods
self._storage_lock = get_namespace_lock( self._storage_lock = get_namespace_lock(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
async def _get_client(self): async def _get_client(self):
@ -292,7 +292,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
# Save data to disk # Save data to disk
self._client.save() self._client.save()
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags(self.namespace, workspace=self.workspace) await set_all_update_flags(
self.final_namespace, workspace=self.workspace
)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False
return True # Return success return True # Return success
@ -414,7 +416,9 @@ class NanoVectorDBStorage(BaseVectorStorage):
) )
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags(self.namespace, workspace=self.workspace) await set_all_update_flags(
self.final_namespace, workspace=self.workspace
)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False

View file

@ -72,11 +72,11 @@ class NetworkXStorage(BaseGraphStorage):
"""Initialize storage data""" """Initialize storage data"""
# Get the update flag for cross-process update notification # Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag( self.storage_updated = await get_update_flag(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
# Get the storage lock for use in other methods # Get the storage lock for use in other methods
self._storage_lock = get_namespace_lock( self._storage_lock = get_namespace_lock(
self.namespace, workspace=self.workspace self.final_namespace, workspace=self.workspace
) )
async def _get_graph(self): async def _get_graph(self):
@ -526,7 +526,9 @@ class NetworkXStorage(BaseGraphStorage):
self._graph, self._graphml_xml_file, self.workspace self._graph, self._graphml_xml_file, self.workspace
) )
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags(self.namespace, workspace=self.workspace) await set_all_update_flags(
self.final_namespace, workspace=self.workspace
)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False
return True # Return success return True # Return success
@ -557,7 +559,9 @@ class NetworkXStorage(BaseGraphStorage):
os.remove(self._graphml_xml_file) os.remove(self._graphml_xml_file)
self._graph = nx.Graph() self._graph = nx.Graph()
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags(self.namespace, workspace=self.workspace) await set_all_update_flags(
self.final_namespace, workspace=self.workspace
)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False
logger.info( logger.info(

View file

@ -33,7 +33,7 @@ from ..base import (
) )
from ..namespace import NameSpace, is_namespace from ..namespace import NameSpace, is_namespace
from ..utils import logger from ..utils import logger
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock from ..kg.shared_storage import get_data_init_lock
import pipmaster as pm import pipmaster as pm
@ -1699,10 +1699,9 @@ class PGKVStorage(BaseKVStorage):
self.workspace = "default" self.workspace = "default"
async def finalize(self): async def finalize(self):
async with get_storage_lock(): if self.db is not None:
if self.db is not None: await ClientManager.release_client(self.db)
await ClientManager.release_client(self.db) self.db = None
self.db = None
################ QUERY METHODS ################ ################ QUERY METHODS ################
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
@ -2144,22 +2143,21 @@ class PGKVStorage(BaseKVStorage):
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage""" """Drop the storage"""
async with get_storage_lock(): try:
try: table_name = namespace_to_table_name(self.namespace)
table_name = namespace_to_table_name(self.namespace) if not table_name:
if not table_name: return {
return { "status": "error",
"status": "error", "message": f"Unknown namespace: {self.namespace}",
"message": f"Unknown namespace: {self.namespace}", }
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name table_name=table_name
) )
await self.db.execute(drop_sql, {"workspace": self.workspace}) await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@final @final
@ -2194,10 +2192,9 @@ class PGVectorStorage(BaseVectorStorage):
self.workspace = "default" self.workspace = "default"
async def finalize(self): async def finalize(self):
async with get_storage_lock(): if self.db is not None:
if self.db is not None: await ClientManager.release_client(self.db)
await ClientManager.release_client(self.db) self.db = None
self.db = None
def _upsert_chunks( def _upsert_chunks(
self, item: dict[str, Any], current_time: datetime.datetime self, item: dict[str, Any], current_time: datetime.datetime
@ -2533,22 +2530,21 @@ class PGVectorStorage(BaseVectorStorage):
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage""" """Drop the storage"""
async with get_storage_lock(): try:
try: table_name = namespace_to_table_name(self.namespace)
table_name = namespace_to_table_name(self.namespace) if not table_name:
if not table_name: return {
return { "status": "error",
"status": "error", "message": f"Unknown namespace: {self.namespace}",
"message": f"Unknown namespace: {self.namespace}", }
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name table_name=table_name
) )
await self.db.execute(drop_sql, {"workspace": self.workspace}) await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@final @final
@ -2583,10 +2579,9 @@ class PGDocStatusStorage(DocStatusStorage):
self.workspace = "default" self.workspace = "default"
async def finalize(self): async def finalize(self):
async with get_storage_lock(): if self.db is not None:
if self.db is not None: await ClientManager.release_client(self.db)
await ClientManager.release_client(self.db) self.db = None
self.db = None
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content""" """Filter out duplicated content"""
@ -3161,22 +3156,21 @@ class PGDocStatusStorage(DocStatusStorage):
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage""" """Drop the storage"""
async with get_storage_lock(): try:
try: table_name = namespace_to_table_name(self.namespace)
table_name = namespace_to_table_name(self.namespace) if not table_name:
if not table_name: return {
return { "status": "error",
"status": "error", "message": f"Unknown namespace: {self.namespace}",
"message": f"Unknown namespace: {self.namespace}", }
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name table_name=table_name
) )
await self.db.execute(drop_sql, {"workspace": self.workspace}) await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
class PGGraphQueryException(Exception): class PGGraphQueryException(Exception):
@ -3308,10 +3302,9 @@ class PGGraphStorage(BaseGraphStorage):
) )
async def finalize(self): async def finalize(self):
async with get_graph_db_lock(): if self.db is not None:
if self.db is not None: await ClientManager.release_client(self.db)
await ClientManager.release_client(self.db) self.db = None
self.db = None
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# PG handles persistence automatically # PG handles persistence automatically
@ -4711,21 +4704,20 @@ class PGGraphStorage(BaseGraphStorage):
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage""" """Drop the storage"""
async with get_graph_db_lock(): try:
try: drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ MATCH (n)
MATCH (n) DETACH DELETE n
DETACH DELETE n $$) AS (result agtype)"""
$$) AS (result agtype)"""
await self._query(drop_query, readonly=False) await self._query(drop_query, readonly=False)
return { return {
"status": "success", "status": "success",
"message": f"workspace '{self.workspace}' graph data dropped", "message": f"workspace '{self.workspace}' graph data dropped",
} }
except Exception as e: except Exception as e:
logger.error(f"[{self.workspace}] Error dropping graph: {e}") logger.error(f"[{self.workspace}] Error dropping graph: {e}")
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
# Note: Order matters! More specific namespaces (e.g., "full_entities") must come before # Note: Order matters! More specific namespaces (e.g., "full_entities") must come before

View file

@ -11,7 +11,7 @@ import pipmaster as pm
from ..base import BaseVectorStorage from ..base import BaseVectorStorage
from ..exceptions import QdrantMigrationError from ..exceptions import QdrantMigrationError
from ..kg.shared_storage import get_data_init_lock, get_storage_lock from ..kg.shared_storage import get_data_init_lock
from ..utils import compute_mdhash_id, logger from ..utils import compute_mdhash_id, logger
if not pm.is_installed("qdrant-client"): if not pm.is_installed("qdrant-client"):
@ -698,25 +698,25 @@ class QdrantVectorDBStorage(BaseVectorStorage):
- On success: {"status": "success", "message": "data dropped"} - On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"} - On failure: {"status": "error", "message": "<error details>"}
""" """
async with get_storage_lock(): # No need to lock: data integrity is ensured by allowing only one process to hold pipeline at a time
try: try:
# Delete all points for the current workspace # Delete all points for the current workspace
self._client.delete( self._client.delete(
collection_name=self.final_namespace, collection_name=self.final_namespace,
points_selector=models.FilterSelector( points_selector=models.FilterSelector(
filter=models.Filter( filter=models.Filter(
must=[workspace_filter_condition(self.effective_workspace)] must=[workspace_filter_condition(self.effective_workspace)]
) )
), ),
wait=True, wait=True,
) )
logger.info( logger.info(
f"[{self.workspace}] Process {os.getpid()} dropped workspace data from Qdrant collection {self.namespace}" f"[{self.workspace}] Process {os.getpid()} dropped workspace data from Qdrant collection {self.namespace}"
) )
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping workspace data from Qdrant collection {self.namespace}: {e}" f"[{self.workspace}] Error dropping workspace data from Qdrant collection {self.namespace}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -1370,19 +1370,11 @@ async def get_all_update_flags_status(workspace: str | None = None) -> Dict[str,
result = {} result = {}
async with get_internal_lock(): async with get_internal_lock():
for namespace, flags in _update_flags.items(): for namespace, flags in _update_flags.items():
# Check if namespace has a workspace prefix (contains ':') namespace_split = namespace.split(":")
if ":" in namespace: if workspace and not namespace_split[0] == workspace:
# Namespace has workspace prefix like "space1:pipeline_status" continue
# Only include if workspace matches the prefix if not workspace and namespace_split[0]:
namespace_split = namespace.split(":", 1) continue
if not workspace or namespace_split[0] != workspace:
continue
else:
# Namespace has no workspace prefix like "pipeline_status"
# Only include if we're querying the default (empty) workspace
if workspace:
continue
worker_statuses = [] worker_statuses = []
for flag in flags: for flag in flags:
if _is_multiprocess: if _is_multiprocess:
@ -1446,21 +1438,18 @@ async def get_namespace_data(
async with get_internal_lock(): async with get_internal_lock():
if final_namespace not in _shared_dicts: if final_namespace not in _shared_dicts:
# Special handling for pipeline_status namespace # Special handling for pipeline_status namespace
if ( if final_namespace.endswith(":pipeline_status") and not first_init:
final_namespace.endswith(":pipeline_status")
or final_namespace == "pipeline_status"
) and not first_init:
# Check if pipeline_status should have been initialized but wasn't # Check if pipeline_status should have been initialized but wasn't
# This helps users to call initialize_pipeline_status() before get_namespace_data() # This helps users to call initialize_pipeline_status() before get_namespace_data()
raise PipelineNotInitializedError(final_namespace) raise PipelineNotInitializedError(namespace)
# For other namespaces or when allow_create=True, create them dynamically # For other namespaces or when allow_create=True, create them dynamically
if _is_multiprocess and _manager is not None: if _is_multiprocess and _manager is not None:
_shared_dicts[final_namespace] = _manager.dict() _shared_dicts[namespace] = _manager.dict()
else: else:
_shared_dicts[final_namespace] = {} _shared_dicts[namespace] = {}
return _shared_dicts[final_namespace] return _shared_dicts[namespace]
def get_namespace_lock( def get_namespace_lock(

View file

@ -1599,12 +1599,8 @@ class LightRAG:
""" """
# Get pipeline status shared data and lock # Get pipeline status shared data and lock
pipeline_status = await get_namespace_data( pipeline_status = await get_namespace_data("pipeline_status")
"pipeline_status", workspace=self.workspace pipeline_status_lock = get_namespace_lock("pipeline_status")
)
pipeline_status_lock = get_namespace_lock(
"pipeline_status", workspace=self.workspace
)
# Check if another process is already processing the queue # Check if another process is already processing the queue
async with pipeline_status_lock: async with pipeline_status_lock:
@ -2956,12 +2952,8 @@ class LightRAG:
doc_llm_cache_ids: list[str] = [] doc_llm_cache_ids: list[str] = []
# Get pipeline status shared data and lock for status updates # Get pipeline status shared data and lock for status updates
pipeline_status = await get_namespace_data( pipeline_status = await get_namespace_data("pipeline_status")
"pipeline_status", workspace=self.workspace pipeline_status_lock = get_namespace_lock("pipeline_status")
)
pipeline_status_lock = get_namespace_lock(
"pipeline_status", workspace=self.workspace
)
async with pipeline_status_lock: async with pipeline_status_lock:
log_message = f"Starting deletion process for document {doc_id}" log_message = f"Starting deletion process for document {doc_id}"

View file

@ -463,7 +463,7 @@ class CleanupTool:
# CRITICAL: Set update flag so changes persist to disk # CRITICAL: Set update flag so changes persist to disk
# Without this, deletions remain in-memory only and are lost on exit # Without this, deletions remain in-memory only and are lost on exit
await set_all_update_flags(storage.final_namespace) await set_all_update_flags(storage.final_namespace, storage.workspace)
# Success # Success
stats.successful_batches += 1 stats.successful_batches += 1

View file

@ -18,7 +18,6 @@ import os
import sys import sys
import importlib import importlib
import numpy as np import numpy as np
import pytest
from dotenv import load_dotenv from dotenv import load_dotenv
from ascii_colors import ASCIIColors from ascii_colors import ASCIIColors
@ -106,12 +105,13 @@ async def initialize_graph_storage():
"vector_db_storage_cls_kwargs": { "vector_db_storage_cls_kwargs": {
"cosine_better_than_threshold": 0.5 # Cosine similarity threshold "cosine_better_than_threshold": 0.5 # Cosine similarity threshold
}, },
"working_dir": os.environ.get("WORKING_DIR", "./rag_storage"), # Working directory "working_dir": os.environ.get(
"WORKING_DIR", "./rag_storage"
), # Working directory
} }
# If using NetworkXStorage, initialize shared_storage first # Initialize shared_storage for all storage types (required for locks)
if graph_storage_type == "NetworkXStorage": initialize_share_data() # Use single-process mode (workers=1)
initialize_share_data() # Use single-process mode
try: try:
storage = storage_class( storage = storage_class(
@ -129,8 +129,6 @@ async def initialize_graph_storage():
return None return None
@pytest.mark.integration
@pytest.mark.requires_db
async def test_graph_basic(storage): async def test_graph_basic(storage):
""" """
Test basic graph database operations: Test basic graph database operations:
@ -176,7 +174,9 @@ async def test_graph_basic(storage):
node1_props = await storage.get_node(node1_id) node1_props = await storage.get_node(node1_id)
if node1_props: if node1_props:
print(f"Successfully read node properties: {node1_id}") print(f"Successfully read node properties: {node1_id}")
print(f"Node description: {node1_props.get('description', 'No description')}") print(
f"Node description: {node1_props.get('description', 'No description')}"
)
print(f"Node type: {node1_props.get('entity_type', 'No type')}") print(f"Node type: {node1_props.get('entity_type', 'No type')}")
print(f"Node keywords: {node1_props.get('keywords', 'No keywords')}") print(f"Node keywords: {node1_props.get('keywords', 'No keywords')}")
# Verify that the returned properties are correct # Verify that the returned properties are correct
@ -198,8 +198,12 @@ async def test_graph_basic(storage):
edge_props = await storage.get_edge(node1_id, node2_id) edge_props = await storage.get_edge(node1_id, node2_id)
if edge_props: if edge_props:
print(f"Successfully read edge properties: {node1_id} -> {node2_id}") print(f"Successfully read edge properties: {node1_id} -> {node2_id}")
print(f"Edge relationship: {edge_props.get('relationship', 'No relationship')}") print(
print(f"Edge description: {edge_props.get('description', 'No description')}") f"Edge relationship: {edge_props.get('relationship', 'No relationship')}"
)
print(
f"Edge description: {edge_props.get('description', 'No description')}"
)
print(f"Edge weight: {edge_props.get('weight', 'No weight')}") print(f"Edge weight: {edge_props.get('weight', 'No weight')}")
# Verify that the returned properties are correct # Verify that the returned properties are correct
assert ( assert (
@ -208,7 +212,9 @@ async def test_graph_basic(storage):
assert ( assert (
edge_props.get("description") == edge_data["description"] edge_props.get("description") == edge_data["description"]
), "Edge description mismatch" ), "Edge description mismatch"
assert edge_props.get("weight") == edge_data["weight"], "Edge weight mismatch" assert (
edge_props.get("weight") == edge_data["weight"]
), "Edge weight mismatch"
else: else:
print(f"Failed to read edge properties: {node1_id} -> {node2_id}") print(f"Failed to read edge properties: {node1_id} -> {node2_id}")
assert False, f"Failed to read edge properties: {node1_id} -> {node2_id}" assert False, f"Failed to read edge properties: {node1_id} -> {node2_id}"
@ -217,20 +223,28 @@ async def test_graph_basic(storage):
print(f"Reading reverse edge properties: {node2_id} -> {node1_id}") print(f"Reading reverse edge properties: {node2_id} -> {node1_id}")
reverse_edge_props = await storage.get_edge(node2_id, node1_id) reverse_edge_props = await storage.get_edge(node2_id, node1_id)
if reverse_edge_props: if reverse_edge_props:
print(f"Successfully read reverse edge properties: {node2_id} -> {node1_id}") print(
print(f"Reverse edge relationship: {reverse_edge_props.get('relationship', 'No relationship')}") f"Successfully read reverse edge properties: {node2_id} -> {node1_id}"
print(f"Reverse edge description: {reverse_edge_props.get('description', 'No description')}") )
print(f"Reverse edge weight: {reverse_edge_props.get('weight', 'No weight')}") print(
f"Reverse edge relationship: {reverse_edge_props.get('relationship', 'No relationship')}"
)
print(
f"Reverse edge description: {reverse_edge_props.get('description', 'No description')}"
)
print(
f"Reverse edge weight: {reverse_edge_props.get('weight', 'No weight')}"
)
# Verify that forward and reverse edge properties are the same # Verify that forward and reverse edge properties are the same
assert ( assert (
edge_props == reverse_edge_props edge_props == reverse_edge_props
), "Forward and reverse edge properties are not consistent, undirected graph property verification failed" ), "Forward and reverse edge properties are not consistent, undirected graph property verification failed"
print("Undirected graph property verification successful: forward and reverse edge properties are consistent") print(
"Undirected graph property verification successful: forward and reverse edge properties are consistent"
)
else: else:
print(f"Failed to read reverse edge properties: {node2_id} -> {node1_id}") print(f"Failed to read reverse edge properties: {node2_id} -> {node1_id}")
assert ( assert False, f"Failed to read reverse edge properties: {node2_id} -> {node1_id}, undirected graph property verification failed"
False
), f"Failed to read reverse edge properties: {node2_id} -> {node1_id}, undirected graph property verification failed"
print("Basic tests completed, data is preserved in the database.") print("Basic tests completed, data is preserved in the database.")
return True return True
@ -240,8 +254,6 @@ async def test_graph_basic(storage):
return False return False
@pytest.mark.integration
@pytest.mark.requires_db
async def test_graph_advanced(storage): async def test_graph_advanced(storage):
""" """
Test advanced graph database operations: Test advanced graph database operations:
@ -312,7 +324,9 @@ async def test_graph_advanced(storage):
print(f"== Testing node_degree: {node1_id}") print(f"== Testing node_degree: {node1_id}")
node1_degree = await storage.node_degree(node1_id) node1_degree = await storage.node_degree(node1_id)
print(f"Degree of node {node1_id}: {node1_degree}") print(f"Degree of node {node1_id}: {node1_degree}")
assert node1_degree == 1, f"Degree of node {node1_id} should be 1, but got {node1_degree}" assert (
node1_degree == 1
), f"Degree of node {node1_id} should be 1, but got {node1_degree}"
# 2.1 Test degrees of all nodes # 2.1 Test degrees of all nodes
print("== Testing degrees of all nodes") print("== Testing degrees of all nodes")
@ -320,8 +334,12 @@ async def test_graph_advanced(storage):
node3_degree = await storage.node_degree(node3_id) node3_degree = await storage.node_degree(node3_id)
print(f"Degree of node {node2_id}: {node2_degree}") print(f"Degree of node {node2_id}: {node2_degree}")
print(f"Degree of node {node3_id}: {node3_degree}") print(f"Degree of node {node3_id}: {node3_degree}")
assert node2_degree == 2, f"Degree of node {node2_id} should be 2, but got {node2_degree}" assert (
assert node3_degree == 1, f"Degree of node {node3_id} should be 1, but got {node3_degree}" node2_degree == 2
), f"Degree of node {node2_id} should be 2, but got {node2_degree}"
assert (
node3_degree == 1
), f"Degree of node {node3_id} should be 1, but got {node3_degree}"
# 3. Test edge_degree - get the degree of an edge # 3. Test edge_degree - get the degree of an edge
print(f"== Testing edge_degree: {node1_id} -> {node2_id}") print(f"== Testing edge_degree: {node1_id} -> {node2_id}")
@ -338,7 +356,9 @@ async def test_graph_advanced(storage):
assert ( assert (
edge_degree == reverse_edge_degree edge_degree == reverse_edge_degree
), "Degrees of forward and reverse edges are not consistent, undirected graph property verification failed" ), "Degrees of forward and reverse edges are not consistent, undirected graph property verification failed"
print("Undirected graph property verification successful: degrees of forward and reverse edges are consistent") print(
"Undirected graph property verification successful: degrees of forward and reverse edges are consistent"
)
# 4. Test get_node_edges - get all edges of a node # 4. Test get_node_edges - get all edges of a node
print(f"== Testing get_node_edges: {node2_id}") print(f"== Testing get_node_edges: {node2_id}")
@ -371,7 +391,9 @@ async def test_graph_advanced(storage):
assert ( assert (
has_connection_with_node3 has_connection_with_node3
), f"Edge list of node {node2_id} should include a connection with {node3_id}" ), f"Edge list of node {node2_id} should include a connection with {node3_id}"
print(f"Undirected graph property verification successful: edge list of node {node2_id} contains all relevant edges") print(
f"Undirected graph property verification successful: edge list of node {node2_id} contains all relevant edges"
)
# 5. Test get_all_labels - get all labels # 5. Test get_all_labels - get all labels
print("== Testing get_all_labels") print("== Testing get_all_labels")
@ -387,9 +409,15 @@ async def test_graph_advanced(storage):
kg = await storage.get_knowledge_graph("*", max_depth=2, max_nodes=10) kg = await storage.get_knowledge_graph("*", max_depth=2, max_nodes=10)
print(f"Number of nodes in knowledge graph: {len(kg.nodes)}") print(f"Number of nodes in knowledge graph: {len(kg.nodes)}")
print(f"Number of edges in knowledge graph: {len(kg.edges)}") print(f"Number of edges in knowledge graph: {len(kg.edges)}")
assert isinstance(kg, KnowledgeGraph), "The returned result should be of type KnowledgeGraph" assert isinstance(
assert len(kg.nodes) == 3, f"The knowledge graph should have 3 nodes, but got {len(kg.nodes)}" kg, KnowledgeGraph
assert len(kg.edges) == 2, f"The knowledge graph should have 2 edges, but got {len(kg.edges)}" ), "The returned result should be of type KnowledgeGraph"
assert (
len(kg.nodes) == 3
), f"The knowledge graph should have 3 nodes, but got {len(kg.nodes)}"
assert (
len(kg.edges) == 2
), f"The knowledge graph should have 2 edges, but got {len(kg.edges)}"
# 7. Test delete_node - delete a node # 7. Test delete_node - delete a node
print(f"== Testing delete_node: {node3_id}") print(f"== Testing delete_node: {node3_id}")
@ -406,17 +434,27 @@ async def test_graph_advanced(storage):
print(f"== Testing remove_edges: {node2_id} -> {node3_id}") print(f"== Testing remove_edges: {node2_id} -> {node3_id}")
await storage.remove_edges([(node2_id, node3_id)]) await storage.remove_edges([(node2_id, node3_id)])
edge_props = await storage.get_edge(node2_id, node3_id) edge_props = await storage.get_edge(node2_id, node3_id)
print(f"Querying edge properties after deletion {node2_id} -> {node3_id}: {edge_props}") print(
assert edge_props is None, f"Edge {node2_id} -> {node3_id} should have been deleted" f"Querying edge properties after deletion {node2_id} -> {node3_id}: {edge_props}"
)
assert (
edge_props is None
), f"Edge {node2_id} -> {node3_id} should have been deleted"
# 8.1 Verify undirected graph property of edge deletion # 8.1 Verify undirected graph property of edge deletion
print(f"== Verifying undirected graph property of edge deletion: {node3_id} -> {node2_id}") print(
f"== Verifying undirected graph property of edge deletion: {node3_id} -> {node2_id}"
)
reverse_edge_props = await storage.get_edge(node3_id, node2_id) reverse_edge_props = await storage.get_edge(node3_id, node2_id)
print(f"Querying reverse edge properties after deletion {node3_id} -> {node2_id}: {reverse_edge_props}") print(
f"Querying reverse edge properties after deletion {node3_id} -> {node2_id}: {reverse_edge_props}"
)
assert ( assert (
reverse_edge_props is None reverse_edge_props is None
), f"Reverse edge {node3_id} -> {node2_id} should also be deleted, undirected graph property verification failed" ), f"Reverse edge {node3_id} -> {node2_id} should also be deleted, undirected graph property verification failed"
print("Undirected graph property verification successful: deleting an edge in one direction also deletes the reverse edge") print(
"Undirected graph property verification successful: deleting an edge in one direction also deletes the reverse edge"
)
# 9. Test remove_nodes - delete multiple nodes # 9. Test remove_nodes - delete multiple nodes
print(f"== Testing remove_nodes: [{node2_id}, {node3_id}]") print(f"== Testing remove_nodes: [{node2_id}, {node3_id}]")
@ -436,8 +474,6 @@ async def test_graph_advanced(storage):
return False return False
@pytest.mark.integration
@pytest.mark.requires_db
async def test_graph_batch_operations(storage): async def test_graph_batch_operations(storage):
""" """
Test batch operations of the graph database: Test batch operations of the graph database:
@ -643,7 +679,9 @@ async def test_graph_batch_operations(storage):
edge_dicts = [{"src": src, "tgt": tgt} for src, tgt in edges] edge_dicts = [{"src": src, "tgt": tgt} for src, tgt in edges]
edges_dict = await storage.get_edges_batch(edge_dicts) edges_dict = await storage.get_edges_batch(edge_dicts)
print(f"Batch get edge properties result: {edges_dict.keys()}") print(f"Batch get edge properties result: {edges_dict.keys()}")
assert len(edges_dict) == 3, f"Should return properties of 3 edges, but got {len(edges_dict)}" assert (
len(edges_dict) == 3
), f"Should return properties of 3 edges, but got {len(edges_dict)}"
assert ( assert (
node1_id, node1_id,
node2_id, node2_id,
@ -682,14 +720,19 @@ async def test_graph_batch_operations(storage):
# Verify that properties of forward and reverse edges are consistent # Verify that properties of forward and reverse edges are consistent
for (src, tgt), props in edges_dict.items(): for (src, tgt), props in edges_dict.items():
assert ( assert (
tgt, (
src, tgt,
) in reverse_edges_dict, f"Reverse edge {tgt} -> {src} should be in the result" src,
)
in reverse_edges_dict
), f"Reverse edge {tgt} -> {src} should be in the result"
assert ( assert (
props == reverse_edges_dict[(tgt, src)] props == reverse_edges_dict[(tgt, src)]
), f"Properties of edge {src} -> {tgt} and reverse edge {tgt} -> {src} are inconsistent" ), f"Properties of edge {src} -> {tgt} and reverse edge {tgt} -> {src} are inconsistent"
print("Undirected graph property verification successful: properties of batch-retrieved forward and reverse edges are consistent") print(
"Undirected graph property verification successful: properties of batch-retrieved forward and reverse edges are consistent"
)
# 6. Test get_nodes_edges_batch - batch get all edges of multiple nodes # 6. Test get_nodes_edges_batch - batch get all edges of multiple nodes
print("== Testing get_nodes_edges_batch") print("== Testing get_nodes_edges_batch")
@ -725,9 +768,15 @@ async def test_graph_batch_operations(storage):
has_edge_to_node4 = any(tgt == node4_id for _, tgt in node1_outgoing_edges) has_edge_to_node4 = any(tgt == node4_id for _, tgt in node1_outgoing_edges)
has_edge_to_node5 = any(tgt == node5_id for _, tgt in node1_outgoing_edges) has_edge_to_node5 = any(tgt == node5_id for _, tgt in node1_outgoing_edges)
assert has_edge_to_node2, f"Edge list of node {node1_id} should include an edge to {node2_id}" assert (
assert has_edge_to_node4, f"Edge list of node {node1_id} should include an edge to {node4_id}" has_edge_to_node2
assert has_edge_to_node5, f"Edge list of node {node1_id} should include an edge to {node5_id}" ), f"Edge list of node {node1_id} should include an edge to {node2_id}"
assert (
has_edge_to_node4
), f"Edge list of node {node1_id} should include an edge to {node4_id}"
assert (
has_edge_to_node5
), f"Edge list of node {node1_id} should include an edge to {node5_id}"
# Check if node 3's edges include all relevant edges (regardless of direction) # Check if node 3's edges include all relevant edges (regardless of direction)
node3_outgoing_edges = [ node3_outgoing_edges = [
@ -766,7 +815,9 @@ async def test_graph_batch_operations(storage):
has_connection_with_node5 has_connection_with_node5
), f"Edge list of node {node3_id} should include a connection with {node5_id}" ), f"Edge list of node {node3_id} should include a connection with {node5_id}"
print("Undirected graph property verification successful: batch-retrieved node edges include all relevant edges (regardless of direction)") print(
"Undirected graph property verification successful: batch-retrieved node edges include all relevant edges (regardless of direction)"
)
print("\nBatch operations tests completed.") print("\nBatch operations tests completed.")
return True return True
@ -776,8 +827,6 @@ async def test_graph_batch_operations(storage):
return False return False
@pytest.mark.integration
@pytest.mark.requires_db
async def test_graph_special_characters(storage): async def test_graph_special_characters(storage):
""" """
Test the graph database's handling of special characters: Test the graph database's handling of special characters:
@ -834,7 +883,9 @@ async def test_graph_special_characters(storage):
"weight": 0.8, "weight": 0.8,
"description": "Contains SQL injection attempt: SELECT * FROM users WHERE name='admin'--", "description": "Contains SQL injection attempt: SELECT * FROM users WHERE name='admin'--",
} }
print(f"Inserting edge with complex special characters: {node2_id} -> {node3_id}") print(
f"Inserting edge with complex special characters: {node2_id} -> {node3_id}"
)
await storage.upsert_edge(node2_id, node3_id, edge2_data) await storage.upsert_edge(node2_id, node3_id, edge2_data)
# 6. Verify that node special characters are saved correctly # 6. Verify that node special characters are saved correctly
@ -847,7 +898,9 @@ async def test_graph_special_characters(storage):
node_props = await storage.get_node(node_id) node_props = await storage.get_node(node_id)
if node_props: if node_props:
print(f"Successfully read node: {node_id}") print(f"Successfully read node: {node_id}")
print(f"Node description: {node_props.get('description', 'No description')}") print(
f"Node description: {node_props.get('description', 'No description')}"
)
# Verify node ID is saved correctly # Verify node ID is saved correctly
assert ( assert (
@ -869,8 +922,12 @@ async def test_graph_special_characters(storage):
edge1_props = await storage.get_edge(node1_id, node2_id) edge1_props = await storage.get_edge(node1_id, node2_id)
if edge1_props: if edge1_props:
print(f"Successfully read edge: {node1_id} -> {node2_id}") print(f"Successfully read edge: {node1_id} -> {node2_id}")
print(f"Edge relationship: {edge1_props.get('relationship', 'No relationship')}") print(
print(f"Edge description: {edge1_props.get('description', 'No description')}") f"Edge relationship: {edge1_props.get('relationship', 'No relationship')}"
)
print(
f"Edge description: {edge1_props.get('description', 'No description')}"
)
# Verify edge relationship is saved correctly # Verify edge relationship is saved correctly
assert ( assert (
@ -882,7 +939,9 @@ async def test_graph_special_characters(storage):
edge1_props.get("description") == edge1_data["description"] edge1_props.get("description") == edge1_data["description"]
), f"Edge description mismatch: expected {edge1_data['description']}, got {edge1_props.get('description')}" ), f"Edge description mismatch: expected {edge1_data['description']}, got {edge1_props.get('description')}"
print(f"Edge {node1_id} -> {node2_id} special character verification successful") print(
f"Edge {node1_id} -> {node2_id} special character verification successful"
)
else: else:
print(f"Failed to read edge properties: {node1_id} -> {node2_id}") print(f"Failed to read edge properties: {node1_id} -> {node2_id}")
assert False, f"Failed to read edge properties: {node1_id} -> {node2_id}" assert False, f"Failed to read edge properties: {node1_id} -> {node2_id}"
@ -890,8 +949,12 @@ async def test_graph_special_characters(storage):
edge2_props = await storage.get_edge(node2_id, node3_id) edge2_props = await storage.get_edge(node2_id, node3_id)
if edge2_props: if edge2_props:
print(f"Successfully read edge: {node2_id} -> {node3_id}") print(f"Successfully read edge: {node2_id} -> {node3_id}")
print(f"Edge relationship: {edge2_props.get('relationship', 'No relationship')}") print(
print(f"Edge description: {edge2_props.get('description', 'No description')}") f"Edge relationship: {edge2_props.get('relationship', 'No relationship')}"
)
print(
f"Edge description: {edge2_props.get('description', 'No description')}"
)
# Verify edge relationship is saved correctly # Verify edge relationship is saved correctly
assert ( assert (
@ -903,7 +966,9 @@ async def test_graph_special_characters(storage):
edge2_props.get("description") == edge2_data["description"] edge2_props.get("description") == edge2_data["description"]
), f"Edge description mismatch: expected {edge2_data['description']}, got {edge2_props.get('description')}" ), f"Edge description mismatch: expected {edge2_data['description']}, got {edge2_props.get('description')}"
print(f"Edge {node2_id} -> {node3_id} special character verification successful") print(
f"Edge {node2_id} -> {node3_id} special character verification successful"
)
else: else:
print(f"Failed to read edge properties: {node2_id} -> {node3_id}") print(f"Failed to read edge properties: {node2_id} -> {node3_id}")
assert False, f"Failed to read edge properties: {node2_id} -> {node3_id}" assert False, f"Failed to read edge properties: {node2_id} -> {node3_id}"
@ -916,8 +981,6 @@ async def test_graph_special_characters(storage):
return False return False
@pytest.mark.integration
@pytest.mark.requires_db
async def test_graph_undirected_property(storage): async def test_graph_undirected_property(storage):
""" """
Specifically test the undirected graph property of the storage: Specifically test the undirected graph property of the storage:
@ -976,18 +1039,24 @@ async def test_graph_undirected_property(storage):
# Verify forward query # Verify forward query
forward_edge = await storage.get_edge(node1_id, node2_id) forward_edge = await storage.get_edge(node1_id, node2_id)
print(f"Forward edge properties: {forward_edge}") print(f"Forward edge properties: {forward_edge}")
assert forward_edge is not None, f"Failed to read forward edge properties: {node1_id} -> {node2_id}" assert (
forward_edge is not None
), f"Failed to read forward edge properties: {node1_id} -> {node2_id}"
# Verify reverse query # Verify reverse query
reverse_edge = await storage.get_edge(node2_id, node1_id) reverse_edge = await storage.get_edge(node2_id, node1_id)
print(f"Reverse edge properties: {reverse_edge}") print(f"Reverse edge properties: {reverse_edge}")
assert reverse_edge is not None, f"Failed to read reverse edge properties: {node2_id} -> {node1_id}" assert (
reverse_edge is not None
), f"Failed to read reverse edge properties: {node2_id} -> {node1_id}"
# Verify that forward and reverse edge properties are consistent # Verify that forward and reverse edge properties are consistent
assert ( assert (
forward_edge == reverse_edge forward_edge == reverse_edge
), "Forward and reverse edge properties are inconsistent, undirected property verification failed" ), "Forward and reverse edge properties are inconsistent, undirected property verification failed"
print("Undirected property verification successful: forward and reverse edge properties are consistent") print(
"Undirected property verification successful: forward and reverse edge properties are consistent"
)
# 3. Test undirected property of edge degree # 3. Test undirected property of edge degree
print("\n== Testing undirected property of edge degree") print("\n== Testing undirected property of edge degree")
@ -1009,7 +1078,9 @@ async def test_graph_undirected_property(storage):
assert ( assert (
forward_degree == reverse_degree forward_degree == reverse_degree
), "Degrees of forward and reverse edges are inconsistent, undirected property verification failed" ), "Degrees of forward and reverse edges are inconsistent, undirected property verification failed"
print("Undirected property verification successful: degrees of forward and reverse edges are consistent") print(
"Undirected property verification successful: degrees of forward and reverse edges are consistent"
)
# 4. Test undirected property of edge deletion # 4. Test undirected property of edge deletion
print("\n== Testing undirected property of edge deletion") print("\n== Testing undirected property of edge deletion")
@ -1020,16 +1091,24 @@ async def test_graph_undirected_property(storage):
# Verify forward edge is deleted # Verify forward edge is deleted
forward_edge = await storage.get_edge(node1_id, node2_id) forward_edge = await storage.get_edge(node1_id, node2_id)
print(f"Querying forward edge properties after deletion {node1_id} -> {node2_id}: {forward_edge}") print(
assert forward_edge is None, f"Edge {node1_id} -> {node2_id} should have been deleted" f"Querying forward edge properties after deletion {node1_id} -> {node2_id}: {forward_edge}"
)
assert (
forward_edge is None
), f"Edge {node1_id} -> {node2_id} should have been deleted"
# Verify reverse edge is also deleted # Verify reverse edge is also deleted
reverse_edge = await storage.get_edge(node2_id, node1_id) reverse_edge = await storage.get_edge(node2_id, node1_id)
print(f"Querying reverse edge properties after deletion {node2_id} -> {node1_id}: {reverse_edge}") print(
f"Querying reverse edge properties after deletion {node2_id} -> {node1_id}: {reverse_edge}"
)
assert ( assert (
reverse_edge is None reverse_edge is None
), f"Reverse edge {node2_id} -> {node1_id} should also be deleted, undirected property verification failed" ), f"Reverse edge {node2_id} -> {node1_id} should also be deleted, undirected property verification failed"
print("Undirected property verification successful: deleting an edge in one direction also deletes the reverse edge") print(
"Undirected property verification successful: deleting an edge in one direction also deletes the reverse edge"
)
# 5. Test undirected property in batch operations # 5. Test undirected property in batch operations
print("\n== Testing undirected property in batch operations") print("\n== Testing undirected property in batch operations")
@ -1056,14 +1135,19 @@ async def test_graph_undirected_property(storage):
# Verify that properties of forward and reverse edges are consistent # Verify that properties of forward and reverse edges are consistent
for (src, tgt), props in edges_dict.items(): for (src, tgt), props in edges_dict.items():
assert ( assert (
tgt, (
src, tgt,
) in reverse_edges_dict, f"Reverse edge {tgt} -> {src} should be in the result" src,
)
in reverse_edges_dict
), f"Reverse edge {tgt} -> {src} should be in the result"
assert ( assert (
props == reverse_edges_dict[(tgt, src)] props == reverse_edges_dict[(tgt, src)]
), f"Properties of edge {src} -> {tgt} and reverse edge {tgt} -> {src} are inconsistent" ), f"Properties of edge {src} -> {tgt} and reverse edge {tgt} -> {src} are inconsistent"
print("Undirected property verification successful: properties of batch-retrieved forward and reverse edges are consistent") print(
"Undirected property verification successful: properties of batch-retrieved forward and reverse edges are consistent"
)
# 6. Test undirected property of batch-retrieved node edges # 6. Test undirected property of batch-retrieved node edges
print("\n== Testing undirected property of batch-retrieved node edges") print("\n== Testing undirected property of batch-retrieved node edges")
@ -1083,8 +1167,12 @@ async def test_graph_undirected_property(storage):
(src == node1_id and tgt == node3_id) for src, tgt in node1_edges (src == node1_id and tgt == node3_id) for src, tgt in node1_edges
) )
assert has_edge_to_node2, f"Edge list of node {node1_id} should include an edge to {node2_id}" assert (
assert has_edge_to_node3, f"Edge list of node {node1_id} should include an edge to {node3_id}" has_edge_to_node2
), f"Edge list of node {node1_id} should include an edge to {node2_id}"
assert (
has_edge_to_node3
), f"Edge list of node {node1_id} should include an edge to {node3_id}"
# Check if node 2 has a connection with node 1 # Check if node 2 has a connection with node 1
has_edge_to_node1 = any( has_edge_to_node1 = any(
@ -1096,7 +1184,9 @@ async def test_graph_undirected_property(storage):
has_edge_to_node1 has_edge_to_node1
), f"Edge list of node {node2_id} should include a connection with {node1_id}" ), f"Edge list of node {node2_id} should include a connection with {node1_id}"
print("Undirected property verification successful: batch-retrieved node edges include all relevant edges (regardless of direction)") print(
"Undirected property verification successful: batch-retrieved node edges include all relevant edges (regardless of direction)"
)
print("\nUndirected property tests completed.") print("\nUndirected property tests completed.")
return True return True
@ -1124,7 +1214,9 @@ async def main():
# Get graph storage type # Get graph storage type
graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage") graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
ASCIIColors.magenta(f"\nCurrently configured graph storage type: {graph_storage_type}") ASCIIColors.magenta(
f"\nCurrently configured graph storage type: {graph_storage_type}"
)
ASCIIColors.white( ASCIIColors.white(
f"Supported graph storage types: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}" f"Supported graph storage types: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
) )
@ -1139,10 +1231,18 @@ async def main():
# Display test options # Display test options
ASCIIColors.yellow("\nPlease select a test type:") ASCIIColors.yellow("\nPlease select a test type:")
ASCIIColors.white("1. Basic Test (Node and edge insertion, reading)") ASCIIColors.white("1. Basic Test (Node and edge insertion, reading)")
ASCIIColors.white("2. Advanced Test (Degree, labels, knowledge graph, deletion, etc.)") ASCIIColors.white(
ASCIIColors.white("3. Batch Operations Test (Batch get node/edge properties, degrees, etc.)") "2. Advanced Test (Degree, labels, knowledge graph, deletion, etc.)"
ASCIIColors.white("4. Undirected Property Test (Verify undirected properties of the storage)") )
ASCIIColors.white("5. Special Characters Test (Verify handling of single/double quotes, backslashes, etc.)") ASCIIColors.white(
"3. Batch Operations Test (Batch get node/edge properties, degrees, etc.)"
)
ASCIIColors.white(
"4. Undirected Property Test (Verify undirected properties of the storage)"
)
ASCIIColors.white(
"5. Special Characters Test (Verify handling of single/double quotes, backslashes, etc.)"
)
ASCIIColors.white("6. All Tests") ASCIIColors.white("6. All Tests")
choice = input("\nEnter your choice (1/2/3/4/5/6): ") choice = input("\nEnter your choice (1/2/3/4/5/6): ")
@ -1182,7 +1282,9 @@ async def main():
) )
if undirected_result: if undirected_result:
ASCIIColors.cyan("\n=== Starting Special Characters Test ===") ASCIIColors.cyan(
"\n=== Starting Special Characters Test ==="
)
await test_graph_special_characters(storage) await test_graph_special_characters(storage)
else: else:
ASCIIColors.red("Invalid choice") ASCIIColors.red("Invalid choice")