fix: sync all kg modules from upstream

This commit is contained in:
Raphaël MANSUY 2025-12-04 19:22:13 +08:00
parent 09d0721cab
commit ccd2f82174
7 changed files with 190 additions and 181 deletions

View file

@ -13,7 +13,7 @@ from lightrag.utils import (
from lightrag.exceptions import StorageNotInitializedError from lightrag.exceptions import StorageNotInitializedError
from .shared_storage import ( from .shared_storage import (
get_namespace_data, get_namespace_data,
get_storage_lock, get_namespace_lock,
get_data_init_lock, get_data_init_lock,
get_update_flag, get_update_flag,
set_all_update_flags, set_all_update_flags,
@ -30,12 +30,10 @@ class JsonKVStorage(BaseKVStorage):
if self.workspace: if self.workspace:
# Include workspace in the file path for data isolation # Include workspace in the file path for data isolation
workspace_dir = os.path.join(working_dir, self.workspace) workspace_dir = os.path.join(working_dir, self.workspace)
self.final_namespace = f"{self.workspace}_{self.namespace}"
else: else:
# Default behavior when workspace is empty # Default behavior when workspace is empty
workspace_dir = working_dir workspace_dir = working_dir
self.final_namespace = self.namespace self.workspace = ""
self.workspace = "_"
os.makedirs(workspace_dir, exist_ok=True) os.makedirs(workspace_dir, exist_ok=True)
self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json") self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json")
@ -46,12 +44,20 @@ class JsonKVStorage(BaseKVStorage):
async def initialize(self): async def initialize(self):
"""Initialize storage data""" """Initialize storage data"""
self._storage_lock = get_storage_lock() self._storage_lock = get_namespace_lock(
self.storage_updated = await get_update_flag(self.final_namespace) self.namespace, workspace=self.workspace
)
self.storage_updated = await get_update_flag(
self.namespace, workspace=self.workspace
)
async with get_data_init_lock(): async with get_data_init_lock():
# check need_init must before get_namespace_data # check need_init must before get_namespace_data
need_init = await try_initialize_namespace(self.final_namespace) need_init = await try_initialize_namespace(
self._data = await get_namespace_data(self.final_namespace) self.namespace, workspace=self.workspace
)
self._data = await get_namespace_data(
self.namespace, workspace=self.workspace
)
if need_init: if need_init:
loaded_data = load_json(self._file_name) or {} loaded_data = load_json(self._file_name) or {}
async with self._storage_lock: async with self._storage_lock:
@ -91,11 +97,11 @@ class JsonKVStorage(BaseKVStorage):
f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}" f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}"
) )
cleaned_data = load_json(self._file_name) cleaned_data = load_json(self._file_name)
if cleaned_data: if cleaned_data is not None:
self._data.clear() self._data.clear()
self._data.update(cleaned_data) self._data.update(cleaned_data)
await clear_all_update_flags(self.final_namespace) await clear_all_update_flags(self.namespace, workspace=self.workspace)
async def get_by_id(self, id: str) -> dict[str, Any] | None: async def get_by_id(self, id: str) -> dict[str, Any] | None:
async with self._storage_lock: async with self._storage_lock:
@ -168,7 +174,7 @@ class JsonKVStorage(BaseKVStorage):
v["_id"] = k v["_id"] = k
self._data.update(data) self._data.update(data)
await set_all_update_flags(self.final_namespace) await set_all_update_flags(self.namespace, workspace=self.workspace)
async def delete(self, ids: list[str]) -> None: async def delete(self, ids: list[str]) -> None:
"""Delete specific records from storage by their IDs """Delete specific records from storage by their IDs
@ -191,7 +197,7 @@ class JsonKVStorage(BaseKVStorage):
any_deleted = True any_deleted = True
if any_deleted: if any_deleted:
await set_all_update_flags(self.final_namespace) await set_all_update_flags(self.namespace, workspace=self.workspace)
async def is_empty(self) -> bool: async def is_empty(self) -> bool:
"""Check if the storage is empty """Check if the storage is empty
@ -219,7 +225,7 @@ class JsonKVStorage(BaseKVStorage):
try: try:
async with self._storage_lock: async with self._storage_lock:
self._data.clear() self._data.clear()
await set_all_update_flags(self.final_namespace) await set_all_update_flags(self.namespace, workspace=self.workspace)
await self.index_done_callback() await self.index_done_callback()
logger.info( logger.info(
@ -283,7 +289,7 @@ class JsonKVStorage(BaseKVStorage):
f"[{self.workspace}] Reloading sanitized migration data for {self.namespace}" f"[{self.workspace}] Reloading sanitized migration data for {self.namespace}"
) )
cleaned_data = load_json(self._file_name) cleaned_data = load_json(self._file_name)
if cleaned_data: if cleaned_data is not None:
return cleaned_data # Return cleaned data to update shared memory return cleaned_data # Return cleaned data to update shared memory
return migrated_data return migrated_data

View file

@ -8,7 +8,7 @@ import configparser
from ..utils import logger from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock from ..kg.shared_storage import get_data_init_lock
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
@ -101,10 +101,9 @@ class MemgraphStorage(BaseGraphStorage):
raise raise
async def finalize(self): async def finalize(self):
async with get_graph_db_lock(): if self._driver is not None:
if self._driver is not None: await self._driver.close()
await self._driver.close() self._driver = None
self._driver = None
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
await self.finalize() await self.finalize()
@ -762,22 +761,21 @@ class MemgraphStorage(BaseGraphStorage):
raise RuntimeError( raise RuntimeError(
"Memgraph driver is not initialized. Call 'await initialize()' first." "Memgraph driver is not initialized. Call 'await initialize()' first."
) )
async with get_graph_db_lock(): try:
try: async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE) as session: workspace_label = self._get_workspace_label()
workspace_label = self._get_workspace_label() query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" result = await session.run(query)
result = await session.run(query) await result.consume()
await result.consume() logger.info(
logger.info( f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}"
)
return {"status": "success", "message": "workspace data dropped"}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "success", "message": "workspace data dropped"}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}"
)
return {"status": "error", "message": str(e)}
async def edge_degree(self, src_id: str, tgt_id: str) -> int: async def edge_degree(self, src_id: str, tgt_id: str) -> int:
"""Get the total degree (sum of relationships) of two nodes. """Get the total degree (sum of relationships) of two nodes.
@ -1050,12 +1048,12 @@ class MemgraphStorage(BaseGraphStorage):
"Memgraph driver is not initialized. Call 'await initialize()' first." "Memgraph driver is not initialized. Call 'await initialize()' first."
) )
workspace_label = self._get_workspace_label() result = None
async with self._driver.session( try:
database=self._DATABASE, default_access_mode="READ" workspace_label = self._get_workspace_label()
) as session: async with self._driver.session(
result = None database=self._DATABASE, default_access_mode="READ"
try: ) as session:
query = f""" query = f"""
MATCH (n:`{workspace_label}`) MATCH (n:`{workspace_label}`)
WHERE n.entity_id IS NOT NULL WHERE n.entity_id IS NOT NULL
@ -1075,13 +1073,11 @@ class MemgraphStorage(BaseGraphStorage):
f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})" f"[{self.workspace}] Retrieved {len(labels)} popular labels (limit: {limit})"
) )
return labels return labels
except Exception as e: except Exception as e:
logger.error( logger.error(f"[{self.workspace}] Error getting popular labels: {str(e)}")
f"[{self.workspace}] Error getting popular labels: {str(e)}" if result is not None:
) await result.consume()
if result is not None: return []
await result.consume()
return []
async def search_labels(self, query: str, limit: int = 50) -> list[str]: async def search_labels(self, query: str, limit: int = 50) -> list[str]:
"""Search labels with fuzzy matching """Search labels with fuzzy matching
@ -1103,12 +1099,12 @@ class MemgraphStorage(BaseGraphStorage):
if not query_lower: if not query_lower:
return [] return []
workspace_label = self._get_workspace_label() result = None
async with self._driver.session( try:
database=self._DATABASE, default_access_mode="READ" workspace_label = self._get_workspace_label()
) as session: async with self._driver.session(
result = None database=self._DATABASE, default_access_mode="READ"
try: ) as session:
cypher_query = f""" cypher_query = f"""
MATCH (n:`{workspace_label}`) MATCH (n:`{workspace_label}`)
WHERE n.entity_id IS NOT NULL WHERE n.entity_id IS NOT NULL
@ -1135,8 +1131,8 @@ class MemgraphStorage(BaseGraphStorage):
f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})" f"[{self.workspace}] Search query '{query}' returned {len(labels)} results (limit: {limit})"
) )
return labels return labels
except Exception as e: except Exception as e:
logger.error(f"[{self.workspace}] Error searching labels: {str(e)}") logger.error(f"[{self.workspace}] Error searching labels: {str(e)}")
if result is not None: if result is not None:
await result.consume() await result.consume()
return [] return []

View file

@ -939,28 +939,22 @@ class MilvusVectorDBStorage(BaseVectorStorage):
milvus_workspace = os.environ.get("MILVUS_WORKSPACE") milvus_workspace = os.environ.get("MILVUS_WORKSPACE")
if milvus_workspace and milvus_workspace.strip(): if milvus_workspace and milvus_workspace.strip():
# Use environment variable value, overriding the passed workspace parameter # Use environment variable value, overriding the passed workspace parameter
self.workspace = milvus_workspace.strip() effective_workspace = milvus_workspace.strip()
logger.info( logger.info(
f"Using MILVUS_WORKSPACE environment variable: '{self.workspace}' (overriding passed workspace)" f"Using MILVUS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')"
) )
else: else:
# Use the workspace parameter passed during initialization # Use the workspace parameter passed during initialization
if self.workspace: effective_workspace = self.workspace
if effective_workspace:
logger.debug( logger.debug(
f"Using passed workspace parameter: '{self.workspace}'" f"Using passed workspace parameter: '{effective_workspace}'"
) )
# Get composite workspace (supports multi-tenant isolation)
composite_workspace = self._get_composite_workspace()
# Sanitize for Milvus (replace colons with underscores)
# Milvus collection names must start with a letter or underscore, and can only contain letters, numbers, and underscores
safe_composite_workspace = composite_workspace.replace(":", "_")
# Build final_namespace with workspace prefix for data isolation # Build final_namespace with workspace prefix for data isolation
# Keep original namespace unchanged for type detection logic # Keep original namespace unchanged for type detection logic
if safe_composite_workspace and safe_composite_workspace != "_": if effective_workspace:
self.final_namespace = f"{safe_composite_workspace}_{self.namespace}" self.final_namespace = f"{effective_workspace}_{self.namespace}"
logger.debug( logger.debug(
f"Final namespace with workspace prefix: '{self.final_namespace}'" f"Final namespace with workspace prefix: '{self.final_namespace}'"
) )

View file

@ -47,7 +47,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
else: else:
# Default behavior when workspace is empty # Default behavior when workspace is empty
self.final_namespace = self.namespace self.final_namespace = self.namespace
self.workspace = "_" self.workspace = ""
workspace_dir = working_dir workspace_dir = working_dir
os.makedirs(workspace_dir, exist_ok=True) os.makedirs(workspace_dir, exist_ok=True)
@ -66,11 +66,11 @@ class NanoVectorDBStorage(BaseVectorStorage):
"""Initialize storage data""" """Initialize storage data"""
# Get the update flag for cross-process update notification # Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag( self.storage_updated = await get_update_flag(
self.final_namespace, workspace=self.workspace self.namespace, workspace=self.workspace
) )
# Get the storage lock for use in other methods # Get the storage lock for use in other methods
self._storage_lock = get_namespace_lock( self._storage_lock = get_namespace_lock(
self.final_namespace, workspace=self.workspace self.namespace, workspace=self.workspace
) )
async def _get_client(self): async def _get_client(self):
@ -292,9 +292,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
# Save data to disk # Save data to disk
self._client.save() self._client.save()
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags( await set_all_update_flags(self.namespace, workspace=self.workspace)
self.final_namespace, workspace=self.workspace
)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False
return True # Return success return True # Return success
@ -416,9 +414,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
) )
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags( await set_all_update_flags(self.namespace, workspace=self.workspace)
self.final_namespace, workspace=self.workspace
)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False

View file

@ -16,7 +16,7 @@ import logging
from ..utils import logger from ..utils import logger
from ..base import BaseGraphStorage from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..kg.shared_storage import get_data_init_lock, get_graph_db_lock from ..kg.shared_storage import get_data_init_lock
import pipmaster as pm import pipmaster as pm
if not pm.is_installed("neo4j"): if not pm.is_installed("neo4j"):
@ -44,6 +44,23 @@ config.read("config.ini", "utf-8")
logging.getLogger("neo4j").setLevel(logging.ERROR) logging.getLogger("neo4j").setLevel(logging.ERROR)
# Exception types that cause a read operation wrapped by READ_RETRY to be retried.
# NOTE(review): ConnectionResetError is a subclass of OSError, so listing both is
# redundant (harmless, but OSError alone already covers it).
# NOTE(review): AttributeError is treated as retryable as well; the reason is not
# visible in this hunk -- confirm it is intentional, since it can also mask
# ordinary programming errors for up to three attempts.
READ_RETRY_EXCEPTIONS = (
neo4jExceptions.ServiceUnavailable,
neo4jExceptions.TransientError,
neo4jExceptions.SessionExpired,
ConnectionResetError,
OSError,
AttributeError,
)
# Shared retry decorator, applied as @READ_RETRY to Neo4JStorage read methods
# such as has_node, has_edge and get_node.
# Presumably built from tenacity's retry helpers imported elsewhere in the file
# (the import is outside this hunk -- TODO confirm):
#   - stop_after_attempt(3): give up after the third attempt
#   - wait_exponential(multiplier=1, min=4, max=10): exponential backoff whose
#     wait is clamped to the range 4..10
#   - retry_if_exception_type(...): retry only on the exception types listed above
#   - reraise=True: assuming tenacity semantics, the last underlying exception is
#     re-raised to the caller once attempts are exhausted
READ_RETRY = retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(READ_RETRY_EXCEPTIONS),
reraise=True,
)
@final @final
@dataclass @dataclass
class Neo4JStorage(BaseGraphStorage): class Neo4JStorage(BaseGraphStorage):
@ -340,10 +357,9 @@ class Neo4JStorage(BaseGraphStorage):
async def finalize(self): async def finalize(self):
"""Close the Neo4j driver and release all resources""" """Close the Neo4j driver and release all resources"""
async with get_graph_db_lock(): if self._driver:
if self._driver: await self._driver.close()
await self._driver.close() self._driver = None
self._driver = None
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
"""Ensure driver is closed when context manager exits""" """Ensure driver is closed when context manager exits"""
@ -353,6 +369,7 @@ class Neo4JStorage(BaseGraphStorage):
# Neo4J handles persistence automatically # Neo4J handles persistence automatically
pass pass
@READ_RETRY
async def has_node(self, node_id: str) -> bool: async def has_node(self, node_id: str) -> bool:
""" """
Check if a node with the given label exists in the database Check if a node with the given label exists in the database
@ -386,6 +403,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error await result.consume() # Ensure results are consumed even on error
raise raise
@READ_RETRY
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
""" """
Check if an edge exists between two nodes Check if an edge exists between two nodes
@ -427,6 +445,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Ensure results are consumed even on error await result.consume() # Ensure results are consumed even on error
raise raise
@READ_RETRY
async def get_node(self, node_id: str) -> dict[str, str] | None: async def get_node(self, node_id: str) -> dict[str, str] | None:
"""Get node by its label identifier, return only node properties """Get node by its label identifier, return only node properties
@ -480,6 +499,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@READ_RETRY
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
""" """
Retrieve multiple nodes in one query using UNWIND. Retrieve multiple nodes in one query using UNWIND.
@ -516,6 +536,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() # Make sure to consume the result fully await result.consume() # Make sure to consume the result fully
return nodes return nodes
@READ_RETRY
async def node_degree(self, node_id: str) -> int: async def node_degree(self, node_id: str) -> int:
"""Get the degree (number of relationships) of a node with the given label. """Get the degree (number of relationships) of a node with the given label.
If multiple nodes have the same label, returns the degree of the first node. If multiple nodes have the same label, returns the degree of the first node.
@ -564,6 +585,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@READ_RETRY
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
""" """
Retrieve the degree for multiple nodes in a single query using UNWIND. Retrieve the degree for multiple nodes in a single query using UNWIND.
@ -622,6 +644,7 @@ class Neo4JStorage(BaseGraphStorage):
degrees = int(src_degree) + int(trg_degree) degrees = int(src_degree) + int(trg_degree)
return degrees return degrees
@READ_RETRY
async def edge_degrees_batch( async def edge_degrees_batch(
self, edge_pairs: list[tuple[str, str]] self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]: ) -> dict[tuple[str, str], int]:
@ -648,6 +671,7 @@ class Neo4JStorage(BaseGraphStorage):
edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0) edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
return edge_degrees return edge_degrees
@READ_RETRY
async def get_edge( async def get_edge(
self, source_node_id: str, target_node_id: str self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None: ) -> dict[str, str] | None:
@ -735,6 +759,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@READ_RETRY
async def get_edges_batch( async def get_edges_batch(
self, pairs: list[dict[str, str]] self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]: ) -> dict[tuple[str, str], dict]:
@ -785,6 +810,7 @@ class Neo4JStorage(BaseGraphStorage):
await result.consume() await result.consume()
return edges_dict return edges_dict
@READ_RETRY
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
"""Retrieves all edges (relationships) for a particular node identified by its label. """Retrieves all edges (relationships) for a particular node identified by its label.
@ -852,6 +878,7 @@ class Neo4JStorage(BaseGraphStorage):
) )
raise raise
@READ_RETRY
async def get_nodes_edges_batch( async def get_nodes_edges_batch(
self, node_ids: list[str] self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]: ) -> dict[str, list[tuple[str, str]]]:
@ -1773,24 +1800,23 @@ class Neo4JStorage(BaseGraphStorage):
- On success: {"status": "success", "message": "workspace data dropped"} - On success: {"status": "success", "message": "workspace data dropped"}
- On failure: {"status": "error", "message": "<error details>"} - On failure: {"status": "error", "message": "<error details>"}
""" """
async with get_graph_db_lock(): workspace_label = self._get_workspace_label()
workspace_label = self._get_workspace_label() try:
try: async with self._driver.session(database=self._DATABASE) as session:
async with self._driver.session(database=self._DATABASE) as session: # Delete all nodes and relationships in current workspace only
# Delete all nodes and relationships in current workspace only query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n"
query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" result = await session.run(query)
result = await session.run(query) await result.consume() # Ensure result is fully consumed
await result.consume() # Ensure result is fully consumed
# logger.debug( # logger.debug(
# f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}"
# ) # )
return { return {
"status": "success", "status": "success",
"message": f"workspace '{workspace_label}' data dropped", "message": f"workspace '{workspace_label}' data dropped",
} }
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}

View file

@ -44,7 +44,7 @@ class NetworkXStorage(BaseGraphStorage):
else: else:
# Default behavior when workspace is empty # Default behavior when workspace is empty
workspace_dir = working_dir workspace_dir = working_dir
self.workspace = "_" self.workspace = ""
os.makedirs(workspace_dir, exist_ok=True) os.makedirs(workspace_dir, exist_ok=True)
self._graphml_xml_file = os.path.join( self._graphml_xml_file = os.path.join(
@ -70,11 +70,11 @@ class NetworkXStorage(BaseGraphStorage):
"""Initialize storage data""" """Initialize storage data"""
# Get the update flag for cross-process update notification # Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag( self.storage_updated = await get_update_flag(
self.final_namespace, workspace=self.workspace self.namespace, workspace=self.workspace
) )
# Get the storage lock for use in other methods # Get the storage lock for use in other methods
self._storage_lock = get_namespace_lock( self._storage_lock = get_namespace_lock(
self.final_namespace, workspace=self.workspace self.namespace, workspace=self.workspace
) )
async def _get_graph(self): async def _get_graph(self):
@ -524,9 +524,7 @@ class NetworkXStorage(BaseGraphStorage):
self._graph, self._graphml_xml_file, self.workspace self._graph, self._graphml_xml_file, self.workspace
) )
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags( await set_all_update_flags(self.namespace, workspace=self.workspace)
self.final_namespace, workspace=self.workspace
)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False
return True # Return success return True # Return success
@ -557,9 +555,7 @@ class NetworkXStorage(BaseGraphStorage):
os.remove(self._graphml_xml_file) os.remove(self._graphml_xml_file)
self._graph = nx.Graph() self._graph = nx.Graph()
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags( await set_all_update_flags(self.namespace, workspace=self.workspace)
self.final_namespace, workspace=self.workspace
)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False
logger.info( logger.info(

View file

@ -21,7 +21,7 @@ from lightrag.base import (
DocStatus, DocStatus,
DocProcessingStatus, DocProcessingStatus,
) )
from ..kg.shared_storage import get_data_init_lock, get_storage_lock from ..kg.shared_storage import get_data_init_lock
import json import json
# Import tenacity for retry logic # Import tenacity for retry logic
@ -153,7 +153,7 @@ class RedisKVStorage(BaseKVStorage):
else: else:
# When workspace is empty, final_namespace equals original namespace # When workspace is empty, final_namespace equals original namespace
self.final_namespace = self.namespace self.final_namespace = self.namespace
self.workspace = "_" self.workspace = ""
logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'")
self._redis_url = os.environ.get( self._redis_url = os.environ.get(
@ -368,12 +368,13 @@ class RedisKVStorage(BaseKVStorage):
Returns: Returns:
bool: True if storage is empty, False otherwise bool: True if storage is empty, False otherwise
""" """
pattern = f"{self.namespace}:{self.workspace}:*" pattern = f"{self.final_namespace}:*"
try: try:
# Use scan to check if any keys exist async with self._get_redis_connection() as redis:
async for key in self.redis.scan_iter(match=pattern, count=1): # Use scan to check if any keys exist
return False # Found at least one key async for key in redis.scan_iter(match=pattern, count=1):
return True # No keys found return False # Found at least one key
return True # No keys found
except Exception as e: except Exception as e:
logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}") logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}")
return True return True
@ -400,42 +401,39 @@ class RedisKVStorage(BaseKVStorage):
Returns: Returns:
dict[str, str]: Status of the operation with keys 'status' and 'message' dict[str, str]: Status of the operation with keys 'status' and 'message'
""" """
async with get_storage_lock(): async with self._get_redis_connection() as redis:
async with self._get_redis_connection() as redis: try:
try: # Use SCAN to find all keys with the namespace prefix
# Use SCAN to find all keys with the namespace prefix pattern = f"{self.final_namespace}:*"
pattern = f"{self.final_namespace}:*" cursor = 0
cursor = 0 deleted_count = 0
deleted_count = 0
while True: while True:
cursor, keys = await redis.scan( cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
cursor, match=pattern, count=1000 if keys:
) # Delete keys in batches
if keys: pipe = redis.pipeline()
# Delete keys in batches for key in keys:
pipe = redis.pipeline() pipe.delete(key)
for key in keys: results = await pipe.execute()
pipe.delete(key) deleted_count += sum(results)
results = await pipe.execute()
deleted_count += sum(results)
if cursor == 0: if cursor == 0:
break break
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}" f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}"
) )
return { return {
"status": "success", "status": "success",
"message": f"{deleted_count} keys dropped", "message": f"{deleted_count} keys dropped",
} }
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}" f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def _migrate_legacy_cache_structure(self): async def _migrate_legacy_cache_structure(self):
"""Migrate legacy nested cache structure to flattened structure for Redis """Migrate legacy nested cache structure to flattened structure for Redis
@ -1090,35 +1088,32 @@ class RedisDocStatusStorage(DocStatusStorage):
async def drop(self) -> dict[str, str]: async def drop(self) -> dict[str, str]:
"""Drop all document status data from storage and clean up resources""" """Drop all document status data from storage and clean up resources"""
async with get_storage_lock(): try:
try: async with self._get_redis_connection() as redis:
async with self._get_redis_connection() as redis: # Use SCAN to find all keys with the namespace prefix
# Use SCAN to find all keys with the namespace prefix pattern = f"{self.final_namespace}:*"
pattern = f"{self.final_namespace}:*" cursor = 0
cursor = 0 deleted_count = 0
deleted_count = 0
while True: while True:
cursor, keys = await redis.scan( cursor, keys = await redis.scan(cursor, match=pattern, count=1000)
cursor, match=pattern, count=1000 if keys:
) # Delete keys in batches
if keys: pipe = redis.pipeline()
# Delete keys in batches for key in keys:
pipe = redis.pipeline() pipe.delete(key)
for key in keys: results = await pipe.execute()
pipe.delete(key) deleted_count += sum(results)
results = await pipe.execute()
deleted_count += sum(results)
if cursor == 0: if cursor == 0:
break break
logger.info( logger.info(
f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}" f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}"
) )
return {"status": "error", "message": str(e)} return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(
f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)}