Fix: add muti-process lock for initialize and drop method for all storage
This commit is contained in:
parent
ca00b9c8ee
commit
fc8ca1a706
8 changed files with 679 additions and 595 deletions
|
|
@ -9,6 +9,7 @@ from ..utils import logger
|
||||||
from ..base import BaseGraphStorage
|
from ..base import BaseGraphStorage
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from ..constants import GRAPH_FIELD_SEP
|
from ..constants import GRAPH_FIELD_SEP
|
||||||
|
from ..kg.shared_storage import get_graph_db_lock
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
if not pm.is_installed("neo4j"):
|
if not pm.is_installed("neo4j"):
|
||||||
|
|
@ -55,53 +56,56 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
return self.workspace
|
return self.workspace
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
URI = os.environ.get(
|
async with get_graph_db_lock(enable_logging=True):
|
||||||
"MEMGRAPH_URI",
|
URI = os.environ.get(
|
||||||
config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
|
"MEMGRAPH_URI",
|
||||||
)
|
config.get("memgraph", "uri", fallback="bolt://localhost:7687"),
|
||||||
USERNAME = os.environ.get(
|
|
||||||
"MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
|
|
||||||
)
|
|
||||||
PASSWORD = os.environ.get(
|
|
||||||
"MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
|
|
||||||
)
|
|
||||||
DATABASE = os.environ.get(
|
|
||||||
"MEMGRAPH_DATABASE", config.get("memgraph", "database", fallback="memgraph")
|
|
||||||
)
|
|
||||||
|
|
||||||
self._driver = AsyncGraphDatabase.driver(
|
|
||||||
URI,
|
|
||||||
auth=(USERNAME, PASSWORD),
|
|
||||||
)
|
|
||||||
self._DATABASE = DATABASE
|
|
||||||
try:
|
|
||||||
async with self._driver.session(database=DATABASE) as session:
|
|
||||||
# Create index for base nodes on entity_id if it doesn't exist
|
|
||||||
try:
|
|
||||||
workspace_label = self._get_workspace_label()
|
|
||||||
await session.run(
|
|
||||||
f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"[{self.workspace}] Created index on :{workspace_label}(entity_id) in Memgraph."
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
# Index may already exist, which is not an error
|
|
||||||
logger.warning(
|
|
||||||
f"[{self.workspace}] Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
|
|
||||||
)
|
|
||||||
await session.run("RETURN 1")
|
|
||||||
logger.info(f"[{self.workspace}] Connected to Memgraph at {URI}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[{self.workspace}] Failed to connect to Memgraph at {URI}: {e}"
|
|
||||||
)
|
)
|
||||||
raise
|
USERNAME = os.environ.get(
|
||||||
|
"MEMGRAPH_USERNAME", config.get("memgraph", "username", fallback="")
|
||||||
|
)
|
||||||
|
PASSWORD = os.environ.get(
|
||||||
|
"MEMGRAPH_PASSWORD", config.get("memgraph", "password", fallback="")
|
||||||
|
)
|
||||||
|
DATABASE = os.environ.get(
|
||||||
|
"MEMGRAPH_DATABASE",
|
||||||
|
config.get("memgraph", "database", fallback="memgraph"),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._driver = AsyncGraphDatabase.driver(
|
||||||
|
URI,
|
||||||
|
auth=(USERNAME, PASSWORD),
|
||||||
|
)
|
||||||
|
self._DATABASE = DATABASE
|
||||||
|
try:
|
||||||
|
async with self._driver.session(database=DATABASE) as session:
|
||||||
|
# Create index for base nodes on entity_id if it doesn't exist
|
||||||
|
try:
|
||||||
|
workspace_label = self._get_workspace_label()
|
||||||
|
await session.run(
|
||||||
|
f"""CREATE INDEX ON :{workspace_label}(entity_id)"""
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] Created index on :{workspace_label}(entity_id) in Memgraph."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Index may already exist, which is not an error
|
||||||
|
logger.warning(
|
||||||
|
f"[{self.workspace}] Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}"
|
||||||
|
)
|
||||||
|
await session.run("RETURN 1")
|
||||||
|
logger.info(f"[{self.workspace}] Connected to Memgraph at {URI}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Failed to connect to Memgraph at {URI}: {e}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self._driver is not None:
|
async with get_graph_db_lock(enable_logging=True):
|
||||||
await self._driver.close()
|
if self._driver is not None:
|
||||||
self._driver = None
|
await self._driver.close()
|
||||||
|
self._driver = None
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
await self.finalize()
|
await self.finalize()
|
||||||
|
|
@ -739,21 +743,22 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||||
)
|
)
|
||||||
try:
|
async with get_graph_db_lock(enable_logging=True):
|
||||||
async with self._driver.session(database=self._DATABASE) as session:
|
try:
|
||||||
workspace_label = self._get_workspace_label()
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
workspace_label = self._get_workspace_label()
|
||||||
result = await session.run(query)
|
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
||||||
await result.consume()
|
result = await session.run(query)
|
||||||
logger.info(
|
await result.consume()
|
||||||
f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
|
logger.info(
|
||||||
|
f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
|
||||||
|
)
|
||||||
|
return {"status": "success", "message": "workspace data dropped"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "success", "message": "workspace data dropped"}
|
return {"status": "error", "message": str(e)}
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
|
|
||||||
)
|
|
||||||
return {"status": "error", "message": str(e)}
|
|
||||||
|
|
||||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||||
"""Get the total degree (sum of relationships) of two nodes.
|
"""Get the total degree (sum of relationships) of two nodes.
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import numpy as np
|
||||||
from lightrag.utils import logger, compute_mdhash_id
|
from lightrag.utils import logger, compute_mdhash_id
|
||||||
from ..base import BaseVectorStorage
|
from ..base import BaseVectorStorage
|
||||||
from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH
|
from ..constants import DEFAULT_MAX_FILE_PATH_LENGTH
|
||||||
|
from ..kg.shared_storage import get_storage_lock
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
if not pm.is_installed("pymilvus"):
|
if not pm.is_installed("pymilvus"):
|
||||||
|
|
@ -752,9 +753,26 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
# Create collection and check compatibility
|
async def initialize(self):
|
||||||
self._create_collection_if_not_exist()
|
"""Initialize Milvus collection"""
|
||||||
|
async with get_storage_lock(enable_logging=True):
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create collection and check compatibility
|
||||||
|
self._create_collection_if_not_exist()
|
||||||
|
self._initialized = True
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] Milvus collection '{self.namespace}' initialized successfully"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Failed to initialize Milvus collection '{self.namespace}': {e}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
|
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
|
||||||
|
|
@ -1012,20 +1030,21 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
- On success: {"status": "success", "message": "data dropped"}
|
- On success: {"status": "success", "message": "data dropped"}
|
||||||
- On failure: {"status": "error", "message": "<error details>"}
|
- On failure: {"status": "error", "message": "<error details>"}
|
||||||
"""
|
"""
|
||||||
try:
|
async with get_storage_lock(enable_logging=True):
|
||||||
# Drop the collection and recreate it
|
try:
|
||||||
if self._client.has_collection(self.final_namespace):
|
# Drop the collection and recreate it
|
||||||
self._client.drop_collection(self.final_namespace)
|
if self._client.has_collection(self.final_namespace):
|
||||||
|
self._client.drop_collection(self.final_namespace)
|
||||||
|
|
||||||
# Recreate the collection
|
# Recreate the collection
|
||||||
self._create_collection_if_not_exist()
|
self._create_collection_if_not_exist()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}"
|
f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}"
|
||||||
)
|
)
|
||||||
return {"status": "success", "message": "data dropped"}
|
return {"status": "success", "message": "data dropped"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}"
|
f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from ..base import (
|
||||||
from ..utils import logger, compute_mdhash_id
|
from ..utils import logger, compute_mdhash_id
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from ..constants import GRAPH_FIELD_SEP
|
from ..constants import GRAPH_FIELD_SEP
|
||||||
|
from ..kg.shared_storage import get_storage_lock, get_graph_db_lock
|
||||||
|
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
|
|
@ -39,39 +40,36 @@ GRAPH_BFS_MODE = os.getenv("MONGO_GRAPH_BFS_MODE", "bidirectional")
|
||||||
|
|
||||||
class ClientManager:
|
class ClientManager:
|
||||||
_instances = {"db": None, "ref_count": 0}
|
_instances = {"db": None, "ref_count": 0}
|
||||||
_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_client(cls) -> AsyncMongoClient:
|
async def get_client(cls) -> AsyncMongoClient:
|
||||||
async with cls._lock:
|
if cls._instances["db"] is None:
|
||||||
if cls._instances["db"] is None:
|
uri = os.environ.get(
|
||||||
uri = os.environ.get(
|
"MONGO_URI",
|
||||||
"MONGO_URI",
|
config.get(
|
||||||
config.get(
|
"mongodb",
|
||||||
"mongodb",
|
"uri",
|
||||||
"uri",
|
fallback="mongodb://root:root@localhost:27017/",
|
||||||
fallback="mongodb://root:root@localhost:27017/",
|
),
|
||||||
),
|
)
|
||||||
)
|
database_name = os.environ.get(
|
||||||
database_name = os.environ.get(
|
"MONGO_DATABASE",
|
||||||
"MONGO_DATABASE",
|
config.get("mongodb", "database", fallback="LightRAG"),
|
||||||
config.get("mongodb", "database", fallback="LightRAG"),
|
)
|
||||||
)
|
client = AsyncMongoClient(uri)
|
||||||
client = AsyncMongoClient(uri)
|
db = client.get_database(database_name)
|
||||||
db = client.get_database(database_name)
|
cls._instances["db"] = db
|
||||||
cls._instances["db"] = db
|
cls._instances["ref_count"] = 0
|
||||||
cls._instances["ref_count"] = 0
|
cls._instances["ref_count"] += 1
|
||||||
cls._instances["ref_count"] += 1
|
return cls._instances["db"]
|
||||||
return cls._instances["db"]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def release_client(cls, db: AsyncDatabase):
|
async def release_client(cls, db: AsyncDatabase):
|
||||||
async with cls._lock:
|
if db is not None:
|
||||||
if db is not None:
|
if db is cls._instances["db"]:
|
||||||
if db is cls._instances["db"]:
|
cls._instances["ref_count"] -= 1
|
||||||
cls._instances["ref_count"] -= 1
|
if cls._instances["ref_count"] == 0:
|
||||||
if cls._instances["ref_count"] == 0:
|
cls._instances["db"] = None
|
||||||
cls._instances["db"] = None
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
|
|
@ -125,18 +123,21 @@ class MongoKVStorage(BaseKVStorage):
|
||||||
self._collection_name = self.final_namespace
|
self._collection_name = self.final_namespace
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
self.db = await ClientManager.get_client()
|
if self.db is None:
|
||||||
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{self.workspace}] Use MongoDB as KV {self._collection_name}"
|
f"[{self.workspace}] Use MongoDB as KV {self._collection_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self.db is not None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
await ClientManager.release_client(self.db)
|
if self.db is not None:
|
||||||
self.db = None
|
await ClientManager.release_client(self.db)
|
||||||
self._data = None
|
self.db = None
|
||||||
|
self._data = None
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||||
# Unified handling for flattened keys
|
# Unified handling for flattened keys
|
||||||
|
|
@ -251,22 +252,23 @@ class MongoKVStorage(BaseKVStorage):
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||||
"""
|
"""
|
||||||
try:
|
async with get_storage_lock(enable_logging=True):
|
||||||
result = await self._data.delete_many({})
|
try:
|
||||||
deleted_count = result.deleted_count
|
result = await self._data.delete_many({})
|
||||||
|
deleted_count = result.deleted_count
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
|
f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": f"{deleted_count} documents dropped",
|
"message": f"{deleted_count} documents dropped",
|
||||||
}
|
}
|
||||||
except PyMongoError as e:
|
except PyMongoError as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
|
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
|
|
@ -318,8 +320,10 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||||
self._collection_name = self.final_namespace
|
self._collection_name = self.final_namespace
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
self.db = await ClientManager.get_client()
|
if self.db is None:
|
||||||
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||||
|
|
||||||
# Create track_id index for better query performance
|
# Create track_id index for better query performance
|
||||||
|
|
@ -333,10 +337,11 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self.db is not None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
await ClientManager.release_client(self.db)
|
if self.db is not None:
|
||||||
self.db = None
|
await ClientManager.release_client(self.db)
|
||||||
self._data = None
|
self.db = None
|
||||||
|
self._data = None
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
|
||||||
return await self._data.find_one({"_id": id})
|
return await self._data.find_one({"_id": id})
|
||||||
|
|
@ -447,22 +452,23 @@ class MongoDocStatusStorage(DocStatusStorage):
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||||
"""
|
"""
|
||||||
try:
|
async with get_storage_lock(enable_logging=True):
|
||||||
result = await self._data.delete_many({})
|
try:
|
||||||
deleted_count = result.deleted_count
|
result = await self._data.delete_many({})
|
||||||
|
deleted_count = result.deleted_count
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
|
f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}"
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": f"{deleted_count} documents dropped",
|
"message": f"{deleted_count} documents dropped",
|
||||||
}
|
}
|
||||||
except PyMongoError as e:
|
except PyMongoError as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
|
f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
async def delete(self, ids: list[str]) -> None:
|
async def delete(self, ids: list[str]) -> None:
|
||||||
await self._data.delete_many({"_id": {"$in": ids}})
|
await self._data.delete_many({"_id": {"$in": ids}})
|
||||||
|
|
@ -699,8 +705,10 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
self._edge_collection_name = f"{self._collection_name}_edges"
|
self._edge_collection_name = f"{self._collection_name}_edges"
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
async with get_graph_db_lock(enable_logging=True):
|
||||||
self.db = await ClientManager.get_client()
|
if self.db is None:
|
||||||
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
self.collection = await get_or_create_collection(
|
self.collection = await get_or_create_collection(
|
||||||
self.db, self._collection_name
|
self.db, self._collection_name
|
||||||
)
|
)
|
||||||
|
|
@ -712,11 +720,12 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self.db is not None:
|
async with get_graph_db_lock(enable_logging=True):
|
||||||
await ClientManager.release_client(self.db)
|
if self.db is not None:
|
||||||
self.db = None
|
await ClientManager.release_client(self.db)
|
||||||
self.collection = None
|
self.db = None
|
||||||
self.edge_collection = None
|
self.collection = None
|
||||||
|
self.edge_collection = None
|
||||||
|
|
||||||
# Sample entity document
|
# Sample entity document
|
||||||
# "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP
|
# "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP
|
||||||
|
|
@ -1567,29 +1576,30 @@ class MongoGraphStorage(BaseGraphStorage):
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||||
"""
|
"""
|
||||||
try:
|
async with get_graph_db_lock(enable_logging=True):
|
||||||
result = await self.collection.delete_many({})
|
try:
|
||||||
deleted_count = result.deleted_count
|
result = await self.collection.delete_many({})
|
||||||
|
deleted_count = result.deleted_count
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}"
|
f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await self.edge_collection.delete_many({})
|
result = await self.edge_collection.delete_many({})
|
||||||
edge_count = result.deleted_count
|
edge_count = result.deleted_count
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}"
|
f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": f"{deleted_count} documents and {edge_count} edges dropped",
|
"message": f"{deleted_count} documents and {edge_count} edges dropped",
|
||||||
}
|
}
|
||||||
except PyMongoError as e:
|
except PyMongoError as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}"
|
f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
|
|
@ -1661,8 +1671,10 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
self.db = await ClientManager.get_client()
|
if self.db is None:
|
||||||
|
self.db = await ClientManager.get_client()
|
||||||
|
|
||||||
self._data = await get_or_create_collection(self.db, self._collection_name)
|
self._data = await get_or_create_collection(self.db, self._collection_name)
|
||||||
|
|
||||||
# Ensure vector index exists
|
# Ensure vector index exists
|
||||||
|
|
@ -1673,10 +1685,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self.db is not None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
await ClientManager.release_client(self.db)
|
if self.db is not None:
|
||||||
self.db = None
|
await ClientManager.release_client(self.db)
|
||||||
self._data = None
|
self.db = None
|
||||||
|
self._data = None
|
||||||
|
|
||||||
async def create_vector_index_if_not_exists(self):
|
async def create_vector_index_if_not_exists(self):
|
||||||
"""Creates an Atlas Vector Search index."""
|
"""Creates an Atlas Vector Search index."""
|
||||||
|
|
@ -1957,26 +1970,27 @@ class MongoVectorDBStorage(BaseVectorStorage):
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||||
"""
|
"""
|
||||||
try:
|
async with get_storage_lock(enable_logging=True):
|
||||||
# Delete all documents
|
try:
|
||||||
result = await self._data.delete_many({})
|
# Delete all documents
|
||||||
deleted_count = result.deleted_count
|
result = await self._data.delete_many({})
|
||||||
|
deleted_count = result.deleted_count
|
||||||
|
|
||||||
# Recreate vector index
|
# Recreate vector index
|
||||||
await self.create_vector_index_if_not_exists()
|
await self.create_vector_index_if_not_exists()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
|
f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index"
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": f"{deleted_count} documents dropped and vector index recreated",
|
"message": f"{deleted_count} documents dropped and vector index recreated",
|
||||||
}
|
}
|
||||||
except PyMongoError as e:
|
except PyMongoError as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}"
|
f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
async def get_or_create_collection(db: AsyncDatabase, collection_name: str):
|
async def get_or_create_collection(db: AsyncDatabase, collection_name: str):
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from ..utils import logger
|
||||||
from ..base import BaseGraphStorage
|
from ..base import BaseGraphStorage
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from ..constants import GRAPH_FIELD_SEP
|
from ..constants import GRAPH_FIELD_SEP
|
||||||
|
from ..kg.shared_storage import get_graph_db_lock
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
if not pm.is_installed("neo4j"):
|
if not pm.is_installed("neo4j"):
|
||||||
|
|
@ -70,174 +71,180 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
return self.workspace
|
return self.workspace
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
|
async with get_graph_db_lock(enable_logging=True):
|
||||||
USERNAME = os.environ.get(
|
URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
|
||||||
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
|
USERNAME = os.environ.get(
|
||||||
)
|
"NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
|
||||||
PASSWORD = os.environ.get(
|
|
||||||
"NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
|
|
||||||
)
|
|
||||||
MAX_CONNECTION_POOL_SIZE = int(
|
|
||||||
os.environ.get(
|
|
||||||
"NEO4J_MAX_CONNECTION_POOL_SIZE",
|
|
||||||
config.get("neo4j", "connection_pool_size", fallback=100),
|
|
||||||
)
|
)
|
||||||
)
|
PASSWORD = os.environ.get(
|
||||||
CONNECTION_TIMEOUT = float(
|
"NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
|
||||||
os.environ.get(
|
)
|
||||||
"NEO4J_CONNECTION_TIMEOUT",
|
MAX_CONNECTION_POOL_SIZE = int(
|
||||||
config.get("neo4j", "connection_timeout", fallback=30.0),
|
os.environ.get(
|
||||||
),
|
"NEO4J_MAX_CONNECTION_POOL_SIZE",
|
||||||
)
|
config.get("neo4j", "connection_pool_size", fallback=100),
|
||||||
CONNECTION_ACQUISITION_TIMEOUT = float(
|
|
||||||
os.environ.get(
|
|
||||||
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
|
|
||||||
config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
MAX_TRANSACTION_RETRY_TIME = float(
|
|
||||||
os.environ.get(
|
|
||||||
"NEO4J_MAX_TRANSACTION_RETRY_TIME",
|
|
||||||
config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
MAX_CONNECTION_LIFETIME = float(
|
|
||||||
os.environ.get(
|
|
||||||
"NEO4J_MAX_CONNECTION_LIFETIME",
|
|
||||||
config.get("neo4j", "max_connection_lifetime", fallback=300.0),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
LIVENESS_CHECK_TIMEOUT = float(
|
|
||||||
os.environ.get(
|
|
||||||
"NEO4J_LIVENESS_CHECK_TIMEOUT",
|
|
||||||
config.get("neo4j", "liveness_check_timeout", fallback=30.0),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
KEEP_ALIVE = os.environ.get(
|
|
||||||
"NEO4J_KEEP_ALIVE",
|
|
||||||
config.get("neo4j", "keep_alive", fallback="true"),
|
|
||||||
).lower() in ("true", "1", "yes", "on")
|
|
||||||
DATABASE = os.environ.get(
|
|
||||||
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
|
|
||||||
)
|
|
||||||
"""The default value approach for the DATABASE is only intended to maintain compatibility with legacy practices."""
|
|
||||||
|
|
||||||
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
|
||||||
URI,
|
|
||||||
auth=(USERNAME, PASSWORD),
|
|
||||||
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
|
|
||||||
connection_timeout=CONNECTION_TIMEOUT,
|
|
||||||
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
|
|
||||||
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
|
|
||||||
max_connection_lifetime=MAX_CONNECTION_LIFETIME,
|
|
||||||
liveness_check_timeout=LIVENESS_CHECK_TIMEOUT,
|
|
||||||
keep_alive=KEEP_ALIVE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to connect to the database and create it if it doesn't exist
|
|
||||||
for database in (DATABASE, None):
|
|
||||||
self._DATABASE = database
|
|
||||||
connected = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with self._driver.session(database=database) as session:
|
|
||||||
try:
|
|
||||||
result = await session.run("MATCH (n) RETURN n LIMIT 0")
|
|
||||||
await result.consume() # Ensure result is consumed
|
|
||||||
logger.info(
|
|
||||||
f"[{self.workspace}] Connected to {database} at {URI}"
|
|
||||||
)
|
|
||||||
connected = True
|
|
||||||
except neo4jExceptions.ServiceUnavailable as e:
|
|
||||||
logger.error(
|
|
||||||
f"[{self.workspace}] "
|
|
||||||
+ f"{database} at {URI} is not available".capitalize()
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
except neo4jExceptions.AuthError as e:
|
|
||||||
logger.error(
|
|
||||||
f"[{self.workspace}] Authentication failed for {database} at {URI}"
|
|
||||||
)
|
)
|
||||||
raise e
|
)
|
||||||
except neo4jExceptions.ClientError as e:
|
CONNECTION_TIMEOUT = float(
|
||||||
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
|
os.environ.get(
|
||||||
logger.info(
|
"NEO4J_CONNECTION_TIMEOUT",
|
||||||
f"[{self.workspace}] "
|
config.get("neo4j", "connection_timeout", fallback=30.0),
|
||||||
+ f"{database} at {URI} not found. Try to create specified database.".capitalize()
|
),
|
||||||
)
|
)
|
||||||
try:
|
CONNECTION_ACQUISITION_TIMEOUT = float(
|
||||||
async with self._driver.session() as session:
|
os.environ.get(
|
||||||
result = await session.run(
|
"NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
|
||||||
f"CREATE DATABASE `{database}` IF NOT EXISTS"
|
config.get(
|
||||||
)
|
"neo4j", "connection_acquisition_timeout", fallback=30.0
|
||||||
await result.consume() # Ensure result is consumed
|
),
|
||||||
logger.info(
|
),
|
||||||
f"[{self.workspace}] "
|
)
|
||||||
+ f"{database} at {URI} created".capitalize()
|
MAX_TRANSACTION_RETRY_TIME = float(
|
||||||
)
|
os.environ.get(
|
||||||
connected = True
|
"NEO4J_MAX_TRANSACTION_RETRY_TIME",
|
||||||
except (
|
config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
|
||||||
neo4jExceptions.ClientError,
|
),
|
||||||
neo4jExceptions.DatabaseError,
|
)
|
||||||
) as e:
|
MAX_CONNECTION_LIFETIME = float(
|
||||||
if (
|
os.environ.get(
|
||||||
e.code
|
"NEO4J_MAX_CONNECTION_LIFETIME",
|
||||||
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
config.get("neo4j", "max_connection_lifetime", fallback=300.0),
|
||||||
) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
|
),
|
||||||
if database is not None:
|
)
|
||||||
logger.warning(
|
LIVENESS_CHECK_TIMEOUT = float(
|
||||||
f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
os.environ.get(
|
||||||
)
|
"NEO4J_LIVENESS_CHECK_TIMEOUT",
|
||||||
if database is None:
|
config.get("neo4j", "liveness_check_timeout", fallback=30.0),
|
||||||
logger.error(
|
),
|
||||||
f"[{self.workspace}] Failed to create {database} at {URI}"
|
)
|
||||||
)
|
KEEP_ALIVE = os.environ.get(
|
||||||
raise e
|
"NEO4J_KEEP_ALIVE",
|
||||||
|
config.get("neo4j", "keep_alive", fallback="true"),
|
||||||
|
).lower() in ("true", "1", "yes", "on")
|
||||||
|
DATABASE = os.environ.get(
|
||||||
|
"NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
|
||||||
|
)
|
||||||
|
"""The default value approach for the DATABASE is only intended to maintain compatibility with legacy practices."""
|
||||||
|
|
||||||
|
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
||||||
|
URI,
|
||||||
|
auth=(USERNAME, PASSWORD),
|
||||||
|
max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
|
||||||
|
connection_timeout=CONNECTION_TIMEOUT,
|
||||||
|
connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
|
||||||
|
max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
|
||||||
|
max_connection_lifetime=MAX_CONNECTION_LIFETIME,
|
||||||
|
liveness_check_timeout=LIVENESS_CHECK_TIMEOUT,
|
||||||
|
keep_alive=KEEP_ALIVE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to connect to the database and create it if it doesn't exist
|
||||||
|
for database in (DATABASE, None):
|
||||||
|
self._DATABASE = database
|
||||||
|
connected = False
|
||||||
|
|
||||||
if connected:
|
|
||||||
# Create index for workspace nodes on entity_id if it doesn't exist
|
|
||||||
workspace_label = self._get_workspace_label()
|
|
||||||
try:
|
try:
|
||||||
async with self._driver.session(database=database) as session:
|
async with self._driver.session(database=database) as session:
|
||||||
# Check if index exists first
|
|
||||||
check_query = f"""
|
|
||||||
CALL db.indexes() YIELD name, labelsOrTypes, properties
|
|
||||||
WHERE labelsOrTypes = ['{workspace_label}'] AND properties = ['entity_id']
|
|
||||||
RETURN count(*) > 0 AS exists
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
check_result = await session.run(check_query)
|
result = await session.run("MATCH (n) RETURN n LIMIT 0")
|
||||||
record = await check_result.single()
|
await result.consume() # Ensure result is consumed
|
||||||
await check_result.consume()
|
logger.info(
|
||||||
|
f"[{self.workspace}] Connected to {database} at {URI}"
|
||||||
index_exists = record and record.get("exists", False)
|
)
|
||||||
|
connected = True
|
||||||
if not index_exists:
|
except neo4jExceptions.ServiceUnavailable as e:
|
||||||
# Create index only if it doesn't exist
|
logger.error(
|
||||||
|
f"[{self.workspace}] "
|
||||||
|
+ f"{database} at {URI} is not available".capitalize()
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
except neo4jExceptions.AuthError as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Authentication failed for {database} at {URI}"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
except neo4jExceptions.ClientError as e:
|
||||||
|
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] "
|
||||||
|
+ f"{database} at {URI} not found. Try to create specified database.".capitalize()
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with self._driver.session() as session:
|
||||||
result = await session.run(
|
result = await session.run(
|
||||||
f"CREATE INDEX FOR (n:`{workspace_label}`) ON (n.entity_id)"
|
f"CREATE DATABASE `{database}` IF NOT EXISTS"
|
||||||
|
)
|
||||||
|
await result.consume() # Ensure result is consumed
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] "
|
||||||
|
+ f"{database} at {URI} created".capitalize()
|
||||||
|
)
|
||||||
|
connected = True
|
||||||
|
except (
|
||||||
|
neo4jExceptions.ClientError,
|
||||||
|
neo4jExceptions.DatabaseError,
|
||||||
|
) as e:
|
||||||
|
if (
|
||||||
|
e.code
|
||||||
|
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
||||||
|
) or (
|
||||||
|
e.code == "Neo.DatabaseError.Statement.ExecutionFailed"
|
||||||
|
):
|
||||||
|
if database is not None:
|
||||||
|
logger.warning(
|
||||||
|
f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
|
||||||
|
)
|
||||||
|
if database is None:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Failed to create {database} at {URI}"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
if connected:
|
||||||
|
# Create index for workspace nodes on entity_id if it doesn't exist
|
||||||
|
workspace_label = self._get_workspace_label()
|
||||||
|
try:
|
||||||
|
async with self._driver.session(database=database) as session:
|
||||||
|
# Check if index exists first
|
||||||
|
check_query = f"""
|
||||||
|
CALL db.indexes() YIELD name, labelsOrTypes, properties
|
||||||
|
WHERE labelsOrTypes = ['{workspace_label}'] AND properties = ['entity_id']
|
||||||
|
RETURN count(*) > 0 AS exists
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
check_result = await session.run(check_query)
|
||||||
|
record = await check_result.single()
|
||||||
|
await check_result.consume()
|
||||||
|
|
||||||
|
index_exists = record and record.get("exists", False)
|
||||||
|
|
||||||
|
if not index_exists:
|
||||||
|
# Create index only if it doesn't exist
|
||||||
|
result = await session.run(
|
||||||
|
f"CREATE INDEX FOR (n:`{workspace_label}`) ON (n.entity_id)"
|
||||||
|
)
|
||||||
|
await result.consume()
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] Created index for {workspace_label} nodes on entity_id in {database}"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Fallback if db.indexes() is not supported in this Neo4j version
|
||||||
|
result = await session.run(
|
||||||
|
f"CREATE INDEX IF NOT EXISTS FOR (n:`{workspace_label}`) ON (n.entity_id)"
|
||||||
)
|
)
|
||||||
await result.consume()
|
await result.consume()
|
||||||
logger.info(
|
except Exception as e:
|
||||||
f"[{self.workspace}] Created index for {workspace_label} nodes on entity_id in {database}"
|
logger.warning(
|
||||||
)
|
f"[{self.workspace}] Failed to create index: {str(e)}"
|
||||||
except Exception:
|
)
|
||||||
# Fallback if db.indexes() is not supported in this Neo4j version
|
break
|
||||||
result = await session.run(
|
|
||||||
f"CREATE INDEX IF NOT EXISTS FOR (n:`{workspace_label}`) ON (n.entity_id)"
|
|
||||||
)
|
|
||||||
await result.consume()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"[{self.workspace}] Failed to create index: {str(e)}"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
"""Close the Neo4j driver and release all resources"""
|
"""Close the Neo4j driver and release all resources"""
|
||||||
if self._driver:
|
async with get_graph_db_lock(enable_logging=True):
|
||||||
await self._driver.close()
|
if self._driver:
|
||||||
self._driver = None
|
await self._driver.close()
|
||||||
|
self._driver = None
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
"""Ensure driver is closed when context manager exits"""
|
"""Ensure driver is closed when context manager exits"""
|
||||||
|
|
@ -1526,23 +1533,24 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
- On success: {"status": "success", "message": "workspace data dropped"}
|
- On success: {"status": "success", "message": "workspace data dropped"}
|
||||||
- On failure: {"status": "error", "message": "<error details>"}
|
- On failure: {"status": "error", "message": "<error details>"}
|
||||||
"""
|
"""
|
||||||
workspace_label = self._get_workspace_label()
|
async with get_graph_db_lock(enable_logging=True):
|
||||||
try:
|
workspace_label = self._get_workspace_label()
|
||||||
async with self._driver.session(database=self._DATABASE) as session:
|
try:
|
||||||
# Delete all nodes and relationships in current workspace only
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
# Delete all nodes and relationships in current workspace only
|
||||||
result = await session.run(query)
|
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
||||||
await result.consume() # Ensure result is fully consumed
|
result = await session.run(query)
|
||||||
|
await result.consume() # Ensure result is fully consumed
|
||||||
|
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
|
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
|
||||||
# )
|
# )
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": f"workspace '{workspace_label}' data dropped",
|
"message": f"workspace '{workspace_label}' data dropped",
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
|
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ from ..base import (
|
||||||
from ..namespace import NameSpace, is_namespace
|
from ..namespace import NameSpace, is_namespace
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..constants import GRAPH_FIELD_SEP
|
from ..constants import GRAPH_FIELD_SEP
|
||||||
from ..kg.shared_storage import get_graph_db_lock
|
from ..kg.shared_storage import get_graph_db_lock, get_storage_lock
|
||||||
|
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
|
|
@ -1406,23 +1406,26 @@ class PGKVStorage(BaseKVStorage):
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
self.db = await ClientManager.get_client()
|
if self.db is None:
|
||||||
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
|
self.db = await ClientManager.get_client()
|
||||||
if self.db.workspace:
|
|
||||||
# Use PostgreSQLDB's workspace (highest priority)
|
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
|
||||||
self.workspace = self.db.workspace
|
if self.db.workspace:
|
||||||
elif hasattr(self, "workspace") and self.workspace:
|
# Use PostgreSQLDB's workspace (highest priority)
|
||||||
# Use storage class's workspace (medium priority)
|
self.workspace = self.db.workspace
|
||||||
pass
|
elif hasattr(self, "workspace") and self.workspace:
|
||||||
else:
|
# Use storage class's workspace (medium priority)
|
||||||
# Use "default" for compatibility (lowest priority)
|
pass
|
||||||
self.workspace = "default"
|
else:
|
||||||
|
# Use "default" for compatibility (lowest priority)
|
||||||
|
self.workspace = "default"
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self.db is not None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
await ClientManager.release_client(self.db)
|
if self.db is not None:
|
||||||
self.db = None
|
await ClientManager.release_client(self.db)
|
||||||
|
self.db = None
|
||||||
|
|
||||||
################ QUERY METHODS ################
|
################ QUERY METHODS ################
|
||||||
async def get_all(self) -> dict[str, Any]:
|
async def get_all(self) -> dict[str, Any]:
|
||||||
|
|
@ -1834,21 +1837,22 @@ class PGKVStorage(BaseKVStorage):
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop the storage"""
|
"""Drop the storage"""
|
||||||
try:
|
async with get_storage_lock(enable_logging=True):
|
||||||
table_name = namespace_to_table_name(self.namespace)
|
try:
|
||||||
if not table_name:
|
table_name = namespace_to_table_name(self.namespace)
|
||||||
return {
|
if not table_name:
|
||||||
"status": "error",
|
return {
|
||||||
"message": f"Unknown namespace: {self.namespace}",
|
"status": "error",
|
||||||
}
|
"message": f"Unknown namespace: {self.namespace}",
|
||||||
|
}
|
||||||
|
|
||||||
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
||||||
table_name=table_name
|
table_name=table_name
|
||||||
)
|
)
|
||||||
await self.db.execute(drop_sql, {"workspace": self.workspace})
|
await self.db.execute(drop_sql, {"workspace": self.workspace})
|
||||||
return {"status": "success", "message": "data dropped"}
|
return {"status": "success", "message": "data dropped"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
|
|
@ -1867,23 +1871,26 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
self.db = await ClientManager.get_client()
|
if self.db is None:
|
||||||
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
|
self.db = await ClientManager.get_client()
|
||||||
if self.db.workspace:
|
|
||||||
# Use PostgreSQLDB's workspace (highest priority)
|
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
|
||||||
self.workspace = self.db.workspace
|
if self.db.workspace:
|
||||||
elif hasattr(self, "workspace") and self.workspace:
|
# Use PostgreSQLDB's workspace (highest priority)
|
||||||
# Use storage class's workspace (medium priority)
|
self.workspace = self.db.workspace
|
||||||
pass
|
elif hasattr(self, "workspace") and self.workspace:
|
||||||
else:
|
# Use storage class's workspace (medium priority)
|
||||||
# Use "default" for compatibility (lowest priority)
|
pass
|
||||||
self.workspace = "default"
|
else:
|
||||||
|
# Use "default" for compatibility (lowest priority)
|
||||||
|
self.workspace = "default"
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self.db is not None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
await ClientManager.release_client(self.db)
|
if self.db is not None:
|
||||||
self.db = None
|
await ClientManager.release_client(self.db)
|
||||||
|
self.db = None
|
||||||
|
|
||||||
def _upsert_chunks(
|
def _upsert_chunks(
|
||||||
self, item: dict[str, Any], current_time: datetime.datetime
|
self, item: dict[str, Any], current_time: datetime.datetime
|
||||||
|
|
@ -2153,21 +2160,22 @@ class PGVectorStorage(BaseVectorStorage):
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop the storage"""
|
"""Drop the storage"""
|
||||||
try:
|
async with get_storage_lock(enable_logging=True):
|
||||||
table_name = namespace_to_table_name(self.namespace)
|
try:
|
||||||
if not table_name:
|
table_name = namespace_to_table_name(self.namespace)
|
||||||
return {
|
if not table_name:
|
||||||
"status": "error",
|
return {
|
||||||
"message": f"Unknown namespace: {self.namespace}",
|
"status": "error",
|
||||||
}
|
"message": f"Unknown namespace: {self.namespace}",
|
||||||
|
}
|
||||||
|
|
||||||
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
|
||||||
table_name=table_name
|
table_name=table_name
|
||||||
)
|
)
|
||||||
await self.db.execute(drop_sql, {"workspace": self.workspace})
|
await self.db.execute(drop_sql, {"workspace": self.workspace})
|
||||||
return {"status": "success", "message": "data dropped"}
|
return {"status": "success", "message": "data dropped"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
|
|
@ -2186,23 +2194,26 @@ class PGDocStatusStorage(DocStatusStorage):
|
||||||
return dt.isoformat()
|
return dt.isoformat()
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
self.db = await ClientManager.get_client()
|
if self.db is None:
|
||||||
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
|
self.db = await ClientManager.get_client()
|
||||||
if self.db.workspace:
|
|
||||||
# Use PostgreSQLDB's workspace (highest priority)
|
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
|
||||||
self.workspace = self.db.workspace
|
if self.db.workspace:
|
||||||
elif hasattr(self, "workspace") and self.workspace:
|
# Use PostgreSQLDB's workspace (highest priority)
|
||||||
# Use storage class's workspace (medium priority)
|
self.workspace = self.db.workspace
|
||||||
pass
|
elif hasattr(self, "workspace") and self.workspace:
|
||||||
else:
|
# Use storage class's workspace (medium priority)
|
||||||
# Use "default" for compatibility (lowest priority)
|
pass
|
||||||
self.workspace = "default"
|
else:
|
||||||
|
# Use "default" for compatibility (lowest priority)
|
||||||
|
self.workspace = "default"
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self.db is not None:
|
async with get_storage_lock(enable_logging=True):
|
||||||
await ClientManager.release_client(self.db)
|
if self.db is not None:
|
||||||
self.db = None
|
await ClientManager.release_client(self.db)
|
||||||
|
self.db = None
|
||||||
|
|
||||||
async def filter_keys(self, keys: set[str]) -> set[str]:
|
async def filter_keys(self, keys: set[str]) -> set[str]:
|
||||||
"""Filter out duplicated content"""
|
"""Filter out duplicated content"""
|
||||||
|
|
@ -2778,30 +2789,29 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
return normalized_id
|
return normalized_id
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.db is None:
|
async with get_graph_db_lock(enable_logging=True):
|
||||||
self.db = await ClientManager.get_client()
|
if self.db is None:
|
||||||
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
|
self.db = await ClientManager.get_client()
|
||||||
if self.db.workspace:
|
|
||||||
# Use PostgreSQLDB's workspace (highest priority)
|
|
||||||
self.workspace = self.db.workspace
|
|
||||||
elif hasattr(self, "workspace") and self.workspace:
|
|
||||||
# Use storage class's workspace (medium priority)
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# Use "default" for compatibility (lowest priority)
|
|
||||||
self.workspace = "default"
|
|
||||||
|
|
||||||
# Dynamically generate graph name based on workspace
|
# Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default"
|
||||||
self.graph_name = self._get_workspace_graph_name()
|
if self.db.workspace:
|
||||||
|
# Use PostgreSQLDB's workspace (highest priority)
|
||||||
|
self.workspace = self.db.workspace
|
||||||
|
elif hasattr(self, "workspace") and self.workspace:
|
||||||
|
# Use storage class's workspace (medium priority)
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Use "default" for compatibility (lowest priority)
|
||||||
|
self.workspace = "default"
|
||||||
|
|
||||||
# Log the graph initialization for debugging
|
# Dynamically generate graph name based on workspace
|
||||||
logger.info(
|
self.graph_name = self._get_workspace_graph_name()
|
||||||
f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'"
|
|
||||||
)
|
# Log the graph initialization for debugging
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] PostgreSQL Graph initialized: graph_name='{self.graph_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
# Use graph database lock to ensure atomic operations and prevent deadlocks
|
|
||||||
graph_db_lock = get_graph_db_lock(enable_logging=False)
|
|
||||||
async with graph_db_lock:
|
|
||||||
# Create AGE extension and configure graph environment once at initialization
|
# Create AGE extension and configure graph environment once at initialization
|
||||||
async with self.db.pool.acquire() as connection:
|
async with self.db.pool.acquire() as connection:
|
||||||
# First ensure AGE extension is created
|
# First ensure AGE extension is created
|
||||||
|
|
@ -2840,9 +2850,10 @@ class PGGraphStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
if self.db is not None:
|
async with get_graph_db_lock(enable_logging=True):
|
||||||
await ClientManager.release_client(self.db)
|
if self.db is not None:
|
||||||
self.db = None
|
await ClientManager.release_client(self.db)
|
||||||
|
self.db = None
|
||||||
|
|
||||||
async def index_done_callback(self) -> None:
|
async def index_done_callback(self) -> None:
|
||||||
# PG handles persistence automatically
|
# PG handles persistence automatically
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import hashlib
|
||||||
import uuid
|
import uuid
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..base import BaseVectorStorage
|
from ..base import BaseVectorStorage
|
||||||
|
from ..kg.shared_storage import get_storage_lock
|
||||||
import configparser
|
import configparser
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
|
|
@ -118,13 +119,33 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._max_batch_size = self.global_config["embedding_batch_num"]
|
self._max_batch_size = self.global_config["embedding_batch_num"]
|
||||||
QdrantVectorDBStorage.create_collection_if_not_exist(
|
self._initialized = False
|
||||||
self._client,
|
|
||||||
self.final_namespace,
|
async def initialize(self):
|
||||||
vectors_config=models.VectorParams(
|
"""Initialize Qdrant collection"""
|
||||||
size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE
|
async with get_storage_lock(enable_logging=True):
|
||||||
),
|
if self._initialized:
|
||||||
)
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create collection if not exists
|
||||||
|
QdrantVectorDBStorage.create_collection_if_not_exist(
|
||||||
|
self._client,
|
||||||
|
self.final_namespace,
|
||||||
|
vectors_config=models.VectorParams(
|
||||||
|
size=self.embedding_func.embedding_dim,
|
||||||
|
distance=models.Distance.COSINE,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._initialized = True
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] Qdrant collection '{self.namespace}' initialized successfully"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Failed to initialize Qdrant collection '{self.namespace}': {e}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
|
||||||
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
|
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
|
||||||
|
|
@ -382,27 +403,28 @@ class QdrantVectorDBStorage(BaseVectorStorage):
|
||||||
- On success: {"status": "success", "message": "data dropped"}
|
- On success: {"status": "success", "message": "data dropped"}
|
||||||
- On failure: {"status": "error", "message": "<error details>"}
|
- On failure: {"status": "error", "message": "<error details>"}
|
||||||
"""
|
"""
|
||||||
try:
|
async with get_storage_lock(enable_logging=True):
|
||||||
# Delete the collection and recreate it
|
try:
|
||||||
if self._client.collection_exists(self.final_namespace):
|
# Delete the collection and recreate it
|
||||||
self._client.delete_collection(self.final_namespace)
|
if self._client.collection_exists(self.final_namespace):
|
||||||
|
self._client.delete_collection(self.final_namespace)
|
||||||
|
|
||||||
# Recreate the collection
|
# Recreate the collection
|
||||||
QdrantVectorDBStorage.create_collection_if_not_exist(
|
QdrantVectorDBStorage.create_collection_if_not_exist(
|
||||||
self._client,
|
self._client,
|
||||||
self.final_namespace,
|
self.final_namespace,
|
||||||
vectors_config=models.VectorParams(
|
vectors_config=models.VectorParams(
|
||||||
size=self.embedding_func.embedding_dim,
|
size=self.embedding_func.embedding_dim,
|
||||||
distance=models.Distance.COSINE,
|
distance=models.Distance.COSINE,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Process {os.getpid()} drop Qdrant collection {self.namespace}"
|
f"[{self.workspace}] Process {os.getpid()} drop Qdrant collection {self.namespace}"
|
||||||
)
|
)
|
||||||
return {"status": "success", "message": "data dropped"}
|
return {"status": "success", "message": "data dropped"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error dropping Qdrant collection {self.namespace}: {e}"
|
f"[{self.workspace}] Error dropping Qdrant collection {self.namespace}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ from lightrag.base import (
|
||||||
DocStatus,
|
DocStatus,
|
||||||
DocProcessingStatus,
|
DocProcessingStatus,
|
||||||
)
|
)
|
||||||
|
from ..kg.shared_storage import get_storage_lock
|
||||||
import json
|
import json
|
||||||
|
|
||||||
# Import tenacity for retry logic
|
# Import tenacity for retry logic
|
||||||
|
|
@ -179,32 +180,33 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize Redis connection and migrate legacy cache structure if needed"""
|
"""Initialize Redis connection and migrate legacy cache structure if needed"""
|
||||||
if self._initialized:
|
async with get_storage_lock(enable_logging=True):
|
||||||
return
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
# Test connection
|
# Test connection
|
||||||
try:
|
|
||||||
async with self._get_redis_connection() as redis:
|
|
||||||
await redis.ping()
|
|
||||||
logger.info(
|
|
||||||
f"[{self.workspace}] Connected to Redis for namespace {self.namespace}"
|
|
||||||
)
|
|
||||||
self._initialized = True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[{self.workspace}] Failed to connect to Redis: {e}")
|
|
||||||
# Clean up on connection failure
|
|
||||||
await self.close()
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Migrate legacy cache structure if this is a cache namespace
|
|
||||||
if self.namespace.endswith("_cache"):
|
|
||||||
try:
|
try:
|
||||||
await self._migrate_legacy_cache_structure()
|
async with self._get_redis_connection() as redis:
|
||||||
|
await redis.ping()
|
||||||
|
logger.info(
|
||||||
|
f"[{self.workspace}] Connected to Redis for namespace {self.namespace}"
|
||||||
|
)
|
||||||
|
self._initialized = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"[{self.workspace}] Failed to connect to Redis: {e}")
|
||||||
f"[{self.workspace}] Failed to migrate legacy cache structure: {e}"
|
# Clean up on connection failure
|
||||||
)
|
await self.close()
|
||||||
# Don't fail initialization for migration errors, just log them
|
raise
|
||||||
|
|
||||||
|
# Migrate legacy cache structure if this is a cache namespace
|
||||||
|
if self.namespace.endswith("_cache"):
|
||||||
|
try:
|
||||||
|
await self._migrate_legacy_cache_structure()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Failed to migrate legacy cache structure: {e}"
|
||||||
|
)
|
||||||
|
# Don't fail initialization for migration errors, just log them
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def _get_redis_connection(self):
|
async def _get_redis_connection(self):
|
||||||
|
|
@ -426,39 +428,42 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||||
"""
|
"""
|
||||||
async with self._get_redis_connection() as redis:
|
async with get_storage_lock(enable_logging=True):
|
||||||
try:
|
async with self._get_redis_connection() as redis:
|
||||||
# Use SCAN to find all keys with the namespace prefix
|
try:
|
||||||
pattern = f"{self.final_namespace}:*"
|
# Use SCAN to find all keys with the namespace prefix
|
||||||
cursor = 0
|
pattern = f"{self.final_namespace}:*"
|
||||||
deleted_count = 0
|
cursor = 0
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
|
cursor, keys = await redis.scan(
|
||||||
if keys:
|
cursor, match=pattern, count=1000
|
||||||
# Delete keys in batches
|
)
|
||||||
pipe = redis.pipeline()
|
if keys:
|
||||||
for key in keys:
|
# Delete keys in batches
|
||||||
pipe.delete(key)
|
pipe = redis.pipeline()
|
||||||
results = await pipe.execute()
|
for key in keys:
|
||||||
deleted_count += sum(results)
|
pipe.delete(key)
|
||||||
|
results = await pipe.execute()
|
||||||
|
deleted_count += sum(results)
|
||||||
|
|
||||||
if cursor == 0:
|
if cursor == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}"
|
f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}"
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": f"{deleted_count} keys dropped",
|
"message": f"{deleted_count} keys dropped",
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}"
|
f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
async def _migrate_legacy_cache_structure(self):
|
async def _migrate_legacy_cache_structure(self):
|
||||||
"""Migrate legacy nested cache structure to flattened structure for Redis
|
"""Migrate legacy nested cache structure to flattened structure for Redis
|
||||||
|
|
@ -601,23 +606,24 @@ class RedisDocStatusStorage(DocStatusStorage):
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize Redis connection"""
|
"""Initialize Redis connection"""
|
||||||
if self._initialized:
|
async with get_storage_lock(enable_logging=True):
|
||||||
return
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._get_redis_connection() as redis:
|
async with self._get_redis_connection() as redis:
|
||||||
await redis.ping()
|
await redis.ping()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Connected to Redis for doc status namespace {self.namespace}"
|
f"[{self.workspace}] Connected to Redis for doc status namespace {self.namespace}"
|
||||||
|
)
|
||||||
|
self._initialized = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Failed to connect to Redis for doc status: {e}"
|
||||||
)
|
)
|
||||||
self._initialized = True
|
# Clean up on connection failure
|
||||||
except Exception as e:
|
await self.close()
|
||||||
logger.error(
|
raise
|
||||||
f"[{self.workspace}] Failed to connect to Redis for doc status: {e}"
|
|
||||||
)
|
|
||||||
# Clean up on connection failure
|
|
||||||
await self.close()
|
|
||||||
raise
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def _get_redis_connection(self):
|
async def _get_redis_connection(self):
|
||||||
|
|
@ -1043,32 +1049,35 @@ class RedisDocStatusStorage(DocStatusStorage):
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop all document status data from storage and clean up resources"""
|
"""Drop all document status data from storage and clean up resources"""
|
||||||
try:
|
async with get_storage_lock(enable_logging=True):
|
||||||
async with self._get_redis_connection() as redis:
|
try:
|
||||||
# Use SCAN to find all keys with the namespace prefix
|
async with self._get_redis_connection() as redis:
|
||||||
pattern = f"{self.final_namespace}:*"
|
# Use SCAN to find all keys with the namespace prefix
|
||||||
cursor = 0
|
pattern = f"{self.final_namespace}:*"
|
||||||
deleted_count = 0
|
cursor = 0
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
|
cursor, keys = await redis.scan(
|
||||||
if keys:
|
cursor, match=pattern, count=1000
|
||||||
# Delete keys in batches
|
)
|
||||||
pipe = redis.pipeline()
|
if keys:
|
||||||
for key in keys:
|
# Delete keys in batches
|
||||||
pipe.delete(key)
|
pipe = redis.pipeline()
|
||||||
results = await pipe.execute()
|
for key in keys:
|
||||||
deleted_count += sum(results)
|
pipe.delete(key)
|
||||||
|
results = await pipe.execute()
|
||||||
|
deleted_count += sum(results)
|
||||||
|
|
||||||
if cursor == 0:
|
if cursor == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}"
|
f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}"
|
||||||
|
)
|
||||||
|
return {"status": "success", "message": "data dropped"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "success", "message": "data dropped"}
|
return {"status": "error", "message": str(e)}
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}"
|
|
||||||
)
|
|
||||||
return {"status": "error", "message": str(e)}
|
|
||||||
|
|
|
||||||
|
|
@ -530,10 +530,8 @@ class LightRAG:
|
||||||
self._storages_status = StoragesStatus.CREATED
|
self._storages_status = StoragesStatus.CREATED
|
||||||
|
|
||||||
async def initialize_storages(self):
|
async def initialize_storages(self):
|
||||||
"""Asynchronously initialize the storages"""
|
"""Storage initialization must be called one by one to prevent deadlock"""
|
||||||
if self._storages_status == StoragesStatus.CREATED:
|
if self._storages_status == StoragesStatus.CREATED:
|
||||||
tasks = []
|
|
||||||
|
|
||||||
for storage in (
|
for storage in (
|
||||||
self.full_docs,
|
self.full_docs,
|
||||||
self.text_chunks,
|
self.text_chunks,
|
||||||
|
|
@ -547,9 +545,7 @@ class LightRAG:
|
||||||
self.doc_status,
|
self.doc_status,
|
||||||
):
|
):
|
||||||
if storage:
|
if storage:
|
||||||
tasks.append(storage.initialize())
|
await storage.initialize()
|
||||||
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
self._storages_status = StoragesStatus.INITIALIZED
|
self._storages_status = StoragesStatus.INITIALIZED
|
||||||
logger.debug("All storage types initialized")
|
logger.debug("All storage types initialized")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue