diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index bd69678a..a0148357 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -9,6 +9,7 @@ from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..constants import GRAPH_FIELD_SEP +from ..kg.shared_storage import get_graph_db_lock import pipmaster as pm if not pm.is_installed("neo4j"): @@ -55,53 +56,56 @@ class MemgraphStorage(BaseGraphStorage): return self.workspace async def initialize(self): - URI = os.environ.get( - "MEMGRAPH_URI", - config.get("memgraph", "uri", fallback="bolt://localhost:7687"), - ) - USERNAME = os.environ.get( - "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="") - ) - PASSWORD = os.environ.get( - "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="") - ) - DATABASE = os.environ.get( - "MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph") - ) - - self._driver = AsyncGraphDatabase.driver( - URI, - auth=(USERNAME, PASSWORD), - ) - self._DATABASE = DATABASE - try: - async with self._driver.session(database=DATABASE) as session: - # Create index for base nodes on entity_id if it doesn't exist - try: - workspace_label = self._get_workspace_label() - await session.run( - f"""CREATE INDEX ON :{workspace_label}(entity_id)""" - ) - logger.info( - f"[{self.workspace}] Created index on :{workspace_label}(entity_id) in Memgraph." - ) - except Exception as e: - # Index may already exist, which is not an error - logger.warning( - f"[{self.workspace}] Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}" - ) - await session.run("RETURN 1") - logger.info(f"[{self.workspace}] Connected to Memgraph at {URI}") - except Exception as e: - logger.error( - f"[{self.workspace}] Failed to connect to Memgraph at {URI}: {e}" + async with get_graph_db_lock(enable_logging=True): + URI = os.environ.get( + "MEMGRAPH_URI", + config.get("memgraph", "uri", fallback="bolt://localhost:7687"), ) - raise + USERNAME = os.environ.get( + "MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="") + ) + PASSWORD = os.environ.get( + "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="") + ) + DATABASE = os.environ.get( + "MEMGRAPH_DATABASE", + config.get("memgraph", "database", fallback="memgraph"), + ) + + self._driver = AsyncGraphDatabase.driver( + URI, + auth=(USERNAME, PASSWORD), + ) + self._DATABASE = DATABASE + try: + async with self._driver.session(database=DATABASE) as session: + # Create index for base nodes on entity_id if it doesn't exist + try: + workspace_label = self._get_workspace_label() + await session.run( + f"""CREATE INDEX ON :{workspace_label}(entity_id)""" + ) + logger.info( + f"[{self.workspace}] Created index on :{workspace_label}(entity_id) in Memgraph." 
+ ) + except Exception as e: + # Index may already exist, which is not an error + logger.warning( + f"[{self.workspace}] Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}" + ) + await session.run("RETURN 1") + logger.info(f"[{self.workspace}] Connected to Memgraph at {URI}") + except Exception as e: + logger.error( + f"[{self.workspace}] Failed to connect to Memgraph at {URI}: {e}" + ) + raise async def finalize(self): - if self._driver is not None: - await self._driver.close() - self._driver = None + async with get_graph_db_lock(enable_logging=True): + if self._driver is not None: + await self._driver.close() + self._driver = None async def __aexit__(self, exc_type, exc, tb): await self.finalize() @@ -739,21 +743,22 @@ class MemgraphStorage(BaseGraphStorage): raise RuntimeError( "Memgraph driver is not initialized. Call 'await initialize()' first." ) - try: - async with self._driver.session(database=self._DATABASE) as session: - workspace_label = self._get_workspace_label() - query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" - result = await session.run(query) - await result.consume() - logger.info( - f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" + async with get_graph_db_lock(enable_logging=True): + try: + async with self._driver.session(database=self._DATABASE) as session: + workspace_label = self._get_workspace_label() + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" + result = await session.run(query) + await result.consume() + logger.info( + f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" + ) + return {"status": "success", "message": "workspace data dropped"} + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" ) - return {"status": "success", "message": "workspace data dropped"} - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" - ) - return {"status": "error", "message": str(e)} + return {"status": "error", "message": str(e)} async def edge_degree(self, src_id: str, tgt_id: str) -> int: """Get the total degree (sum of relationships) of two nodes. 
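
The same guard pattern recurs in every backend touched by this diff: connection setup, teardown, and `drop()` are wrapped in the shared graph-db (or storage) lock so concurrent workers cannot race on driver creation or index setup. Below is a minimal, self-contained sketch of that shape only; it uses a plain `asyncio.Lock` as a stand-in for `get_graph_db_lock(enable_logging=True)` and a hypothetical `FakeDriver` in place of the real Memgraph/Neo4j async driver, so it illustrates the locking structure rather than the actual client calls.

```python
import asyncio

# Stand-in for lightrag.kg.shared_storage.get_graph_db_lock(enable_logging=True);
# the real helper returns a shared lock object, which this sketch approximates.
_graph_db_lock = asyncio.Lock()


class FakeDriver:
    """Hypothetical driver used only to illustrate the lifecycle."""

    async def verify(self) -> None:
        await asyncio.sleep(0)  # pretend to run "RETURN 1"

    async def close(self) -> None:
        await asyncio.sleep(0)


class LockedGraphStorage:
    def __init__(self) -> None:
        self._driver: FakeDriver | None = None

    async def initialize(self) -> None:
        # Serialize setup: only one coroutine creates the driver / index.
        async with _graph_db_lock:
            if self._driver is not None:  # already initialized by another caller
                return
            driver = FakeDriver()
            await driver.verify()         # fail fast before publishing the driver
            self._driver = driver

    async def finalize(self) -> None:
        # Teardown under the same lock so a concurrent initialize() or drop()
        # never observes a half-closed driver.
        async with _graph_db_lock:
            if self._driver is not None:
                await self._driver.close()
                self._driver = None


async def main() -> None:
    storage = LockedGraphStorage()
    # Concurrent calls are safe: the lock makes initialize() effectively idempotent.
    await asyncio.gather(*(storage.initialize() for _ in range(4)))
    await storage.finalize()


asyncio.run(main())
```
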
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index f3333bb7..7c4edcd3 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -6,6 +6,7 @@ import numpy as np from lightrag.utils import logger, compute_mdhash_id from ..base import BaseVectorStorage from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH +from ..kg.shared_storage import get_storage_lock import pipmaster as pm if not pm.is_installed("pymilvus"): @@ -752,9 +753,26 @@ class MilvusVectorDBStorage(BaseVectorStorage): ), ) self._max_batch_size = self.global_config["embedding_batch_num"] + self._initialized = False - # Create collection and check compatibility - self._create_collection_if_not_exist() + async def initialize(self): + """Initialize Milvus collection""" + async with get_storage_lock(enable_logging=True): + if self._initialized: + return + + try: + # Create collection and check compatibility + self._create_collection_if_not_exist() + self._initialized = True + logger.info( + f"[{self.workspace}] Milvus collection '{self.namespace}' initialized successfully" + ) + except Exception as e: + logger.error( + f"[{self.workspace}] Failed to initialize Milvus collection '{self.namespace}': {e}" + ) + raise async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") @@ -1012,20 +1030,21 @@ class MilvusVectorDBStorage(BaseVectorStorage): - On success: {"status": "success", "message": "data dropped"} - On failure: {"status": "error", "message": ""} """ - try: - # Drop the collection and recreate it - if self._client.has_collection(self.final_namespace): - self._client.drop_collection(self.final_namespace) + async with get_storage_lock(enable_logging=True): + try: + # Drop the collection and recreate it + if self._client.has_collection(self.final_namespace): + self._client.drop_collection(self.final_namespace) - # Recreate the collection - self._create_collection_if_not_exist() + # Recreate the collection + self._create_collection_if_not_exist() - logger.info( - f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}" - ) - return {"status": "success", "message": "data dropped"} - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}" - ) - return {"status": "error", "message": str(e)} + logger.info( + f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}" + ) + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index b4550c1b..6b249b94 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -18,6 +18,7 @@ from ..base import ( from ..utils import logger, compute_mdhash_id from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..constants import GRAPH_FIELD_SEP +from ..kg.shared_storage import get_storage_lock, get_graph_db_lock import pipmaster as pm @@ -39,39 +40,36 @@ GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional") class ClientManager: _instances = {"db": None, "ref_count": 0} - _lock = asyncio.Lock() @classmethod async def get_client(cls) -> AsyncMongoClient: - async with cls._lock: - if cls._instances["db"] is None: - uri = os.environ.get( - "MONGO_URI", - config.get( - "mongodb", - 
"uri", - fallback="mongodb://root:root@localhost:27017/", - ), - ) - database_name = os.environ.get( - "MONGO_DATABASE", - config.get("mongodb", "database", fallback="LightRAG"), - ) - client = AsyncMongoClient(uri) - db = client.get_database(database_name) - cls._instances["db"] = db - cls._instances["ref_count"] = 0 - cls._instances["ref_count"] += 1 - return cls._instances["db"] + if cls._instances["db"] is None: + uri = os.environ.get( + "MONGO_URI", + config.get( + "mongodb", + "uri", + fallback="mongodb://root:root@localhost:27017/", + ), + ) + database_name = os.environ.get( + "MONGO_DATABASE", + config.get("mongodb", "database", fallback="LightRAG"), + ) + client = AsyncMongoClient(uri) + db = client.get_database(database_name) + cls._instances["db"] = db + cls._instances["ref_count"] = 0 + cls._instances["ref_count"] += 1 + return cls._instances["db"] @classmethod async def release_client(cls, db: AsyncDatabase): - async with cls._lock: - if db is not None: - if db is cls._instances["db"]: - cls._instances["ref_count"] -= 1 - if cls._instances["ref_count"] == 0: - cls._instances["db"] = None + if db is not None: + if db is cls._instances["db"]: + cls._instances["ref_count"] -= 1 + if cls._instances["ref_count"] == 0: + cls._instances["db"] = None @final @@ -125,18 +123,21 @@ class MongoKVStorage(BaseKVStorage): self._collection_name = self.final_namespace async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() + async with get_storage_lock(enable_logging=True): + if self.db is None: + self.db = await ClientManager.get_client() + self._data = await get_or_create_collection(self.db, self._collection_name) logger.debug( f"[{self.workspace}] Use MongoDB as KV {self._collection_name}" ) async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - self._data = None + async with get_storage_lock(enable_logging=True): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + self._data = None async def get_by_id(self, id: str) -> dict[str, Any] | None: # Unified handling for flattened keys @@ -251,22 +252,23 @@ class MongoKVStorage(BaseKVStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - try: - result = await self._data.delete_many({}) - deleted_count = result.deleted_count + async with get_storage_lock(enable_logging=True): + try: + result = await self._data.delete_many({}) + deleted_count = result.deleted_count - logger.info( - f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" - ) - return { - "status": "success", - "message": f"{deleted_count} documents dropped", - } - except PyMongoError as e: - logger.error( - f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" - ) - return {"status": "error", "message": str(e)} + logger.info( + f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" + ) + return { + "status": "success", + "message": f"{deleted_count} documents dropped", + } + except PyMongoError as e: + logger.error( + f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" + ) + return {"status": "error", "message": str(e)} @final @@ -318,8 +320,10 @@ class MongoDocStatusStorage(DocStatusStorage): self._collection_name = self.final_namespace async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() + async with 
get_storage_lock(enable_logging=True): + if self.db is None: + self.db = await ClientManager.get_client() + self._data = await get_or_create_collection(self.db, self._collection_name) # Create track_id index for better query performance @@ -333,10 +337,11 @@ class MongoDocStatusStorage(DocStatusStorage): ) async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - self._data = None + async with get_storage_lock(enable_logging=True): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + self._data = None async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: return await self._data.find_one({"_id": id}) @@ -447,22 +452,23 @@ class MongoDocStatusStorage(DocStatusStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - try: - result = await self._data.delete_many({}) - deleted_count = result.deleted_count + async with get_storage_lock(enable_logging=True): + try: + result = await self._data.delete_many({}) + deleted_count = result.deleted_count - logger.info( - f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" - ) - return { - "status": "success", - "message": f"{deleted_count} documents dropped", - } - except PyMongoError as e: - logger.error( - f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" - ) - return {"status": "error", "message": str(e)} + logger.info( + f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" + ) + return { + "status": "success", + "message": f"{deleted_count} documents dropped", + } + except PyMongoError as e: + logger.error( + f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" + ) + return {"status": "error", "message": str(e)} async def delete(self, ids: list[str]) -> None: await self._data.delete_many({"_id": {"$in": ids}}) @@ -699,8 +705,10 @@ class MongoGraphStorage(BaseGraphStorage): self._edge_collection_name = f"{self._collection_name}_edges" async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() + async with get_graph_db_lock(enable_logging=True): + if self.db is None: + self.db = await ClientManager.get_client() + self.collection = await get_or_create_collection( self.db, self._collection_name ) @@ -712,11 +720,12 @@ class MongoGraphStorage(BaseGraphStorage): ) async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - self.collection = None - self.edge_collection = None + async with get_graph_db_lock(enable_logging=True): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + self.collection = None + self.edge_collection = None # Sample entity document # "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP @@ -1567,29 +1576,30 @@ class MongoGraphStorage(BaseGraphStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - try: - result = await self.collection.delete_many({}) - deleted_count = result.deleted_count + async with get_graph_db_lock(enable_logging=True): + try: + result = await self.collection.delete_many({}) + deleted_count = result.deleted_count - logger.info( - f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}" - ) + logger.info( + f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}" + ) - 
result = await self.edge_collection.delete_many({}) - edge_count = result.deleted_count - logger.info( - f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}" - ) + result = await self.edge_collection.delete_many({}) + edge_count = result.deleted_count + logger.info( + f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}" + ) - return { - "status": "success", - "message": f"{deleted_count} documents and {edge_count} edges dropped", - } - except PyMongoError as e: - logger.error( - f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}" - ) - return {"status": "error", "message": str(e)} + return { + "status": "success", + "message": f"{deleted_count} documents and {edge_count} edges dropped", + } + except PyMongoError as e: + logger.error( + f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}" + ) + return {"status": "error", "message": str(e)} @final @@ -1661,8 +1671,10 @@ class MongoVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() + async with get_storage_lock(enable_logging=True): + if self.db is None: + self.db = await ClientManager.get_client() + self._data = await get_or_create_collection(self.db, self._collection_name) # Ensure vector index exists @@ -1673,10 +1685,11 @@ class MongoVectorDBStorage(BaseVectorStorage): ) async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - self._data = None + async with get_storage_lock(enable_logging=True): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + self._data = None async def create_vector_index_if_not_exists(self): """Creates an Atlas Vector Search index.""" @@ -1957,26 +1970,27 @@ class MongoVectorDBStorage(BaseVectorStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - try: - # Delete all documents - result = await self._data.delete_many({}) - deleted_count = result.deleted_count + async with get_storage_lock(enable_logging=True): + try: + # Delete all documents + result = await self._data.delete_many({}) + deleted_count = result.deleted_count - # Recreate vector index - await self.create_vector_index_if_not_exists() + # Recreate vector index + await self.create_vector_index_if_not_exists() - logger.info( - f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index" - ) - return { - "status": "success", - "message": f"{deleted_count} documents dropped and vector index recreated", - } - except PyMongoError as e: - logger.error( - f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}" - ) - return {"status": "error", "message": str(e)} + logger.info( + f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index" + ) + return { + "status": "success", + "message": f"{deleted_count} documents dropped and vector index recreated", + } + except PyMongoError as e: + logger.error( + f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}" + ) + return {"status": "error", "message": str(e)} async def get_or_create_collection(db: AsyncDatabase, collection_name: str): diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 953946a1..dc3471ca 100644 --- 
a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -17,6 +17,7 @@ from ..utils import logger from ..base import BaseGraphStorage from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..constants import GRAPH_FIELD_SEP +from ..kg.shared_storage import get_graph_db_lock import pipmaster as pm if not pm.is_installed("neo4j"): @@ -70,174 +71,180 @@ class Neo4JStorage(BaseGraphStorage): return self.workspace async def initialize(self): - URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) - USERNAME = os.environ.get( - "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) - ) - PASSWORD = os.environ.get( - "NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None) - ) - MAX_CONNECTION_POOL_SIZE = int( - os.environ.get( - "NEO4J_MAX_CONNECTION_POOL_SIZE", - config.get("neo4j", "connection_pool_size", fallback=100), + async with get_graph_db_lock(enable_logging=True): + URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) + USERNAME = os.environ.get( + "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) ) - ) - CONNECTION_TIMEOUT = float( - os.environ.get( - "NEO4J_CONNECTION_TIMEOUT", - config.get("neo4j", "connection_timeout", fallback=30.0), - ), - ) - CONNECTION_ACQUISITION_TIMEOUT = float( - os.environ.get( - "NEO4J_CONNECTION_ACQUISITION_TIMEOUT", - config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), - ), - ) - MAX_TRANSACTION_RETRY_TIME = float( - os.environ.get( - "NEO4J_MAX_TRANSACTION_RETRY_TIME", - config.get("neo4j", "max_transaction_retry_time", fallback=30.0), - ), - ) - MAX_CONNECTION_LIFETIME = float( - os.environ.get( - "NEO4J_MAX_CONNECTION_LIFETIME", - config.get("neo4j", "max_connection_lifetime", fallback=300.0), - ), - ) - LIVENESS_CHECK_TIMEOUT = float( - os.environ.get( - "NEO4J_LIVENESS_CHECK_TIMEOUT", - config.get("neo4j", "liveness_check_timeout", fallback=30.0), - ), - ) - KEEP_ALIVE = os.environ.get( - "NEO4J_KEEP_ALIVE", - config.get("neo4j", "keep_alive", fallback="true"), - ).lower() in ("true", "1", "yes", "on") - DATABASE = os.environ.get( - "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace) - ) - """The default value approach for the DATABASE is only intended to maintain compatibility with legacy practices.""" - - self._driver: AsyncDriver = AsyncGraphDatabase.driver( - URI, - auth=(USERNAME, PASSWORD), - max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, - connection_timeout=CONNECTION_TIMEOUT, - connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT, - max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME, - max_connection_lifetime=MAX_CONNECTION_LIFETIME, - liveness_check_timeout=LIVENESS_CHECK_TIMEOUT, - keep_alive=KEEP_ALIVE, - ) - - # Try to connect to the database and create it if it doesn't exist - for database in (DATABASE, None): - self._DATABASE = database - connected = False - - try: - async with self._driver.session(database=database) as session: - try: - result = await session.run("MATCH (n) RETURN n LIMIT 0") - await result.consume() # Ensure result is consumed - logger.info( - f"[{self.workspace}] Connected to {database} at {URI}" - ) - connected = True - except neo4jExceptions.ServiceUnavailable as e: - logger.error( - f"[{self.workspace}] " - + f"{database} at {URI} is not available".capitalize() - ) - raise e - except neo4jExceptions.AuthError as e: - logger.error( - f"[{self.workspace}] Authentication failed for {database} at {URI}" + PASSWORD = os.environ.get( + "NEO4J_PASSWORD", 
config.get("neo4j", "password", fallback=None) + ) + MAX_CONNECTION_POOL_SIZE = int( + os.environ.get( + "NEO4J_MAX_CONNECTION_POOL_SIZE", + config.get("neo4j", "connection_pool_size", fallback=100), ) - raise e - except neo4jExceptions.ClientError as e: - if e.code == "Neo.ClientError.Database.DatabaseNotFound": - logger.info( - f"[{self.workspace}] " - + f"{database} at {URI} not found. Try to create specified database.".capitalize() - ) - try: - async with self._driver.session() as session: - result = await session.run( - f"CREATE DATABASE `{database}` IF NOT EXISTS" - ) - await result.consume() # Ensure result is consumed - logger.info( - f"[{self.workspace}] " - + f"{database} at {URI} created".capitalize() - ) - connected = True - except ( - neo4jExceptions.ClientError, - neo4jExceptions.DatabaseError, - ) as e: - if ( - e.code - == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" - ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"): - if database is not None: - logger.warning( - f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." - ) - if database is None: - logger.error( - f"[{self.workspace}] Failed to create {database} at {URI}" - ) - raise e + ) + CONNECTION_TIMEOUT = float( + os.environ.get( + "NEO4J_CONNECTION_TIMEOUT", + config.get("neo4j", "connection_timeout", fallback=30.0), + ), + ) + CONNECTION_ACQUISITION_TIMEOUT = float( + os.environ.get( + "NEO4J_CONNECTION_ACQUISITION_TIMEOUT", + config.get( + "neo4j", "connection_acquisition_timeout", fallback=30.0 + ), + ), + ) + MAX_TRANSACTION_RETRY_TIME = float( + os.environ.get( + "NEO4J_MAX_TRANSACTION_RETRY_TIME", + config.get("neo4j", "max_transaction_retry_time", fallback=30.0), + ), + ) + MAX_CONNECTION_LIFETIME = float( + os.environ.get( + "NEO4J_MAX_CONNECTION_LIFETIME", + config.get("neo4j", "max_connection_lifetime", fallback=300.0), + ), + ) + LIVENESS_CHECK_TIMEOUT = float( + os.environ.get( + "NEO4J_LIVENESS_CHECK_TIMEOUT", + config.get("neo4j", "liveness_check_timeout", fallback=30.0), + ), + ) + KEEP_ALIVE = os.environ.get( + "NEO4J_KEEP_ALIVE", + config.get("neo4j", "keep_alive", fallback="true"), + ).lower() in ("true", "1", "yes", "on") + DATABASE = os.environ.get( + "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace) + ) + """The default value approach for the DATABASE is only intended to maintain compatibility with legacy practices.""" + + self._driver: AsyncDriver = AsyncGraphDatabase.driver( + URI, + auth=(USERNAME, PASSWORD), + max_connection_pool_size=MAX_CONNECTION_POOL_SIZE, + connection_timeout=CONNECTION_TIMEOUT, + connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT, + max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME, + max_connection_lifetime=MAX_CONNECTION_LIFETIME, + liveness_check_timeout=LIVENESS_CHECK_TIMEOUT, + keep_alive=KEEP_ALIVE, + ) + + # Try to connect to the database and create it if it doesn't exist + for database in (DATABASE, None): + self._DATABASE = database + connected = False - if connected: - # Create index for workspace nodes on entity_id if it doesn't exist - workspace_label = self._get_workspace_label() try: async with self._driver.session(database=database) as session: - # Check if index exists first - check_query = f""" - CALL db.indexes() YIELD name, labelsOrTypes, properties - WHERE labelsOrTypes = ['{workspace_label}'] AND properties = ['entity_id'] - RETURN count(*) > 0 AS exists - """ try: - 
check_result = await session.run(check_query) - record = await check_result.single() - await check_result.consume() - - index_exists = record and record.get("exists", False) - - if not index_exists: - # Create index only if it doesn't exist + result = await session.run("MATCH (n) RETURN n LIMIT 0") + await result.consume() # Ensure result is consumed + logger.info( + f"[{self.workspace}] Connected to {database} at {URI}" + ) + connected = True + except neo4jExceptions.ServiceUnavailable as e: + logger.error( + f"[{self.workspace}] " + + f"{database} at {URI} is not available".capitalize() + ) + raise e + except neo4jExceptions.AuthError as e: + logger.error( + f"[{self.workspace}] Authentication failed for {database} at {URI}" + ) + raise e + except neo4jExceptions.ClientError as e: + if e.code == "Neo.ClientError.Database.DatabaseNotFound": + logger.info( + f"[{self.workspace}] " + + f"{database} at {URI} not found. Try to create specified database.".capitalize() + ) + try: + async with self._driver.session() as session: result = await session.run( - f"CREATE INDEX FOR (n:`{workspace_label}`) ON (n.entity_id)" + f"CREATE DATABASE `{database}` IF NOT EXISTS" + ) + await result.consume() # Ensure result is consumed + logger.info( + f"[{self.workspace}] " + + f"{database} at {URI} created".capitalize() + ) + connected = True + except ( + neo4jExceptions.ClientError, + neo4jExceptions.DatabaseError, + ) as e: + if ( + e.code + == "Neo.ClientError.Statement.UnsupportedAdministrationCommand" + ) or ( + e.code == "Neo.DatabaseError.Statement.ExecutionFailed" + ): + if database is not None: + logger.warning( + f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." 
+ ) + if database is None: + logger.error( + f"[{self.workspace}] Failed to create {database} at {URI}" + ) + raise e + + if connected: + # Create index for workspace nodes on entity_id if it doesn't exist + workspace_label = self._get_workspace_label() + try: + async with self._driver.session(database=database) as session: + # Check if index exists first + check_query = f""" + CALL db.indexes() YIELD name, labelsOrTypes, properties + WHERE labelsOrTypes = ['{workspace_label}'] AND properties = ['entity_id'] + RETURN count(*) > 0 AS exists + """ + try: + check_result = await session.run(check_query) + record = await check_result.single() + await check_result.consume() + + index_exists = record and record.get("exists", False) + + if not index_exists: + # Create index only if it doesn't exist + result = await session.run( + f"CREATE INDEX FOR (n:`{workspace_label}`) ON (n.entity_id)" + ) + await result.consume() + logger.info( + f"[{self.workspace}] Created index for {workspace_label} nodes on entity_id in {database}" + ) + except Exception: + # Fallback if db.indexes() is not supported in this Neo4j version + result = await session.run( + f"CREATE INDEX IF NOT EXISTS FOR (n:`{workspace_label}`) ON (n.entity_id)" ) await result.consume() - logger.info( - f"[{self.workspace}] Created index for {workspace_label} nodes on entity_id in {database}" - ) - except Exception: - # Fallback if db.indexes() is not supported in this Neo4j version - result = await session.run( - f"CREATE INDEX IF NOT EXISTS FOR (n:`{workspace_label}`) ON (n.entity_id)" - ) - await result.consume() - except Exception as e: - logger.warning( - f"[{self.workspace}] Failed to create index: {str(e)}" - ) - break + except Exception as e: + logger.warning( + f"[{self.workspace}] Failed to create index: {str(e)}" + ) + break async def finalize(self): """Close the Neo4j driver and release all resources""" - if self._driver: - await self._driver.close() - self._driver = None + async with get_graph_db_lock(enable_logging=True): + if self._driver: + await self._driver.close() + self._driver = None async def __aexit__(self, exc_type, exc, tb): """Ensure driver is closed when context manager exits""" @@ -1526,23 +1533,24 @@ class Neo4JStorage(BaseGraphStorage): - On success: {"status": "success", "message": "workspace data dropped"} - On failure: {"status": "error", "message": ""} """ - workspace_label = self._get_workspace_label() - try: - async with self._driver.session(database=self._DATABASE) as session: - # Delete all nodes and relationships in current workspace only - query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" - result = await session.run(query) - await result.consume() # Ensure result is fully consumed + async with get_graph_db_lock(enable_logging=True): + workspace_label = self._get_workspace_label() + try: + async with self._driver.session(database=self._DATABASE) as session: + # Delete all nodes and relationships in current workspace only + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" + result = await session.run(query) + await result.consume() # Ensure result is fully consumed - # logger.debug( - # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" - # ) - return { - "status": "success", - "message": f"workspace '{workspace_label}' data dropped", - } - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" - ) - return {"status": "error", 
"message": str(e)} + # logger.debug( + # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" + # ) + return { + "status": "success", + "message": f"workspace '{workspace_label}' data dropped", + } + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" + ) + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index e7119df0..c5a58b6d 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -30,7 +30,7 @@ from ..base import ( from ..namespace import NameSpace, is_namespace from ..utils import logger from ..constants import GRAPH_FIELD_SEP -from ..kg.shared_storage import get_graph_db_lock +from ..kg.shared_storage import get_graph_db_lock, get_storage_lock import pipmaster as pm @@ -1406,23 +1406,26 @@ class PGKVStorage(BaseKVStorage): self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" - if self.db.workspace: - # Use PostgreSQLDB's workspace (highest priority) - self.workspace = self.db.workspace - elif hasattr(self, "workspace") and self.workspace: - # Use storage class's workspace (medium priority) - pass - else: - # Use "default" for compatibility (lowest priority) - self.workspace = "default" + async with get_storage_lock(enable_logging=True): + if self.db is None: + self.db = await ClientManager.get_client() + + # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" + if self.db.workspace: + # Use PostgreSQLDB's workspace (highest priority) + self.workspace = self.db.workspace + elif hasattr(self, "workspace") and self.workspace: + # Use storage class's workspace (medium priority) + pass + else: + # Use "default" for compatibility (lowest priority) + self.workspace = "default" async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + async with get_storage_lock(enable_logging=True): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None ################ QUERY METHODS ################ async def get_all(self) -> dict[str, Any]: @@ -1834,21 +1837,22 @@ class PGKVStorage(BaseKVStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + async with get_storage_lock(enable_logging=True): + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} @final @@ -1867,23 +1871,26 
@@ class PGVectorStorage(BaseVectorStorage): self.cosine_better_than_threshold = cosine_threshold async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" - if self.db.workspace: - # Use PostgreSQLDB's workspace (highest priority) - self.workspace = self.db.workspace - elif hasattr(self, "workspace") and self.workspace: - # Use storage class's workspace (medium priority) - pass - else: - # Use "default" for compatibility (lowest priority) - self.workspace = "default" + async with get_storage_lock(enable_logging=True): + if self.db is None: + self.db = await ClientManager.get_client() + + # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" + if self.db.workspace: + # Use PostgreSQLDB's workspace (highest priority) + self.workspace = self.db.workspace + elif hasattr(self, "workspace") and self.workspace: + # Use storage class's workspace (medium priority) + pass + else: + # Use "default" for compatibility (lowest priority) + self.workspace = "default" async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + async with get_storage_lock(enable_logging=True): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None def _upsert_chunks( self, item: dict[str, Any], current_time: datetime.datetime @@ -2153,21 +2160,22 @@ class PGVectorStorage(BaseVectorStorage): async def drop(self) -> dict[str, str]: """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } + async with get_storage_lock(enable_logging=True): + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} @final @@ -2186,23 +2194,26 @@ class PGDocStatusStorage(DocStatusStorage): return dt.isoformat() async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" - if self.db.workspace: - # Use PostgreSQLDB's workspace (highest priority) - self.workspace = self.db.workspace - elif hasattr(self, "workspace") and self.workspace: - # Use storage class's workspace (medium priority) - pass - else: - # Use "default" for compatibility (lowest priority) - self.workspace = "default" + async with get_storage_lock(enable_logging=True): + if self.db is None: + self.db = await ClientManager.get_client() + + # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" + if self.db.workspace: + # Use PostgreSQLDB's workspace (highest priority) + self.workspace = self.db.workspace + elif hasattr(self, "workspace") and self.workspace: + # Use storage 
class's workspace (medium priority) + pass + else: + # Use "default" for compatibility (lowest priority) + self.workspace = "default" async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + async with get_storage_lock(enable_logging=True): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None async def filter_keys(self, keys: set[str]) -> set[str]: """Filter out duplicated content""" @@ -2778,30 +2789,29 @@ class PGGraphStorage(BaseGraphStorage): return normalized_id async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" - if self.db.workspace: - # Use PostgreSQLDB's workspace (highest priority) - self.workspace = self.db.workspace - elif hasattr(self, "workspace") and self.workspace: - # Use storage class's workspace (medium priority) - pass - else: - # Use "default" for compatibility (lowest priority) - self.workspace = "default" + async with get_graph_db_lock(enable_logging=True): + if self.db is None: + self.db = await ClientManager.get_client() - # Dynamically generate graph name based on workspace - self.graph_name = self._get_workspace_graph_name() + # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" + if self.db.workspace: + # Use PostgreSQLDB's workspace (highest priority) + self.workspace = self.db.workspace + elif hasattr(self, "workspace") and self.workspace: + # Use storage class's workspace (medium priority) + pass + else: + # Use "default" for compatibility (lowest priority) + self.workspace = "default" - # Log the graph initialization for debugging - logger.info( - f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'" - ) + # Dynamically generate graph name based on workspace + self.graph_name = self._get_workspace_graph_name() + + # Log the graph initialization for debugging + logger.info( + f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'" + ) - # Use graph database lock to ensure atomic operations and prevent deadlocks - graph_db_lock = get_graph_db_lock(enable_logging=False) - async with graph_db_lock: # Create AGE extension and configure graph environment once at initialization async with self.db.pool.acquire() as connection: # First ensure AGE extension is created @@ -2840,9 +2850,10 @@ class PGGraphStorage(BaseGraphStorage): ) async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None + async with get_graph_db_lock(enable_logging=True): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None async def index_done_callback(self) -> None: # PG handles persistence automatically diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index f5e4d1c2..96eb0c2f 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -7,6 +7,7 @@ import hashlib import uuid from ..utils import logger from ..base import BaseVectorStorage +from ..kg.shared_storage import get_storage_lock import configparser import pipmaster as pm @@ -118,13 +119,33 @@ class QdrantVectorDBStorage(BaseVectorStorage): ), ) self._max_batch_size = self.global_config["embedding_batch_num"] - QdrantVectorDBStorage.create_collection_if_not_exist( - self._client, - self.final_namespace, - vectors_config=models.VectorParams( - size=self.embedding_func.embedding_dim, 
distance=models.Distance.COSINE - ), - ) + self._initialized = False + + async def initialize(self): + """Initialize Qdrant collection""" + async with get_storage_lock(enable_logging=True): + if self._initialized: + return + + try: + # Create collection if not exists + QdrantVectorDBStorage.create_collection_if_not_exist( + self._client, + self.final_namespace, + vectors_config=models.VectorParams( + size=self.embedding_func.embedding_dim, + distance=models.Distance.COSINE, + ), + ) + self._initialized = True + logger.info( + f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully" + ) + except Exception as e: + logger.error( + f"[{self.workspace}] Failed to initialize Qdrant collection '{self.namespace}': {e}" + ) + raise async def upsert(self, data: dict[str, dict[str, Any]]) -> None: logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") @@ -382,27 +403,28 @@ class QdrantVectorDBStorage(BaseVectorStorage): - On success: {"status": "success", "message": "data dropped"} - On failure: {"status": "error", "message": ""} """ - try: - # Delete the collection and recreate it - if self._client.collection_exists(self.final_namespace): - self._client.delete_collection(self.final_namespace) + async with get_storage_lock(enable_logging=True): + try: + # Delete the collection and recreate it + if self._client.collection_exists(self.final_namespace): + self._client.delete_collection(self.final_namespace) - # Recreate the collection - QdrantVectorDBStorage.create_collection_if_not_exist( - self._client, - self.final_namespace, - vectors_config=models.VectorParams( - size=self.embedding_func.embedding_dim, - distance=models.Distance.COSINE, - ), - ) + # Recreate the collection + QdrantVectorDBStorage.create_collection_if_not_exist( + self._client, + self.final_namespace, + vectors_config=models.VectorParams( + size=self.embedding_func.embedding_dim, + distance=models.Distance.COSINE, + ), + ) - logger.info( - f"[{self.workspace}] Process {os.getpid()} drop Qdrant collection {self.namespace}" - ) - return {"status": "success", "message": "data dropped"} - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping Qdrant collection {self.namespace}: {e}" - ) - return {"status": "error", "message": str(e)} + logger.info( + f"[{self.workspace}] Process {os.getpid()} drop Qdrant collection {self.namespace}" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping Qdrant collection {self.namespace}: {e}" + ) + return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index cdb4793c..733e07a5 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -20,6 +20,7 @@ from lightrag.base import ( DocStatus, DocProcessingStatus, ) +from ..kg.shared_storage import get_storage_lock import json # Import tenacity for retry logic @@ -179,32 +180,33 @@ class RedisKVStorage(BaseKVStorage): async def initialize(self): """Initialize Redis connection and migrate legacy cache structure if needed""" - if self._initialized: - return + async with get_storage_lock(enable_logging=True): + if self._initialized: + return - # Test connection - try: - async with self._get_redis_connection() as redis: - await redis.ping() - logger.info( - f"[{self.workspace}] Connected to Redis for namespace {self.namespace}" - ) - self._initialized = True - except Exception as e: - logger.error(f"[{self.workspace}] Failed to connect 
to Redis: {e}") - # Clean up on connection failure - await self.close() - raise - - # Migrate legacy cache structure if this is a cache namespace - if self.namespace.endswith("_cache"): + # Test connection try: - await self._migrate_legacy_cache_structure() + async with self._get_redis_connection() as redis: + await redis.ping() + logger.info( + f"[{self.workspace}] Connected to Redis for namespace {self.namespace}" + ) + self._initialized = True except Exception as e: - logger.error( - f"[{self.workspace}] Failed to migrate legacy cache structure: {e}" - ) - # Don't fail initialization for migration errors, just log them + logger.error(f"[{self.workspace}] Failed to connect to Redis: {e}") + # Clean up on connection failure + await self.close() + raise + + # Migrate legacy cache structure if this is a cache namespace + if self.namespace.endswith("_cache"): + try: + await self._migrate_legacy_cache_structure() + except Exception as e: + logger.error( + f"[{self.workspace}] Failed to migrate legacy cache structure: {e}" + ) + # Don't fail initialization for migration errors, just log them @asynccontextmanager async def _get_redis_connection(self): @@ -426,39 +428,42 @@ class RedisKVStorage(BaseKVStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with self._get_redis_connection() as redis: - try: - # Use SCAN to find all keys with the namespace prefix - pattern = f"{self.final_namespace}:*" - cursor = 0 - deleted_count = 0 + async with get_storage_lock(enable_logging=True): + async with self._get_redis_connection() as redis: + try: + # Use SCAN to find all keys with the namespace prefix + pattern = f"{self.final_namespace}:*" + cursor = 0 + deleted_count = 0 - while True: - cursor, keys = await redis.scan(cursor, match=pattern, count=1000) - if keys: - # Delete keys in batches - pipe = redis.pipeline() - for key in keys: - pipe.delete(key) - results = await pipe.execute() - deleted_count += sum(results) + while True: + cursor, keys = await redis.scan( + cursor, match=pattern, count=1000 + ) + if keys: + # Delete keys in batches + pipe = redis.pipeline() + for key in keys: + pipe.delete(key) + results = await pipe.execute() + deleted_count += sum(results) - if cursor == 0: - break + if cursor == 0: + break - logger.info( - f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}" - ) - return { - "status": "success", - "message": f"{deleted_count} keys dropped", - } + logger.info( + f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}" + ) + return { + "status": "success", + "message": f"{deleted_count} keys dropped", + } - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}" - ) - return {"status": "error", "message": str(e)} + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}" + ) + return {"status": "error", "message": str(e)} async def _migrate_legacy_cache_structure(self): """Migrate legacy nested cache structure to flattened structure for Redis @@ -601,23 +606,24 @@ class RedisDocStatusStorage(DocStatusStorage): async def initialize(self): """Initialize Redis connection""" - if self._initialized: - return + async with get_storage_lock(enable_logging=True): + if self._initialized: + return - try: - async with self._get_redis_connection() as redis: - await redis.ping() - logger.info( - f"[{self.workspace}] Connected to Redis for doc status namespace {self.namespace}" + try: + async with 
self._get_redis_connection() as redis: + await redis.ping() + logger.info( + f"[{self.workspace}] Connected to Redis for doc status namespace {self.namespace}" + ) + self._initialized = True + except Exception as e: + logger.error( + f"[{self.workspace}] Failed to connect to Redis for doc status: {e}" ) - self._initialized = True - except Exception as e: - logger.error( - f"[{self.workspace}] Failed to connect to Redis for doc status: {e}" - ) - # Clean up on connection failure - await self.close() - raise + # Clean up on connection failure + await self.close() + raise @asynccontextmanager async def _get_redis_connection(self): @@ -1043,32 +1049,35 @@ class RedisDocStatusStorage(DocStatusStorage): async def drop(self) -> dict[str, str]: """Drop all document status data from storage and clean up resources""" - try: - async with self._get_redis_connection() as redis: - # Use SCAN to find all keys with the namespace prefix - pattern = f"{self.final_namespace}:*" - cursor = 0 - deleted_count = 0 + async with get_storage_lock(enable_logging=True): + try: + async with self._get_redis_connection() as redis: + # Use SCAN to find all keys with the namespace prefix + pattern = f"{self.final_namespace}:*" + cursor = 0 + deleted_count = 0 - while True: - cursor, keys = await redis.scan(cursor, match=pattern, count=1000) - if keys: - # Delete keys in batches - pipe = redis.pipeline() - for key in keys: - pipe.delete(key) - results = await pipe.execute() - deleted_count += sum(results) + while True: + cursor, keys = await redis.scan( + cursor, match=pattern, count=1000 + ) + if keys: + # Delete keys in batches + pipe = redis.pipeline() + for key in keys: + pipe.delete(key) + results = await pipe.execute() + deleted_count += sum(results) - if cursor == 0: - break + if cursor == 0: + break - logger.info( - f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}" + logger.info( + f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}" + ) + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}" ) - return {"status": "success", "message": "data dropped"} - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}" - ) - return {"status": "error", "message": str(e)} + return {"status": "error", "message": str(e)} diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 47e6bec0..7421a4fd 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -530,10 +530,8 @@ class LightRAG: self._storages_status = StoragesStatus.CREATED async def initialize_storages(self): - """Asynchronously initialize the storages""" + """Storage initialization must be called one by one to prevent deadlock""" if self._storages_status == StoragesStatus.CREATED: - tasks = [] - for storage in ( self.full_docs, self.text_chunks, @@ -547,9 +545,7 @@ class LightRAG: self.doc_status, ): if storage: - tasks.append(storage.initialize()) - - await asyncio.gather(*tasks) + await storage.initialize() self._storages_status = StoragesStatus.INITIALIZED logger.debug("All storage types initialized")
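
The final hunk in `lightrag.py` replaces the `asyncio.gather` fan-out with a plain sequential loop, because every backend's `initialize()` now acquires a shared storage or graph-db lock, and concurrent acquisition from a single call site is what the new docstring's deadlock warning refers to. A rough sketch of the resulting shape, with hypothetical storage objects standing in for LightRAG's real backends:

```python
import asyncio


class DummyStorage:
    """Hypothetical storage with a lock-guarded initialize(), as in the diff."""

    def __init__(self, name: str, lock: asyncio.Lock) -> None:
        self.name = name
        self._lock = lock
        self.initialized = False

    async def initialize(self) -> None:
        async with self._lock:         # shared lock, analogous to get_storage_lock()
            await asyncio.sleep(0.01)  # simulate connection / index setup
            self.initialized = True


async def init_storages_sequentially(storages: list[DummyStorage]) -> None:
    # Mirrors the new LightRAG.initialize_storages(): one storage at a time,
    # so at most one coroutine ever waits on the shared lock.
    for storage in storages:
        if storage:
            await storage.initialize()


async def main() -> None:
    shared_lock = asyncio.Lock()
    storages = [DummyStorage(n, shared_lock) for n in ("kv", "vector", "graph")]
    await init_storages_sequentially(storages)
    assert all(s.initialized for s in storages)


asyncio.run(main())
```

With a plain in-process `asyncio.Lock` the gathered version would merely serialize rather than deadlock; the sequential loop matters when the underlying `shared_storage` locks span processes, which is an assumption of this sketch rather than something the diff itself demonstrates.
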