Refactor storage classes to use namespace instead of final_namespace

(cherry picked from commit fd486bc922)
This commit is contained in:
yangdx 2025-11-17 05:07:53 +08:00 committed by Raphaël MANSUY
parent ed46d375fb
commit 5bd1320a1d
2 changed files with 37 additions and 61 deletions

View file

@ -15,7 +15,7 @@ from lightrag.utils import (
from lightrag.base import BaseVectorStorage from lightrag.base import BaseVectorStorage
from nano_vectordb import NanoVectorDB from nano_vectordb import NanoVectorDB
from .shared_storage import ( from .shared_storage import (
get_storage_lock, get_namespace_lock,
get_update_flag, get_update_flag,
set_all_update_flags, set_all_update_flags,
) )
@ -40,20 +40,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
self.cosine_better_than_threshold = cosine_threshold self.cosine_better_than_threshold = cosine_threshold
working_dir = self.global_config["working_dir"] working_dir = self.global_config["working_dir"]
if self.workspace:
# Get composite workspace (supports multi-tenant isolation) # Include workspace in the file path for data isolation
composite_workspace = self._get_composite_workspace() workspace_dir = os.path.join(working_dir, self.workspace)
self.final_namespace = f"{self.workspace}_{self.namespace}"
if composite_workspace and composite_workspace != "_":
# Include composite workspace in the file path for data isolation
# For multi-tenant: tenant_id:kb_id:workspace
# For single-tenant: just workspace
workspace_dir = os.path.join(working_dir, composite_workspace)
self.final_namespace = f"{composite_workspace}_{self.namespace}"
else: else:
# Default behavior when workspace is empty # Default behavior when workspace is empty
self.final_namespace = self.namespace self.final_namespace = self.namespace
self.workspace = "" self.workspace = "_"
workspace_dir = working_dir workspace_dir = working_dir
os.makedirs(workspace_dir, exist_ok=True) os.makedirs(workspace_dir, exist_ok=True)
@ -71,9 +65,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def initialize(self): async def initialize(self):
"""Initialize storage data""" """Initialize storage data"""
# Get the update flag for cross-process update notification # Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.final_namespace) self.storage_updated = await get_update_flag(
self.namespace, workspace=self.workspace
)
# Get the storage lock for use in other methods # Get the storage lock for use in other methods
self._storage_lock = get_storage_lock(enable_logging=False) self._storage_lock = get_namespace_lock(
self.namespace, workspace=self.workspace
)
async def _get_client(self): async def _get_client(self):
"""Check if the storage should be reloaded""" """Check if the storage should be reloaded"""
@ -190,9 +188,17 @@ class NanoVectorDBStorage(BaseVectorStorage):
""" """
try: try:
client = await self._get_client() client = await self._get_client()
# Record count before deletion
before_count = len(client)
client.delete(ids) client.delete(ids)
# Calculate actual deleted count
after_count = len(client)
deleted_count = before_count - after_count
logger.debug( logger.debug(
f"[{self.workspace}] Successfully deleted {len(ids)} vectors from {self.namespace}" f"[{self.workspace}] Successfully deleted {deleted_count} vectors from {self.namespace}"
) )
except Exception as e: except Exception as e:
logger.error( logger.error(
@ -286,7 +292,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
# Save data to disk # Save data to disk
self._client.save() self._client.save()
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags(self.final_namespace) await set_all_update_flags(self.namespace, workspace=self.workspace)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False
return True # Return success return True # Return success
@ -408,7 +414,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
) )
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags(self.final_namespace) await set_all_update_flags(self.namespace, workspace=self.workspace)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False

View file

@ -5,10 +5,9 @@ from typing import final
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from lightrag.utils import logger from lightrag.utils import logger
from lightrag.base import BaseGraphStorage from lightrag.base import BaseGraphStorage
from lightrag.constants import GRAPH_FIELD_SEP
import networkx as nx import networkx as nx
from .shared_storage import ( from .shared_storage import (
get_storage_lock, get_namespace_lock,
get_update_flag, get_update_flag,
set_all_update_flags, set_all_update_flags,
) )
@ -39,21 +38,15 @@ class NetworkXStorage(BaseGraphStorage):
def __post_init__(self): def __post_init__(self):
working_dir = self.global_config["working_dir"] working_dir = self.global_config["working_dir"]
if self.workspace:
# Get composite workspace (supports multi-tenant isolation) # Include workspace in the file path for data isolation
composite_workspace = self._get_composite_workspace() workspace_dir = os.path.join(working_dir, self.workspace)
self.final_namespace = f"{self.workspace}_{self.namespace}"
if composite_workspace and composite_workspace != "_":
# Include composite workspace in the file path for data isolation
# For multi-tenant: tenant_id:kb_id:workspace
# For single-tenant: just workspace
workspace_dir = os.path.join(working_dir, composite_workspace)
self.final_namespace = f"{composite_workspace}_{self.namespace}"
else: else:
# Default behavior when workspace is empty # Default behavior when workspace is empty
self.final_namespace = self.namespace self.final_namespace = self.namespace
workspace_dir = working_dir workspace_dir = working_dir
self.workspace = "" self.workspace = "_"
os.makedirs(workspace_dir, exist_ok=True) os.makedirs(workspace_dir, exist_ok=True)
self._graphml_xml_file = os.path.join( self._graphml_xml_file = os.path.join(
@ -78,9 +71,13 @@ class NetworkXStorage(BaseGraphStorage):
async def initialize(self): async def initialize(self):
"""Initialize storage data""" """Initialize storage data"""
# Get the update flag for cross-process update notification # Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.final_namespace) self.storage_updated = await get_update_flag(
self.namespace, workspace=self.workspace
)
# Get the storage lock for use in other methods # Get the storage lock for use in other methods
self._storage_lock = get_storage_lock() self._storage_lock = get_namespace_lock(
self.namespace, workspace=self.workspace
)
async def _get_graph(self): async def _get_graph(self):
"""Check if the storage should be reloaded""" """Check if the storage should be reloaded"""
@ -476,33 +473,6 @@ class NetworkXStorage(BaseGraphStorage):
) )
return result return result
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
chunk_ids_set = set(chunk_ids)
graph = await self._get_graph()
matching_nodes = []
for node_id, node_data in graph.nodes(data=True):
if "source_id" in node_data:
node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
if not node_source_ids.isdisjoint(chunk_ids_set):
node_data_with_id = node_data.copy()
node_data_with_id["id"] = node_id
matching_nodes.append(node_data_with_id)
return matching_nodes
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
chunk_ids_set = set(chunk_ids)
graph = await self._get_graph()
matching_edges = []
for u, v, edge_data in graph.edges(data=True):
if "source_id" in edge_data:
edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
if not edge_source_ids.isdisjoint(chunk_ids_set):
edge_data_with_nodes = edge_data.copy()
edge_data_with_nodes["source"] = u
edge_data_with_nodes["target"] = v
matching_edges.append(edge_data_with_nodes)
return matching_edges
async def get_all_nodes(self) -> list[dict]: async def get_all_nodes(self) -> list[dict]:
"""Get all nodes in the graph. """Get all nodes in the graph.
@ -556,7 +526,7 @@ class NetworkXStorage(BaseGraphStorage):
self._graph, self._graphml_xml_file, self.workspace self._graph, self._graphml_xml_file, self.workspace
) )
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags(self.final_namespace) await set_all_update_flags(self.namespace, workspace=self.workspace)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False
return True # Return success return True # Return success
@ -587,7 +557,7 @@ class NetworkXStorage(BaseGraphStorage):
os.remove(self._graphml_xml_file) os.remove(self._graphml_xml_file)
self._graph = nx.Graph() self._graph = nx.Graph()
# Notify other processes that data has been updated # Notify other processes that data has been updated
await set_all_update_flags(self.final_namespace) await set_all_update_flags(self.namespace, workspace=self.workspace)
# Reset own update flag to avoid self-reloading # Reset own update flag to avoid self-reloading
self.storage_updated.value = False self.storage_updated.value = False
logger.info( logger.info(