Refactor storage classes to use namespace instead of final_namespace
(cherry picked from commit fd486bc922)
This commit is contained in:
parent ed46d375fb
commit 5bd1320a1d
2 changed files with 37 additions and 61 deletions
|
|
@ -15,7 +15,7 @@ from lightrag.utils import (
|
||||||
from lightrag.base import BaseVectorStorage
|
from lightrag.base import BaseVectorStorage
|
||||||
from nano_vectordb import NanoVectorDB
|
from nano_vectordb import NanoVectorDB
|
||||||
from .shared_storage import (
|
from .shared_storage import (
|
||||||
get_storage_lock,
|
get_namespace_lock,
|
||||||
get_update_flag,
|
get_update_flag,
|
||||||
set_all_update_flags,
|
set_all_update_flags,
|
||||||
)
|
)
|
||||||
|
|
@ -40,20 +40,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
self.cosine_better_than_threshold = cosine_threshold
|
self.cosine_better_than_threshold = cosine_threshold
|
||||||
|
|
||||||
working_dir = self.global_config["working_dir"]
|
working_dir = self.global_config["working_dir"]
|
||||||
|
if self.workspace:
|
||||||
# Get composite workspace (supports multi-tenant isolation)
|
# Include workspace in the file path for data isolation
|
||||||
composite_workspace = self._get_composite_workspace()
|
workspace_dir = os.path.join(working_dir, self.workspace)
|
||||||
|
self.final_namespace = f"{self.workspace}_{self.namespace}"
|
||||||
if composite_workspace and composite_workspace != "_":
|
|
||||||
# Include composite workspace in the file path for data isolation
|
|
||||||
# For multi-tenant: tenant_id:kb_id:workspace
|
|
||||||
# For single-tenant: just workspace
|
|
||||||
workspace_dir = os.path.join(working_dir, composite_workspace)
|
|
||||||
self.final_namespace = f"{composite_workspace}_{self.namespace}"
|
|
||||||
else:
|
else:
|
||||||
# Default behavior when workspace is empty
|
# Default behavior when workspace is empty
|
||||||
self.final_namespace = self.namespace
|
self.final_namespace = self.namespace
|
||||||
self.workspace = ""
|
self.workspace = "_"
|
||||||
workspace_dir = working_dir
|
workspace_dir = working_dir
|
||||||
|
|
||||||
os.makedirs(workspace_dir, exist_ok=True)
|
os.makedirs(workspace_dir, exist_ok=True)
|
||||||
|
|
@ -71,9 +65,13 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize storage data"""
|
"""Initialize storage data"""
|
||||||
# Get the update flag for cross-process update notification
|
# Get the update flag for cross-process update notification
|
||||||
self.storage_updated = await get_update_flag(self.final_namespace)
|
self.storage_updated = await get_update_flag(
|
||||||
|
self.namespace, workspace=self.workspace
|
||||||
|
)
|
||||||
# Get the storage lock for use in other methods
|
# Get the storage lock for use in other methods
|
||||||
self._storage_lock = get_storage_lock(enable_logging=False)
|
self._storage_lock = get_namespace_lock(
|
||||||
|
self.namespace, workspace=self.workspace
|
||||||
|
)
|
||||||
|
|
||||||
async def _get_client(self):
|
async def _get_client(self):
|
||||||
"""Check if the storage should be reloaded"""
|
"""Check if the storage should be reloaded"""
|
||||||
|
|
@ -190,9 +188,17 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
client = await self._get_client()
|
client = await self._get_client()
|
||||||
|
# Record count before deletion
|
||||||
|
before_count = len(client)
|
||||||
|
|
||||||
client.delete(ids)
|
client.delete(ids)
|
||||||
|
|
||||||
|
# Calculate actual deleted count
|
||||||
|
after_count = len(client)
|
||||||
|
deleted_count = before_count - after_count
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[{self.workspace}] Successfully deleted {len(ids)} vectors from {self.namespace}"
|
f"[{self.workspace}] Successfully deleted {deleted_count} vectors from {self.namespace}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|
@ -286,7 +292,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
# Save data to disk
|
# Save data to disk
|
||||||
self._client.save()
|
self._client.save()
|
||||||
# Notify other processes that data has been updated
|
# Notify other processes that data has been updated
|
||||||
await set_all_update_flags(self.final_namespace)
|
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
# Reset own update flag to avoid self-reloading
|
# Reset own update flag to avoid self-reloading
|
||||||
self.storage_updated.value = False
|
self.storage_updated.value = False
|
||||||
return True # Return success
|
return True # Return success
|
||||||
|
|
@ -408,7 +414,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Notify other processes that data has been updated
|
# Notify other processes that data has been updated
|
||||||
await set_all_update_flags(self.final_namespace)
|
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
# Reset own update flag to avoid self-reloading
|
# Reset own update flag to avoid self-reloading
|
||||||
self.storage_updated.value = False
|
self.storage_updated.value = False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,10 +5,9 @@ from typing import final
|
||||||
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
||||||
from lightrag.utils import logger
|
from lightrag.utils import logger
|
||||||
from lightrag.base import BaseGraphStorage
|
from lightrag.base import BaseGraphStorage
|
||||||
from lightrag.constants import GRAPH_FIELD_SEP
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from .shared_storage import (
|
from .shared_storage import (
|
||||||
get_storage_lock,
|
get_namespace_lock,
|
||||||
get_update_flag,
|
get_update_flag,
|
||||||
set_all_update_flags,
|
set_all_update_flags,
|
||||||
)
|
)
|
||||||
|
|
@ -39,21 +38,15 @@ class NetworkXStorage(BaseGraphStorage):
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
working_dir = self.global_config["working_dir"]
|
working_dir = self.global_config["working_dir"]
|
||||||
|
if self.workspace:
|
||||||
# Get composite workspace (supports multi-tenant isolation)
|
# Include workspace in the file path for data isolation
|
||||||
composite_workspace = self._get_composite_workspace()
|
workspace_dir = os.path.join(working_dir, self.workspace)
|
||||||
|
self.final_namespace = f"{self.workspace}_{self.namespace}"
|
||||||
if composite_workspace and composite_workspace != "_":
|
|
||||||
# Include composite workspace in the file path for data isolation
|
|
||||||
# For multi-tenant: tenant_id:kb_id:workspace
|
|
||||||
# For single-tenant: just workspace
|
|
||||||
workspace_dir = os.path.join(working_dir, composite_workspace)
|
|
||||||
self.final_namespace = f"{composite_workspace}_{self.namespace}"
|
|
||||||
else:
|
else:
|
||||||
# Default behavior when workspace is empty
|
# Default behavior when workspace is empty
|
||||||
self.final_namespace = self.namespace
|
self.final_namespace = self.namespace
|
||||||
workspace_dir = working_dir
|
workspace_dir = working_dir
|
||||||
self.workspace = ""
|
self.workspace = "_"
|
||||||
|
|
||||||
os.makedirs(workspace_dir, exist_ok=True)
|
os.makedirs(workspace_dir, exist_ok=True)
|
||||||
self._graphml_xml_file = os.path.join(
|
self._graphml_xml_file = os.path.join(
|
||||||
|
|
@ -78,9 +71,13 @@ class NetworkXStorage(BaseGraphStorage):
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize storage data"""
|
"""Initialize storage data"""
|
||||||
# Get the update flag for cross-process update notification
|
# Get the update flag for cross-process update notification
|
||||||
self.storage_updated = await get_update_flag(self.final_namespace)
|
self.storage_updated = await get_update_flag(
|
||||||
|
self.namespace, workspace=self.workspace
|
||||||
|
)
|
||||||
# Get the storage lock for use in other methods
|
# Get the storage lock for use in other methods
|
||||||
self._storage_lock = get_storage_lock()
|
self._storage_lock = get_namespace_lock(
|
||||||
|
self.namespace, workspace=self.workspace
|
||||||
|
)
|
||||||
|
|
||||||
async def _get_graph(self):
|
async def _get_graph(self):
|
||||||
"""Check if the storage should be reloaded"""
|
"""Check if the storage should be reloaded"""
|
||||||
|
|
@ -476,33 +473,6 @@ class NetworkXStorage(BaseGraphStorage):
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
|
||||||
chunk_ids_set = set(chunk_ids)
|
|
||||||
graph = await self._get_graph()
|
|
||||||
matching_nodes = []
|
|
||||||
for node_id, node_data in graph.nodes(data=True):
|
|
||||||
if "source_id" in node_data:
|
|
||||||
node_source_ids = set(node_data["source_id"].split(GRAPH_FIELD_SEP))
|
|
||||||
if not node_source_ids.isdisjoint(chunk_ids_set):
|
|
||||||
node_data_with_id = node_data.copy()
|
|
||||||
node_data_with_id["id"] = node_id
|
|
||||||
matching_nodes.append(node_data_with_id)
|
|
||||||
return matching_nodes
|
|
||||||
|
|
||||||
async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
|
|
||||||
chunk_ids_set = set(chunk_ids)
|
|
||||||
graph = await self._get_graph()
|
|
||||||
matching_edges = []
|
|
||||||
for u, v, edge_data in graph.edges(data=True):
|
|
||||||
if "source_id" in edge_data:
|
|
||||||
edge_source_ids = set(edge_data["source_id"].split(GRAPH_FIELD_SEP))
|
|
||||||
if not edge_source_ids.isdisjoint(chunk_ids_set):
|
|
||||||
edge_data_with_nodes = edge_data.copy()
|
|
||||||
edge_data_with_nodes["source"] = u
|
|
||||||
edge_data_with_nodes["target"] = v
|
|
||||||
matching_edges.append(edge_data_with_nodes)
|
|
||||||
return matching_edges
|
|
||||||
|
|
||||||
async def get_all_nodes(self) -> list[dict]:
|
async def get_all_nodes(self) -> list[dict]:
|
||||||
"""Get all nodes in the graph.
|
"""Get all nodes in the graph.
|
||||||
|
|
||||||
|
|
@ -556,7 +526,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||||
self._graph, self._graphml_xml_file, self.workspace
|
self._graph, self._graphml_xml_file, self.workspace
|
||||||
)
|
)
|
||||||
# Notify other processes that data has been updated
|
# Notify other processes that data has been updated
|
||||||
await set_all_update_flags(self.final_namespace)
|
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
# Reset own update flag to avoid self-reloading
|
# Reset own update flag to avoid self-reloading
|
||||||
self.storage_updated.value = False
|
self.storage_updated.value = False
|
||||||
return True # Return success
|
return True # Return success
|
||||||
|
|
@ -587,7 +557,7 @@ class NetworkXStorage(BaseGraphStorage):
|
||||||
os.remove(self._graphml_xml_file)
|
os.remove(self._graphml_xml_file)
|
||||||
self._graph = nx.Graph()
|
self._graph = nx.Graph()
|
||||||
# Notify other processes that data has been updated
|
# Notify other processes that data has been updated
|
||||||
await set_all_update_flags(self.final_namespace)
|
await set_all_update_flags(self.namespace, workspace=self.workspace)
|
||||||
# Reset own update flag to avoid self-reloading
|
# Reset own update flag to avoid self-reloading
|
||||||
self.storage_updated.value = False
|
self.storage_updated.value = False
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue