fix: sync all kg modules from upstream
This commit is contained in:
parent
09d0721cab
commit
ccd2f82174
7 changed files with 190 additions and 181 deletions
|
|
@ -13,7 +13,7 @@ from lightrag.utils import (
|
||||||
from lightrag.exceptions import StorageNotInitializedError
|
from lightrag.exceptions import StorageNotInitializedError
|
||||||
from .shared_storage import (
|
from .shared_storage import (
|
||||||
get_namespace_data,
|
get_namespace_data,
|
||||||
get_storage_lock,
|
get_namespace_lock,
|
||||||
get_data_init_lock,
|
get_data_init_lock,
|
||||||
get_update_flag,
|
get_update_flag,
|
||||||
set_all_update_flags,
|
set_all_update_flags,
|
||||||
|
|
@ -30,12 +30,10 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
if self.workspace:
|
if self.workspace:
|
||||||
# Include workspace in the file path for data isolation
|
# Include workspace in the file path for data isolation
|
||||||
workspace_dir = os.path.join(working_dir, self.workspace)
|
workspace_dir = os.path.join(working_dir, self.workspace)
|
||||||
self.final_namespace = f"{self.workspace}_{self.namespace}"
|
|
||||||
else:
|
else:
|
||||||
# Default behavior when workspace is empty
|
# Default behavior when workspace is empty
|
||||||
workspace_dir = working_dir
|
workspace_dir = working_dir
|
||||||
self.final_namespace = self.namespace
|
self.workspace = ""
|
||||||
self.workspace = "_"
|
|
||||||
|
|
||||||
os.makedirs(workspace_dir, exist_ok=True)
|
os.makedirs(workspace_dir, exist_ok=True)
|
||||||
self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json")
|
self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json")
|
||||||
|
|
@ -46,12 +44,20 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize storage data"""
|
"""Initialize storage data"""
|
||||||
self._storage_lock = get_storage_lock()
|
self._storage_lock = get_namespace_lock(
|
||||||
self.storage_updated = await get_update_flag(self.final_namespace)
|
self.namespace, workspace=self.workspace
|
||||||
|
)
|
||||||
|
self.storage_updated = await get_update_flag(
|
||||||
|
self.namespace, workspace=self.workspace
|
||||||
|
)
|
||||||
async with get_data_init_lock():
|
async with get_data_init_lock():
|
||||||
# check need_init must before get_namespace_data
|
# check need_init must before get_namespace_data
|
||||||
need_init = await try_initialize_namespace(self.final_namespace)
|
need_init = await try_initialize_namespace(
|
||||||
self._data = await get_namespace_data(self.final_namespace)
|
self.namespace, workspace=self.workspace
|
||||||
|
)
|
||||||
|
self._data = await get_namespace_data(
|
||||||
|
self.namespace, workspace=self.workspace
|
||||||
|
)
|
||||||
if need_init:
|
if need_init:
|
||||||
loaded_data = load_json(self._file_name) or {}
|
loaded_data = load_json(self._file_name) or {}
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
|
|
@ -91,11 +97,11 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
|
f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
|
||||||
)
|
)
|
||||||
cleaned_data = load_json(self._file_name)
|
cleaned_data = load_json(self._file_name)
|
||||||
if cleaned_data:
|
if cleaned_data is not None:
|
||||||
self._data.clear()
|
self._data.clear()
|
||||||
self._data.update(cleaned_data)
|
self._data.update(cleaned_data)
|
||||||
|
|
||||||
await clear_all_update_flags(self.final_namespace)
|
await clear_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
|
|
||||||
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
async def get_by_id(self, id: str) -> dict[str, Any] | None:
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
|
|
@ -168,7 +174,7 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
v["_id"] = k
|
v["_id"] = k
|
||||||
|
|
||||||
self._data.update(data)
|
self._data.update(data)
|
||||||
await set_all_update_flags(self.final_namespace)
|
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
|
|
||||||
async def delete(self, ids: list[str]) -> None:
|
async def delete(self, ids: list[str]) -> None:
|
||||||
"""Delete specific records from storage by their IDs
|
"""Delete specific records from storage by their IDs
|
||||||
|
|
@ -191,7 +197,7 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
any_deleted = True
|
any_deleted = True
|
||||||
|
|
||||||
if any_deleted:
|
if any_deleted:
|
||||||
await set_all_update_flags(self.final_namespace)
|
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
|
|
||||||
async def is_empty(self) -> bool:
|
async def is_empty(self) -> bool:
|
||||||
"""Check if the storage is empty
|
"""Check if the storage is empty
|
||||||
|
|
@ -219,7 +225,7 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
try:
|
try:
|
||||||
async with self._storage_lock:
|
async with self._storage_lock:
|
||||||
self._data.clear()
|
self._data.clear()
|
||||||
await set_all_update_flags(self.final_namespace)
|
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
|
|
||||||
await self.index_done_callback()
|
await self.index_done_callback()
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -283,7 +289,7 @@ class JsonKVStorage(BaseKVStorage):
|
||||||
f"[{self.workspace}] Reloading sanitized migration data for {self.namespace}"
|
f"[{self.workspace}] Reloading sanitized migration data for {self.namespace}"
|
||||||
)
|
)
|
||||||
cleaned_data = load_json(self._file_name)
|
cleaned_data = load_json(self._file_name)
|
||||||
if cleaned_data:
|
if cleaned_data is not None:
|
||||||
return cleaned_data # Return cleaned data to update shared memory
|
return cleaned_data # Return cleaned data to update shared memory
|
||||||
|
|
||||||
return migrated_data
|
return migrated_data
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import configparser
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..base import BaseGraphStorage
|
from ..base import BaseGraphStorage
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
|
from ..kg.shared_storage import get_data_init_lock
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
if not pm.is_installed("neo4j"):
|
if not pm.is_installed("neo4j"):
|
||||||
|
|
@ -101,10 +101,9 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
async with get_graph_db_lock():
|
if self._driver is not None:
|
||||||
if self._driver is not None:
|
await self._driver.close()
|
||||||
await self._driver.close()
|
self._driver = None
|
||||||
self._driver = None
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
await self.finalize()
|
await self.finalize()
|
||||||
|
|
@ -762,22 +761,21 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||||
)
|
)
|
||||||
async with get_graph_db_lock():
|
try:
|
||||||
try:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
async with self._driver.session(database=self._DATABASE) as session:
|
workspace_label = self._get_workspace_label()
|
||||||
workspace_label = self._get_workspace_label()
|
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
||||||
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
result = await session.run(query)
|
||||||
result = await session.run(query)
|
await result.consume()
|
||||||
await result.consume()
|
logger.info(
|
||||||
logger.info(
|
f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
|
||||||
f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
|
|
||||||
)
|
|
||||||
return {"status": "success", "message": "workspace data dropped"}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
|
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "success", "message": "workspace data dropped"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
|
||||||
|
)
|
||||||
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
||||||
"""Get the total degree (sum of relationships) of two nodes.
|
"""Get the total degree (sum of relationships) of two nodes.
|
||||||
|
|
@ -1050,12 +1048,12 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
"Memgraph driver is not initialized. Call 'await initialize()' first."
|
||||||
)
|
)
|
||||||
|
|
||||||
workspace_label = self._get_workspace_label()
|
result = None
|
||||||
async with self._driver.session(
|
try:
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
workspace_label = self._get_workspace_label()
|
||||||
) as session:
|
async with self._driver.session(
|
||||||
result = None
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
try:
|
) as session:
|
||||||
query = f"""
|
query = f"""
|
||||||
MATCH (n:`{workspace_label}`)
|
MATCH (n:`{workspace_label}`)
|
||||||
WHERE n.entity_id IS NOT NULL
|
WHERE n.entity_id IS NOT NULL
|
||||||
|
|
@ -1075,13 +1073,11 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})"
|
f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})"
|
||||||
)
|
)
|
||||||
return labels
|
return labels
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
|
||||||
f"[{self.workspace}] Error getting popular labels: {str(e)}"
|
if result is not None:
|
||||||
)
|
await result.consume()
|
||||||
if result is not None:
|
return []
|
||||||
await result.consume()
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
async def search_labels(self, query: str, limit: int = 50) -> list[str]:
|
||||||
"""Search labels with fuzzy matching
|
"""Search labels with fuzzy matching
|
||||||
|
|
@ -1103,12 +1099,12 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
if not query_lower:
|
if not query_lower:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
workspace_label = self._get_workspace_label()
|
result = None
|
||||||
async with self._driver.session(
|
try:
|
||||||
database=self._DATABASE, default_access_mode="READ"
|
workspace_label = self._get_workspace_label()
|
||||||
) as session:
|
async with self._driver.session(
|
||||||
result = None
|
database=self._DATABASE, default_access_mode="READ"
|
||||||
try:
|
) as session:
|
||||||
cypher_query = f"""
|
cypher_query = f"""
|
||||||
MATCH (n:`{workspace_label}`)
|
MATCH (n:`{workspace_label}`)
|
||||||
WHERE n.entity_id IS NOT NULL
|
WHERE n.entity_id IS NOT NULL
|
||||||
|
|
@ -1135,8 +1131,8 @@ class MemgraphStorage(BaseGraphStorage):
|
||||||
f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})"
|
f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})"
|
||||||
)
|
)
|
||||||
return labels
|
return labels
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
|
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
|
||||||
if result is not None:
|
if result is not None:
|
||||||
await result.consume()
|
await result.consume()
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -939,28 +939,22 @@ class MilvusVectorDBStorage(BaseVectorStorage):
|
||||||
milvus_workspace = os.environ.get("MILVUS_WORKSPACE")
|
milvus_workspace = os.environ.get("MILVUS_WORKSPACE")
|
||||||
if milvus_workspace and milvus_workspace.strip():
|
if milvus_workspace and milvus_workspace.strip():
|
||||||
# Use environment variable value, overriding the passed workspace parameter
|
# Use environment variable value, overriding the passed workspace parameter
|
||||||
self.workspace = milvus_workspace.strip()
|
effective_workspace = milvus_workspace.strip()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Using MILVUS_WORKSPACE environment variable: '{self.workspace}' (overriding passed workspace)"
|
f"Using MILVUS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use the workspace parameter passed during initialization
|
# Use the workspace parameter passed during initialization
|
||||||
if self.workspace:
|
effective_workspace = self.workspace
|
||||||
|
if effective_workspace:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Using passed workspace parameter: '{self.workspace}'"
|
f"Using passed workspace parameter: '{effective_workspace}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get composite workspace (supports multi-tenant isolation)
|
|
||||||
composite_workspace = self._get_composite_workspace()
|
|
||||||
|
|
||||||
# Sanitize for Milvus (replace colons with underscores)
|
|
||||||
# Milvus collection names must start with a letter or underscore, and can only contain letters, numbers, and underscores
|
|
||||||
safe_composite_workspace = composite_workspace.replace(":", "_")
|
|
||||||
|
|
||||||
# Build final_namespace with workspace prefix for data isolation
|
# Build final_namespace with workspace prefix for data isolation
|
||||||
# Keep original namespace unchanged for type detection logic
|
# Keep original namespace unchanged for type detection logic
|
||||||
if safe_composite_workspace and safe_composite_workspace != "_":
|
if effective_workspace:
|
||||||
self.final_namespace = f"{safe_composite_workspace}_{self.namespace}"
|
self.final_namespace = f"{effective_workspace}_{self.namespace}"
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Final namespace with workspace prefix: '{self.final_namespace}'"
|
f"Final namespace with workspace prefix: '{self.final_namespace}'"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
else:
|
else:
|
||||||
# Default behavior when workspace is empty
|
# Default behavior when workspace is empty
|
||||||
self.final_namespace = self.namespace
|
self.final_namespace = self.namespace
|
||||||
self.workspace = "_"
|
self.workspace = ""
|
||||||
workspace_dir = working_dir
|
workspace_dir = working_dir
|
||||||
|
|
||||||
os.makedirs(workspace_dir, exist_ok=True)
|
os.makedirs(workspace_dir, exist_ok=True)
|
||||||
|
|
@ -66,11 +66,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
"""Initialize storage data"""
|
"""Initialize storage data"""
|
||||||
# Get the update flag for cross-process update notification
|
# Get the update flag for cross-process update notification
|
||||||
self.storage_updated = await get_update_flag(
|
self.storage_updated = await get_update_flag(
|
||||||
self.final_namespace, workspace=self.workspace
|
self.namespace, workspace=self.workspace
|
||||||
)
|
)
|
||||||
# Get the storage lock for use in other methods
|
# Get the storage lock for use in other methods
|
||||||
self._storage_lock = get_namespace_lock(
|
self._storage_lock = get_namespace_lock(
|
||||||
self.final_namespace, workspace=self.workspace
|
self.namespace, workspace=self.workspace
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_client(self):
|
async def _get_client(self):
|
||||||
|
|
@ -292,9 +292,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
# Save data to disk
|
# Save data to disk
|
||||||
self._client.save()
|
self._client.save()
|
||||||
# Notify other processes that data has been updated
|
# Notify other processes that data has been updated
|
||||||
await set_all_update_flags(
|
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
self.final_namespace, workspace=self.workspace
|
|
||||||
)
|
|
||||||
# Reset own update flag to avoid self-reloading
|
# Reset own update flag to avoid self-reloading
|
||||||
self.storage_updated.value = False
|
self.storage_updated.value = False
|
||||||
return True # Return success
|
return True # Return success
|
||||||
|
|
@ -416,9 +414,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Notify other processes that data has been updated
|
# Notify other processes that data has been updated
|
||||||
await set_all_update_flags(
|
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
self.final_namespace, workspace=self.workspace
|
|
||||||
)
|
|
||||||
# Reset own update flag to avoid self-reloading
|
# Reset own update flag to avoid self-reloading
|
||||||
self.storage_updated.value = False
|
self.storage_updated.value = False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ import logging
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from ..base import BaseGraphStorage
|
from ..base import BaseGraphStorage
|
||||||
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock
|
from ..kg.shared_storage import get_data_init_lock
|
||||||
import pipmaster as pm
|
import pipmaster as pm
|
||||||
|
|
||||||
if not pm.is_installed("neo4j"):
|
if not pm.is_installed("neo4j"):
|
||||||
|
|
@ -44,6 +44,23 @@ config.read("config.ini", "utf-8")
|
||||||
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
logging.getLogger("neo4j").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
READ_RETRY_EXCEPTIONS = (
|
||||||
|
neo4jExceptions.ServiceUnavailable,
|
||||||
|
neo4jExceptions.TransientError,
|
||||||
|
neo4jExceptions.SessionExpired,
|
||||||
|
ConnectionResetError,
|
||||||
|
OSError,
|
||||||
|
AttributeError,
|
||||||
|
)
|
||||||
|
|
||||||
|
READ_RETRY = retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
@dataclass
|
@dataclass
|
||||||
class Neo4JStorage(BaseGraphStorage):
|
class Neo4JStorage(BaseGraphStorage):
|
||||||
|
|
@ -340,10 +357,9 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
|
|
||||||
async def finalize(self):
|
async def finalize(self):
|
||||||
"""Close the Neo4j driver and release all resources"""
|
"""Close the Neo4j driver and release all resources"""
|
||||||
async with get_graph_db_lock():
|
if self._driver:
|
||||||
if self._driver:
|
await self._driver.close()
|
||||||
await self._driver.close()
|
self._driver = None
|
||||||
self._driver = None
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
"""Ensure driver is closed when context manager exits"""
|
"""Ensure driver is closed when context manager exits"""
|
||||||
|
|
@ -353,6 +369,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
# Neo4J handles persistence automatically
|
# Neo4J handles persistence automatically
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def has_node(self, node_id: str) -> bool:
|
async def has_node(self, node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a node with the given label exists in the database
|
Check if a node with the given label exists in the database
|
||||||
|
|
@ -386,6 +403,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume() # Ensure results are consumed even on error
|
await result.consume() # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if an edge exists between two nodes
|
Check if an edge exists between two nodes
|
||||||
|
|
@ -427,6 +445,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume() # Ensure results are consumed even on error
|
await result.consume() # Ensure results are consumed even on error
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
||||||
"""Get node by its label identifier, return only node properties
|
"""Get node by its label identifier, return only node properties
|
||||||
|
|
||||||
|
|
@ -480,6 +499,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve multiple nodes in one query using UNWIND.
|
Retrieve multiple nodes in one query using UNWIND.
|
||||||
|
|
@ -516,6 +536,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume() # Make sure to consume the result fully
|
await result.consume() # Make sure to consume the result fully
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def node_degree(self, node_id: str) -> int:
|
async def node_degree(self, node_id: str) -> int:
|
||||||
"""Get the degree (number of relationships) of a node with the given label.
|
"""Get the degree (number of relationships) of a node with the given label.
|
||||||
If multiple nodes have the same label, returns the degree of the first node.
|
If multiple nodes have the same label, returns the degree of the first node.
|
||||||
|
|
@ -564,6 +585,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
||||||
|
|
@ -622,6 +644,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
degrees = int(src_degree) + int(trg_degree)
|
degrees = int(src_degree) + int(trg_degree)
|
||||||
return degrees
|
return degrees
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def edge_degrees_batch(
|
async def edge_degrees_batch(
|
||||||
self, edge_pairs: list[tuple[str, str]]
|
self, edge_pairs: list[tuple[str, str]]
|
||||||
) -> dict[tuple[str, str], int]:
|
) -> dict[tuple[str, str], int]:
|
||||||
|
|
@ -648,6 +671,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
|
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
|
||||||
return edge_degrees
|
return edge_degrees
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_edge(
|
async def get_edge(
|
||||||
self, source_node_id: str, target_node_id: str
|
self, source_node_id: str, target_node_id: str
|
||||||
) -> dict[str, str] | None:
|
) -> dict[str, str] | None:
|
||||||
|
|
@ -735,6 +759,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_edges_batch(
|
async def get_edges_batch(
|
||||||
self, pairs: list[dict[str, str]]
|
self, pairs: list[dict[str, str]]
|
||||||
) -> dict[tuple[str, str], dict]:
|
) -> dict[tuple[str, str], dict]:
|
||||||
|
|
@ -785,6 +810,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
await result.consume()
|
await result.consume()
|
||||||
return edges_dict
|
return edges_dict
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
|
||||||
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
"""Retrieves all edges (relationships) for a particular node identified by its label.
|
||||||
|
|
||||||
|
|
@ -852,6 +878,7 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@READ_RETRY
|
||||||
async def get_nodes_edges_batch(
|
async def get_nodes_edges_batch(
|
||||||
self, node_ids: list[str]
|
self, node_ids: list[str]
|
||||||
) -> dict[str, list[tuple[str, str]]]:
|
) -> dict[str, list[tuple[str, str]]]:
|
||||||
|
|
@ -1773,24 +1800,23 @@ class Neo4JStorage(BaseGraphStorage):
|
||||||
- On success: {"status": "success", "message": "workspace data dropped"}
|
- On success: {"status": "success", "message": "workspace data dropped"}
|
||||||
- On failure: {"status": "error", "message": "<error details>"}
|
- On failure: {"status": "error", "message": "<error details>"}
|
||||||
"""
|
"""
|
||||||
async with get_graph_db_lock():
|
workspace_label = self._get_workspace_label()
|
||||||
workspace_label = self._get_workspace_label()
|
try:
|
||||||
try:
|
async with self._driver.session(database=self._DATABASE) as session:
|
||||||
async with self._driver.session(database=self._DATABASE) as session:
|
# Delete all nodes and relationships in current workspace only
|
||||||
# Delete all nodes and relationships in current workspace only
|
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
||||||
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
|
result = await session.run(query)
|
||||||
result = await session.run(query)
|
await result.consume() # Ensure result is fully consumed
|
||||||
await result.consume() # Ensure result is fully consumed
|
|
||||||
|
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
|
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
|
||||||
# )
|
# )
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": f"workspace '{workspace_label}' data dropped",
|
"message": f"workspace '{workspace_label}' data dropped",
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
|
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||||
else:
|
else:
|
||||||
# Default behavior when workspace is empty
|
# Default behavior when workspace is empty
|
||||||
workspace_dir = working_dir
|
workspace_dir = working_dir
|
||||||
self.workspace = "_"
|
self.workspace = ""
|
||||||
|
|
||||||
os.makedirs(workspace_dir, exist_ok=True)
|
os.makedirs(workspace_dir, exist_ok=True)
|
||||||
self._graphml_xml_file = os.path.join(
|
self._graphml_xml_file = os.path.join(
|
||||||
|
|
@ -70,11 +70,11 @@ class NetworkXStorage(BaseGraphStorage):
|
||||||
"""Initialize storage data"""
|
"""Initialize storage data"""
|
||||||
# Get the update flag for cross-process update notification
|
# Get the update flag for cross-process update notification
|
||||||
self.storage_updated = await get_update_flag(
|
self.storage_updated = await get_update_flag(
|
||||||
self.final_namespace, workspace=self.workspace
|
self.namespace, workspace=self.workspace
|
||||||
)
|
)
|
||||||
# Get the storage lock for use in other methods
|
# Get the storage lock for use in other methods
|
||||||
self._storage_lock = get_namespace_lock(
|
self._storage_lock = get_namespace_lock(
|
||||||
self.final_namespace, workspace=self.workspace
|
self.namespace, workspace=self.workspace
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_graph(self):
|
async def _get_graph(self):
|
||||||
|
|
@ -524,9 +524,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||||
self._graph, self._graphml_xml_file, self.workspace
|
self._graph, self._graphml_xml_file, self.workspace
|
||||||
)
|
)
|
||||||
# Notify other processes that data has been updated
|
# Notify other processes that data has been updated
|
||||||
await set_all_update_flags(
|
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
self.final_namespace, workspace=self.workspace
|
|
||||||
)
|
|
||||||
# Reset own update flag to avoid self-reloading
|
# Reset own update flag to avoid self-reloading
|
||||||
self.storage_updated.value = False
|
self.storage_updated.value = False
|
||||||
return True # Return success
|
return True # Return success
|
||||||
|
|
@ -557,9 +555,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||||
os.remove(self._graphml_xml_file)
|
os.remove(self._graphml_xml_file)
|
||||||
self._graph = nx.Graph()
|
self._graph = nx.Graph()
|
||||||
# Notify other processes that data has been updated
|
# Notify other processes that data has been updated
|
||||||
await set_all_update_flags(
|
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
self.final_namespace, workspace=self.workspace
|
|
||||||
)
|
|
||||||
# Reset own update flag to avoid self-reloading
|
# Reset own update flag to avoid self-reloading
|
||||||
self.storage_updated.value = False
|
self.storage_updated.value = False
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from lightrag.base import (
|
||||||
DocStatus,
|
DocStatus,
|
||||||
DocProcessingStatus,
|
DocProcessingStatus,
|
||||||
)
|
)
|
||||||
from ..kg.shared_storage import get_data_init_lock, get_storage_lock
|
from ..kg.shared_storage import get_data_init_lock
|
||||||
import json
|
import json
|
||||||
|
|
||||||
# Import tenacity for retry logic
|
# Import tenacity for retry logic
|
||||||
|
|
@ -153,7 +153,7 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
else:
|
else:
|
||||||
# When workspace is empty, final_namespace equals original namespace
|
# When workspace is empty, final_namespace equals original namespace
|
||||||
self.final_namespace = self.namespace
|
self.final_namespace = self.namespace
|
||||||
self.workspace = "_"
|
self.workspace = ""
|
||||||
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
|
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
|
||||||
|
|
||||||
self._redis_url = os.environ.get(
|
self._redis_url = os.environ.get(
|
||||||
|
|
@ -368,12 +368,13 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if storage is empty, False otherwise
|
bool: True if storage is empty, False otherwise
|
||||||
"""
|
"""
|
||||||
pattern = f"{self.namespace}:{self.workspace}:*"
|
pattern = f"{self.final_namespace}:*"
|
||||||
try:
|
try:
|
||||||
# Use scan to check if any keys exist
|
async with self._get_redis_connection() as redis:
|
||||||
async for key in self.redis.scan_iter(match=pattern, count=1):
|
# Use scan to check if any keys exist
|
||||||
return False # Found at least one key
|
async for key in redis.scan_iter(match=pattern, count=1):
|
||||||
return True # No keys found
|
return False # Found at least one key
|
||||||
|
return True # No keys found
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
|
logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
|
||||||
return True
|
return True
|
||||||
|
|
@ -400,42 +401,39 @@ class RedisKVStorage(BaseKVStorage):
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
dict[str, str]: Status of the operation with keys 'status' and 'message'
|
||||||
"""
|
"""
|
||||||
async with get_storage_lock():
|
async with self._get_redis_connection() as redis:
|
||||||
async with self._get_redis_connection() as redis:
|
try:
|
||||||
try:
|
# Use SCAN to find all keys with the namespace prefix
|
||||||
# Use SCAN to find all keys with the namespace prefix
|
pattern = f"{self.final_namespace}:*"
|
||||||
pattern = f"{self.final_namespace}:*"
|
cursor = 0
|
||||||
cursor = 0
|
deleted_count = 0
|
||||||
deleted_count = 0
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
cursor, keys = await redis.scan(
|
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
|
||||||
cursor, match=pattern, count=1000
|
if keys:
|
||||||
)
|
# Delete keys in batches
|
||||||
if keys:
|
pipe = redis.pipeline()
|
||||||
# Delete keys in batches
|
for key in keys:
|
||||||
pipe = redis.pipeline()
|
pipe.delete(key)
|
||||||
for key in keys:
|
results = await pipe.execute()
|
||||||
pipe.delete(key)
|
deleted_count += sum(results)
|
||||||
results = await pipe.execute()
|
|
||||||
deleted_count += sum(results)
|
|
||||||
|
|
||||||
if cursor == 0:
|
if cursor == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}"
|
f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}"
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": f"{deleted_count} keys dropped",
|
"message": f"{deleted_count} keys dropped",
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}"
|
f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}"
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
async def _migrate_legacy_cache_structure(self):
|
async def _migrate_legacy_cache_structure(self):
|
||||||
"""Migrate legacy nested cache structure to flattened structure for Redis
|
"""Migrate legacy nested cache structure to flattened structure for Redis
|
||||||
|
|
@ -1090,35 +1088,32 @@ class RedisDocStatusStorage(DocStatusStorage):
|
||||||
|
|
||||||
async def drop(self) -> dict[str, str]:
|
async def drop(self) -> dict[str, str]:
|
||||||
"""Drop all document status data from storage and clean up resources"""
|
"""Drop all document status data from storage and clean up resources"""
|
||||||
async with get_storage_lock():
|
try:
|
||||||
try:
|
async with self._get_redis_connection() as redis:
|
||||||
async with self._get_redis_connection() as redis:
|
# Use SCAN to find all keys with the namespace prefix
|
||||||
# Use SCAN to find all keys with the namespace prefix
|
pattern = f"{self.final_namespace}:*"
|
||||||
pattern = f"{self.final_namespace}:*"
|
cursor = 0
|
||||||
cursor = 0
|
deleted_count = 0
|
||||||
deleted_count = 0
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
cursor, keys = await redis.scan(
|
cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
|
||||||
cursor, match=pattern, count=1000
|
if keys:
|
||||||
)
|
# Delete keys in batches
|
||||||
if keys:
|
pipe = redis.pipeline()
|
||||||
# Delete keys in batches
|
for key in keys:
|
||||||
pipe = redis.pipeline()
|
pipe.delete(key)
|
||||||
for key in keys:
|
results = await pipe.execute()
|
||||||
pipe.delete(key)
|
deleted_count += sum(results)
|
||||||
results = await pipe.execute()
|
|
||||||
deleted_count += sum(results)
|
|
||||||
|
|
||||||
if cursor == 0:
|
if cursor == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}"
|
f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}"
|
||||||
)
|
|
||||||
return {"status": "success", "message": "data dropped"}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}"
|
|
||||||
)
|
)
|
||||||
return {"status": "error", "message": str(e)}
|
return {"status": "success", "message": "data dropped"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}"
|
||||||
|
)
|
||||||
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue