Fix: add muti-process lock for initialize and drop method for all storage

This commit is contained in:
yangdx 2025-08-12 04:25:09 +08:00
parent ca00b9c8ee
commit fc8ca1a706
8 changed files with 679 additions and 595 deletions

View file

@ -9,6 +9,7 @@ from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_graph_db_lock
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
@ -55,6 +56,7 @@ class MemgraphStorage(BaseGraphStorage):
return self.workspace return self.workspace
async def initialize(self): async def initialize(self):
async with get_graph_db_lock(enable_logging=True):
URI = os.environ.get( URI = os.environ.get(
"MEMGRAPH_URI", "MEMGRAPH_URI",
config.get("memgraph", "uri", fallback="bolt://localhost:7687"), config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
@ -66,7 +68,8 @@ class MemgraphStorage(BaseGraphStorage):
"MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="") "MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
) )
DATABASE = os.environ.get( DATABASE = os.environ.get(
"MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph") "MEMGRAPH_DATABASE",
config.get("memgraph", "database", fallback="memgraph"),
) )
self._driver = AsyncGraphDatabase.driver( self._driver = AsyncGraphDatabase.driver(
@ -99,6 +102,7 @@ class MemgraphStorage(BaseGraphStorage):
raise raise
async def finalize(self): async def finalize(self):
async with get_graph_db_lock(enable_logging=True):
if self._driver is not None: if self._driver is not None:
await self._driver.close() await self._driver.close()
self._driver = None self._driver = None
@ -739,6 +743,7 @@ class MemgraphStorage(BaseGraphStorage):
raise RuntimeError( raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first." "Memgraph driver is not initialized. Call 'await initialize()' first."
) )
async with get_graph_db_lock(enable_logging=True):
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()

View file

@ -6,6 +6,7 @@ import numpy as np
from lightrag.utils import logger, compute_mdhash_id from lightrag.utils import logger, compute_mdhash_id
from ..base import BaseVectorStorage from ..base import BaseVectorStorage
from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH
from ..kg.shared_storage import get_storage_lock
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("pymilvus"): if not pm.is_installed("pymilvus"):
@ -752,9 +753,26 @@ class MilvusVectorDBStorage(BaseVectorStorage):
), ),
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
self._initialized = False
async def initialize(self):
"""Initialize Milvus collection"""
async with get_storage_lock(enable_logging=True):
if self._initialized:
return
try:
# Create collection and check compatibility # Create collection and check compatibility
self._create_collection_if_not_exist() self._create_collection_if_not_exist()
self._initialized = True
logger.info(
f"[{self.workspace}] Milvus collection '{self.namespace}' initialized successfully"
)
except Exception as e:
logger.error(
f"[{self.workspace}] Failed to initialize Milvus collection '{self.namespace}': {e}"
)
raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
@ -1012,6 +1030,7 @@ class MilvusVectorDBStorage(BaseVectorStorage):
- On success: {"status": "success", "message": "data dropped"} - On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"} - On failure: {"status": "error", "message": "<error details>"}
""" """
async with get_storage_lock(enable_logging=True):
try: try:
# Drop the collection and recreate it # Drop the collection and recreate it
if self._client.has_collection(self.final_namespace): if self._client.has_collection(self.final_namespace):

View file

@ -18,6 +18,7 @@ from ..base import (
from ..utils import logger, compute_mdhash_id from ..utils import logger, compute_mdhash_id
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_storage_lock, get_graph_db_lock
import pipmaster as pm import pipmaster as pm
@ -39,11 +40,9 @@ GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
class ClientManager: class ClientManager:
_instances = {"db": None, "ref_count": 0} _instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@classmethod @classmethod
async def get_client(cls) -> AsyncMongoClient: async def get_client(cls) -> AsyncMongoClient:
async with cls._lock:
if cls._instances["db"] is None: if cls._instances["db"] is None:
uri = os.environ.get( uri = os.environ.get(
"MONGO_URI", "MONGO_URI",
@ -66,7 +65,6 @@ class ClientManager:
@classmethod @classmethod
async def release_client(cls, db: AsyncDatabase): async def release_client(cls, db: AsyncDatabase):
async with cls._lock:
if db is not None: if db is not None:
if db is cls._instances["db"]: if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1 cls._instances["ref_count"] -= 1
@ -125,14 +123,17 @@ class MongoKVStorage(BaseKVStorage):
self._collection_name = self.final_namespace self._collection_name = self.final_namespace
async def initialize(self): async def initialize(self):
async with get_storage_lock(enable_logging=True):
if self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name) self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug( logger.debug(
f"[{self.workspace}] Use MongoDB as KV {self._collection_name}" f"[{self.workspace}] Use MongoDB as KV {self._collection_name}"
) )
async def finalize(self): async def finalize(self):
async with get_storage_lock(enable_logging=True):
if self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@ -251,6 +252,7 @@ class MongoKVStorage(BaseKVStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
async with get_storage_lock(enable_logging=True):
try: try:
result = await self._data.delete_many({}) result = await self._data.delete_many({})
deleted_count = result.deleted_count deleted_count = result.deleted_count
@ -318,8 +320,10 @@ class MongoDocStatusStorage(DocStatusStorage):
self._collection_name = self.final_namespace self._collection_name = self.final_namespace
async def initialize(self): async def initialize(self):
async with get_storage_lock(enable_logging=True):
if self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name) self._data = await get_or_create_collection(self.db, self._collection_name)
# Create track_id index for better query performance # Create track_id index for better query performance
@ -333,6 +337,7 @@ class MongoDocStatusStorage(DocStatusStorage):
) )
async def finalize(self): async def finalize(self):
async with get_storage_lock(enable_logging=True):
if self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@ -447,6 +452,7 @@ class MongoDocStatusStorage(DocStatusStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
async with get_storage_lock(enable_logging=True):
try: try:
result = await self._data.delete_many({}) result = await self._data.delete_many({})
deleted_count = result.deleted_count deleted_count = result.deleted_count
@ -699,8 +705,10 @@ class MongoGraphStorage(BaseGraphStorage):
self._edge_collection_name = f"{self._collection_name}_edges" self._edge_collection_name = f"{self._collection_name}_edges"
async def initialize(self): async def initialize(self):
async with get_graph_db_lock(enable_logging=True):
if self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
self.collection = await get_or_create_collection( self.collection = await get_or_create_collection(
self.db, self._collection_name self.db, self._collection_name
) )
@ -712,6 +720,7 @@ class MongoGraphStorage(BaseGraphStorage):
) )
async def finalize(self): async def finalize(self):
async with get_graph_db_lock(enable_logging=True):
if self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@ -1567,6 +1576,7 @@ class MongoGraphStorage(BaseGraphStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
async with get_graph_db_lock(enable_logging=True):
try: try:
result = await self.collection.delete_many({}) result = await self.collection.delete_many({})
deleted_count = result.deleted_count deleted_count = result.deleted_count
@ -1661,8 +1671,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self): async def initialize(self):
async with get_storage_lock(enable_logging=True):
if self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name) self._data = await get_or_create_collection(self.db, self._collection_name)
# Ensure vector index exists # Ensure vector index exists
@ -1673,6 +1685,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
) )
async def finalize(self): async def finalize(self):
async with get_storage_lock(enable_logging=True):
if self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@ -1957,6 +1970,7 @@ class MongoVectorDBStorage(BaseVectorStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
async with get_storage_lock(enable_logging=True):
try: try:
# Delete all documents # Delete all documents
result = await self._data.delete_many({}) result = await self._data.delete_many({})

View file

@ -17,6 +17,7 @@ from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_graph_db_lock
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
@ -70,6 +71,7 @@ class Neo4JStorage(BaseGraphStorage):
return self.workspace return self.workspace
async def initialize(self): async def initialize(self):
async with get_graph_db_lock(enable_logging=True):
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
USERNAME = os.environ.get( USERNAME = os.environ.get(
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
@ -92,7 +94,9 @@ class Neo4JStorage(BaseGraphStorage):
CONNECTION_ACQUISITION_TIMEOUT = float( CONNECTION_ACQUISITION_TIMEOUT = float(
os.environ.get( os.environ.get(
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT", "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0), config.get(
"neo4j", "connection_acquisition_timeout", fallback=30.0
),
), ),
) )
MAX_TRANSACTION_RETRY_TIME = float( MAX_TRANSACTION_RETRY_TIME = float(
@ -183,7 +187,9 @@ class Neo4JStorage(BaseGraphStorage):
if ( if (
e.code e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand" == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"): ) or (
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
):
if database is not None: if database is not None:
logger.warning( logger.warning(
f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
@ -235,6 +241,7 @@ class Neo4JStorage(BaseGraphStorage):
async def finalize(self): async def finalize(self):
"""Close the Neo4j driver and release all resources""" """Close the Neo4j driver and release all resources"""
async with get_graph_db_lock(enable_logging=True):
if self._driver: if self._driver:
await self._driver.close() await self._driver.close()
self._driver = None self._driver = None
@ -1526,6 +1533,7 @@ class Neo4JStorage(BaseGraphStorage):
- On success: {"status": "success", "message": "workspace data dropped"} - On success: {"status": "success", "message": "workspace data dropped"}
- On failure: {"status": "error", "message": "<error details>"} - On failure: {"status": "error", "message": "<error details>"}
""" """
async with get_graph_db_lock(enable_logging=True):
workspace_label = self._get_workspace_label() workspace_label = self._get_workspace_label()
try: try:
async with self._driver.session(database=self._DATABASE) as session: async with self._driver.session(database=self._DATABASE) as session:

View file

@ -30,7 +30,7 @@ from ..base import (
from ..namespace import NameSpace, is_namespace from ..namespace import NameSpace, is_namespace
from ..utils import logger from ..utils import logger
from ..constants import GRAPH_FIELD_SEP from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_graph_db_lock from ..kg.shared_storage import get_graph_db_lock, get_storage_lock
import pipmaster as pm import pipmaster as pm
@ -1406,8 +1406,10 @@ class PGKVStorage(BaseKVStorage):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self): async def initialize(self):
async with get_storage_lock(enable_logging=True):
if self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace: if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority) # Use PostgreSQLDB's workspace (highest priority)
@ -1420,6 +1422,7 @@ class PGKVStorage(BaseKVStorage):
self.workspace = "default" self.workspace = "default"
async def finalize(self): async def finalize(self):
async with get_storage_lock(enable_logging=True):
if self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@ -1834,6 +1837,7 @@ class PGKVStorage(BaseKVStorage):
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage""" """Drop the storage"""
async with get_storage_lock(enable_logging=True):
try: try:
table_name = namespace_to_table_name(self.namespace) table_name = namespace_to_table_name(self.namespace)
if not table_name: if not table_name:
@ -1867,8 +1871,10 @@ class PGVectorStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
async def initialize(self): async def initialize(self):
async with get_storage_lock(enable_logging=True):
if self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace: if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority) # Use PostgreSQLDB's workspace (highest priority)
@ -1881,6 +1887,7 @@ class PGVectorStorage(BaseVectorStorage):
self.workspace = "default" self.workspace = "default"
async def finalize(self): async def finalize(self):
async with get_storage_lock(enable_logging=True):
if self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@ -2153,6 +2160,7 @@ class PGVectorStorage(BaseVectorStorage):
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage""" """Drop the storage"""
async with get_storage_lock(enable_logging=True):
try: try:
table_name = namespace_to_table_name(self.namespace) table_name = namespace_to_table_name(self.namespace)
if not table_name: if not table_name:
@ -2186,8 +2194,10 @@ class PGDocStatusStorage(DocStatusStorage):
return dt.isoformat() return dt.isoformat()
async def initialize(self): async def initialize(self):
async with get_storage_lock(enable_logging=True):
if self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace: if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority) # Use PostgreSQLDB's workspace (highest priority)
@ -2200,6 +2210,7 @@ class PGDocStatusStorage(DocStatusStorage):
self.workspace = "default" self.workspace = "default"
async def finalize(self): async def finalize(self):
async with get_storage_lock(enable_logging=True):
if self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None
@ -2778,8 +2789,10 @@ class PGGraphStorage(BaseGraphStorage):
return normalized_id return normalized_id
async def initialize(self): async def initialize(self):
async with get_graph_db_lock(enable_logging=True):
if self.db is None: if self.db is None:
self.db = await ClientManager.get_client() self.db = await ClientManager.get_client()
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
if self.db.workspace: if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority) # Use PostgreSQLDB's workspace (highest priority)
@ -2799,9 +2812,6 @@ class PGGraphStorage(BaseGraphStorage):
f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'" f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'"
) )
# Use graph database lock to ensure atomic operations and prevent deadlocks
graph_db_lock = get_graph_db_lock(enable_logging=False)
async with graph_db_lock:
# Create AGE extension and configure graph environment once at initialization # Create AGE extension and configure graph environment once at initialization
async with self.db.pool.acquire() as connection: async with self.db.pool.acquire() as connection:
# First ensure AGE extension is created # First ensure AGE extension is created
@ -2840,6 +2850,7 @@ class PGGraphStorage(BaseGraphStorage):
) )
async def finalize(self): async def finalize(self):
async with get_graph_db_lock(enable_logging=True):
if self.db is not None: if self.db is not None:
await ClientManager.release_client(self.db) await ClientManager.release_client(self.db)
self.db = None self.db = None

View file

@ -7,6 +7,7 @@ import hashlib
import uuid import uuid
from ..utils import logger from ..utils import logger
from ..base import BaseVectorStorage from ..base import BaseVectorStorage
from ..kg.shared_storage import get_storage_lock
import configparser import configparser
import pipmaster as pm import pipmaster as pm
@ -118,13 +119,33 @@ class QdrantVectorDBStorage(BaseVectorStorage):
), ),
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
self._initialized = False
async def initialize(self):
"""Initialize Qdrant collection"""
async with get_storage_lock(enable_logging=True):
if self._initialized:
return
try:
# Create collection if not exists
QdrantVectorDBStorage.create_collection_if_not_exist( QdrantVectorDBStorage.create_collection_if_not_exist(
self._client, self._client,
self.final_namespace, self.final_namespace,
vectors_config=models.VectorParams( vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE size=self.embedding_func.embedding_dim,
distance=models.Distance.COSINE,
), ),
) )
self._initialized = True
logger.info(
f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully"
)
except Exception as e:
logger.error(
f"[{self.workspace}] Failed to initialize Qdrant collection '{self.namespace}': {e}"
)
raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
@ -382,6 +403,7 @@ class QdrantVectorDBStorage(BaseVectorStorage):
- On success: {"status": "success", "message": "data dropped"} - On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"} - On failure: {"status": "error", "message": "<error details>"}
""" """
async with get_storage_lock(enable_logging=True):
try: try:
# Delete the collection and recreate it # Delete the collection and recreate it
if self._client.collection_exists(self.final_namespace): if self._client.collection_exists(self.final_namespace):

View file

@ -20,6 +20,7 @@ from lightrag.base import (
DocStatus, DocStatus,
DocProcessingStatus, DocProcessingStatus,
) )
from ..kg.shared_storage import get_storage_lock
import json import json
# Import tenacity for retry logic # Import tenacity for retry logic
@ -179,6 +180,7 @@ class RedisKVStorage(BaseKVStorage):
async def initialize(self): async def initialize(self):
"""Initialize Redis connection and migrate legacy cache structure if needed""" """Initialize Redis connection and migrate legacy cache structure if needed"""
async with get_storage_lock(enable_logging=True):
if self._initialized: if self._initialized:
return return
@ -426,6 +428,7 @@ class RedisKVStorage(BaseKVStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
async with get_storage_lock(enable_logging=True):
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
try: try:
# Use SCAN to find all keys with the namespace prefix # Use SCAN to find all keys with the namespace prefix
@ -434,7 +437,9 @@ class RedisKVStorage(BaseKVStorage):
deleted_count = 0 deleted_count = 0
while True: while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000) cursor, keys = await redis.scan(
cursor, match=pattern, count=1000
)
if keys: if keys:
# Delete keys in batches # Delete keys in batches
pipe = redis.pipeline() pipe = redis.pipeline()
@ -601,6 +606,7 @@ class RedisDocStatusStorage(DocStatusStorage):
async def initialize(self): async def initialize(self):
"""Initialize Redis connection""" """Initialize Redis connection"""
async with get_storage_lock(enable_logging=True):
if self._initialized: if self._initialized:
return return
@ -1043,6 +1049,7 @@ class RedisDocStatusStorage(DocStatusStorage):
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop all document status data from storage and clean up resources""" """Drop all document status data from storage and clean up resources"""
async with get_storage_lock(enable_logging=True):
try: try:
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
# Use SCAN to find all keys with the namespace prefix # Use SCAN to find all keys with the namespace prefix
@ -1051,7 +1058,9 @@ class RedisDocStatusStorage(DocStatusStorage):
deleted_count = 0 deleted_count = 0
while True: while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000) cursor, keys = await redis.scan(
cursor, match=pattern, count=1000
)
if keys: if keys:
# Delete keys in batches # Delete keys in batches
pipe = redis.pipeline() pipe = redis.pipeline()

View file

@ -530,10 +530,8 @@ class LightRAG:
self._storages_status = StoragesStatus.CREATED self._storages_status = StoragesStatus.CREATED
async def initialize_storages(self): async def initialize_storages(self):
"""Asynchronously initialize the storages""" """Storage initialization must be called one by one to prevent deadlock"""
if self._storages_status == StoragesStatus.CREATED: if self._storages_status == StoragesStatus.CREATED:
tasks = []
for storage in ( for storage in (
self.full_docs, self.full_docs,
self.text_chunks, self.text_chunks,
@ -547,9 +545,7 @@ class LightRAG:
self.doc_status, self.doc_status,
): ):
if storage: if storage:
tasks.append(storage.initialize()) await storage.initialize()
await asyncio.gather(*tasks)
self._storages_status = StoragesStatus.INITIALIZED self._storages_status = StoragesStatus.INITIALIZED
logger.debug("All storage types initialized") logger.debug("All storage types initialized")