Fix: add multi-process lock for initialize and drop method for all storage

This commit is contained in:
yangdx 2025-08-12 04:25:09 +08:00
parent ca00b9c8ee
commit fc8ca1a706
8 changed files with 679 additions and 595 deletions

View file

@ -9,6 +9,7 @@ from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_graph_db_lock
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
@ -55,53 +56,56 @@ class MemgraphStorage(BaseGraphStorage):
return self.workspace return self.workspace
async def initialize(self): async def initialize(self):
URI = os.environ.get( async with get_graph_db_lock(enable_logging=True):
"MEMGRAPH_URI", URI = os.environ.get(
config.get("memgraph", "uri", fallback="bolt://localhost:7687"), "MEMGRAPH_URI",
) config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
USERNAME = os.environ.get(
"MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
)
PASSWORD = os.environ.get(
"MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
)
DATABASE = os.environ.get(
"MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")
)
self._driver = AsyncGraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
)
self._DATABASE = DATABASE
try:
async with self._driver.session(database=DATABASE) as session:
# Create index for base nodes on entity_id if it doesn't exist
try:
workspace_label = self._get_workspace_label()
await session.run(
f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
)
logger.info(
f"[{self.workspace}] Created index on :{workspace_label}(entity_id) in Memgraph."
)
except Exception as e:
# Index may already exist, which is not an error
logger.warning(
f"[{self.workspace}] Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
)
await session.run("RETURN 1")
logger.info(f"[{self.workspace}] Connected to Memgraph at {URI}")
except Exception as e:
logger.error(
f"[{self.workspace}] Failed to connect to Memgraph at {URI}: {e}"
) )
raise USERNAME = os.environ.get(
"MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
)
PASSWORD = os.environ.get(
"MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
)
DATABASE = os.environ.get(
"MEMGRAPH_DATABASE",
config.get("memgraph", "database", fallback="memgraph"),
)
self._driver = AsyncGraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
)
self._DATABASE = DATABASE
try:
async with self._driver.session(database=DATABASE) as session:
# Create index for base nodes on entity_id if it doesn't exist
try:
workspace_label = self._get_workspace_label()
await session.run(
f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
)
logger.info(
f"[{self.workspace}] Created index on :{workspace_label}(entity_id) in Memgraph."
)
except Exception as e:
# Index may already exist, which is not an error
logger.warning(
f"[{self.workspace}] Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
)
await session.run("RETURN 1")
logger.info(f"[{self.workspace}] Connected to Memgraph at {URI}")
except Exception as e:
logger.error(
f"[{self.workspace}] Failed to connect to Memgraph at {URI}: {e}"
)
raise
async def finalize(self): async def finalize(self):
if self._driver is not None: async with get_graph_db_lock(enable_logging=True):
await self._driver.close() if self._driver is not None:
self._driver = None await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
await self.finalize() await self.finalize()
@ -739,21 +743,22 @@ class MemgraphStorage(BaseGraphStorage):
raise RuntimeError( raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first." "Memgraph driver is not initialized. Call 'await initialize()' first."
) )
try: async with get_graph_db_lock(enable_logging=True):
async with self._driver.session(database=self._DATABASE) as session: try:
workspace_label = self._get_workspace_label() async with self._driver.session(database=self._DATABASE) as session:
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" workspace_label = self._get_workspace_label()
result = await session.run(query) query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
await result.consume() result = await session.run(query)
logger.info( await result.consume()
f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" logger.info(
f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
)
return {"status": "success", "message": "workspace data dropped"}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
) )
return {"status": "success", "message": "workspace data dropped"} return {"status": "error", "message": str(e)}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)}
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes. """Get the total degree (sum of relationships) of two nodes.

View file

@ -6,6 +6,7 @@ import numpy as np
from lightrag.utils import logger, compute_mdhash_id from lightrag.utils import logger, compute_mdhash_id
from ..base import BaseVectorStorage from ..base import BaseVectorStorage
from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH
from ..kg.shared_storage import get_storage_lock
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("pymilvus"): if not pm.is_installed("pymilvus"):
@ -752,9 +753,26 @@ class MilvusVectorDBStorage(BaseVectorStorage):
), ),
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
self._initialized = False
# Create collection and check compatibility async def initialize(self):
self._create_collection_if_not_exist() """Initialize Milvus collection"""
async with get_storage_lock(enable_logging=True):
if self._initialized:
return
try:
# Create collection and check compatibility
self._create_collection_if_not_exist()
self._initialized = True
logger.info(
f"[{self.workspace}] Milvus collection '{self.namespace}' initialized successfully"
)
except Exception as e:
logger.error(
f"[{self.workspace}] Failed to initialize Milvus collection '{self.namespace}': {e}"
)
raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
@ -1012,20 +1030,21 @@ class MilvusVectorDBStorage(BaseVectorStorage):
- On success: {"status": "success", "message": "data dropped"} - On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"} - On failure: {"status": "error", "message": "<error details>"}
""" """
try: async with get_storage_lock(enable_logging=True):
# Drop the collection and recreate it try:
if self._client.has_collection(self.final_namespace): # Drop the collection and recreate it
self._client.drop_collection(self.final_namespace) if self._client.has_collection(self.final_namespace):
self._client.drop_collection(self.final_namespace)
# Recreate the collection # Recreate the collection
self._create_collection_if_not_exist() self._create_collection_if_not_exist()
logger.info( logger.info(
f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}" f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}"
) )
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}" f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -18,6 +18,7 @@ from ..base import (
from ..utils import logger, compute_mdhash_id from ..utils import logger, compute_mdhash_id
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_storage_lock, get_graph_db_lock
import pipmaster as pm import pipmaster as pm
@ -39,39 +40,36 @@ GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
class ClientManager: class ClientManager:
_instances = {"db": None, "ref_count": 0} _instances = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@classmethod @classmethod
async def get_client(cls) -> AsyncMongoClient: async def get_client(cls) -> AsyncMongoClient:
async with cls._lock: if cls._instances["db"] is None:
if cls._instances["db"] is None: uri = os.environ.get(
uri = os.environ.get( "MONGO_URI",
"MONGO_URI", config.get(
config.get( "mongodb",
"mongodb", "uri",
"uri", fallback="mongodb://root:root@localhost:27017/",
fallback="mongodb://root:root@localhost:27017/", ),
), )
) database_name = os.environ.get(
database_name = os.environ.get( "MONGO_DATABASE",
"MONGO_DATABASE", config.get("mongodb", "database", fallback="LightRAG"),
config.get("mongodb", "database", fallback="LightRAG"), )
) client = AsyncMongoClient(uri)
client = AsyncMongoClient(uri) db = client.get_database(database_name)
db = client.get_database(database_name) cls._instances["db"] = db
cls._instances["db"] = db cls._instances["ref_count"] = 0
cls._instances["ref_count"] = 0 cls._instances["ref_count"] += 1
cls._instances["ref_count"] += 1 return cls._instances["db"]
return cls._instances["db"]
@classmethod @classmethod
async def release_client(cls, db: AsyncDatabase): async def release_client(cls, db: AsyncDatabase):
async with cls._lock: if db is not None:
if db is not None: if db is cls._instances["db"]:
if db is cls._instances["db"]: cls._instances["ref_count"] -= 1
cls._instances["ref_count"] -= 1 if cls._instances["ref_count"] == 0:
if cls._instances["ref_count"] == 0: cls._instances["db"] = None
cls._instances["db"] = None
@final @final
@ -125,18 +123,21 @@ class MongoKVStorage(BaseKVStorage):
self._collection_name = self.final_namespace self._collection_name = self.final_namespace
async def initialize(self): async def initialize(self):
if self.db is None: async with get_storage_lock(enable_logging=True):
self.db = await ClientManager.get_client() if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name) self._data = await get_or_create_collection(self.db, self._collection_name)
logger.debug( logger.debug(
f"[{self.workspace}] Use MongoDB as KV {self._collection_name}" f"[{self.workspace}] Use MongoDB as KV {self._collection_name}"
) )
async def finalize(self): async def finalize(self):
if self.db is not None: async with get_storage_lock(enable_logging=True):
await ClientManager.release_client(self.db) if self.db is not None:
self.db = None await ClientManager.release_client(self.db)
self._data = None self.db = None
self._data = None
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
# Unified handling for flattened keys # Unified handling for flattened keys
@ -251,22 +252,23 @@ class MongoKVStorage(BaseKVStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
try: async with get_storage_lock(enable_logging=True):
result = await self._data.delete_many({}) try:
deleted_count = result.deleted_count result = await self._data.delete_many({})
deleted_count = result.deleted_count
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents dropped", "message": f"{deleted_count} documents dropped",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@final @final
@ -318,8 +320,10 @@ class MongoDocStatusStorage(DocStatusStorage):
self._collection_name = self.final_namespace self._collection_name = self.final_namespace
async def initialize(self): async def initialize(self):
if self.db is None: async with get_storage_lock(enable_logging=True):
self.db = await ClientManager.get_client() if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name) self._data = await get_or_create_collection(self.db, self._collection_name)
# Create track_id index for better query performance # Create track_id index for better query performance
@ -333,10 +337,11 @@ class MongoDocStatusStorage(DocStatusStorage):
) )
async def finalize(self): async def finalize(self):
if self.db is not None: async with get_storage_lock(enable_logging=True):
await ClientManager.release_client(self.db) if self.db is not None:
self.db = None await ClientManager.release_client(self.db)
self._data = None self.db = None
self._data = None
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
return await self._data.find_one({"_id": id}) return await self._data.find_one({"_id": id})
@ -447,22 +452,23 @@ class MongoDocStatusStorage(DocStatusStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
try: async with get_storage_lock(enable_logging=True):
result = await self._data.delete_many({}) try:
deleted_count = result.deleted_count result = await self._data.delete_many({})
deleted_count = result.deleted_count
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents dropped", "message": f"{deleted_count} documents dropped",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
await self._data.delete_many({"_id": {"$in": ids}}) await self._data.delete_many({"_id": {"$in": ids}})
@ -699,8 +705,10 @@ class MongoGraphStorage(BaseGraphStorage):
self._edge_collection_name = f"{self._collection_name}_edges" self._edge_collection_name = f"{self._collection_name}_edges"
async def initialize(self): async def initialize(self):
if self.db is None: async with get_graph_db_lock(enable_logging=True):
self.db = await ClientManager.get_client() if self.db is None:
self.db = await ClientManager.get_client()
self.collection = await get_or_create_collection( self.collection = await get_or_create_collection(
self.db, self._collection_name self.db, self._collection_name
) )
@ -712,11 +720,12 @@ class MongoGraphStorage(BaseGraphStorage):
) )
async def finalize(self): async def finalize(self):
if self.db is not None: async with get_graph_db_lock(enable_logging=True):
await ClientManager.release_client(self.db) if self.db is not None:
self.db = None await ClientManager.release_client(self.db)
self.collection = None self.db = None
self.edge_collection = None self.collection = None
self.edge_collection = None
# Sample entity document # Sample entity document
# "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP # "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP
@ -1567,29 +1576,30 @@ class MongoGraphStorage(BaseGraphStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
try: async with get_graph_db_lock(enable_logging=True):
result = await self.collection.delete_many({}) try:
deleted_count = result.deleted_count result = await self.collection.delete_many({})
deleted_count = result.deleted_count
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}" f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}"
) )
result = await self.edge_collection.delete_many({}) result = await self.edge_collection.delete_many({})
edge_count = result.deleted_count edge_count = result.deleted_count
logger.info( logger.info(
f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}" f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents and {edge_count} edges dropped", "message": f"{deleted_count} documents and {edge_count} edges dropped",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}" f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@final @final
@ -1661,8 +1671,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self): async def initialize(self):
if self.db is None: async with get_storage_lock(enable_logging=True):
self.db = await ClientManager.get_client() if self.db is None:
self.db = await ClientManager.get_client()
self._data = await get_or_create_collection(self.db, self._collection_name) self._data = await get_or_create_collection(self.db, self._collection_name)
# Ensure vector index exists # Ensure vector index exists
@ -1673,10 +1685,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
) )
async def finalize(self): async def finalize(self):
if self.db is not None: async with get_storage_lock(enable_logging=True):
await ClientManager.release_client(self.db) if self.db is not None:
self.db = None await ClientManager.release_client(self.db)
self._data = None self.db = None
self._data = None
async def create_vector_index_if_not_exists(self): async def create_vector_index_if_not_exists(self):
"""Creates an Atlas Vector Search index.""" """Creates an Atlas Vector Search index."""
@ -1957,26 +1970,27 @@ class MongoVectorDBStorage(BaseVectorStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
try: async with get_storage_lock(enable_logging=True):
# Delete all documents try:
result = await self._data.delete_many({}) # Delete all documents
deleted_count = result.deleted_count result = await self._data.delete_many({})
deleted_count = result.deleted_count
# Recreate vector index # Recreate vector index
await self.create_vector_index_if_not_exists() await self.create_vector_index_if_not_exists()
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index" f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} documents dropped and vector index recreated", "message": f"{deleted_count} documents dropped and vector index recreated",
} }
except PyMongoError as e: except PyMongoError as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}" f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def get_or_create_collection(db: AsyncDatabase, collection_name: str): async def get_or_create_collection(db: AsyncDatabase, collection_name: str):

View file

@ -17,6 +17,7 @@ from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..constants import GRAPH_FIELD_SEP from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_graph_db_lock
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
@ -70,174 +71,180 @@ class Neo4JStorage(BaseGraphStorage):
return self.workspace return self.workspace
async def initialize(self): async def initialize(self):
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) async with get_graph_db_lock(enable_logging=True):
USERNAME = os.environ.get( URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None) USERNAME = os.environ.get(
) "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
PASSWORD = os.environ.get(
"NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
)
MAX_CONNECTION_POOL_SIZE = int(
os.environ.get(
"NEO4J_MAX_CONNECTION_POOL_SIZE",
config.get("neo4j", "connection_pool_size", fallback=100),
) )
) PASSWORD = os.environ.get(
CONNECTION_TIMEOUT = float( "NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
os.environ.get( )
"NEO4J_CONNECTION_TIMEOUT", MAX_CONNECTION_POOL_SIZE = int(
config.get("neo4j", "connection_timeout", fallback=30.0), os.environ.get(
), "NEO4J_MAX_CONNECTION_POOL_SIZE",
) config.get("neo4j", "connection_pool_size", fallback=100),
CONNECTION_ACQUISITION_TIMEOUT = float(
os.environ.get(
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
),
)
MAX_TRANSACTION_RETRY_TIME = float(
os.environ.get(
"NEO4J_MAX_TRANSACTION_RETRY_TIME",
config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
),
)
MAX_CONNECTION_LIFETIME = float(
os.environ.get(
"NEO4J_MAX_CONNECTION_LIFETIME",
config.get("neo4j", "max_connection_lifetime", fallback=300.0),
),
)
LIVENESS_CHECK_TIMEOUT = float(
os.environ.get(
"NEO4J_LIVENESS_CHECK_TIMEOUT",
config.get("neo4j", "liveness_check_timeout", fallback=30.0),
),
)
KEEP_ALIVE = os.environ.get(
"NEO4J_KEEP_ALIVE",
config.get("neo4j", "keep_alive", fallback="true"),
).lower() in ("true", "1", "yes", "on")
DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
)
"""The default value approach for the DATABASE is only intended to maintain compatibility with legacy practices."""
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
max_connection_lifetime=MAX_CONNECTION_LIFETIME,
liveness_check_timeout=LIVENESS_CHECK_TIMEOUT,
keep_alive=KEEP_ALIVE,
)
# Try to connect to the database and create it if it doesn't exist
for database in (DATABASE, None):
self._DATABASE = database
connected = False
try:
async with self._driver.session(database=database) as session:
try:
result = await session.run("MATCH (n) RETURN n LIMIT 0")
await result.consume() # Ensure result is consumed
logger.info(
f"[{self.workspace}] Connected to {database} at {URI}"
)
connected = True
except neo4jExceptions.ServiceUnavailable as e:
logger.error(
f"[{self.workspace}] "
+ f"{database} at {URI} is not available".capitalize()
)
raise e
except neo4jExceptions.AuthError as e:
logger.error(
f"[{self.workspace}] Authentication failed for {database} at {URI}"
) )
raise e )
except neo4jExceptions.ClientError as e: CONNECTION_TIMEOUT = float(
if e.code == "Neo.ClientError.Database.DatabaseNotFound": os.environ.get(
logger.info( "NEO4J_CONNECTION_TIMEOUT",
f"[{self.workspace}] " config.get("neo4j", "connection_timeout", fallback=30.0),
+ f"{database} at {URI} not found. Try to create specified database.".capitalize() ),
) )
try: CONNECTION_ACQUISITION_TIMEOUT = float(
async with self._driver.session() as session: os.environ.get(
result = await session.run( "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
f"CREATE DATABASE `{database}` IF NOT EXISTS" config.get(
) "neo4j", "connection_acquisition_timeout", fallback=30.0
await result.consume() # Ensure result is consumed ),
logger.info( ),
f"[{self.workspace}] " )
+ f"{database} at {URI} created".capitalize() MAX_TRANSACTION_RETRY_TIME = float(
) os.environ.get(
connected = True "NEO4J_MAX_TRANSACTION_RETRY_TIME",
except ( config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
neo4jExceptions.ClientError, ),
neo4jExceptions.DatabaseError, )
) as e: MAX_CONNECTION_LIFETIME = float(
if ( os.environ.get(
e.code "NEO4J_MAX_CONNECTION_LIFETIME",
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand" config.get("neo4j", "max_connection_lifetime", fallback=300.0),
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"): ),
if database is not None: )
logger.warning( LIVENESS_CHECK_TIMEOUT = float(
f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." os.environ.get(
) "NEO4J_LIVENESS_CHECK_TIMEOUT",
if database is None: config.get("neo4j", "liveness_check_timeout", fallback=30.0),
logger.error( ),
f"[{self.workspace}] Failed to create {database} at {URI}" )
) KEEP_ALIVE = os.environ.get(
raise e "NEO4J_KEEP_ALIVE",
config.get("neo4j", "keep_alive", fallback="true"),
).lower() in ("true", "1", "yes", "on")
DATABASE = os.environ.get(
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
)
"""The default value approach for the DATABASE is only intended to maintain compatibility with legacy practices."""
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
URI,
auth=(USERNAME, PASSWORD),
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
connection_timeout=CONNECTION_TIMEOUT,
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
max_connection_lifetime=MAX_CONNECTION_LIFETIME,
liveness_check_timeout=LIVENESS_CHECK_TIMEOUT,
keep_alive=KEEP_ALIVE,
)
# Try to connect to the database and create it if it doesn't exist
for database in (DATABASE, None):
self._DATABASE = database
connected = False
if connected:
# Create index for workspace nodes on entity_id if it doesn't exist
workspace_label = self._get_workspace_label()
try: try:
async with self._driver.session(database=database) as session: async with self._driver.session(database=database) as session:
# Check if index exists first
check_query = f"""
CALL db.indexes() YIELD name, labelsOrTypes, properties
WHERE labelsOrTypes = ['{workspace_label}'] AND properties = ['entity_id']
RETURN count(*) > 0 AS exists
"""
try: try:
check_result = await session.run(check_query) result = await session.run("MATCH (n) RETURN n LIMIT 0")
record = await check_result.single() await result.consume() # Ensure result is consumed
await check_result.consume() logger.info(
f"[{self.workspace}] Connected to {database} at {URI}"
index_exists = record and record.get("exists", False) )
connected = True
if not index_exists: except neo4jExceptions.ServiceUnavailable as e:
# Create index only if it doesn't exist logger.error(
f"[{self.workspace}] "
+ f"{database} at {URI} is not available".capitalize()
)
raise e
except neo4jExceptions.AuthError as e:
logger.error(
f"[{self.workspace}] Authentication failed for {database} at {URI}"
)
raise e
except neo4jExceptions.ClientError as e:
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
logger.info(
f"[{self.workspace}] "
+ f"{database} at {URI} not found. Try to create specified database.".capitalize()
)
try:
async with self._driver.session() as session:
result = await session.run( result = await session.run(
f"CREATE INDEX FOR (n:`{workspace_label}`) ON (n.entity_id)" f"CREATE DATABASE `{database}` IF NOT EXISTS"
)
await result.consume() # Ensure result is consumed
logger.info(
f"[{self.workspace}] "
+ f"{database} at {URI} created".capitalize()
)
connected = True
except (
neo4jExceptions.ClientError,
neo4jExceptions.DatabaseError,
) as e:
if (
e.code
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
) or (
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
):
if database is not None:
logger.warning(
f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
)
if database is None:
logger.error(
f"[{self.workspace}] Failed to create {database} at {URI}"
)
raise e
if connected:
# Create index for workspace nodes on entity_id if it doesn't exist
workspace_label = self._get_workspace_label()
try:
async with self._driver.session(database=database) as session:
# Check if index exists first
check_query = f"""
CALL db.indexes() YIELD name, labelsOrTypes, properties
WHERE labelsOrTypes = ['{workspace_label}'] AND properties = ['entity_id']
RETURN count(*) > 0 AS exists
"""
try:
check_result = await session.run(check_query)
record = await check_result.single()
await check_result.consume()
index_exists = record and record.get("exists", False)
if not index_exists:
# Create index only if it doesn't exist
result = await session.run(
f"CREATE INDEX FOR (n:`{workspace_label}`) ON (n.entity_id)"
)
await result.consume()
logger.info(
f"[{self.workspace}] Created index for {workspace_label} nodes on entity_id in {database}"
)
except Exception:
# Fallback if db.indexes() is not supported in this Neo4j version
result = await session.run(
f"CREATE INDEX IF NOT EXISTS FOR (n:`{workspace_label}`) ON (n.entity_id)"
) )
await result.consume() await result.consume()
logger.info( except Exception as e:
f"[{self.workspace}] Created index for {workspace_label} nodes on entity_id in {database}" logger.warning(
) f"[{self.workspace}] Failed to create index: {str(e)}"
except Exception: )
# Fallback if db.indexes() is not supported in this Neo4j version break
result = await session.run(
f"CREATE INDEX IF NOT EXISTS FOR (n:`{workspace_label}`) ON (n.entity_id)"
)
await result.consume()
except Exception as e:
logger.warning(
f"[{self.workspace}] Failed to create index: {str(e)}"
)
break
async def finalize(self): async def finalize(self):
"""Close the Neo4j driver and release all resources""" """Close the Neo4j driver and release all resources"""
if self._driver: async with get_graph_db_lock(enable_logging=True):
await self._driver.close() if self._driver:
self._driver = None await self._driver.close()
self._driver = None
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
"""Ensure driver is closed when context manager exits""" """Ensure driver is closed when context manager exits"""
@ -1526,23 +1533,24 @@ class Neo4JStorage(BaseGraphStorage):
- On success: {"status": "success", "message": "workspace data dropped"} - On success: {"status": "success", "message": "workspace data dropped"}
- On failure: {"status": "error", "message": "<error details>"} - On failure: {"status": "error", "message": "<error details>"}
""" """
workspace_label = self._get_workspace_label() async with get_graph_db_lock(enable_logging=True):
try: workspace_label = self._get_workspace_label()
async with self._driver.session(database=self._DATABASE) as session: try:
# Delete all nodes and relationships in current workspace only async with self._driver.session(database=self._DATABASE) as session:
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" # Delete all nodes and relationships in current workspace only
result = await session.run(query) query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
await result.consume() # Ensure result is fully consumed result = await session.run(query)
await result.consume() # Ensure result is fully consumed
# logger.debug( # logger.debug(
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
# ) # )
return { return {
"status": "success", "status": "success",
"message": f"workspace '{workspace_label}' data dropped", "message": f"workspace '{workspace_label}' data dropped",
} }
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -30,7 +30,7 @@ from ..base import (
from ..namespace import NameSpace, is_namespace from ..namespace import NameSpace, is_namespace
from ..utils import logger from ..utils import logger
from ..constants import GRAPH_FIELD_SEP from ..constants import GRAPH_FIELD_SEP
from ..kg.shared_storage import get_graph_db_lock from ..kg.shared_storage import get_graph_db_lock, get_storage_lock
import pipmaster as pm import pipmaster as pm
@ -1406,23 +1406,26 @@ class PGKVStorage(BaseKVStorage):
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self): async def initialize(self):
if self.db is None: async with get_storage_lock(enable_logging=True):
self.db = await ClientManager.get_client() if self.db is None:
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" self.db = await ClientManager.get_client()
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority) # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
self.workspace = self.db.workspace if self.db.workspace:
elif hasattr(self, "workspace") and self.workspace: # Use PostgreSQLDB's workspace (highest priority)
# Use storage class's workspace (medium priority) self.workspace = self.db.workspace
pass elif hasattr(self, "workspace") and self.workspace:
else: # Use storage class's workspace (medium priority)
# Use "default" for compatibility (lowest priority) pass
self.workspace = "default" else:
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
async def finalize(self): async def finalize(self):
if self.db is not None: async with get_storage_lock(enable_logging=True):
await ClientManager.release_client(self.db) if self.db is not None:
self.db = None await ClientManager.release_client(self.db)
self.db = None
################ QUERY METHODS ################ ################ QUERY METHODS ################
async def get_all(self) -> dict[str, Any]: async def get_all(self) -> dict[str, Any]:
@ -1834,21 +1837,22 @@ class PGKVStorage(BaseKVStorage):
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage""" """Drop the storage"""
try: async with get_storage_lock(enable_logging=True):
table_name = namespace_to_table_name(self.namespace) try:
if not table_name: table_name = namespace_to_table_name(self.namespace)
return { if not table_name:
"status": "error", return {
"message": f"Unknown namespace: {self.namespace}", "status": "error",
} "message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name table_name=table_name
) )
await self.db.execute(drop_sql, {"workspace": self.workspace}) await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@final @final
@ -1867,23 +1871,26 @@ class PGVectorStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
async def initialize(self): async def initialize(self):
if self.db is None: async with get_storage_lock(enable_logging=True):
self.db = await ClientManager.get_client() if self.db is None:
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" self.db = await ClientManager.get_client()
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority) # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
self.workspace = self.db.workspace if self.db.workspace:
elif hasattr(self, "workspace") and self.workspace: # Use PostgreSQLDB's workspace (highest priority)
# Use storage class's workspace (medium priority) self.workspace = self.db.workspace
pass elif hasattr(self, "workspace") and self.workspace:
else: # Use storage class's workspace (medium priority)
# Use "default" for compatibility (lowest priority) pass
self.workspace = "default" else:
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
async def finalize(self): async def finalize(self):
if self.db is not None: async with get_storage_lock(enable_logging=True):
await ClientManager.release_client(self.db) if self.db is not None:
self.db = None await ClientManager.release_client(self.db)
self.db = None
def _upsert_chunks( def _upsert_chunks(
self, item: dict[str, Any], current_time: datetime.datetime self, item: dict[str, Any], current_time: datetime.datetime
@ -2153,21 +2160,22 @@ class PGVectorStorage(BaseVectorStorage):
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop the storage""" """Drop the storage"""
try: async with get_storage_lock(enable_logging=True):
table_name = namespace_to_table_name(self.namespace) try:
if not table_name: table_name = namespace_to_table_name(self.namespace)
return { if not table_name:
"status": "error", return {
"message": f"Unknown namespace: {self.namespace}", "status": "error",
} "message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name table_name=table_name
) )
await self.db.execute(drop_sql, {"workspace": self.workspace}) await self.db.execute(drop_sql, {"workspace": self.workspace})
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
@final @final
@ -2186,23 +2194,26 @@ class PGDocStatusStorage(DocStatusStorage):
return dt.isoformat() return dt.isoformat()
async def initialize(self): async def initialize(self):
if self.db is None: async with get_storage_lock(enable_logging=True):
self.db = await ClientManager.get_client() if self.db is None:
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" self.db = await ClientManager.get_client()
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority) # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
self.workspace = self.db.workspace if self.db.workspace:
elif hasattr(self, "workspace") and self.workspace: # Use PostgreSQLDB's workspace (highest priority)
# Use storage class's workspace (medium priority) self.workspace = self.db.workspace
pass elif hasattr(self, "workspace") and self.workspace:
else: # Use storage class's workspace (medium priority)
# Use "default" for compatibility (lowest priority) pass
self.workspace = "default" else:
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
async def finalize(self): async def finalize(self):
if self.db is not None: async with get_storage_lock(enable_logging=True):
await ClientManager.release_client(self.db) if self.db is not None:
self.db = None await ClientManager.release_client(self.db)
self.db = None
async def filter_keys(self, keys: set[str]) -> set[str]: async def filter_keys(self, keys: set[str]) -> set[str]:
"""Filter out duplicated content""" """Filter out duplicated content"""
@ -2778,30 +2789,29 @@ class PGGraphStorage(BaseGraphStorage):
return normalized_id return normalized_id
async def initialize(self): async def initialize(self):
if self.db is None: async with get_graph_db_lock(enable_logging=True):
self.db = await ClientManager.get_client() if self.db is None:
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" self.db = await ClientManager.get_client()
if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
self.workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
pass
else:
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
# Dynamically generate graph name based on workspace # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
self.graph_name = self._get_workspace_graph_name() if self.db.workspace:
# Use PostgreSQLDB's workspace (highest priority)
self.workspace = self.db.workspace
elif hasattr(self, "workspace") and self.workspace:
# Use storage class's workspace (medium priority)
pass
else:
# Use "default" for compatibility (lowest priority)
self.workspace = "default"
# Log the graph initialization for debugging # Dynamically generate graph name based on workspace
logger.info( self.graph_name = self._get_workspace_graph_name()
f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'"
) # Log the graph initialization for debugging
logger.info(
f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'"
)
# Use graph database lock to ensure atomic operations and prevent deadlocks
graph_db_lock = get_graph_db_lock(enable_logging=False)
async with graph_db_lock:
# Create AGE extension and configure graph environment once at initialization # Create AGE extension and configure graph environment once at initialization
async with self.db.pool.acquire() as connection: async with self.db.pool.acquire() as connection:
# First ensure AGE extension is created # First ensure AGE extension is created
@ -2840,9 +2850,10 @@ class PGGraphStorage(BaseGraphStorage):
) )
async def finalize(self): async def finalize(self):
if self.db is not None: async with get_graph_db_lock(enable_logging=True):
await ClientManager.release_client(self.db) if self.db is not None:
self.db = None await ClientManager.release_client(self.db)
self.db = None
async def index_done_callback(self) -> None: async def index_done_callback(self) -> None:
# PG handles persistence automatically # PG handles persistence automatically

View file

@ -7,6 +7,7 @@ import hashlib
import uuid import uuid
from ..utils import logger from ..utils import logger
from ..base import BaseVectorStorage from ..base import BaseVectorStorage
from ..kg.shared_storage import get_storage_lock
import configparser import configparser
import pipmaster as pm import pipmaster as pm
@ -118,13 +119,33 @@ class QdrantVectorDBStorage(BaseVectorStorage):
), ),
) )
self._max_batch_size = self.global_config["embedding_batch_num"] self._max_batch_size = self.global_config["embedding_batch_num"]
QdrantVectorDBStorage.create_collection_if_not_exist( self._initialized = False
self._client,
self.final_namespace, async def initialize(self):
vectors_config=models.VectorParams( """Initialize Qdrant collection"""
size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE async with get_storage_lock(enable_logging=True):
), if self._initialized:
) return
try:
# Create collection if not exists
QdrantVectorDBStorage.create_collection_if_not_exist(
self._client,
self.final_namespace,
vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim,
distance=models.Distance.COSINE,
),
)
self._initialized = True
logger.info(
f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully"
)
except Exception as e:
logger.error(
f"[{self.workspace}] Failed to initialize Qdrant collection '{self.namespace}': {e}"
)
raise
async def upsert(self, data: dict[str, dict[str, Any]]) -> None: async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
@ -382,27 +403,28 @@ class QdrantVectorDBStorage(BaseVectorStorage):
- On success: {"status": "success", "message": "data dropped"} - On success: {"status": "success", "message": "data dropped"}
- On failure: {"status": "error", "message": "<error details>"} - On failure: {"status": "error", "message": "<error details>"}
""" """
try: async with get_storage_lock(enable_logging=True):
# Delete the collection and recreate it try:
if self._client.collection_exists(self.final_namespace): # Delete the collection and recreate it
self._client.delete_collection(self.final_namespace) if self._client.collection_exists(self.final_namespace):
self._client.delete_collection(self.final_namespace)
# Recreate the collection # Recreate the collection
QdrantVectorDBStorage.create_collection_if_not_exist( QdrantVectorDBStorage.create_collection_if_not_exist(
self._client, self._client,
self.final_namespace, self.final_namespace,
vectors_config=models.VectorParams( vectors_config=models.VectorParams(
size=self.embedding_func.embedding_dim, size=self.embedding_func.embedding_dim,
distance=models.Distance.COSINE, distance=models.Distance.COSINE,
), ),
) )
logger.info( logger.info(
f"[{self.workspace}] Process {os.getpid()} drop Qdrant collection {self.namespace}" f"[{self.workspace}] Process {os.getpid()} drop Qdrant collection {self.namespace}"
) )
return {"status": "success", "message": "data dropped"} return {"status": "success", "message": "data dropped"}
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping Qdrant collection {self.namespace}: {e}" f"[{self.workspace}] Error dropping Qdrant collection {self.namespace}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -20,6 +20,7 @@ from lightrag.base import (
DocStatus, DocStatus,
DocProcessingStatus, DocProcessingStatus,
) )
from ..kg.shared_storage import get_storage_lock
import json import json
# Import tenacity for retry logic # Import tenacity for retry logic
@ -179,32 +180,33 @@ class RedisKVStorage(BaseKVStorage):
async def initialize(self): async def initialize(self):
"""Initialize Redis connection and migrate legacy cache structure if needed""" """Initialize Redis connection and migrate legacy cache structure if needed"""
if self._initialized: async with get_storage_lock(enable_logging=True):
return if self._initialized:
return
# Test connection # Test connection
try:
async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(
f"[{self.workspace}] Connected to Redis for namespace {self.namespace}"
)
self._initialized = True
except Exception as e:
logger.error(f"[{self.workspace}] Failed to connect to Redis: {e}")
# Clean up on connection failure
await self.close()
raise
# Migrate legacy cache structure if this is a cache namespace
if self.namespace.endswith("_cache"):
try: try:
await self._migrate_legacy_cache_structure() async with self._get_redis_connection() as redis:
await redis.ping()
logger.info(
f"[{self.workspace}] Connected to Redis for namespace {self.namespace}"
)
self._initialized = True
except Exception as e: except Exception as e:
logger.error( logger.error(f"[{self.workspace}] Failed to connect to Redis: {e}")
f"[{self.workspace}] Failed to migrate legacy cache structure: {e}" # Clean up on connection failure
) await self.close()
# Don't fail initialization for migration errors, just log them raise
# Migrate legacy cache structure if this is a cache namespace
if self.namespace.endswith("_cache"):
try:
await self._migrate_legacy_cache_structure()
except Exception as e:
logger.error(
f"[{self.workspace}] Failed to migrate legacy cache structure: {e}"
)
# Don't fail initialization for migration errors, just log them
@asynccontextmanager @asynccontextmanager
async def _get_redis_connection(self): async def _get_redis_connection(self):
@ -426,39 +428,42 @@ class RedisKVStorage(BaseKVStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
async with self._get_redis_connection() as redis: async with get_storage_lock(enable_logging=True):
try: async with self._get_redis_connection() as redis:
# Use SCAN to find all keys with the namespace prefix try:
pattern = f"{self.final_namespace}:*" # Use SCAN to find all keys with the namespace prefix
cursor = 0 pattern = f"{self.final_namespace}:*"
deleted_count = 0 cursor = 0
deleted_count = 0
while True: while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000) cursor, keys = await redis.scan(
if keys: cursor, match=pattern, count=1000
# Delete keys in batches )
pipe = redis.pipeline() if keys:
for key in keys: # Delete keys in batches
pipe.delete(key) pipe = redis.pipeline()
results = await pipe.execute() for key in keys:
deleted_count += sum(results) pipe.delete(key)
results = await pipe.execute()
deleted_count += sum(results)
if cursor == 0: if cursor == 0:
break break
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}" f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} keys dropped", "message": f"{deleted_count} keys dropped",
} }
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}" f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self): async def _migrate_legacy_cache_structure(self):
"""Migrate legacy nested cache structure to flattened structure for Redis """Migrate legacy nested cache structure to flattened structure for Redis
@ -601,23 +606,24 @@ class RedisDocStatusStorage(DocStatusStorage):
async def initialize(self): async def initialize(self):
"""Initialize Redis connection""" """Initialize Redis connection"""
if self._initialized: async with get_storage_lock(enable_logging=True):
return if self._initialized:
return
try: try:
async with self._get_redis_connection() as redis: async with self._get_redis_connection() as redis:
await redis.ping() await redis.ping()
logger.info( logger.info(
f"[{self.workspace}] Connected to Redis for doc status namespace {self.namespace}" f"[{self.workspace}] Connected to Redis for doc status namespace {self.namespace}"
)
self._initialized = True
except Exception as e:
logger.error(
f"[{self.workspace}] Failed to connect to Redis for doc status: {e}"
) )
self._initialized = True # Clean up on connection failure
except Exception as e: await self.close()
logger.error( raise
f"[{self.workspace}] Failed to connect to Redis for doc status: {e}"
)
# Clean up on connection failure
await self.close()
raise
@asynccontextmanager @asynccontextmanager
async def _get_redis_connection(self): async def _get_redis_connection(self):
@ -1043,32 +1049,35 @@ class RedisDocStatusStorage(DocStatusStorage):
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop all document status data from storage and clean up resources""" """Drop all document status data from storage and clean up resources"""
try: async with get_storage_lock(enable_logging=True):
async with self._get_redis_connection() as redis: try:
# Use SCAN to find all keys with the namespace prefix async with self._get_redis_connection() as redis:
pattern = f"{self.final_namespace}:*" # Use SCAN to find all keys with the namespace prefix
cursor = 0 pattern = f"{self.final_namespace}:*"
deleted_count = 0 cursor = 0
deleted_count = 0
while True: while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=1000) cursor, keys = await redis.scan(
if keys: cursor, match=pattern, count=1000
# Delete keys in batches )
pipe = redis.pipeline() if keys:
for key in keys: # Delete keys in batches
pipe.delete(key) pipe = redis.pipeline()
results = await pipe.execute() for key in keys:
deleted_count += sum(results) pipe.delete(key)
results = await pipe.execute()
deleted_count += sum(results)
if cursor == 0: if cursor == 0:
break break
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}" f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}"
) )
return {"status": "success", "message": "data dropped"} return {"status": "error", "message": str(e)}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)}

View file

@ -530,10 +530,8 @@ class LightRAG:
self._storages_status = StoragesStatus.CREATED self._storages_status = StoragesStatus.CREATED
async def initialize_storages(self): async def initialize_storages(self):
"""Asynchronously initialize the storages""" """Storage initialization must be called one by one to prevent deadlock"""
if self._storages_status == StoragesStatus.CREATED: if self._storages_status == StoragesStatus.CREATED:
tasks = []
for storage in ( for storage in (
self.full_docs, self.full_docs,
self.text_chunks, self.text_chunks,
@ -547,9 +545,7 @@ class LightRAG:
self.doc_status, self.doc_status,
): ):
if storage: if storage:
tasks.append(storage.initialize()) await storage.initialize()
await asyncio.gather(*tasks)
self._storages_status = StoragesStatus.INITIALIZED self._storages_status = StoragesStatus.INITIALIZED
logger.debug("All storage types initialized") logger.debug("All storage types initialized")