Refactor storage classes to use namespace instead of final_namespace

(cherry picked from commit fd486bc922)
This commit is contained in:
yangdx 2025-11-17 05:07:53 +08:00 committed by Raphaël MANSUY
parent ed46d375fb
commit 5bd1320a1d
2 changed files with 37 additions and 61 deletions

View file

@@ -15,7 +15,7 @@ from lightrag.utils import (
from lightrag.base import BaseVectorStorage
from nano_vectordb import NanoVectorDB
from .shared_storage import (
get_storage_lock,
get_namespace_lock,
get_update_flag,
set_all_update_flags,
)
@@ -40,20 +40,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold
working_dir = self.global_config["working_dir"]
# Get composite workspace (supports multi-tenant isolation)
composite_workspace = self._get_composite_workspace()
if composite_workspace and composite_workspace != "_":
# Include composite workspace in the file path for data isolation
# For multi-tenant: tenant_id:kb_id:workspace
# For single-tenant: just workspace
workspace_dir = os.path.join(working_dir, composite_workspace)
self.final_namespace = f"{composite_workspace}_{self.namespace}"
if self.workspace:
# Include workspace in the file path for data isolation
workspace_dir = os.path.join(working_dir, self.workspace)
self.final_namespace = f"{self.workspace}_{self.namespace}"
else:
# Default behavior when workspace is empty
self.final_namespace = self.namespace
self.workspace = ""
self.workspace = "_"
workspace_dir = working_dir
os.makedirs(workspace_dir, exist_ok=True)
@@ -71,9 +65,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.final_namespace)
self.storage_updated = await get_update_flag(
self.namespace, workspace=self.workspace
)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock(enable_logging=False)
self._storage_lock = get_namespace_lock(
self.namespace, workspace=self.workspace
)
async def _get_client(self):
"""Check if the storage should be reloaded"""
@@ -190,9 +188,17 @@ class NanoVectorDBStorage(BaseVectorStorage):
"""
try:
client = await self._get_client()
# Record count before deletion
before_count = len(client)
client.delete(ids)
# Calculate actual deleted count
after_count = len(client)
deleted_count = before_count - after_count
logger.debug(
f"[{self.workspace}] Successfully deleted {len(ids)} vectors from {self.namespace}"
f"[{self.workspace}] Successfully deleted {deleted_count} vectors from {self.namespace}"
)
except Exception as e:
logger.error(
@@ -286,7 +292,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
# Save data to disk
self._client.save()
# Notify other processes that data has been updated
await set_all_update_flags(self.final_namespace)
await set_all_update_flags(self.namespace, workspace=self.workspace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
return True # Return success
@@ -408,7 +414,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
# Notify other processes that data has been updated
await set_all_update_flags(self.final_namespace)
await set_all_update_flags(self.namespace, workspace=self.workspace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False

View file

@@ -5,10 +5,9 @@ from typing import final
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger
from lightrag.base import BaseGraphStorage
from lightrag.constants import GRAPH_FIELD_SEP
import networkx as nx
from .shared_storage import (
get_storage_lock,
get_namespace_lock,
get_update_flag,
set_all_update_flags,
)
@@ -39,21 +38,15 @@ class NetworkXStorage(BaseGraphStorage):
def __post_init__(self):
working_dir = self.global_config["working_dir"]
# Get composite workspace (supports multi-tenant isolation)
composite_workspace = self._get_composite_workspace()
if composite_workspace and composite_workspace != "_":
# Include composite workspace in the file path for data isolation
# For multi-tenant: tenant_id:kb_id:workspace
# For single-tenant: just workspace
workspace_dir = os.path.join(working_dir, composite_workspace)
self.final_namespace = f"{composite_workspace}_{self.namespace}"
if self.workspace:
# Include workspace in the file path for data isolation
workspace_dir = os.path.join(working_dir, self.workspace)
self.final_namespace = f"{self.workspace}_{self.namespace}"
else:
# Default behavior when workspace is empty
self.final_namespace = self.namespace
workspace_dir = working_dir
self.workspace = ""
self.workspace = "_"
os.makedirs(workspace_dir, exist_ok=True)
self._graphml_xml_file = os.path.join(
@@ -78,9 +71,13 @@ class NetworkXStorage(BaseGraphStorage):
async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.final_namespace)
self.storage_updated = await get_update_flag(
self.namespace, workspace=self.workspace
)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
self._storage_lock = get_namespace_lock(
self.namespace, workspace=self.workspace
)
async def _get_graph(self):
"""Check if the storage should be reloaded"""
@@ -476,33 +473,6 @@ class NetworkXStorage(BaseGraphStorage):
)
return result
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
chunk_ids_set = set(chunk_ids)
graph = await self._get_graph()
matching_nodes = []
for node_id, node_data in graph.nodes(data=True):
if "source_id" in node_data:
node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
if not node_source_ids.isdisjoint(chunk_ids_set):
node_data_with_id = node_data.copy()
node_data_with_id["id"] = node_id
matching_nodes.append(node_data_with_id)
return matching_nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
chunk_ids_set = set(chunk_ids)
graph = await self._get_graph()
matching_edges = []
for u, v, edge_data in graph.edges(data=True):
if "source_id" in edge_data:
edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
if not edge_source_ids.isdisjoint(chunk_ids_set):
edge_data_with_nodes = edge_data.copy()
edge_data_with_nodes["source"] = u
edge_data_with_nodes["target"] = v
matching_edges.append(edge_data_with_nodes)
return matching_edges
async def get_all_nodes(self) -> list[dict]:
"""Get all nodes in the graph.
@@ -556,7 +526,7 @@ class NetworkXStorage(BaseGraphStorage):
self._graph, self._graphml_xml_file, self.workspace
)
# Notify other processes that data has been updated
await set_all_update_flags(self.final_namespace)
await set_all_update_flags(self.namespace, workspace=self.workspace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
return True # Return success
@@ -587,7 +557,7 @@ class NetworkXStorage(BaseGraphStorage):
os.remove(self._graphml_xml_file)
self._graph = nx.Graph()
# Notify other processes that data has been updated
await set_all_update_flags(self.final_namespace)
await set_all_update_flags(self.namespace, workspace=self.workspace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
logger.info(