diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index a0148357..2892f591 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -9,7 +9,7 @@ from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..constants import GRAPH_FIELD_SEP -from ..kg.shared_storage import get_graph_db_lock +from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock import pipmaster as pm if not pm.is_installed("neo4j"): @@ -56,7 +56,7 @@ class MemgraphStorage(BaseGraphStorage): return self.workspace async def initialize(self): - async with get_graph_db_lock(enable_logging=True): + async with get_data_init_lock(): URI = os.environ.get( "MEMGRAPH_URI", config.get("memgraph", "uri", fallback="bolt://localhost:7687"), @@ -102,7 +102,7 @@ class MemgraphStorage(BaseGraphStorage): raise async def finalize(self): - async with get_graph_db_lock(enable_logging=True): + async with get_graph_db_lock(): if self._driver is not None: await self._driver.close() self._driver = None @@ -743,7 +743,7 @@ class MemgraphStorage(BaseGraphStorage): raise RuntimeError( "Memgraph driver is not initialized. Call 'await initialize()' first." 
) - async with get_graph_db_lock(enable_logging=True): + async with get_graph_db_lock(): try: async with self._driver.session(database=self._DATABASE) as session: workspace_label = self._get_workspace_label() diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 04cc137e..1aa41021 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -6,7 +6,7 @@ import numpy as np from lightrag.utils import logger, compute_mdhash_id from ..base import BaseVectorStorage from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH -from ..kg.shared_storage import get_storage_lock +from ..kg.shared_storage import get_data_init_lock, get_storage_lock import pipmaster as pm if not pm.is_installed("pymilvus"): @@ -723,7 +723,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): async def initialize(self): """Initialize Milvus collection""" - async with get_storage_lock(enable_logging=True): + async with get_data_init_lock(): if self._initialized: return @@ -1028,7 +1028,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): - On success: {"status": "success", "message": "data dropped"} - On failure: {"status": "error", "message": ""} """ - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): try: # Drop the collection and recreate it if self._client.has_collection(self.final_namespace): diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index cf4822e5..ff671c56 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -18,7 +18,7 @@ from ..base import ( from ..utils import logger, compute_mdhash_id from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..constants import GRAPH_FIELD_SEP -from ..kg.shared_storage import get_storage_lock, get_graph_db_lock +from ..kg.shared_storage import get_data_init_lock, get_storage_lock, get_graph_db_lock import pipmaster as pm @@ -126,7 +126,7 @@ class MongoKVStorage(BaseKVStorage): self._collection_name =
self.final_namespace async def initialize(self): - async with get_storage_lock(enable_logging=True): + async with get_data_init_lock(): if self.db is None: self.db = await ClientManager.get_client() @@ -136,7 +136,7 @@ class MongoKVStorage(BaseKVStorage): ) async def finalize(self): - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -255,7 +255,7 @@ class MongoKVStorage(BaseKVStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): try: result = await self._data.delete_many({}) deleted_count = result.deleted_count @@ -323,7 +323,7 @@ class MongoDocStatusStorage(DocStatusStorage): self._collection_name = self.final_namespace async def initialize(self): - async with get_storage_lock(enable_logging=True): + async with get_data_init_lock(): if self.db is None: self.db = await ClientManager.get_client() @@ -340,7 +340,7 @@ class MongoDocStatusStorage(DocStatusStorage): ) async def finalize(self): - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -455,7 +455,7 @@ class MongoDocStatusStorage(DocStatusStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): try: result = await self._data.delete_many({}) deleted_count = result.deleted_count @@ -708,7 +708,7 @@ class MongoGraphStorage(BaseGraphStorage): self._edge_collection_name = f"{self._collection_name}_edges" async def initialize(self): - async with get_graph_db_lock(enable_logging=True): + async with get_data_init_lock(): if self.db is None: self.db = await ClientManager.get_client() @@ -723,7 +723,7 @@ class 
MongoGraphStorage(BaseGraphStorage): ) async def finalize(self): - async with get_graph_db_lock(enable_logging=True): + async with get_graph_db_lock(): if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -1579,7 +1579,7 @@ class MongoGraphStorage(BaseGraphStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with get_graph_db_lock(enable_logging=True): + async with get_graph_db_lock(): try: result = await self.collection.delete_many({}) deleted_count = result.deleted_count @@ -1674,7 +1674,7 @@ class MongoVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): - async with get_storage_lock(enable_logging=True): + async with get_data_init_lock(): if self.db is None: self.db = await ClientManager.get_client() @@ -1688,7 +1688,7 @@ class MongoVectorDBStorage(BaseVectorStorage): ) async def finalize(self): - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -1973,7 +1973,7 @@ class MongoVectorDBStorage(BaseVectorStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): try: # Delete all documents result = await self._data.delete_many({}) diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index dc3471ca..febd07e5 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -17,7 +17,7 @@ from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..constants import GRAPH_FIELD_SEP -from ..kg.shared_storage import get_graph_db_lock +from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock import pipmaster as pm if not pm.is_installed("neo4j"): @@ -71,7 
+71,7 @@ class Neo4JStorage(BaseGraphStorage): return self.workspace async def initialize(self): - async with get_graph_db_lock(enable_logging=True): + async with get_data_init_lock(): URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) USERNAME = os.environ.get( "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) @@ -241,7 +241,7 @@ class Neo4JStorage(BaseGraphStorage): async def finalize(self): """Close the Neo4j driver and release all resources""" - async with get_graph_db_lock(enable_logging=True): + async with get_graph_db_lock(): if self._driver: await self._driver.close() self._driver = None @@ -1533,7 +1533,7 @@ class Neo4JStorage(BaseGraphStorage): - On success: {"status": "success", "message": "workspace data dropped"} - On failure: {"status": "error", "message": ""} """ - async with get_graph_db_lock(enable_logging=True): + async with get_graph_db_lock(): workspace_label = self._get_workspace_label() try: async with self._driver.session(database=self._DATABASE) as session: diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index c5a58b6d..88292db5 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -30,7 +30,7 @@ from ..base import ( from ..namespace import NameSpace, is_namespace from ..utils import logger from ..constants import GRAPH_FIELD_SEP -from ..kg.shared_storage import get_graph_db_lock, get_storage_lock +from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock, get_storage_lock import pipmaster as pm @@ -1406,7 +1406,7 @@ class PGKVStorage(BaseKVStorage): self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): - async with get_storage_lock(enable_logging=True): + async with get_data_init_lock(): if self.db is None: self.db = await ClientManager.get_client() @@ -1422,7 +1422,7 @@ class PGKVStorage(BaseKVStorage): self.workspace = "default" async def finalize(self): - async with 
get_storage_lock(enable_logging=True): + async with get_storage_lock(): if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -1837,7 +1837,7 @@ class PGKVStorage(BaseKVStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): try: table_name = namespace_to_table_name(self.namespace) if not table_name: @@ -1871,7 +1871,7 @@ class PGVectorStorage(BaseVectorStorage): self.cosine_better_than_threshold = cosine_threshold async def initialize(self): - async with get_storage_lock(enable_logging=True): + async with get_data_init_lock(): if self.db is None: self.db = await ClientManager.get_client() @@ -1887,7 +1887,7 @@ class PGVectorStorage(BaseVectorStorage): self.workspace = "default" async def finalize(self): - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -2160,7 +2160,7 @@ class PGVectorStorage(BaseVectorStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): try: table_name = namespace_to_table_name(self.namespace) if not table_name: @@ -2194,7 +2194,7 @@ class PGDocStatusStorage(DocStatusStorage): return dt.isoformat() async def initialize(self): - async with get_storage_lock(enable_logging=True): + async with get_data_init_lock(): if self.db is None: self.db = await ClientManager.get_client() @@ -2210,7 +2210,7 @@ class PGDocStatusStorage(DocStatusStorage): self.workspace = "default" async def finalize(self): - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -2704,21 +2704,22 @@ class PGDocStatusStorage(DocStatusStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - try: - 
table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + async with get_storage_lock(): + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} class PGGraphQueryException(Exception): @@ -2789,7 +2790,7 @@ class PGGraphStorage(BaseGraphStorage): return normalized_id async def initialize(self): - async with get_graph_db_lock(enable_logging=True): + async with get_data_init_lock(): if self.db is None: self.db = await ClientManager.get_client() @@ -2850,7 +2851,7 @@ class PGGraphStorage(BaseGraphStorage): ) async def finalize(self): - async with get_graph_db_lock(enable_logging=True): + async with get_graph_db_lock(): if self.db is not None: await ClientManager.release_client(self.db) self.db = None @@ -4141,20 +4142,21 @@ class PGGraphStorage(BaseGraphStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - try: - drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n) - DETACH DELETE n - $$) AS (result agtype)""" + async with get_graph_db_lock(): + try: + drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n) + DETACH DELETE n + $$) AS (result agtype)""" - await self._query(drop_query, readonly=False) - return { - 
"status": "success", - "message": f"workspace '{self.workspace}' graph data dropped", - } - except Exception as e: - logger.error(f"[{self.workspace}] Error dropping graph: {e}") - return {"status": "error", "message": str(e)} + await self._query(drop_query, readonly=False) + return { + "status": "success", + "message": f"workspace '{self.workspace}' graph data dropped", + } + except Exception as e: + logger.error(f"[{self.workspace}] Error dropping graph: {e}") + return {"status": "error", "message": str(e)} # Note: Order matters! More specific namespaces (e.g., "full_entities") must come before diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index ab0ee02b..1686cba6 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -7,7 +7,7 @@ import hashlib import uuid from ..utils import logger from ..base import BaseVectorStorage -from ..kg.shared_storage import get_storage_lock +from ..kg.shared_storage import get_data_init_lock, get_storage_lock import configparser import pipmaster as pm @@ -117,7 +117,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): async def initialize(self): """Initialize Qdrant collection""" - async with get_storage_lock(enable_logging=True): + async with get_data_init_lock(): if self._initialized: return @@ -412,7 +412,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): - On success: {"status": "success", "message": "data dropped"} - On failure: {"status": "error", "message": ""} """ - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): try: # Delete the collection and recreate it if self._client.collection_exists(self.final_namespace): diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 733e07a5..8fc1ec4b 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -20,7 +20,7 @@ from lightrag.base import ( DocStatus, DocProcessingStatus, ) -from ..kg.shared_storage import get_storage_lock +from ..kg.shared_storage import 
get_data_init_lock, get_storage_lock import json # Import tenacity for retry logic @@ -180,7 +180,7 @@ class RedisKVStorage(BaseKVStorage): async def initialize(self): """Initialize Redis connection and migrate legacy cache structure if needed""" - async with get_storage_lock(enable_logging=True): + async with get_data_init_lock(): if self._initialized: return @@ -428,7 +428,7 @@ class RedisKVStorage(BaseKVStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): async with self._get_redis_connection() as redis: try: # Use SCAN to find all keys with the namespace prefix @@ -606,7 +606,7 @@ class RedisDocStatusStorage(DocStatusStorage): async def initialize(self): """Initialize Redis connection""" - async with get_storage_lock(enable_logging=True): + async with get_data_init_lock(): if self._initialized: return @@ -1049,7 +1049,7 @@ class RedisDocStatusStorage(DocStatusStorage): async def drop(self) -> dict[str, str]: """Drop all document status data from storage and clean up resources""" - async with get_storage_lock(enable_logging=True): + async with get_storage_lock(): try: async with self._get_redis_connection() as redis: # Use SCAN to find all keys with the namespace prefix