Fix: Resolve workspace isolation issues in in-memory database with multiple LightRAG instances

yangdx 2025-08-12 01:26:09 +08:00
parent 095e0cbfa2
commit d9c1f935f5
3 changed files with 135 additions and 81 deletions
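The fix applies one pattern across all three in-memory storages below (FaissVectorDBStorage, NanoVectorDBStorage, NetworkXStorage): when a workspace is set, data files move into a per-workspace sub-directory, and the key used for cross-process coordination becomes a workspace-prefixed final_namespace, so two LightRAG instances sharing a working directory and namespace no longer read or invalidate each other's data. A minimal sketch of that keying and path composition; the helper names isolation_key and storage_path and the tenant/namespace values are illustrative only, not taken from the codebase:

import os

def isolation_key(workspace: str, namespace: str) -> str:
    # Mirrors final_namespace in the diff: prefix with the workspace when set.
    return f"{workspace}_{namespace}" if workspace else namespace

def storage_path(working_dir: str, workspace: str, filename: str) -> str:
    # Files for a non-empty workspace live in their own sub-directory.
    target_dir = os.path.join(working_dir, workspace) if workspace else working_dir
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, filename)

# Two instances with the same namespace but different workspaces stay separate:
print(isolation_key("tenant_a", "chunks"))  # tenant_a_chunks
print(isolation_key("tenant_b", "chunks"))  # tenant_b_chunks
print(storage_path("./rag_storage", "tenant_a", "vdb_chunks.json"))
# -> ./rag_storage/tenant_a/vdb_chunks.json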


@ -42,15 +42,18 @@ class FaissVectorDBStorage(BaseVectorStorage):
if self.workspace:
# Include workspace in the file path for data isolation
workspace_dir = os.path.join(working_dir, self.workspace)
os.makedirs(workspace_dir, exist_ok=True)
self._faiss_index_file = os.path.join(
workspace_dir, f"faiss_index_{self.namespace}.index"
)
self.final_namespace = f"{self.workspace}_{self.namespace}"
else:
# Default behavior when workspace is empty
self._faiss_index_file = os.path.join(
working_dir, f"faiss_index_{self.namespace}.index"
)
self.final_namespace = self.namespace
self.workspace = "_"
workspace_dir = working_dir
os.makedirs(workspace_dir, exist_ok=True)
self._faiss_index_file = os.path.join(
workspace_dir, f"faiss_index_{self.namespace}.index"
)
self._meta_file = self._faiss_index_file + ".meta.json"
self._max_batch_size = self.global_config["embedding_batch_num"]
@ -70,7 +73,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
self.storage_updated = await get_update_flag(self.final_namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
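The update flag obtained here is the cross-process "storage changed" signal; keying it by final_namespace rather than the bare namespace means a flush in one workspace no longer forces reloads in other workspaces that happen to use the same namespace. A simplified, self-contained illustration of such a per-key flag registry (a toy model, not the implementation behind get_update_flag / set_all_update_flags):

from collections import defaultdict
from types import SimpleNamespace

class UpdateFlags:
    """Toy per-key update flags: one flag object per registered consumer."""

    def __init__(self):
        self._flags = defaultdict(list)

    def register(self, key: str) -> SimpleNamespace:
        # Each storage instance gets its own flag under its final_namespace key.
        flag = SimpleNamespace(value=False)
        self._flags[key].append(flag)
        return flag

    def set_all(self, key: str) -> None:
        # Equivalent in spirit to set_all_update_flags(final_namespace).
        for flag in self._flags[key]:
            flag.value = True

flags = UpdateFlags()
a = flags.register("tenant_a_chunks")
b = flags.register("tenant_b_chunks")
flags.set_all("tenant_a_chunks")   # only tenant_a consumers are told to reload
print(a.value, b.value)            # True False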
@ -81,7 +84,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Check if storage was updated by another process
if self.storage_updated.value:
logger.info(
f"Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
f"[{self.workspace}] Process {os.getpid()} FAISS reloading {self.namespace} due to update by another process"
)
# Reload data
self._index = faiss.IndexFlatIP(self._dim)
@ -106,7 +109,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
...
}
"""
logger.debug(f"FAISS: Inserting {len(data)} to {self.namespace}")
logger.debug(
f"[{self.workspace}] FAISS: Inserting {len(data)} to {self.namespace}"
)
if not data:
return
@ -136,7 +141,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
embeddings = np.concatenate(embeddings_list, axis=0)
if len(embeddings) != len(list_data):
logger.error(
f"Embedding size mismatch. Embeddings: {len(embeddings)}, Data: {len(list_data)}"
f"[{self.workspace}] Embedding size mismatch. Embeddings: {len(embeddings)}, Data: {len(list_data)}"
)
return []
@ -169,7 +174,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
meta["__vector__"] = embeddings[i].tolist()
self._id_to_meta.update({fid: meta})
logger.debug(f"Upserted {len(list_data)} vectors into Faiss index.")
logger.debug(
f"[{self.workspace}] Upserted {len(list_data)} vectors into Faiss index."
)
return [m["__id__"] for m in list_data]
async def query(
@ -228,7 +235,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}")
logger.debug(
f"[{self.workspace}] Deleting {len(ids)} vectors from {self.namespace}"
)
to_remove = []
for cid in ids:
fid = self._find_faiss_id_by_custom_id(cid)
@ -238,7 +247,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
if to_remove:
await self._remove_faiss_ids(to_remove)
logger.debug(
f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
f"[{self.workspace}] Successfully deleted {len(to_remove)} vectors from {self.namespace}"
)
async def delete_entity(self, entity_name: str) -> None:
@ -249,7 +258,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
KG-storage-log should be used to avoid data corruption
"""
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
logger.debug(
f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
)
await self.delete([entity_id])
async def delete_entity_relation(self, entity_name: str) -> None:
@ -259,16 +270,20 @@ class FaissVectorDBStorage(BaseVectorStorage):
2. Only one process should update the storage at a time before index_done_callback,
KG-storage-log should be used to avoid data corruption
"""
logger.debug(f"Searching relations for entity {entity_name}")
logger.debug(f"[{self.workspace}] Searching relations for entity {entity_name}")
relations = []
for fid, meta in self._id_to_meta.items():
if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
relations.append(fid)
logger.debug(f"Found {len(relations)} relations for {entity_name}")
logger.debug(
f"[{self.workspace}] Found {len(relations)} relations for {entity_name}"
)
if relations:
await self._remove_faiss_ids(relations)
logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
logger.debug(
f"[{self.workspace}] Deleted {len(relations)} relations for {entity_name}"
)
# --------------------------------------------------------------------------------
# Internal helper methods
@ -330,7 +345,9 @@ class FaissVectorDBStorage(BaseVectorStorage):
and rebuild in-memory structures so we can query.
"""
if not os.path.exists(self._faiss_index_file):
logger.warning(f"No existing Faiss index file found for {self.namespace}")
logger.warning(
f"[{self.workspace}] No existing Faiss index file found for {self.namespace}"
)
return
try:
@ -347,11 +364,13 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._id_to_meta[fid] = meta
logger.info(
f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
f"[{self.workspace}] Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
)
except Exception as e:
logger.error(f"Failed to load Faiss index or metadata: {e}")
logger.warning("Starting with an empty Faiss index.")
logger.error(
f"[{self.workspace}] Failed to load Faiss index or metadata: {e}"
)
logger.warning(f"[{self.workspace}] Starting with an empty Faiss index.")
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
@ -361,7 +380,7 @@ class FaissVectorDBStorage(BaseVectorStorage):
if self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for FAISS {self.namespace} was updated by another process, reloading..."
f"[{self.workspace}] Storage for FAISS {self.namespace} was updated by another process, reloading..."
)
self._index = faiss.IndexFlatIP(self._dim)
self._id_to_meta = {}
@ -375,11 +394,13 @@ class FaissVectorDBStorage(BaseVectorStorage):
# Save data to disk
self._save_faiss_index()
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
await set_all_update_flags(self.final_namespace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
except Exception as e:
logger.error(f"Error saving FAISS index for {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error saving FAISS index for {self.namespace}: {e}"
)
return False # Return error
return True # Return success
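index_done_callback now follows the same guarded sequence in every storage: if another process has already flagged an update, reload rather than overwrite it; otherwise persist, raise the flag for all other consumers under the workspace-scoped key, and clear the local flag. A condensed sketch of that control flow, with save/reload_/notify standing in for the storage-specific calls:

from types import SimpleNamespace

def flush(flag, save, reload_, notify) -> bool:
    """Condensed save-or-reload sequence used by index_done_callback."""
    if flag.value:
        reload_()           # another process wrote first: adopt its data
        flag.value = False
        return False        # nothing of ours was persisted
    save()                  # write to the workspace-scoped file
    notify()                # e.g. set_all_update_flags(final_namespace)
    flag.value = False      # avoid reloading our own write later
    return True

flag = SimpleNamespace(value=False)
print(flush(flag,
            save=lambda: print("saved"),
            reload_=lambda: print("reloaded"),
            notify=lambda: print("notified")))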
@ -469,11 +490,15 @@ class FaissVectorDBStorage(BaseVectorStorage):
self._load_faiss_index()
# Notify other processes
await set_all_update_flags(self.namespace)
await set_all_update_flags(self.final_namespace)
self.storage_updated.value = False
logger.info(f"Process {os.getpid()} drop FAISS index {self.namespace}")
logger.info(
f"[{self.workspace}] Process {os.getpid()} dropped FAISS index {self.namespace}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping FAISS index {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error dropping FAISS index {self.namespace}: {e}"
)
return {"status": "error", "message": str(e)}
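For the FAISS backend the artifacts are an index file plus a "<index>.meta.json" sidecar, both now living under the workspace directory. A standalone sketch of that persistence pattern; save_index/load_index are illustrative stand-ins for the class's private _save_faiss_index/_load_faiss_index, and the paths and metadata values are made up:

import json
import os

import faiss
import numpy as np

def save_index(index, meta: dict, index_file: str) -> None:
    # Index file plus a "<index>.meta.json" sidecar, as in the diff above.
    faiss.write_index(index, index_file)
    with open(index_file + ".meta.json", "w", encoding="utf-8") as f:
        json.dump(meta, f)

def load_index(index_file: str):
    # Read both files back to rebuild the in-memory state.
    index = faiss.read_index(index_file)
    with open(index_file + ".meta.json", encoding="utf-8") as f:
        return index, json.load(f)

dim = 8
index = faiss.IndexFlatIP(dim)
index.add(np.random.rand(3, dim).astype("float32"))

workspace_dir = os.path.join("rag_storage", "tenant_a")   # per-workspace directory
os.makedirs(workspace_dir, exist_ok=True)
index_file = os.path.join(workspace_dir, "faiss_index_chunks.index")

save_index(index, {"0": {"__id__": "doc-1"}}, index_file)
print(load_index(index_file)[0].ntotal)  # 3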


@ -41,15 +41,18 @@ class NanoVectorDBStorage(BaseVectorStorage):
if self.workspace:
# Include workspace in the file path for data isolation
workspace_dir = os.path.join(working_dir, self.workspace)
os.makedirs(workspace_dir, exist_ok=True)
self._client_file_name = os.path.join(
workspace_dir, f"vdb_{self.namespace}.json"
)
self.final_namespace = f"{self.workspace}_{self.namespace}"
else:
# Default behavior when workspace is empty
self._client_file_name = os.path.join(
working_dir, f"vdb_{self.namespace}.json"
)
self.final_namespace = self.namespace
self.workspace = "_"
workspace_dir = working_dir
os.makedirs(workspace_dir, exist_ok=True)
self._client_file_name = os.path.join(
workspace_dir, f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
self._client = NanoVectorDB(
@ -60,7 +63,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
self.storage_updated = await get_update_flag(self.final_namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock(enable_logging=False)
@ -71,7 +74,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
# Check if data needs to be reloaded
if self.storage_updated.value:
logger.info(
f"Process {os.getpid()} reloading {self.namespace} due to update by another process"
f"[{self.workspace}] Process {os.getpid()} reloading {self.namespace} due to update by another process"
)
# Reload data
self._client = NanoVectorDB(
@ -91,7 +94,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
KG-storage-log should be used to avoid data corruption
"""
logger.debug(f"Inserting {len(data)} to {self.namespace}")
logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}")
if not data:
return
@ -124,7 +127,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
else:
# sometimes the embedding is not returned correctly. just log it.
logger.error(
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
f"[{self.workspace}] embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
)
async def query(
@ -173,10 +176,12 @@ class NanoVectorDBStorage(BaseVectorStorage):
client = await self._get_client()
client.delete(ids)
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
f"[{self.workspace}] Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}"
)
async def delete_entity(self, entity_name: str) -> None:
"""
@ -189,18 +194,22 @@ class NanoVectorDBStorage(BaseVectorStorage):
try:
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
logger.debug(
f"Attempting to delete entity {entity_name} with ID {entity_id}"
f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}"
)
# Check if the entity exists
client = await self._get_client()
if client.get([entity_id]):
client.delete([entity_id])
logger.debug(f"Successfully deleted entity {entity_name}")
logger.debug(
f"[{self.workspace}] Successfully deleted entity {entity_name}"
)
else:
logger.debug(f"Entity {entity_name} not found in storage")
logger.debug(
f"[{self.workspace}] Entity {entity_name} not found in storage"
)
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""
@ -218,19 +227,25 @@ class NanoVectorDBStorage(BaseVectorStorage):
for dp in storage["data"]
if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
]
logger.debug(f"Found {len(relations)} relations for entity {entity_name}")
logger.debug(
f"[{self.workspace}] Found {len(relations)} relations for entity {entity_name}"
)
ids_to_delete = [relation["__id__"] for relation in relations]
if ids_to_delete:
client = await self._get_client()
client.delete(ids_to_delete)
logger.debug(
f"Deleted {len(ids_to_delete)} relations for {entity_name}"
f"[{self.workspace}] Deleted {len(ids_to_delete)} relations for {entity_name}"
)
else:
logger.debug(f"No relations found for entity {entity_name}")
logger.debug(
f"[{self.workspace}] No relations found for entity {entity_name}"
)
except Exception as e:
logger.error(f"Error deleting relations for {entity_name}: {e}")
logger.error(
f"[{self.workspace}] Error deleting relations for {entity_name}: {e}"
)
async def index_done_callback(self) -> bool:
"""Save data to disk"""
@ -239,7 +254,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
if self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.warning(
f"Storage for {self.namespace} was updated by another process, reloading..."
f"[{self.workspace}] Storage for {self.namespace} was updated by another process, reloading..."
)
self._client = NanoVectorDB(
self.embedding_func.embedding_dim,
@ -255,12 +270,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
# Save data to disk
self._client.save()
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
await set_all_update_flags(self.final_namespace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
return True # Return success
except Exception as e:
logger.error(f"Error saving data for {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error saving data for {self.namespace}: {e}"
)
return False # Return error
return True # Return success
@ -336,14 +353,14 @@ class NanoVectorDBStorage(BaseVectorStorage):
)
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
await set_all_update_flags(self.final_namespace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
logger.info(
f"Process {os.getpid()} drop {self.namespace}(file:{self._client_file_name})"
f"[{self.workspace}] Process {os.getpid()} dropped {self.namespace} (file:{self._client_file_name})"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping {self.namespace}: {e}")
logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}")
return {"status": "error", "message": str(e)}


@ -31,9 +31,9 @@ class NetworkXStorage(BaseGraphStorage):
return None
@staticmethod
def write_nx_graph(graph: nx.Graph, file_name):
def write_nx_graph(graph: nx.Graph, file_name, workspace="_"):
logger.info(
f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
f"[{workspace}] Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
)
nx.write_graphml(graph, file_name)
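write_nx_graph gains a workspace argument that is only used to tag the log line; callers inside the class now pass self.workspace (see the index_done_callback hunk further down). A standalone usage sketch with networkx, using print in place of logger.info and a made-up file name:

import networkx as nx

def write_nx_graph(graph: nx.Graph, file_name, workspace="_"):
    # Same behaviour as the static method: log with a workspace tag, then persist.
    print(f"[{workspace}] Writing graph with {graph.number_of_nodes()} nodes, "
          f"{graph.number_of_edges()} edges")
    nx.write_graphml(graph, file_name)

g = nx.Graph()
g.add_edge("entity_a", "entity_b", description="co-occurs")
write_nx_graph(g, "graph_demo.graphml", workspace="tenant_a")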
@ -42,15 +42,17 @@ class NetworkXStorage(BaseGraphStorage):
if self.workspace:
# Include workspace in the file path for data isolation
workspace_dir = os.path.join(working_dir, self.workspace)
os.makedirs(workspace_dir, exist_ok=True)
self._graphml_xml_file = os.path.join(
workspace_dir, f"graph_{self.namespace}.graphml"
)
self.final_namespace = f"{self.workspace}_{self.namespace}"
else:
# Default behavior when workspace is empty
self._graphml_xml_file = os.path.join(
working_dir, f"graph_{self.namespace}.graphml"
)
self.final_namespace = self.namespace
workspace_dir = working_dir
self.workspace = "_"
os.makedirs(workspace_dir, exist_ok=True)
self._graphml_xml_file = os.path.join(
working_dir, f"graph_{self.namespace}.graphml"
)
self._storage_lock = None
self.storage_updated = None
self._graph = None
@ -59,16 +61,18 @@ class NetworkXStorage(BaseGraphStorage):
preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
if preloaded_graph is not None:
logger.info(
f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
f"[{self.workspace}] Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
)
else:
logger.info("Created new empty graph")
logger.info(
f"[{self.workspace}] Created new empty graph file: {self._graphml_xml_file}"
)
self._graph = preloaded_graph or nx.Graph()
async def initialize(self):
"""Initialize storage data"""
# Get the update flag for cross-process update notification
self.storage_updated = await get_update_flag(self.namespace)
self.storage_updated = await get_update_flag(self.final_namespace)
# Get the storage lock for use in other methods
self._storage_lock = get_storage_lock()
@ -79,7 +83,7 @@ class NetworkXStorage(BaseGraphStorage):
# Check if data needs to be reloaded
if self.storage_updated.value:
logger.info(
f"Process {os.getpid()} reloading graph {self.namespace} due to update by another process"
f"[{self.workspace}] Process {os.getpid()} reloading graph {self._graphml_xml_file} due to modifications by another process"
)
# Reload data
self._graph = (
@ -156,9 +160,11 @@ class NetworkXStorage(BaseGraphStorage):
graph = await self._get_graph()
if graph.has_node(node_id):
graph.remove_node(node_id)
logger.debug(f"Node {node_id} deleted from the graph.")
logger.debug(f"[{self.workspace}] Node {node_id} deleted from the graph")
else:
logger.warning(f"Node {node_id} not found in the graph for deletion.")
logger.warning(
f"[{self.workspace}] Node {node_id} not found in the graph for deletion"
)
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
@ -246,7 +252,7 @@ class NetworkXStorage(BaseGraphStorage):
if len(sorted_nodes) > max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
f"[{self.workspace}] Graph truncated: {len(sorted_nodes)} nodes found, limited to {max_nodes}"
)
limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
@ -255,7 +261,9 @@ class NetworkXStorage(BaseGraphStorage):
else:
# Check if node exists
if node_label not in graph:
logger.warning(f"Node {node_label} not found in the graph")
logger.warning(
f"[{self.workspace}] Node {node_label} not found in the graph"
)
return KnowledgeGraph() # Return empty graph
# Use modified BFS to get nodes, prioritizing high-degree nodes at the same depth
@ -305,7 +313,7 @@ class NetworkXStorage(BaseGraphStorage):
if queue and len(bfs_nodes) >= max_nodes:
result.is_truncated = True
logger.info(
f"Graph truncated: breadth-first search limited to {max_nodes} nodes"
f"[{self.workspace}] Graph truncated: breadth-first search limited to {max_nodes} nodes"
)
# Create subgraph with BFS discovered nodes
@ -362,7 +370,7 @@ class NetworkXStorage(BaseGraphStorage):
seen_edges.add(edge_id)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
@ -429,7 +437,7 @@ class NetworkXStorage(BaseGraphStorage):
if self.storage_updated.value:
# Storage was updated by another process, reload data instead of saving
logger.info(
f"Graph for {self.namespace} was updated by another process, reloading..."
f"[{self.workspace}] Graph was updated by another process, reloading..."
)
self._graph = (
NetworkXStorage.load_nx_graph(self._graphml_xml_file) or nx.Graph()
@ -442,14 +450,16 @@ class NetworkXStorage(BaseGraphStorage):
async with self._storage_lock:
try:
# Save data to disk
NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
NetworkXStorage.write_nx_graph(
self._graph, self._graphml_xml_file, self.workspace
)
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
await set_all_update_flags(self.final_namespace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
return True # Return success
except Exception as e:
logger.error(f"Error saving graph for {self.namespace}: {e}")
logger.error(f"[{self.workspace}] Error saving graph: {e}")
return False # Return error
return True
@ -475,13 +485,15 @@ class NetworkXStorage(BaseGraphStorage):
os.remove(self._graphml_xml_file)
self._graph = nx.Graph()
# Notify other processes that data has been updated
await set_all_update_flags(self.namespace)
await set_all_update_flags(self.final_namespace)
# Reset own update flag to avoid self-reloading
self.storage_updated.value = False
logger.info(
f"Process {os.getpid()} drop graph {self.namespace} (file:{self._graphml_xml_file})"
f"[{self.workspace}] Process {os.getpid()} dropped graph file:{self._graphml_xml_file}"
)
return {"status": "success", "message": "data dropped"}
except Exception as e:
logger.error(f"Error dropping graph {self.namespace}: {e}")
logger.error(
f"[{self.workspace}] Error dropping graph file:{self._graphml_xml_file}: {e}"
)
return {"status": "error", "message": str(e)}
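A quick sanity check of the isolation after this change can stay at the filesystem level: with a shared working_dir, every artifact of a workspace should sit in its own sub-directory rather than in the formerly shared top level. A hedged sketch of that check; the file-name patterns come from the diff above, while the workspace and namespace values are invented:

import os

working_dir = "./rag_storage"
for ws, ns in [("tenant_a", "chunks"), ("tenant_b", "chunks")]:
    target = os.path.join(working_dir, ws)
    os.makedirs(target, exist_ok=True)
    # Same file-name patterns as the three storages above:
    for name in (f"faiss_index_{ns}.index", f"vdb_{ns}.json", f"graph_{ns}.graphml"):
        open(os.path.join(target, name), "a").close()

# Identical namespace, but each workspace owns a disjoint set of absolute paths.
a = {os.path.join(working_dir, "tenant_a", n)
     for n in os.listdir(os.path.join(working_dir, "tenant_a"))}
b = {os.path.join(working_dir, "tenant_b", n)
     for n in os.listdir(os.path.join(working_dir, "tenant_b"))}
print(a.isdisjoint(b))  # True: no file is shared between the two workspaces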