diff --git a/env.example b/env.example index 537e5c81..be054576 100644 --- a/env.example +++ b/env.example @@ -252,6 +252,7 @@ POSTGRES_IVFFLAT_LISTS=100 NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io NEO4J_USERNAME=neo4j NEO4J_PASSWORD='your_password' +# NEO4J_DATABASE=chunk_entity_relation NEO4J_MAX_CONNECTION_POOL_SIZE=100 NEO4J_CONNECTION_TIMEOUT=30 NEO4J_CONNECTION_ACQUISITION_TIMEOUT=30 diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index 09af3ef1..9fb114f2 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -33,19 +33,18 @@ class JsonDocStatusStorage(DocStatusStorage): if self.workspace: # Include workspace in the file path for data isolation workspace_dir = os.path.join(working_dir, self.workspace) - os.makedirs(workspace_dir, exist_ok=True) - self._file_name = os.path.join( - workspace_dir, f"kv_store_{self.namespace}.json" - ) + self.final_namespace = f"{self.workspace}_{self.namespace}" else: # Default behavior when workspace is empty - self._file_name = os.path.join( - working_dir, f"kv_store_{self.namespace}.json" - ) + self.final_namespace = self.namespace + self.workspace = "_" + workspace_dir = working_dir + + os.makedirs(workspace_dir, exist_ok=True) + self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json") self._data = None self._storage_lock = None self.storage_updated = None - self.final_namespace = f"{self.workspace}_{self.namespace}" async def initialize(self): """Initialize storage data""" @@ -60,7 +59,7 @@ class JsonDocStatusStorage(DocStatusStorage): async with self._storage_lock: self._data.update(loaded_data) logger.info( - f"Process {os.getpid()} doc status load {self.final_namespace} with {len(loaded_data)} records" + f"[{self.workspace}] Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records" ) async def filter_keys(self, keys: set[str]) -> set[str]: @@ -108,7 +107,9 @@ class JsonDocStatusStorage(DocStatusStorage): data["error_msg"] = None result[k] = DocProcessingStatus(**data) except KeyError as e: - logger.error(f"Missing required field for document {k}: {e}") + logger.error( + f"[{self.workspace}] Missing required field for document {k}: {e}" + ) continue return result @@ -135,7 +136,9 @@ class JsonDocStatusStorage(DocStatusStorage): data["error_msg"] = None result[k] = DocProcessingStatus(**data) except KeyError as e: - logger.error(f"Missing required field for document {k}: {e}") + logger.error( + f"[{self.workspace}] Missing required field for document {k}: {e}" + ) continue return result @@ -146,7 +149,7 @@ class JsonDocStatusStorage(DocStatusStorage): dict(self._data) if hasattr(self._data, "_getvalue") else self._data ) logger.debug( - f"Process {os.getpid()} doc status writting {len(data_dict)} records to {self.final_namespace}" + f"[{self.workspace}] Process {os.getpid()} doc status writting {len(data_dict)} records to {self.namespace}" ) write_json(data_dict, self._file_name) await clear_all_update_flags(self.final_namespace) @@ -159,7 +162,9 @@ class JsonDocStatusStorage(DocStatusStorage): """ if not data: return - logger.debug(f"Inserting {len(data)} records to {self.final_namespace}") + logger.debug( + f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}" + ) async with self._storage_lock: # Ensure chunks_list field exists for new documents for doc_id, doc_data in data.items(): @@ -242,7 +247,9 @@ class JsonDocStatusStorage(DocStatusStorage): all_docs.append((doc_id, doc_status)) except 
KeyError as e: - logger.error(f"Error processing document {doc_id}: {e}") + logger.error( + f"[{self.workspace}] Error processing document {doc_id}: {e}" + ) continue # Sort documents @@ -321,8 +328,10 @@ class JsonDocStatusStorage(DocStatusStorage): await set_all_update_flags(self.final_namespace) await self.index_done_callback() - logger.info(f"Process {os.getpid()} drop {self.final_namespace}") + logger.info( + f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}" + ) return {"status": "success", "message": "data dropped"} except Exception as e: - logger.error(f"Error dropping {self.final_namespace}: {e}") + logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}") return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index d6d80079..ca3aa453 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -29,19 +29,19 @@ class JsonKVStorage(BaseKVStorage): if self.workspace: # Include workspace in the file path for data isolation workspace_dir = os.path.join(working_dir, self.workspace) - os.makedirs(workspace_dir, exist_ok=True) - self._file_name = os.path.join( - workspace_dir, f"kv_store_{self.namespace}.json" - ) + self.final_namespace = f"{self.workspace}_{self.namespace}" else: # Default behavior when workspace is empty - self._file_name = os.path.join( - working_dir, f"kv_store_{self.namespace}.json" - ) + workspace_dir = working_dir + self.final_namespace = self.namespace + self.workspace = "_" + + os.makedirs(workspace_dir, exist_ok=True) + self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json") + self._data = None self._storage_lock = None self.storage_updated = None - self.final_namespace = f"{self.workspace}_{self.namespace}" async def initialize(self): """Initialize storage data""" @@ -64,7 +64,7 @@ class JsonKVStorage(BaseKVStorage): data_count = len(loaded_data) logger.info( - f"Process {os.getpid()} KV load {self.final_namespace} with {data_count} records" + f"[{self.workspace}] Process {os.getpid()} KV load {self.namespace} with {data_count} records" ) async def index_done_callback(self) -> None: @@ -78,7 +78,7 @@ class JsonKVStorage(BaseKVStorage): data_count = len(data_dict) logger.debug( - f"Process {os.getpid()} KV writting {data_count} records to {self.final_namespace}" + f"[{self.workspace}] Process {os.getpid()} KV writting {data_count} records to {self.namespace}" ) write_json(data_dict, self._file_name) await clear_all_update_flags(self.final_namespace) @@ -151,12 +151,14 @@ class JsonKVStorage(BaseKVStorage): current_time = int(time.time()) # Get current Unix timestamp - logger.debug(f"Inserting {len(data)} records to {self.final_namespace}") + logger.debug( + f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}" + ) async with self._storage_lock: # Add timestamps to data based on whether key exists for k, v in data.items(): # For text_chunks namespace, ensure llm_cache_list field exists - if "text_chunks" in self.namespace: + if self.namespace.endswith("text_chunks"): if "llm_cache_list" not in v: v["llm_cache_list"] = [] @@ -215,10 +217,12 @@ class JsonKVStorage(BaseKVStorage): await set_all_update_flags(self.final_namespace) await self.index_done_callback() - logger.info(f"Process {os.getpid()} drop {self.final_namespace}") + logger.info( + f"[{self.workspace}] Process {os.getpid()} drop {self.namespace}" + ) return {"status": "success", "message": "data dropped"} except Exception as e: - logger.error(f"Error 
dropping {self.final_namespace}: {e}") + logger.error(f"[{self.workspace}] Error dropping {self.namespace}: {e}") return {"status": "error", "message": str(e)} async def _migrate_legacy_cache_structure(self, data: dict) -> dict: @@ -263,7 +267,7 @@ class JsonKVStorage(BaseKVStorage): if migration_count > 0: logger.info( - f"Migrated {migration_count} legacy cache entries to flattened structure" + f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure" ) # Persist migrated data immediately write_json(migrated_data, self._file_name) diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index af26b961..bd69678a 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -34,21 +34,25 @@ config.read("config.ini", "utf-8") @dataclass class MemgraphStorage(BaseGraphStorage): def __init__(self, namespace, global_config, embedding_func, workspace=None): + # Priority: 1) MEMGRAPH_WORKSPACE env 2) user arg 3) default 'base' memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE") if memgraph_workspace and memgraph_workspace.strip(): workspace = memgraph_workspace + + if not workspace or not str(workspace).strip(): + workspace = "base" + super().__init__( namespace=namespace, - workspace=workspace or "", + workspace=workspace, global_config=global_config, embedding_func=embedding_func, ) self._driver = None def _get_workspace_label(self) -> str: - """Get workspace label, return 'base' for compatibility when workspace is empty""" - workspace = getattr(self, "workspace", None) - return workspace if workspace else "base" + """Return workspace label (guaranteed non-empty during initialization)""" + return self.workspace async def initialize(self): URI = os.environ.get( @@ -79,17 +83,19 @@ class MemgraphStorage(BaseGraphStorage): f"""CREATE INDEX ON :{workspace_label}(entity_id)""" ) logger.info( - f"Created index on :{workspace_label}(entity_id) in Memgraph." + f"[{self.workspace}] Created index on :{workspace_label}(entity_id) in Memgraph." 
) except Exception as e: # Index may already exist, which is not an error logger.warning( - f"Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}" + f"[{self.workspace}] Index creation on :{workspace_label}(entity_id) may have failed or already exists: {e}" ) await session.run("RETURN 1") - logger.info(f"Connected to Memgraph at {URI}") + logger.info(f"[{self.workspace}] Connected to Memgraph at {URI}") except Exception as e: - logger.error(f"Failed to connect to Memgraph at {URI}: {e}") + logger.error( + f"[{self.workspace}] Failed to connect to Memgraph at {URI}: {e}" + ) raise async def finalize(self): @@ -134,7 +140,9 @@ class MemgraphStorage(BaseGraphStorage): single_result["node_exists"] if single_result is not None else False ) except Exception as e: - logger.error(f"Error checking node existence for {node_id}: {str(e)}") + logger.error( + f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}" + ) await result.consume() # Ensure the result is consumed even on error raise @@ -177,7 +185,7 @@ class MemgraphStorage(BaseGraphStorage): ) except Exception as e: logger.error( - f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" + f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" ) await result.consume() # Ensure the result is consumed even on error raise @@ -215,7 +223,7 @@ class MemgraphStorage(BaseGraphStorage): if len(records) > 1: logger.warning( - f"Multiple nodes found with label '{node_id}'. Using first node." + f"[{self.workspace}] Multiple nodes found with label '{node_id}'. Using first node." ) if records: node = records[0]["n"] @@ -232,7 +240,9 @@ class MemgraphStorage(BaseGraphStorage): finally: await result.consume() # Ensure result is fully consumed except Exception as e: - logger.error(f"Error getting node for {node_id}: {str(e)}") + logger.error( + f"[{self.workspace}] Error getting node for {node_id}: {str(e)}" + ) raise async def node_degree(self, node_id: str) -> int: @@ -268,7 +278,9 @@ class MemgraphStorage(BaseGraphStorage): record = await result.single() if not record: - logger.warning(f"No node found with label '{node_id}'") + logger.warning( + f"[{self.workspace}] No node found with label '{node_id}'" + ) return 0 degree = record["degree"] @@ -276,7 +288,9 @@ class MemgraphStorage(BaseGraphStorage): finally: await result.consume() # Ensure result is fully consumed except Exception as e: - logger.error(f"Error getting node degree for {node_id}: {str(e)}") + logger.error( + f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}" + ) raise async def get_all_labels(self) -> list[str]: @@ -310,7 +324,7 @@ class MemgraphStorage(BaseGraphStorage): await result.consume() return labels except Exception as e: - logger.error(f"Error getting all labels: {str(e)}") + logger.error(f"[{self.workspace}] Error getting all labels: {str(e)}") await result.consume() # Ensure the result is consumed even on error raise @@ -370,12 +384,14 @@ class MemgraphStorage(BaseGraphStorage): return edges except Exception as e: logger.error( - f"Error getting edges for node {source_node_id}: {str(e)}" + f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}" ) await results.consume() # Ensure results are consumed even on error raise except Exception as e: - logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") + logger.error( + f"[{self.workspace}] Error in get_node_edges for {source_node_id}: 
{str(e)}" + ) raise async def get_edge( @@ -424,13 +440,13 @@ class MemgraphStorage(BaseGraphStorage): if key not in edge_result: edge_result[key] = default_value logger.warning( - f"Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}" + f"[{self.workspace}] Edge between {source_node_id} and {target_node_id} is missing property: {key}. Using default value: {default_value}" ) return edge_result return None except Exception as e: logger.error( - f"Error getting edge between {source_node_id} and {target_node_id}: {str(e)}" + f"[{self.workspace}] Error getting edge between {source_node_id} and {target_node_id}: {str(e)}" ) await result.consume() # Ensure the result is consumed even on error raise @@ -463,7 +479,7 @@ class MemgraphStorage(BaseGraphStorage): for attempt in range(max_retries): try: logger.debug( - f"Attempting node upsert, attempt {attempt + 1}/{max_retries}" + f"[{self.workspace}] Attempting node upsert, attempt {attempt + 1}/{max_retries}" ) async with self._driver.session(database=self._DATABASE) as session: workspace_label = self._get_workspace_label() @@ -504,20 +520,24 @@ class MemgraphStorage(BaseGraphStorage): initial_wait_time * (backoff_factor**attempt) + jitter ) logger.warning( - f"Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}" + f"[{self.workspace}] Node upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}" ) await asyncio.sleep(wait_time) else: logger.error( - f"Memgraph transient error during node upsert after {max_retries} retries: {str(e)}" + f"[{self.workspace}] Memgraph transient error during node upsert after {max_retries} retries: {str(e)}" ) raise else: # Non-transient error, don't retry - logger.error(f"Non-transient error during node upsert: {str(e)}") + logger.error( + f"[{self.workspace}] Non-transient error during node upsert: {str(e)}" + ) raise except Exception as e: - logger.error(f"Unexpected error during node upsert: {str(e)}") + logger.error( + f"[{self.workspace}] Unexpected error during node upsert: {str(e)}" + ) raise async def upsert_edge( @@ -552,7 +572,7 @@ class MemgraphStorage(BaseGraphStorage): for attempt in range(max_retries): try: logger.debug( - f"Attempting edge upsert, attempt {attempt + 1}/{max_retries}" + f"[{self.workspace}] Attempting edge upsert, attempt {attempt + 1}/{max_retries}" ) async with self._driver.session(database=self._DATABASE) as session: @@ -602,20 +622,24 @@ class MemgraphStorage(BaseGraphStorage): initial_wait_time * (backoff_factor**attempt) + jitter ) logger.warning( - f"Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... Error: {str(e)}" + f"[{self.workspace}] Edge upsert failed. Attempt #{attempt + 1} retrying in {wait_time:.3f} seconds... 
Error: {str(e)}" ) await asyncio.sleep(wait_time) else: logger.error( - f"Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}" + f"[{self.workspace}] Memgraph transient error during edge upsert after {max_retries} retries: {str(e)}" ) raise else: # Non-transient error, don't retry - logger.error(f"Non-transient error during edge upsert: {str(e)}") + logger.error( + f"[{self.workspace}] Non-transient error during edge upsert: {str(e)}" + ) raise except Exception as e: - logger.error(f"Unexpected error during edge upsert: {str(e)}") + logger.error( + f"[{self.workspace}] Unexpected error during edge upsert: {str(e)}" + ) raise async def delete_node(self, node_id: str) -> None: @@ -639,14 +663,14 @@ class MemgraphStorage(BaseGraphStorage): DETACH DELETE n """ result = await tx.run(query, entity_id=node_id) - logger.debug(f"Deleted node with label {node_id}") + logger.debug(f"[{self.workspace}] Deleted node with label {node_id}") await result.consume() try: async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_delete) except Exception as e: - logger.error(f"Error during node deletion: {str(e)}") + logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}") raise async def remove_nodes(self, nodes: list[str]): @@ -686,14 +710,16 @@ class MemgraphStorage(BaseGraphStorage): result = await tx.run( query, source_entity_id=source, target_entity_id=target ) - logger.debug(f"Deleted edge from '{source}' to '{target}'") + logger.debug( + f"[{self.workspace}] Deleted edge from '{source}' to '{target}'" + ) await result.consume() # Ensure result is fully consumed try: async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_delete_edge) except Exception as e: - logger.error(f"Error during edge deletion: {str(e)}") + logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}") raise async def drop(self) -> dict[str, str]: @@ -720,12 +746,12 @@ class MemgraphStorage(BaseGraphStorage): result = await session.run(query) await result.consume() logger.info( - f"Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" + f"[{self.workspace}] Dropped workspace {workspace_label} from Memgraph database {self._DATABASE}" ) return {"status": "success", "message": "workspace data dropped"} except Exception as e: logger.error( - f"Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" + f"[{self.workspace}] Error dropping workspace {workspace_label} from Memgraph database {self._DATABASE}: {e}" ) return {"status": "error", "message": str(e)} @@ -945,14 +971,16 @@ class MemgraphStorage(BaseGraphStorage): # If no record found, return empty KnowledgeGraph if not record: - logger.debug(f"No nodes found for entity_id: {node_label}") + logger.debug( + f"[{self.workspace}] No nodes found for entity_id: {node_label}" + ) return result # Check if the result was truncated if record.get("is_truncated"): result.is_truncated = True logger.info( - f"Graph truncated: breadth-first search limited to {max_nodes} nodes" + f"[{self.workspace}] Graph truncated: breadth-first search limited to {max_nodes} nodes" ) finally: @@ -990,11 +1018,13 @@ class MemgraphStorage(BaseGraphStorage): seen_edges.add(edge_id) logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) except 
Exception as e: - logger.warning(f"Memgraph error during subgraph query: {str(e)}") + logger.warning( + f"[{self.workspace}] Memgraph error during subgraph query: {str(e)}" + ) return result diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 55e5f9eb..f3333bb7 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -37,7 +37,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): ] # Determine specific fields based on namespace - if "entities" in self.namespace.lower(): + if self.namespace.endswith("entities"): specific_fields = [ FieldSchema( name="entity_name", @@ -54,7 +54,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): ] description = "LightRAG entities vector storage" - elif "relationships" in self.namespace.lower(): + elif self.namespace.endswith("relationships"): specific_fields = [ FieldSchema( name="src_id", dtype=DataType.VARCHAR, max_length=512, nullable=True @@ -71,7 +71,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): ] description = "LightRAG relationships vector storage" - elif "chunks" in self.namespace.lower(): + elif self.namespace.endswith("chunks"): specific_fields = [ FieldSchema( name="full_doc_id", @@ -147,7 +147,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): """Fallback method to create vector index using direct API""" try: self._client.create_index( - collection_name=self.namespace, + collection_name=self.final_namespace, field_name="vector", index_params={ "index_type": "HNSW", @@ -155,29 +155,35 @@ class MilvusVectorDBStorage(BaseVectorStorage): "params": {"M": 16, "efConstruction": 256}, }, ) - logger.debug("Created vector index using fallback method") + logger.debug( + f"[{self.workspace}] Created vector index using fallback method" + ) except Exception as e: - logger.warning(f"Failed to create vector index using fallback method: {e}") + logger.warning( + f"[{self.workspace}] Failed to create vector index using fallback method: {e}" + ) def _create_scalar_index_fallback(self, field_name: str, index_type: str): """Fallback method to create scalar index using direct API""" # Skip unsupported index types if index_type == "SORTED": logger.info( - f"Skipping SORTED index for {field_name} (not supported in this Milvus version)" + f"[{self.workspace}] Skipping SORTED index for {field_name} (not supported in this Milvus version)" ) return try: self._client.create_index( - collection_name=self.namespace, + collection_name=self.final_namespace, field_name=field_name, index_params={"index_type": index_type}, ) - logger.debug(f"Created {field_name} index using fallback method") + logger.debug( + f"[{self.workspace}] Created {field_name} index using fallback method" + ) except Exception as e: logger.info( - f"Could not create {field_name} index using fallback method: {e}" + f"[{self.workspace}] Could not create {field_name} index using fallback method: {e}" ) def _create_indexes_after_collection(self): @@ -198,15 +204,19 @@ class MilvusVectorDBStorage(BaseVectorStorage): params={"M": 16, "efConstruction": 256}, ) self._client.create_index( - collection_name=self.namespace, index_params=vector_index + collection_name=self.final_namespace, index_params=vector_index + ) + logger.debug( + f"[{self.workspace}] Created vector index using IndexParams" ) - logger.debug("Created vector index using IndexParams") except Exception as e: - logger.debug(f"IndexParams method failed for vector index: {e}") + logger.debug( + f"[{self.workspace}] IndexParams method failed for vector index: {e}" + ) self._create_vector_index_fallback() # 
Create scalar indexes based on namespace - if "entities" in self.namespace.lower(): + if self.namespace.endswith("entities"): # Create indexes for entity fields try: entity_name_index = self._get_index_params() @@ -214,14 +224,16 @@ class MilvusVectorDBStorage(BaseVectorStorage): field_name="entity_name", index_type="INVERTED" ) self._client.create_index( - collection_name=self.namespace, + collection_name=self.final_namespace, index_params=entity_name_index, ) except Exception as e: - logger.debug(f"IndexParams method failed for entity_name: {e}") + logger.debug( + f"[{self.workspace}] IndexParams method failed for entity_name: {e}" + ) self._create_scalar_index_fallback("entity_name", "INVERTED") - elif "relationships" in self.namespace.lower(): + elif self.namespace.endswith("relationships"): # Create indexes for relationship fields try: src_id_index = self._get_index_params() @@ -229,10 +241,13 @@ class MilvusVectorDBStorage(BaseVectorStorage): field_name="src_id", index_type="INVERTED" ) self._client.create_index( - collection_name=self.namespace, index_params=src_id_index + collection_name=self.final_namespace, + index_params=src_id_index, ) except Exception as e: - logger.debug(f"IndexParams method failed for src_id: {e}") + logger.debug( + f"[{self.workspace}] IndexParams method failed for src_id: {e}" + ) self._create_scalar_index_fallback("src_id", "INVERTED") try: @@ -241,13 +256,16 @@ class MilvusVectorDBStorage(BaseVectorStorage): field_name="tgt_id", index_type="INVERTED" ) self._client.create_index( - collection_name=self.namespace, index_params=tgt_id_index + collection_name=self.final_namespace, + index_params=tgt_id_index, ) except Exception as e: - logger.debug(f"IndexParams method failed for tgt_id: {e}") + logger.debug( + f"[{self.workspace}] IndexParams method failed for tgt_id: {e}" + ) self._create_scalar_index_fallback("tgt_id", "INVERTED") - elif "chunks" in self.namespace.lower(): + elif self.namespace.endswith("chunks"): # Create indexes for chunk fields try: doc_id_index = self._get_index_params() @@ -255,10 +273,13 @@ class MilvusVectorDBStorage(BaseVectorStorage): field_name="full_doc_id", index_type="INVERTED" ) self._client.create_index( - collection_name=self.namespace, index_params=doc_id_index + collection_name=self.final_namespace, + index_params=doc_id_index, ) except Exception as e: - logger.debug(f"IndexParams method failed for full_doc_id: {e}") + logger.debug( + f"[{self.workspace}] IndexParams method failed for full_doc_id: {e}" + ) self._create_scalar_index_fallback("full_doc_id", "INVERTED") # No common indexes needed @@ -266,25 +287,29 @@ class MilvusVectorDBStorage(BaseVectorStorage): else: # Fallback to direct API calls if IndexParams is not available logger.info( - f"IndexParams not available, using fallback methods for {self.namespace}" + f"[{self.workspace}] IndexParams not available, using fallback methods for {self.namespace}" ) # Create vector index using fallback self._create_vector_index_fallback() # Create scalar indexes using fallback - if "entities" in self.namespace.lower(): + if self.namespace.endswith("entities"): self._create_scalar_index_fallback("entity_name", "INVERTED") - elif "relationships" in self.namespace.lower(): + elif self.namespace.endswith("relationships"): self._create_scalar_index_fallback("src_id", "INVERTED") self._create_scalar_index_fallback("tgt_id", "INVERTED") - elif "chunks" in self.namespace.lower(): + elif self.namespace.endswith("chunks"): self._create_scalar_index_fallback("full_doc_id", "INVERTED") 
- logger.info(f"Created indexes for collection: {self.namespace}") + logger.info( + f"[{self.workspace}] Created indexes for collection: {self.namespace}" + ) except Exception as e: - logger.warning(f"Failed to create some indexes for {self.namespace}: {e}") + logger.warning( + f"[{self.workspace}] Failed to create some indexes for {self.namespace}: {e}" + ) def _get_required_fields_for_namespace(self) -> dict: """Get required core field definitions for current namespace""" @@ -297,18 +322,18 @@ class MilvusVectorDBStorage(BaseVectorStorage): } # Add specific fields based on namespace - if "entities" in self.namespace.lower(): + if self.namespace.endswith("entities"): specific_fields = { "entity_name": {"type": "VarChar"}, "file_path": {"type": "VarChar"}, } - elif "relationships" in self.namespace.lower(): + elif self.namespace.endswith("relationships"): specific_fields = { "src_id": {"type": "VarChar"}, "tgt_id": {"type": "VarChar"}, "file_path": {"type": "VarChar"}, } - elif "chunks" in self.namespace.lower(): + elif self.namespace.endswith("chunks"): specific_fields = { "full_doc_id": {"type": "VarChar"}, "file_path": {"type": "VarChar"}, @@ -327,7 +352,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): expected_type = expected_config.get("type") logger.debug( - f"Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}" + f"[{self.workspace}] Checking field '{field_name}': existing_type={existing_type} (type={type(existing_type)}), expected_type={expected_type}" ) # Convert DataType enum values to string names if needed @@ -335,7 +360,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): if hasattr(existing_type, "name"): existing_type = existing_type.name logger.debug( - f"Converted enum to name: {original_existing_type} -> {existing_type}" + f"[{self.workspace}] Converted enum to name: {original_existing_type} -> {existing_type}" ) elif isinstance(existing_type, int): # Map common Milvus internal type codes to type names for backward compatibility @@ -346,7 +371,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): 9: "Double", } mapped_type = type_mapping.get(existing_type, str(existing_type)) - logger.debug(f"Mapped numeric type: {existing_type} -> {mapped_type}") + logger.debug( + f"[{self.workspace}] Mapped numeric type: {existing_type} -> {mapped_type}" + ) existing_type = mapped_type # Normalize type names for comparison @@ -367,18 +394,18 @@ class MilvusVectorDBStorage(BaseVectorStorage): if original_existing != existing_type or original_expected != expected_type: logger.debug( - f"Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}" + f"[{self.workspace}] Applied aliases: {original_existing} -> {existing_type}, {original_expected} -> {expected_type}" ) # Basic type compatibility check type_compatible = existing_type == expected_type logger.debug( - f"Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}" + f"[{self.workspace}] Type compatibility for '{field_name}': {existing_type} == {expected_type} -> {type_compatible}" ) if not type_compatible: logger.warning( - f"Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}" + f"[{self.workspace}] Type mismatch for field '{field_name}': expected {expected_type}, got {existing_type}" ) return False @@ -391,23 +418,25 @@ class MilvusVectorDBStorage(BaseVectorStorage): or existing_field.get("primary_key", False) ) logger.debug( - f"Primary key check for 
'{field_name}': expected=True, actual={is_primary}" + f"[{self.workspace}] Primary key check for '{field_name}': expected=True, actual={is_primary}" + ) + logger.debug( + f"[{self.workspace}] Raw field data for '{field_name}': {existing_field}" ) - logger.debug(f"Raw field data for '{field_name}': {existing_field}") # For ID field, be more lenient - if it's the ID field, assume it should be primary if field_name == "id" and not is_primary: logger.info( - f"ID field '{field_name}' not marked as primary in existing collection, but treating as compatible" + f"[{self.workspace}] ID field '{field_name}' not marked as primary in existing collection, but treating as compatible" ) # Don't fail for ID field primary key mismatch elif not is_primary: logger.warning( - f"Primary key mismatch for field '{field_name}': expected primary key, but field is not primary" + f"[{self.workspace}] Primary key mismatch for field '{field_name}': expected primary key, but field is not primary" ) return False - logger.debug(f"Field '{field_name}' is compatible") + logger.debug(f"[{self.workspace}] Field '{field_name}' is compatible") return True def _check_vector_dimension(self, collection_info: dict): @@ -434,18 +463,22 @@ class MilvusVectorDBStorage(BaseVectorStorage): if existing_dimension != current_dimension: raise ValueError( - f"Vector dimension mismatch for collection '{self.namespace}': " + f"Vector dimension mismatch for collection '{self.final_namespace}': " f"existing={existing_dimension}, current={current_dimension}" ) - logger.debug(f"Vector dimension check passed: {current_dimension}") + logger.debug( + f"[{self.workspace}] Vector dimension check passed: {current_dimension}" + ) return # If no vector field found, this might be an old collection created with simple schema logger.warning( - f"Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema." + f"[{self.workspace}] Vector field not found in collection '{self.namespace}'. This might be an old collection created with simple schema." + ) + logger.warning( + f"[{self.workspace}] Consider recreating the collection for optimal performance." 
) - logger.warning("Consider recreating the collection for optimal performance.") return def _check_schema_compatibility(self, collection_info: dict): @@ -461,12 +494,14 @@ class MilvusVectorDBStorage(BaseVectorStorage): if not has_vector_field: logger.warning( - f"Collection {self.namespace} appears to be created with old simple schema (no vector field)" + f"[{self.workspace}] Collection {self.namespace} appears to be created with old simple schema (no vector field)" ) logger.warning( - "This collection will work but may have suboptimal performance" + f"[{self.workspace}] This collection will work but may have suboptimal performance" + ) + logger.warning( + f"[{self.workspace}] Consider recreating the collection for optimal performance" ) - logger.warning("Consider recreating the collection for optimal performance") return # For collections with vector field, check basic compatibility @@ -486,7 +521,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): if incompatible_fields: raise ValueError( - f"Critical schema incompatibility in collection '{self.namespace}': {incompatible_fields}" + f"Critical schema incompatibility in collection '{self.final_namespace}': {incompatible_fields}" ) # Get all expected fields for informational purposes @@ -497,18 +532,20 @@ class MilvusVectorDBStorage(BaseVectorStorage): if missing_fields: logger.info( - f"Collection {self.namespace} missing optional fields: {missing_fields}" + f"[{self.workspace}] Collection {self.namespace} missing optional fields: {missing_fields}" ) logger.info( "These fields would be available in a newly created collection for better performance" ) - logger.debug(f"Schema compatibility check passed for {self.namespace}") + logger.debug( + f"[{self.workspace}] Schema compatibility check passed for {self.namespace}" + ) def _validate_collection_compatibility(self): """Validate existing collection's dimension and schema compatibility""" try: - collection_info = self._client.describe_collection(self.namespace) + collection_info = self._client.describe_collection(self.final_namespace) # 1. 
Check vector dimension self._check_vector_dimension(collection_info) @@ -517,12 +554,12 @@ class MilvusVectorDBStorage(BaseVectorStorage): self._check_schema_compatibility(collection_info) logger.info( - f"VectorDB Collection '{self.namespace}' compatibility validation passed" + f"[{self.workspace}] VectorDB Collection '{self.namespace}' compatibility validation passed" ) except Exception as e: logger.error( - f"Collection compatibility validation failed for {self.namespace}: {e}" + f"[{self.workspace}] Collection compatibility validation failed for {self.namespace}: {e}" ) raise @@ -530,17 +567,21 @@ class MilvusVectorDBStorage(BaseVectorStorage): """Ensure the collection is loaded into memory for search operations""" try: # Check if collection exists first - if not self._client.has_collection(self.namespace): - logger.error(f"Collection {self.namespace} does not exist") - raise ValueError(f"Collection {self.namespace} does not exist") + if not self._client.has_collection(self.final_namespace): + logger.error( + f"[{self.workspace}] Collection {self.namespace} does not exist" + ) + raise ValueError(f"Collection {self.final_namespace} does not exist") # Load the collection if it's not already loaded # In Milvus, collections need to be loaded before they can be searched - self._client.load_collection(self.namespace) - # logger.debug(f"Collection {self.namespace} loaded successfully") + self._client.load_collection(self.final_namespace) + # logger.debug(f"[{self.workspace}] Collection {self.namespace} loaded successfully") except Exception as e: - logger.error(f"Failed to load collection {self.namespace}: {e}") + logger.error( + f"[{self.workspace}] Failed to load collection {self.namespace}: {e}" + ) raise def _create_collection_if_not_exist(self): @@ -550,41 +591,45 @@ class MilvusVectorDBStorage(BaseVectorStorage): # First, list all collections to see what actually exists try: all_collections = self._client.list_collections() - logger.debug(f"All collections in database: {all_collections}") + logger.debug( + f"[{self.workspace}] All collections in database: {all_collections}" + ) except Exception as list_error: - logger.warning(f"Could not list collections: {list_error}") + logger.warning( + f"[{self.workspace}] Could not list collections: {list_error}" + ) all_collections = [] # Check if our specific collection exists - collection_exists = self._client.has_collection(self.namespace) + collection_exists = self._client.has_collection(self.final_namespace) logger.info( - f"VectorDB collection '{self.namespace}' exists check: {collection_exists}" + f"[{self.workspace}] VectorDB collection '{self.namespace}' exists check: {collection_exists}" ) if collection_exists: # Double-check by trying to describe the collection try: - self._client.describe_collection(self.namespace) + self._client.describe_collection(self.final_namespace) self._validate_collection_compatibility() # Ensure the collection is loaded after validation self._ensure_collection_loaded() return except Exception as describe_error: logger.warning( - f"Collection '{self.namespace}' exists but cannot be described: {describe_error}" + f"[{self.workspace}] Collection '{self.namespace}' exists but cannot be described: {describe_error}" ) logger.info( - "Treating as if collection doesn't exist and creating new one..." + f"[{self.workspace}] Treating as if collection doesn't exist and creating new one..." 
) # Fall through to creation logic # Collection doesn't exist, create new collection - logger.info(f"Creating new collection: {self.namespace}") + logger.info(f"[{self.workspace}] Creating new collection: {self.namespace}") schema = self._create_schema_for_namespace() # Create collection with schema only first self._client.create_collection( - collection_name=self.namespace, schema=schema + collection_name=self.final_namespace, schema=schema ) # Then create indexes @@ -593,43 +638,49 @@ class MilvusVectorDBStorage(BaseVectorStorage): # Load the newly created collection self._ensure_collection_loaded() - logger.info(f"Successfully created Milvus collection: {self.namespace}") + logger.info( + f"[{self.workspace}] Successfully created Milvus collection: {self.namespace}" + ) except Exception as e: logger.error( - f"Error in _create_collection_if_not_exist for {self.namespace}: {e}" + f"[{self.workspace}] Error in _create_collection_if_not_exist for {self.namespace}: {e}" ) # If there's any error, try to force create the collection - logger.info(f"Attempting to force create collection {self.namespace}...") + logger.info( + f"[{self.workspace}] Attempting to force create collection {self.namespace}..." + ) try: # Try to drop the collection first if it exists in a bad state try: - if self._client.has_collection(self.namespace): + if self._client.has_collection(self.final_namespace): logger.info( - f"Dropping potentially corrupted collection {self.namespace}" + f"[{self.workspace}] Dropping potentially corrupted collection {self.namespace}" ) - self._client.drop_collection(self.namespace) + self._client.drop_collection(self.final_namespace) except Exception as drop_error: logger.warning( - f"Could not drop collection {self.namespace}: {drop_error}" + f"[{self.workspace}] Could not drop collection {self.namespace}: {drop_error}" ) # Create fresh collection schema = self._create_schema_for_namespace() self._client.create_collection( - collection_name=self.namespace, schema=schema + collection_name=self.final_namespace, schema=schema ) self._create_indexes_after_collection() # Load the newly created collection self._ensure_collection_loaded() - logger.info(f"Successfully force-created collection {self.namespace}") + logger.info( + f"[{self.workspace}] Successfully force-created collection {self.namespace}" + ) except Exception as create_error: logger.error( - f"Failed to force-create collection {self.namespace}: {create_error}" + f"[{self.workspace}] Failed to force-create collection {self.namespace}: {create_error}" ) raise @@ -651,11 +702,18 @@ class MilvusVectorDBStorage(BaseVectorStorage): f"Using passed workspace parameter: '{effective_workspace}'" ) - # Build namespace with workspace prefix for data isolation + # Build final_namespace with workspace prefix for data isolation + # Keep original namespace unchanged for type detection logic if effective_workspace: - self.namespace = f"{effective_workspace}_{self.namespace}" - logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") - # When workspace is empty, keep the original namespace unchanged + self.final_namespace = f"{effective_workspace}_{self.namespace}" + logger.debug( + f"Final namespace with workspace prefix: '{self.final_namespace}'" + ) + else: + # When workspace is empty, final_namespace equals original namespace + self.final_namespace = self.namespace + logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") + self.workspace = "_" kwargs = self.global_config.get("vector_db_storage_cls_kwargs", 
{}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -699,7 +757,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): self._create_collection_if_not_exist() async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.debug(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") if not data: return @@ -730,7 +788,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): embeddings = np.concatenate(embeddings_list) for i, d in enumerate(list_data): d["vector"] = embeddings[i] - results = self._client.upsert(collection_name=self.namespace, data=list_data) + results = self._client.upsert( + collection_name=self.final_namespace, data=list_data + ) return results async def query( @@ -747,7 +807,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): output_fields = list(self.meta_fields) results = self._client.search( - collection_name=self.namespace, + collection_name=self.final_namespace, data=embedding, limit=top_k, output_fields=output_fields, @@ -780,21 +840,25 @@ class MilvusVectorDBStorage(BaseVectorStorage): # Compute entity ID from name entity_id = compute_mdhash_id(entity_name, prefix="ent-") logger.debug( - f"Attempting to delete entity {entity_name} with ID {entity_id}" + f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}" ) # Delete the entity from Milvus collection result = self._client.delete( - collection_name=self.namespace, pks=[entity_id] + collection_name=self.final_namespace, pks=[entity_id] ) if result and result.get("delete_count", 0) > 0: - logger.debug(f"Successfully deleted entity {entity_name}") + logger.debug( + f"[{self.workspace}] Successfully deleted entity {entity_name}" + ) else: - logger.debug(f"Entity {entity_name} not found in storage") + logger.debug( + f"[{self.workspace}] Entity {entity_name} not found in storage" + ) except Exception as e: - logger.error(f"Error deleting entity {entity_name}: {e}") + logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity @@ -811,31 +875,35 @@ class MilvusVectorDBStorage(BaseVectorStorage): # Find all relations involving this entity results = self._client.query( - collection_name=self.namespace, filter=expr, output_fields=["id"] + collection_name=self.final_namespace, filter=expr, output_fields=["id"] ) if not results or len(results) == 0: - logger.debug(f"No relations found for entity {entity_name}") + logger.debug( + f"[{self.workspace}] No relations found for entity {entity_name}" + ) return # Extract IDs of relations to delete relation_ids = [item["id"] for item in results] logger.debug( - f"Found {len(relation_ids)} relations for entity {entity_name}" + f"[{self.workspace}] Found {len(relation_ids)} relations for entity {entity_name}" ) # Delete the relations if relation_ids: delete_result = self._client.delete( - collection_name=self.namespace, pks=relation_ids + collection_name=self.final_namespace, pks=relation_ids ) logger.debug( - f"Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}" + f"[{self.workspace}] Deleted {delete_result.get('delete_count', 0)} relations for {entity_name}" ) except Exception as e: - logger.error(f"Error deleting relations for {entity_name}: {e}") + logger.error( + f"[{self.workspace}] Error deleting relations for {entity_name}: {e}" + ) async def delete(self, ids: list[str]) -> None: """Delete vectors with 
specified IDs @@ -848,17 +916,21 @@ class MilvusVectorDBStorage(BaseVectorStorage): self._ensure_collection_loaded() # Delete vectors by IDs - result = self._client.delete(collection_name=self.namespace, pks=ids) + result = self._client.delete(collection_name=self.final_namespace, pks=ids) if result and result.get("delete_count", 0) > 0: logger.debug( - f"Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}" + f"[{self.workspace}] Successfully deleted {result.get('delete_count', 0)} vectors from {self.namespace}" ) else: - logger.debug(f"No vectors were deleted from {self.namespace}") + logger.debug( + f"[{self.workspace}] No vectors were deleted from {self.namespace}" + ) except Exception as e: - logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + logger.error( + f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}" + ) async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get vector data by its ID @@ -878,7 +950,7 @@ class MilvusVectorDBStorage(BaseVectorStorage): # Query Milvus for a specific ID result = self._client.query( - collection_name=self.namespace, + collection_name=self.final_namespace, filter=f'id == "{id}"', output_fields=output_fields, ) @@ -888,7 +960,9 @@ class MilvusVectorDBStorage(BaseVectorStorage): return result[0] except Exception as e: - logger.error(f"Error retrieving vector data for ID {id}: {e}") + logger.error( + f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}" + ) return None async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: @@ -916,14 +990,16 @@ class MilvusVectorDBStorage(BaseVectorStorage): # Query Milvus with the filter result = self._client.query( - collection_name=self.namespace, + collection_name=self.final_namespace, filter=filter_expr, output_fields=output_fields, ) return result or [] except Exception as e: - logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + logger.error( + f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}" + ) return [] async def drop(self) -> dict[str, str]: @@ -938,16 +1014,18 @@ class MilvusVectorDBStorage(BaseVectorStorage): """ try: # Drop the collection and recreate it - if self._client.has_collection(self.namespace): - self._client.drop_collection(self.namespace) + if self._client.has_collection(self.final_namespace): + self._client.drop_collection(self.final_namespace) # Recreate the collection self._create_collection_if_not_exist() logger.info( - f"Process {os.getpid()} drop Milvus collection {self.namespace}" + f"[{self.workspace}] Process {os.getpid()} drop Milvus collection {self.namespace}" ) return {"status": "success", "message": "data dropped"} except Exception as e: - logger.error(f"Error dropping Milvus collection {self.namespace}: {e}") + logger.error( + f"[{self.workspace}] Error dropping Milvus collection {self.namespace}: {e}" + ) return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 64622127..b4550c1b 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -107,19 +107,30 @@ class MongoKVStorage(BaseKVStorage): f"Using passed workspace parameter: '{effective_workspace}'" ) - # Build namespace with workspace prefix for data isolation + # Build final_namespace with workspace prefix for data isolation + # Keep original namespace unchanged for type detection logic if effective_workspace: - self.namespace = f"{effective_workspace}_{self.namespace}" - logger.debug(f"Final 
namespace with workspace prefix: '{self.namespace}'") - # When workspace is empty, keep the original namespace unchanged + self.final_namespace = f"{effective_workspace}_{self.namespace}" + logger.debug( + f"Final namespace with workspace prefix: '{self.final_namespace}'" + ) + else: + # When workspace is empty, final_namespace equals original namespace + self.final_namespace = self.namespace + self.workspace = "_" + logger.debug( + f"[{self.workspace}] Final namespace (no workspace): '{self.namespace}'" + ) - self._collection_name = self.namespace + self._collection_name = self.final_namespace async def initialize(self): if self.db is None: self.db = await ClientManager.get_client() self._data = await get_or_create_collection(self.db, self._collection_name) - logger.debug(f"Use MongoDB as KV {self._collection_name}") + logger.debug( + f"[{self.workspace}] Use MongoDB as KV {self._collection_name}" + ) async def finalize(self): if self.db is not None: @@ -167,7 +178,7 @@ class MongoKVStorage(BaseKVStorage): return result async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.debug(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") if not data: return @@ -227,10 +238,12 @@ class MongoKVStorage(BaseKVStorage): try: result = await self._data.delete_many({"_id": {"$in": ids}}) logger.info( - f"Deleted {result.deleted_count} documents from {self.namespace}" + f"[{self.workspace}] Deleted {result.deleted_count} documents from {self.namespace}" ) except PyMongoError as e: - logger.error(f"Error deleting documents from {self.namespace}: {e}") + logger.error( + f"[{self.workspace}] Error deleting documents from {self.namespace}: {e}" + ) async def drop(self) -> dict[str, str]: """Drop the storage by removing all documents in the collection. 
@@ -243,14 +256,16 @@ class MongoKVStorage(BaseKVStorage): deleted_count = result.deleted_count logger.info( - f"Dropped {deleted_count} documents from doc status {self._collection_name}" + f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" ) return { "status": "success", "message": f"{deleted_count} documents dropped", } except PyMongoError as e: - logger.error(f"Error dropping doc status {self._collection_name}: {e}") + logger.error( + f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" + ) return {"status": "error", "message": str(e)} @@ -287,13 +302,20 @@ class MongoDocStatusStorage(DocStatusStorage): f"Using passed workspace parameter: '{effective_workspace}'" ) - # Build namespace with workspace prefix for data isolation + # Build final_namespace with workspace prefix for data isolation + # Keep original namespace unchanged for type detection logic if effective_workspace: - self.namespace = f"{effective_workspace}_{self.namespace}" - logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") - # When workspace is empty, keep the original namespace unchanged + self.final_namespace = f"{effective_workspace}_{self.namespace}" + logger.debug( + f"Final namespace with workspace prefix: '{self.final_namespace}'" + ) + else: + # When workspace is empty, final_namespace equals original namespace + self.final_namespace = self.namespace + self.workspace = "_" + logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") - self._collection_name = self.namespace + self._collection_name = self.final_namespace async def initialize(self): if self.db is None: @@ -306,7 +328,9 @@ class MongoDocStatusStorage(DocStatusStorage): # Create pagination indexes for better query performance await self.create_pagination_indexes_if_not_exists() - logger.debug(f"Use MongoDB as DocStatus {self._collection_name}") + logger.debug( + f"[{self.workspace}] Use MongoDB as DocStatus {self._collection_name}" + ) async def finalize(self): if self.db is not None: @@ -327,7 +351,7 @@ class MongoDocStatusStorage(DocStatusStorage): return data - existing_ids async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.debug(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") if not data: return update_tasks: list[Any] = [] @@ -376,7 +400,9 @@ class MongoDocStatusStorage(DocStatusStorage): data["error_msg"] = None processed_result[doc["_id"]] = DocProcessingStatus(**data) except KeyError as e: - logger.error(f"Missing required field for document {doc['_id']}: {e}") + logger.error( + f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}" + ) continue return processed_result @@ -405,7 +431,9 @@ class MongoDocStatusStorage(DocStatusStorage): data["error_msg"] = None processed_result[doc["_id"]] = DocProcessingStatus(**data) except KeyError as e: - logger.error(f"Missing required field for document {doc['_id']}: {e}") + logger.error( + f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}" + ) continue return processed_result @@ -424,14 +452,16 @@ class MongoDocStatusStorage(DocStatusStorage): deleted_count = result.deleted_count logger.info( - f"Dropped {deleted_count} documents from doc status {self._collection_name}" + f"[{self.workspace}] Dropped {deleted_count} documents from doc status {self._collection_name}" ) return { "status": "success", "message": f"{deleted_count} documents dropped", } except 
PyMongoError as e: - logger.error(f"Error dropping doc status {self._collection_name}: {e}") + logger.error( + f"[{self.workspace}] Error dropping doc status {self._collection_name}: {e}" + ) return {"status": "error", "message": str(e)} async def delete(self, ids: list[str]) -> None: @@ -450,16 +480,16 @@ class MongoDocStatusStorage(DocStatusStorage): if not track_id_index_exists: await self._data.create_index("track_id") logger.info( - f"Created track_id index for collection {self._collection_name}" + f"[{self.workspace}] Created track_id index for collection {self._collection_name}" ) else: logger.debug( - f"track_id index already exists for collection {self._collection_name}" + f"[{self.workspace}] track_id index already exists for collection {self._collection_name}" ) except PyMongoError as e: logger.error( - f"Error creating track_id index for {self._collection_name}: {e}" + f"[{self.workspace}] Error creating track_id index for {self._collection_name}: {e}" ) async def create_pagination_indexes_if_not_exists(self): @@ -492,16 +522,16 @@ class MongoDocStatusStorage(DocStatusStorage): if index_name not in existing_index_names: await self._data.create_index(index_info["keys"], name=index_name) logger.info( - f"Created pagination index '{index_name}' for collection {self._collection_name}" + f"[{self.workspace}] Created pagination index '{index_name}' for collection {self._collection_name}" ) else: logger.debug( - f"Pagination index '{index_name}' already exists for collection {self._collection_name}" + f"[{self.workspace}] Pagination index '{index_name}' already exists for collection {self._collection_name}" ) except PyMongoError as e: logger.error( - f"Error creating pagination indexes for {self._collection_name}: {e}" + f"[{self.workspace}] Error creating pagination indexes for {self._collection_name}: {e}" ) async def get_docs_paginated( @@ -586,7 +616,9 @@ class MongoDocStatusStorage(DocStatusStorage): doc_status = DocProcessingStatus(**data) documents.append((doc_id, doc_status)) except KeyError as e: - logger.error(f"Missing required field for document {doc['_id']}: {e}") + logger.error( + f"[{self.workspace}] Missing required field for document {doc['_id']}: {e}" + ) continue return documents, total_count @@ -650,13 +682,20 @@ class MongoGraphStorage(BaseGraphStorage): f"Using passed workspace parameter: '{effective_workspace}'" ) - # Build namespace with workspace prefix for data isolation + # Build final_namespace with workspace prefix for data isolation + # Keep original namespace unchanged for type detection logic if effective_workspace: - self.namespace = f"{effective_workspace}_{self.namespace}" - logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") - # When workspace is empty, keep the original namespace unchanged + self.final_namespace = f"{effective_workspace}_{self.namespace}" + logger.debug( + f"Final namespace with workspace prefix: '{self.final_namespace}'" + ) + else: + # When workspace is empty, final_namespace equals original namespace + self.final_namespace = self.namespace + self.workspace = "_" + logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") - self._collection_name = self.namespace + self._collection_name = self.final_namespace self._edge_collection_name = f"{self._collection_name}_edges" async def initialize(self): @@ -668,7 +707,9 @@ class MongoGraphStorage(BaseGraphStorage): self.edge_collection = await get_or_create_collection( self.db, self._edge_collection_name ) - logger.debug(f"Use MongoDB as KG 
{self._collection_name}") + logger.debug( + f"[{self.workspace}] Use MongoDB as KG {self._collection_name}" + ) async def finalize(self): if self.db is not None: @@ -1248,7 +1289,9 @@ class MongoGraphStorage(BaseGraphStorage): # Verify if starting node exists start_node = await self.collection.find_one({"_id": node_label}) if not start_node: - logger.warning(f"Starting node with label {node_label} does not exist!") + logger.warning( + f"[{self.workspace}] Starting node with label {node_label} does not exist!" + ) return result seen_nodes.add(node_label) @@ -1407,14 +1450,14 @@ class MongoGraphStorage(BaseGraphStorage): duration = time.perf_counter() - start logger.info( - f"Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" + f"[{self.workspace}] Subgraph query successful in {duration:.4f} seconds | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}" ) except PyMongoError as e: # Handle memory limit errors specifically if "memory limit" in str(e).lower() or "sort exceeded" in str(e).lower(): logger.warning( - f"MongoDB memory limit exceeded, falling back to simple query: {str(e)}" + f"[{self.workspace}] MongoDB memory limit exceeded, falling back to simple query: {str(e)}" ) # Fallback to a simple query without complex aggregation try: @@ -1425,12 +1468,14 @@ class MongoGraphStorage(BaseGraphStorage): ) result.is_truncated = True logger.info( - f"Fallback query completed | Node count: {len(result.nodes)}" + f"[{self.workspace}] Fallback query completed | Node count: {len(result.nodes)}" ) except PyMongoError as fallback_error: - logger.error(f"Fallback query also failed: {str(fallback_error)}") + logger.error( + f"[{self.workspace}] Fallback query also failed: {str(fallback_error)}" + ) else: - logger.error(f"MongoDB query failed: {str(e)}") + logger.error(f"[{self.workspace}] MongoDB query failed: {str(e)}") return result @@ -1444,7 +1489,7 @@ class MongoGraphStorage(BaseGraphStorage): Args: nodes: List of node IDs to be deleted """ - logger.info(f"Deleting {len(nodes)} nodes") + logger.info(f"[{self.workspace}] Deleting {len(nodes)} nodes") if not nodes: return @@ -1461,7 +1506,7 @@ class MongoGraphStorage(BaseGraphStorage): # 2. Delete the node documents await self.collection.delete_many({"_id": {"$in": nodes}}) - logger.debug(f"Successfully deleted nodes: {nodes}") + logger.debug(f"[{self.workspace}] Successfully deleted nodes: {nodes}") async def remove_edges(self, edges: list[tuple[str, str]]) -> None: """Delete multiple edges @@ -1469,7 +1514,7 @@ class MongoGraphStorage(BaseGraphStorage): Args: edges: List of edges to be deleted, each edge is a (source, target) tuple """ - logger.info(f"Deleting {len(edges)} edges") + logger.info(f"[{self.workspace}] Deleting {len(edges)} edges") if not edges: return @@ -1484,7 +1529,7 @@ class MongoGraphStorage(BaseGraphStorage): await self.edge_collection.delete_many({"$or": all_edge_pairs}) - logger.debug(f"Successfully deleted edges: {edges}") + logger.debug(f"[{self.workspace}] Successfully deleted edges: {edges}") async def get_all_nodes(self) -> list[dict]: """Get all nodes in the graph. 
@@ -1527,13 +1572,13 @@ class MongoGraphStorage(BaseGraphStorage): deleted_count = result.deleted_count logger.info( - f"Dropped {deleted_count} documents from graph {self._collection_name}" + f"[{self.workspace}] Dropped {deleted_count} documents from graph {self._collection_name}" ) result = await self.edge_collection.delete_many({}) edge_count = result.deleted_count logger.info( - f"Dropped {edge_count} edges from graph {self._edge_collection_name}" + f"[{self.workspace}] Dropped {edge_count} edges from graph {self._edge_collection_name}" ) return { @@ -1541,7 +1586,9 @@ class MongoGraphStorage(BaseGraphStorage): "message": f"{deleted_count} documents and {edge_count} edges dropped", } except PyMongoError as e: - logger.error(f"Error dropping graph {self._collection_name}: {e}") + logger.error( + f"[{self.workspace}] Error dropping graph {self._collection_name}: {e}" + ) return {"status": "error", "message": str(e)} @@ -1582,16 +1629,23 @@ class MongoVectorDBStorage(BaseVectorStorage): f"Using passed workspace parameter: '{effective_workspace}'" ) - # Build namespace with workspace prefix for data isolation + # Build final_namespace with workspace prefix for data isolation + # Keep original namespace unchanged for type detection logic if effective_workspace: - self.namespace = f"{effective_workspace}_{self.namespace}" - logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") - # When workspace is empty, keep the original namespace unchanged + self.final_namespace = f"{effective_workspace}_{self.namespace}" + logger.debug( + f"Final namespace with workspace prefix: '{self.final_namespace}'" + ) + else: + # When workspace is empty, final_namespace equals original namespace + self.final_namespace = self.namespace + self.workspace = "_" + logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") # Set index name based on workspace for backward compatibility if effective_workspace: # Use collection-specific index name for workspaced collections to avoid conflicts - self._index_name = f"vector_knn_index_{self.namespace}" + self._index_name = f"vector_knn_index_{self.final_namespace}" else: # Keep original index name for backward compatibility with existing deployments self._index_name = "vector_knn_index" @@ -1603,7 +1657,7 @@ class MongoVectorDBStorage(BaseVectorStorage): "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" ) self.cosine_better_than_threshold = cosine_threshold - self._collection_name = self.namespace + self._collection_name = self.final_namespace self._max_batch_size = self.global_config["embedding_batch_num"] async def initialize(self): @@ -1614,7 +1668,9 @@ class MongoVectorDBStorage(BaseVectorStorage): # Ensure vector index exists await self.create_vector_index_if_not_exists() - logger.debug(f"Use MongoDB as VDB {self._collection_name}") + logger.debug( + f"[{self.workspace}] Use MongoDB as VDB {self._collection_name}" + ) async def finalize(self): if self.db is not None: @@ -1629,7 +1685,9 @@ class MongoVectorDBStorage(BaseVectorStorage): indexes = await indexes_cursor.to_list(length=None) for index in indexes: if index["name"] == self._index_name: - logger.info(f"vector index {self._index_name} already exist") + logger.info( + f"[{self.workspace}] vector index {self._index_name} already exist" + ) return search_index_model = SearchIndexModel( @@ -1648,17 +1706,19 @@ class MongoVectorDBStorage(BaseVectorStorage): ) await self._data.create_search_index(search_index_model) - logger.info(f"Vector index 
{self._index_name} created successfully.") + logger.info( + f"[{self.workspace}] Vector index {self._index_name} created successfully." + ) except PyMongoError as e: - error_msg = f"Error creating vector index {self._index_name}: {e}" + error_msg = f"[{self.workspace}] Error creating vector index {self._index_name}: {e}" logger.error(error_msg) raise SystemExit( f"Failed to create MongoDB vector index. Program cannot continue. {error_msg}" ) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.debug(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") if not data: return @@ -1747,7 +1807,9 @@ class MongoVectorDBStorage(BaseVectorStorage): Args: ids: List of vector IDs to be deleted """ - logger.debug(f"Deleting {len(ids)} vectors from {self.namespace}") + logger.debug( + f"[{self.workspace}] Deleting {len(ids)} vectors from {self.namespace}" + ) if not ids: return @@ -1758,11 +1820,11 @@ class MongoVectorDBStorage(BaseVectorStorage): try: result = await self._data.delete_many({"_id": {"$in": ids}}) logger.debug( - f"Successfully deleted {result.deleted_count} vectors from {self.namespace}" + f"[{self.workspace}] Successfully deleted {result.deleted_count} vectors from {self.namespace}" ) except PyMongoError as e: logger.error( - f"Error while deleting vectors from {self.namespace}: {str(e)}" + f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {str(e)}" ) async def delete_entity(self, entity_name: str) -> None: @@ -1774,16 +1836,22 @@ class MongoVectorDBStorage(BaseVectorStorage): try: entity_id = compute_mdhash_id(entity_name, prefix="ent-") logger.debug( - f"Attempting to delete entity {entity_name} with ID {entity_id}" + f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}" ) result = await self._data.delete_one({"_id": entity_id}) if result.deleted_count > 0: - logger.debug(f"Successfully deleted entity {entity_name}") + logger.debug( + f"[{self.workspace}] Successfully deleted entity {entity_name}" + ) else: - logger.debug(f"Entity {entity_name} not found in storage") + logger.debug( + f"[{self.workspace}] Entity {entity_name} not found in storage" + ) except PyMongoError as e: - logger.error(f"Error deleting entity {entity_name}: {str(e)}") + logger.error( + f"[{self.workspace}] Error deleting entity {entity_name}: {str(e)}" + ) async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity @@ -1799,23 +1867,31 @@ class MongoVectorDBStorage(BaseVectorStorage): relations = await relations_cursor.to_list(length=None) if not relations: - logger.debug(f"No relations found for entity {entity_name}") + logger.debug( + f"[{self.workspace}] No relations found for entity {entity_name}" + ) return # Extract IDs of relations to delete relation_ids = [relation["_id"] for relation in relations] logger.debug( - f"Found {len(relation_ids)} relations for entity {entity_name}" + f"[{self.workspace}] Found {len(relation_ids)} relations for entity {entity_name}" ) # Delete the relations result = await self._data.delete_many({"_id": {"$in": relation_ids}}) - logger.debug(f"Deleted {result.deleted_count} relations for {entity_name}") + logger.debug( + f"[{self.workspace}] Deleted {result.deleted_count} relations for {entity_name}" + ) except PyMongoError as e: - logger.error(f"Error deleting relations for {entity_name}: {str(e)}") + logger.error( + f"[{self.workspace}] Error deleting relations for 
{entity_name}: {str(e)}" + ) except PyMongoError as e: - logger.error(f"Error searching by prefix in {self.namespace}: {str(e)}") + logger.error( + f"[{self.workspace}] Error searching by prefix in {self.namespace}: {str(e)}" + ) return [] async def get_by_id(self, id: str) -> dict[str, Any] | None: @@ -1838,7 +1914,9 @@ class MongoVectorDBStorage(BaseVectorStorage): return result_dict return None except Exception as e: - logger.error(f"Error retrieving vector data for ID {id}: {e}") + logger.error( + f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}" + ) return None async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: @@ -1868,7 +1946,9 @@ class MongoVectorDBStorage(BaseVectorStorage): return formatted_results except Exception as e: - logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + logger.error( + f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}" + ) return [] async def drop(self) -> dict[str, str]: @@ -1886,14 +1966,16 @@ class MongoVectorDBStorage(BaseVectorStorage): await self.create_vector_index_if_not_exists() logger.info( - f"Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index" + f"[{self.workspace}] Dropped {deleted_count} documents from vector storage {self._collection_name} and recreated vector index" ) return { "status": "success", "message": f"{deleted_count} documents dropped and vector index recreated", } except PyMongoError as e: - logger.error(f"Error dropping vector storage {self._collection_name}: {e}") + logger.error( + f"[{self.workspace}] Error dropping vector storage {self._collection_name}: {e}" + ) return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 6532fe6f..953946a1 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -48,23 +48,26 @@ logging.getLogger("neo4j").setLevel(logging.ERROR) @dataclass class Neo4JStorage(BaseGraphStorage): def __init__(self, namespace, global_config, embedding_func, workspace=None): - # Check NEO4J_WORKSPACE environment variable and override workspace if set + # Read env and override the arg if present neo4j_workspace = os.environ.get("NEO4J_WORKSPACE") if neo4j_workspace and neo4j_workspace.strip(): workspace = neo4j_workspace + # Default to 'base' when both arg and env are empty + if not workspace or not str(workspace).strip(): + workspace = "base" + super().__init__( namespace=namespace, - workspace=workspace or "", + workspace=workspace, global_config=global_config, embedding_func=embedding_func, ) self._driver = None def _get_workspace_label(self) -> str: - """Get workspace label, return 'base' for compatibility when workspace is empty""" - workspace = getattr(self, "workspace", None) - return workspace if workspace else "base" + """Return workspace label (guaranteed non-empty during initialization)""" + return self.workspace async def initialize(self): URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) @@ -117,6 +120,7 @@ class Neo4JStorage(BaseGraphStorage): DATABASE = os.environ.get( "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace) ) + """The default value approach for the DATABASE is only intended to maintain compatibility with legacy practices.""" self._driver: AsyncDriver = AsyncGraphDatabase.driver( URI, @@ -140,20 +144,26 @@ class Neo4JStorage(BaseGraphStorage): try: result = await session.run("MATCH (n) RETURN n LIMIT 0") await result.consume() # Ensure result is consumed - 
logger.info(f"Connected to {database} at {URI}") + logger.info( + f"[{self.workspace}] Connected to {database} at {URI}" + ) connected = True except neo4jExceptions.ServiceUnavailable as e: logger.error( - f"{database} at {URI} is not available".capitalize() + f"[{self.workspace}] " + + f"{database} at {URI} is not available".capitalize() ) raise e except neo4jExceptions.AuthError as e: - logger.error(f"Authentication failed for {database} at {URI}") + logger.error( + f"[{self.workspace}] Authentication failed for {database} at {URI}" + ) raise e except neo4jExceptions.ClientError as e: if e.code == "Neo.ClientError.Database.DatabaseNotFound": logger.info( - f"{database} at {URI} not found. Try to create specified database.".capitalize() + f"[{self.workspace}] " + + f"{database} at {URI} not found. Try to create specified database.".capitalize() ) try: async with self._driver.session() as session: @@ -161,7 +171,10 @@ class Neo4JStorage(BaseGraphStorage): f"CREATE DATABASE `{database}` IF NOT EXISTS" ) await result.consume() # Ensure result is consumed - logger.info(f"{database} at {URI} created".capitalize()) + logger.info( + f"[{self.workspace}] " + + f"{database} at {URI} created".capitalize() + ) connected = True except ( neo4jExceptions.ClientError, @@ -173,10 +186,12 @@ class Neo4JStorage(BaseGraphStorage): ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"): if database is not None: logger.warning( - "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." + f"[{self.workspace}] This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database." 
) if database is None: - logger.error(f"Failed to create {database} at {URI}") + logger.error( + f"[{self.workspace}] Failed to create {database} at {URI}" + ) raise e if connected: @@ -204,7 +219,7 @@ class Neo4JStorage(BaseGraphStorage): ) await result.consume() logger.info( - f"Created index for {workspace_label} nodes on entity_id in {database}" + f"[{self.workspace}] Created index for {workspace_label} nodes on entity_id in {database}" ) except Exception: # Fallback if db.indexes() is not supported in this Neo4j version @@ -213,7 +228,9 @@ class Neo4JStorage(BaseGraphStorage): ) await result.consume() except Exception as e: - logger.warning(f"Failed to create index: {str(e)}") + logger.warning( + f"[{self.workspace}] Failed to create index: {str(e)}" + ) break async def finalize(self): @@ -255,7 +272,9 @@ class Neo4JStorage(BaseGraphStorage): await result.consume() # Ensure result is fully consumed return single_result["node_exists"] except Exception as e: - logger.error(f"Error checking node existence for {node_id}: {str(e)}") + logger.error( + f"[{self.workspace}] Error checking node existence for {node_id}: {str(e)}" + ) await result.consume() # Ensure results are consumed even on error raise @@ -293,7 +312,7 @@ class Neo4JStorage(BaseGraphStorage): return single_result["edgeExists"] except Exception as e: logger.error( - f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" + f"[{self.workspace}] Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}" ) await result.consume() # Ensure results are consumed even on error raise @@ -328,7 +347,7 @@ class Neo4JStorage(BaseGraphStorage): if len(records) > 1: logger.warning( - f"Multiple nodes found with label '{node_id}'. Using first node." + f"[{self.workspace}] Multiple nodes found with label '{node_id}'. Using first node." ) if records: node = records[0]["n"] @@ -346,7 +365,9 @@ class Neo4JStorage(BaseGraphStorage): finally: await result.consume() # Ensure result is fully consumed except Exception as e: - logger.error(f"Error getting node for {node_id}: {str(e)}") + logger.error( + f"[{self.workspace}] Error getting node for {node_id}: {str(e)}" + ) raise async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: @@ -415,18 +436,22 @@ class Neo4JStorage(BaseGraphStorage): record = await result.single() if not record: - logger.warning(f"No node found with label '{node_id}'") + logger.warning( + f"[{self.workspace}] No node found with label '{node_id}'" + ) return 0 degree = record["degree"] # logger.debug( - # f"Neo4j query node degree for {node_id} return: {degree}" + # f"[{self.workspace}] Neo4j query node degree for {node_id} return: {degree}" # ) return degree finally: await result.consume() # Ensure result is fully consumed except Exception as e: - logger.error(f"Error getting node degree for {node_id}: {str(e)}") + logger.error( + f"[{self.workspace}] Error getting node degree for {node_id}: {str(e)}" + ) raise async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: @@ -459,10 +484,12 @@ class Neo4JStorage(BaseGraphStorage): # For any node_id that did not return a record, set degree to 0. 
for nid in node_ids: if nid not in degrees: - logger.warning(f"No node found with label '{nid}'") + logger.warning( + f"[{self.workspace}] No node found with label '{nid}'" + ) degrees[nid] = 0 - # logger.debug(f"Neo4j batch node degree query returned: {degrees}") + # logger.debug(f"[{self.workspace}] Neo4j batch node degree query returned: {degrees}") return degrees async def edge_degree(self, src_id: str, tgt_id: str) -> int: @@ -546,7 +573,7 @@ class Neo4JStorage(BaseGraphStorage): if len(records) > 1: logger.warning( - f"Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge." + f"[{self.workspace}] Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge." ) if records: try: @@ -563,7 +590,7 @@ class Neo4JStorage(BaseGraphStorage): if key not in edge_result: edge_result[key] = default_value logger.warning( - f"Edge between {source_node_id} and {target_node_id} " + f"[{self.workspace}] Edge between {source_node_id} and {target_node_id} " f"missing {key}, using default: {default_value}" ) @@ -573,7 +600,7 @@ class Neo4JStorage(BaseGraphStorage): return edge_result except (KeyError, TypeError, ValueError) as e: logger.error( - f"Error processing edge properties between {source_node_id} " + f"[{self.workspace}] Error processing edge properties between {source_node_id} " f"and {target_node_id}: {str(e)}" ) # Return default edge properties on error @@ -594,7 +621,7 @@ class Neo4JStorage(BaseGraphStorage): except Exception as e: logger.error( - f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" + f"[{self.workspace}] Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}" ) raise @@ -701,12 +728,14 @@ class Neo4JStorage(BaseGraphStorage): return edges except Exception as e: logger.error( - f"Error getting edges for node {source_node_id}: {str(e)}" + f"[{self.workspace}] Error getting edges for node {source_node_id}: {str(e)}" ) await results.consume() # Ensure results are consumed even on error raise except Exception as e: - logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}") + logger.error( + f"[{self.workspace}] Error in get_node_edges for {source_node_id}: {str(e)}" + ) raise async def get_nodes_edges_batch( @@ -856,7 +885,7 @@ class Neo4JStorage(BaseGraphStorage): await session.execute_write(execute_upsert) except Exception as e: - logger.error(f"Error during upsert: {str(e)}") + logger.error(f"[{self.workspace}] Error during upsert: {str(e)}") raise @retry( @@ -917,7 +946,7 @@ class Neo4JStorage(BaseGraphStorage): await session.execute_write(execute_upsert) except Exception as e: - logger.error(f"Error during edge upsert: {str(e)}") + logger.error(f"[{self.workspace}] Error during edge upsert: {str(e)}") raise async def get_knowledge_graph( @@ -967,7 +996,7 @@ class Neo4JStorage(BaseGraphStorage): if count_record and count_record["total"] > max_nodes: result.is_truncated = True logger.info( - f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" + f"[{self.workspace}] Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}" ) finally: if count_result: @@ -1034,7 +1063,9 @@ class Neo4JStorage(BaseGraphStorage): # If no record found, return empty KnowledgeGraph if not full_record: - logger.debug(f"No nodes found for entity_id: {node_label}") + logger.debug( + f"[{self.workspace}] No nodes found for entity_id: {node_label}" + ) return result # If record found, check node count @@ -1043,14 +1074,14 @@ class 
Neo4JStorage(BaseGraphStorage): if total_nodes <= max_nodes: # If node count is within limit, use full result directly logger.debug( - f"Using full result with {total_nodes} nodes (no truncation needed)" + f"[{self.workspace}] Using full result with {total_nodes} nodes (no truncation needed)" ) record = full_record else: # If node count exceeds limit, set truncated flag and run limited query result.is_truncated = True logger.info( - f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}" + f"[{self.workspace}] Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}" ) # Run limited query @@ -1122,19 +1153,19 @@ class Neo4JStorage(BaseGraphStorage): seen_edges.add(edge_id) logger.info( - f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + f"[{self.workspace}] Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) except neo4jExceptions.ClientError as e: - logger.warning(f"APOC plugin error: {str(e)}") + logger.warning(f"[{self.workspace}] APOC plugin error: {str(e)}") if node_label != "*": logger.warning( - "Neo4j: falling back to basic Cypher recursive search..." + f"[{self.workspace}] Neo4j: falling back to basic Cypher recursive search..." ) return await self._robust_fallback(node_label, max_depth, max_nodes) else: logger.warning( - "Neo4j: APOC plugin error with wildcard query, returning empty result" + f"[{self.workspace}] Neo4j: APOC plugin error with wildcard query, returning empty result" ) return result @@ -1193,7 +1224,7 @@ class Neo4JStorage(BaseGraphStorage): if current_depth > max_depth: logger.debug( - f"Skipping node at depth {current_depth} (max_depth: {max_depth})" + f"[{self.workspace}] Skipping node at depth {current_depth} (max_depth: {max_depth})" ) continue @@ -1210,7 +1241,7 @@ class Neo4JStorage(BaseGraphStorage): if len(visited_nodes) >= max_nodes: result.is_truncated = True logger.info( - f"Graph truncated: breadth-first search limited to: {max_nodes} nodes" + f"[{self.workspace}] Graph truncated: breadth-first search limited to: {max_nodes} nodes" ) break @@ -1281,20 +1312,20 @@ class Neo4JStorage(BaseGraphStorage): # At max depth, we've already added the edge but we don't add the node # This prevents adding nodes beyond max_depth to the result logger.debug( - f"Node {target_id} beyond max depth {max_depth}, edge added but node not included" + f"[{self.workspace}] Node {target_id} beyond max depth {max_depth}, edge added but node not included" ) else: # If target node already exists in result, we don't need to add it again logger.debug( - f"Node {target_id} already visited, edge added but node not queued" + f"[{self.workspace}] Node {target_id} already visited, edge added but node not queued" ) else: logger.warning( - f"Skipping edge {edge_id} due to missing entity_id on target node" + f"[{self.workspace}] Skipping edge {edge_id} due to missing entity_id on target node" ) logger.info( - f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" + f"[{self.workspace}] BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) return result @@ -1358,14 +1389,14 @@ class Neo4JStorage(BaseGraphStorage): DETACH DELETE n """ result = await tx.run(query, entity_id=node_id) - logger.debug(f"Deleted node with label '{node_id}'") + logger.debug(f"[{self.workspace}] Deleted node with label '{node_id}'") await result.consume() # Ensure 
result is fully consumed try: async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_delete) except Exception as e: - logger.error(f"Error during node deletion: {str(e)}") + logger.error(f"[{self.workspace}] Error during node deletion: {str(e)}") raise @retry( @@ -1424,14 +1455,16 @@ class Neo4JStorage(BaseGraphStorage): result = await tx.run( query, source_entity_id=source, target_entity_id=target ) - logger.debug(f"Deleted edge from '{source}' to '{target}'") + logger.debug( + f"[{self.workspace}] Deleted edge from '{source}' to '{target}'" + ) await result.consume() # Ensure result is fully consumed try: async with self._driver.session(database=self._DATABASE) as session: await session.execute_write(_do_delete_edge) except Exception as e: - logger.error(f"Error during edge deletion: {str(e)}") + logger.error(f"[{self.workspace}] Error during edge deletion: {str(e)}") raise async def get_all_nodes(self) -> list[dict]: @@ -1501,15 +1534,15 @@ class Neo4JStorage(BaseGraphStorage): result = await session.run(query) await result.consume() # Ensure result is fully consumed - logger.info( - f"Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" - ) + # logger.debug( + # f"[{self.workspace}] Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" + # ) return { "status": "success", "message": f"workspace '{workspace_label}' data dropped", } except Exception as e: logger.error( - f"Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" + f"[{self.workspace}] Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" ) return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index 9cdeff7a..f5e4d1c2 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -88,11 +88,18 @@ class QdrantVectorDBStorage(BaseVectorStorage): f"Using passed workspace parameter: '{effective_workspace}'" ) - # Build namespace with workspace prefix for data isolation + # Build final_namespace with workspace prefix for data isolation + # Keep original namespace unchanged for type detection logic if effective_workspace: - self.namespace = f"{effective_workspace}_{self.namespace}" - logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") - # When workspace is empty, keep the original namespace unchanged + self.final_namespace = f"{effective_workspace}_{self.namespace}" + logger.debug( + f"Final namespace with workspace prefix: '{self.final_namespace}'" + ) + else: + # When workspace is empty, final_namespace equals original namespace + self.final_namespace = self.namespace + self.workspace = "_" + logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") @@ -113,14 +120,14 @@ class QdrantVectorDBStorage(BaseVectorStorage): self._max_batch_size = self.global_config["embedding_batch_num"] QdrantVectorDBStorage.create_collection_if_not_exist( self._client, - self.namespace, + self.final_namespace, vectors_config=models.VectorParams( size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE ), ) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.debug(f"Inserting {len(data)} to {self.namespace}") + logger.debug(f"[{self.workspace}] Inserting {len(data)} to {self.namespace}") if 
not data: return @@ -158,7 +165,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): ) results = self._client.upsert( - collection_name=self.namespace, points=list_points, wait=True + collection_name=self.final_namespace, points=list_points, wait=True ) return results @@ -169,14 +176,14 @@ class QdrantVectorDBStorage(BaseVectorStorage): [query], _priority=5 ) # higher priority for query results = self._client.search( - collection_name=self.namespace, + collection_name=self.final_namespace, query_vector=embedding[0], limit=top_k, with_payload=True, score_threshold=self.cosine_better_than_threshold, ) - logger.debug(f"query result: {results}") + logger.debug(f"[{self.workspace}] query result: {results}") return [ { @@ -202,17 +209,19 @@ class QdrantVectorDBStorage(BaseVectorStorage): qdrant_ids = [compute_mdhash_id_for_qdrant(id) for id in ids] # Delete points from the collection self._client.delete( - collection_name=self.namespace, + collection_name=self.final_namespace, points_selector=models.PointIdsList( points=qdrant_ids, ), wait=True, ) logger.debug( - f"Successfully deleted {len(ids)} vectors from {self.namespace}" + f"[{self.workspace}] Successfully deleted {len(ids)} vectors from {self.namespace}" ) except Exception as e: - logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + logger.error( + f"[{self.workspace}] Error while deleting vectors from {self.namespace}: {e}" + ) async def delete_entity(self, entity_name: str) -> None: """Delete an entity by name @@ -224,20 +233,22 @@ class QdrantVectorDBStorage(BaseVectorStorage): # Generate the entity ID entity_id = compute_mdhash_id_for_qdrant(entity_name, prefix="ent-") logger.debug( - f"Attempting to delete entity {entity_name} with ID {entity_id}" + f"[{self.workspace}] Attempting to delete entity {entity_name} with ID {entity_id}" ) # Delete the entity point from the collection self._client.delete( - collection_name=self.namespace, + collection_name=self.final_namespace, points_selector=models.PointIdsList( points=[entity_id], ), wait=True, ) - logger.debug(f"Successfully deleted entity {entity_name}") + logger.debug( + f"[{self.workspace}] Successfully deleted entity {entity_name}" + ) except Exception as e: - logger.error(f"Error deleting entity {entity_name}: {e}") + logger.error(f"[{self.workspace}] Error deleting entity {entity_name}: {e}") async def delete_entity_relation(self, entity_name: str) -> None: """Delete all relations associated with an entity @@ -248,7 +259,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): try: # Find relations where the entity is either source or target results = self._client.scroll( - collection_name=self.namespace, + collection_name=self.final_namespace, scroll_filter=models.Filter( should=[ models.FieldCondition( @@ -270,19 +281,23 @@ class QdrantVectorDBStorage(BaseVectorStorage): if ids_to_delete: # Delete the relations self._client.delete( - collection_name=self.namespace, + collection_name=self.final_namespace, points_selector=models.PointIdsList( points=ids_to_delete, ), wait=True, ) logger.debug( - f"Deleted {len(ids_to_delete)} relations for {entity_name}" + f"[{self.workspace}] Deleted {len(ids_to_delete)} relations for {entity_name}" ) else: - logger.debug(f"No relations found for entity {entity_name}") + logger.debug( + f"[{self.workspace}] No relations found for entity {entity_name}" + ) except Exception as e: - logger.error(f"Error deleting relations for {entity_name}: {e}") + logger.error( + f"[{self.workspace}] Error deleting relations for {entity_name}: 
{e}" + ) async def get_by_id(self, id: str) -> dict[str, Any] | None: """Get vector data by its ID @@ -299,7 +314,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): # Retrieve the point by ID result = self._client.retrieve( - collection_name=self.namespace, + collection_name=self.final_namespace, ids=[qdrant_id], with_payload=True, ) @@ -314,7 +329,9 @@ class QdrantVectorDBStorage(BaseVectorStorage): return payload except Exception as e: - logger.error(f"Error retrieving vector data for ID {id}: {e}") + logger.error( + f"[{self.workspace}] Error retrieving vector data for ID {id}: {e}" + ) return None async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: @@ -335,7 +352,7 @@ class QdrantVectorDBStorage(BaseVectorStorage): # Retrieve the points by IDs results = self._client.retrieve( - collection_name=self.namespace, + collection_name=self.final_namespace, ids=qdrant_ids, with_payload=True, ) @@ -350,7 +367,9 @@ class QdrantVectorDBStorage(BaseVectorStorage): return payloads except Exception as e: - logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + logger.error( + f"[{self.workspace}] Error retrieving vector data for IDs {ids}: {e}" + ) return [] async def drop(self) -> dict[str, str]: @@ -365,13 +384,13 @@ class QdrantVectorDBStorage(BaseVectorStorage): """ try: # Delete the collection and recreate it - if self._client.collection_exists(self.namespace): - self._client.delete_collection(self.namespace) + if self._client.collection_exists(self.final_namespace): + self._client.delete_collection(self.final_namespace) # Recreate the collection QdrantVectorDBStorage.create_collection_if_not_exist( self._client, - self.namespace, + self.final_namespace, vectors_config=models.VectorParams( size=self.embedding_func.embedding_dim, distance=models.Distance.COSINE, @@ -379,9 +398,11 @@ class QdrantVectorDBStorage(BaseVectorStorage): ) logger.info( - f"Process {os.getpid()} drop Qdrant collection {self.namespace}" + f"[{self.workspace}] Process {os.getpid()} drop Qdrant collection {self.namespace}" ) return {"status": "success", "message": "data dropped"} except Exception as e: - logger.error(f"Error dropping Qdrant collection {self.namespace}: {e}") + logger.error( + f"[{self.workspace}] Error dropping Qdrant collection {self.namespace}: {e}" + ) return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 1c8d3c68..cdb4793c 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -141,11 +141,18 @@ class RedisKVStorage(BaseKVStorage): f"Using passed workspace parameter: '{effective_workspace}'" ) - # Build namespace with workspace prefix for data isolation + # Build final_namespace with workspace prefix for data isolation + # Keep original namespace unchanged for type detection logic if effective_workspace: - self.namespace = f"{effective_workspace}_{self.namespace}" - logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") - # When workspace is empty, keep the original namespace unchanged + self.final_namespace = f"{effective_workspace}_{self.namespace}" + logger.debug( + f"Final namespace with workspace prefix: '{self.final_namespace}'" + ) + else: + # When workspace is empty, final_namespace equals original namespace + self.final_namespace = self.namespace + self.workspace = "_" + logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") self._redis_url = os.environ.get( "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") @@ -159,13 
+166,15 @@ class RedisKVStorage(BaseKVStorage): self._pool = RedisConnectionManager.get_pool(self._redis_url) self._redis = Redis(connection_pool=self._pool) logger.info( - f"Initialized Redis KV storage for {self.namespace} using shared connection pool" + f"[{self.workspace}] Initialized Redis KV storage for {self.namespace} using shared connection pool" ) except Exception as e: # Clean up on initialization failure if self._redis_url: RedisConnectionManager.release_pool(self._redis_url) - logger.error(f"Failed to initialize Redis KV storage: {e}") + logger.error( + f"[{self.workspace}] Failed to initialize Redis KV storage: {e}" + ) raise async def initialize(self): @@ -177,10 +186,12 @@ class RedisKVStorage(BaseKVStorage): try: async with self._get_redis_connection() as redis: await redis.ping() - logger.info(f"Connected to Redis for namespace {self.namespace}") + logger.info( + f"[{self.workspace}] Connected to Redis for namespace {self.namespace}" + ) self._initialized = True except Exception as e: - logger.error(f"Failed to connect to Redis: {e}") + logger.error(f"[{self.workspace}] Failed to connect to Redis: {e}") # Clean up on connection failure await self.close() raise @@ -190,7 +201,9 @@ class RedisKVStorage(BaseKVStorage): try: await self._migrate_legacy_cache_structure() except Exception as e: - logger.error(f"Failed to migrate legacy cache structure: {e}") + logger.error( + f"[{self.workspace}] Failed to migrate legacy cache structure: {e}" + ) # Don't fail initialization for migration errors, just log them @asynccontextmanager @@ -203,14 +216,18 @@ class RedisKVStorage(BaseKVStorage): # Use the existing Redis instance with shared pool yield self._redis except ConnectionError as e: - logger.error(f"Redis connection error in {self.namespace}: {e}") + logger.error( + f"[{self.workspace}] Redis connection error in {self.namespace}: {e}" + ) raise except RedisError as e: - logger.error(f"Redis operation error in {self.namespace}: {e}") + logger.error( + f"[{self.workspace}] Redis operation error in {self.namespace}: {e}" + ) raise except Exception as e: logger.error( - f"Unexpected error in Redis operation for {self.namespace}: {e}" + f"[{self.workspace}] Unexpected error in Redis operation for {self.namespace}: {e}" ) raise @@ -219,9 +236,11 @@ class RedisKVStorage(BaseKVStorage): if hasattr(self, "_redis") and self._redis: try: await self._redis.close() - logger.debug(f"Closed Redis connection for {self.namespace}") + logger.debug( + f"[{self.workspace}] Closed Redis connection for {self.namespace}" + ) except Exception as e: - logger.error(f"Error closing Redis connection: {e}") + logger.error(f"[{self.workspace}] Error closing Redis connection: {e}") finally: self._redis = None @@ -230,7 +249,7 @@ class RedisKVStorage(BaseKVStorage): RedisConnectionManager.release_pool(self._redis_url) self._pool = None logger.debug( - f"Released Redis connection pool reference for {self.namespace}" + f"[{self.workspace}] Released Redis connection pool reference for {self.namespace}" ) async def __aenter__(self): @@ -245,7 +264,7 @@ class RedisKVStorage(BaseKVStorage): async def get_by_id(self, id: str) -> dict[str, Any] | None: async with self._get_redis_connection() as redis: try: - data = await redis.get(f"{self.namespace}:{id}") + data = await redis.get(f"{self.final_namespace}:{id}") if data: result = json.loads(data) # Ensure time fields are present, provide default values for old data @@ -254,7 +273,7 @@ class RedisKVStorage(BaseKVStorage): return result return None except 
json.JSONDecodeError as e: - logger.error(f"JSON decode error for id {id}: {e}") + logger.error(f"[{self.workspace}] JSON decode error for id {id}: {e}") return None @redis_retry @@ -263,7 +282,7 @@ class RedisKVStorage(BaseKVStorage): try: pipe = redis.pipeline() for id in ids: - pipe.get(f"{self.namespace}:{id}") + pipe.get(f"{self.final_namespace}:{id}") results = await pipe.execute() processed_results = [] @@ -279,7 +298,7 @@ class RedisKVStorage(BaseKVStorage): return processed_results except json.JSONDecodeError as e: - logger.error(f"JSON decode error in batch get: {e}") + logger.error(f"[{self.workspace}] JSON decode error in batch get: {e}") return [None] * len(ids) async def get_all(self) -> dict[str, Any]: @@ -291,7 +310,7 @@ class RedisKVStorage(BaseKVStorage): async with self._get_redis_connection() as redis: try: # Get all keys for this namespace - keys = await redis.keys(f"{self.namespace}:*") + keys = await redis.keys(f"{self.final_namespace}:*") if not keys: return {} @@ -315,12 +334,16 @@ class RedisKVStorage(BaseKVStorage): data.setdefault("update_time", 0) result[key_id] = data except json.JSONDecodeError as e: - logger.error(f"JSON decode error for key {key}: {e}") + logger.error( + f"[{self.workspace}] JSON decode error for key {key}: {e}" + ) continue return result except Exception as e: - logger.error(f"Error getting all data from Redis: {e}") + logger.error( + f"[{self.workspace}] Error getting all data from Redis: {e}" + ) return {} async def filter_keys(self, keys: set[str]) -> set[str]: @@ -328,7 +351,7 @@ class RedisKVStorage(BaseKVStorage): pipe = redis.pipeline() keys_list = list(keys) # Convert set to list for indexing for key in keys_list: - pipe.exists(f"{self.namespace}:{key}") + pipe.exists(f"{self.final_namespace}:{key}") results = await pipe.execute() existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists} @@ -348,13 +371,13 @@ class RedisKVStorage(BaseKVStorage): # Check which keys already exist to determine create vs update pipe = redis.pipeline() for k in data.keys(): - pipe.exists(f"{self.namespace}:{k}") + pipe.exists(f"{self.final_namespace}:{k}") exists_results = await pipe.execute() # Add timestamps to data for i, (k, v) in enumerate(data.items()): # For text_chunks namespace, ensure llm_cache_list field exists - if "text_chunks" in self.namespace: + if self.namespace.endswith("text_chunks"): if "llm_cache_list" not in v: v["llm_cache_list"] = [] @@ -370,11 +393,11 @@ class RedisKVStorage(BaseKVStorage): # Store the data pipe = redis.pipeline() for k, v in data.items(): - pipe.set(f"{self.namespace}:{k}", json.dumps(v)) + pipe.set(f"{self.final_namespace}:{k}", json.dumps(v)) await pipe.execute() except json.JSONDecodeError as e: - logger.error(f"JSON decode error during upsert: {e}") + logger.error(f"[{self.workspace}] JSON decode error during upsert: {e}") raise async def index_done_callback(self) -> None: @@ -389,12 +412,12 @@ class RedisKVStorage(BaseKVStorage): async with self._get_redis_connection() as redis: pipe = redis.pipeline() for id in ids: - pipe.delete(f"{self.namespace}:{id}") + pipe.delete(f"{self.final_namespace}:{id}") results = await pipe.execute() deleted_count = sum(results) logger.info( - f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}" + f"[{self.workspace}] Deleted {deleted_count} of {len(ids)} entries from {self.namespace}" ) async def drop(self) -> dict[str, str]: @@ -406,7 +429,7 @@ class RedisKVStorage(BaseKVStorage): async with self._get_redis_connection() as redis: try: # 
Use SCAN to find all keys with the namespace prefix
-                pattern = f"{self.namespace}:*"
+                pattern = f"{self.final_namespace}:*"
                 cursor = 0
                 deleted_count = 0
@@ -423,14 +446,18 @@ class RedisKVStorage(BaseKVStorage):
                     if cursor == 0:
                         break
-                logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
+                logger.info(
+                    f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}"
+                )
                 return {
                     "status": "success",
                     "message": f"{deleted_count} keys dropped",
                 }
             except Exception as e:
-                logger.error(f"Error dropping keys from {self.namespace}: {e}")
+                logger.error(
+                    f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}"
+                )
                 return {"status": "error", "message": str(e)}
     async def _migrate_legacy_cache_structure(self):
@@ -445,7 +472,7 @@ class RedisKVStorage(BaseKVStorage):
         async with self._get_redis_connection() as redis:
             # Get all keys for this namespace
-            keys = await redis.keys(f"{self.namespace}:*")
+            keys = await redis.keys(f"{self.final_namespace}:*")
             if not keys:
                 return
@@ -480,7 +507,7 @@ class RedisKVStorage(BaseKVStorage):
             # If we found any flattened keys, assume migration is already done
             if has_flattened_keys:
                 logger.debug(
-                    f"Found flattened cache keys in {self.namespace}, skipping migration"
+                    f"[{self.workspace}] Found flattened cache keys in {self.namespace}, skipping migration"
                 )
                 return
@@ -499,7 +526,7 @@ class RedisKVStorage(BaseKVStorage):
                 for cache_hash, cache_entry in nested_data.items():
                     cache_type = cache_entry.get("cache_type", "extract")
                     flattened_key = generate_cache_key(mode, cache_type, cache_hash)
-                    full_key = f"{self.namespace}:{flattened_key}"
+                    full_key = f"{self.final_namespace}:{flattened_key}"
                     pipe.set(full_key, json.dumps(cache_entry))
                     migration_count += 1
@@ -507,7 +534,7 @@ class RedisKVStorage(BaseKVStorage):
         if migration_count > 0:
             logger.info(
-                f"Migrated {migration_count} legacy cache entries to flattened structure in Redis"
+                f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure in Redis"
             )
@@ -534,11 +561,20 @@ class RedisDocStatusStorage(DocStatusStorage):
                 f"Using passed workspace parameter: '{effective_workspace}'"
             )
-        # Build namespace with workspace prefix for data isolation
+        # Build final_namespace with workspace prefix for data isolation
+        # Keep original namespace unchanged for type detection logic
         if effective_workspace:
-            self.namespace = f"{effective_workspace}_{self.namespace}"
-            logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'")
-        # When workspace is empty, keep the original namespace unchanged
+            self.final_namespace = f"{effective_workspace}_{self.namespace}"
+            logger.debug(
+                f"[{self.workspace}] Final namespace with workspace prefix: '{self.final_namespace}'"
+            )
+        else:
+            # When workspace is empty, final_namespace equals original namespace
+            self.final_namespace = self.namespace
+            self.workspace = "_"
+            logger.debug(
+                f"[{self.workspace}] Final namespace (no workspace): '{self.final_namespace}'"
+            )
         self._redis_url = os.environ.get(
             "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
@@ -552,13 +588,15 @@ class RedisDocStatusStorage(DocStatusStorage):
            self._pool = RedisConnectionManager.get_pool(self._redis_url)
            self._redis = Redis(connection_pool=self._pool)
            logger.info(
-                f"Initialized Redis doc status storage for {self.namespace} using shared connection pool"
+                f"[{self.workspace}] Initialized Redis doc status storage for {self.namespace} using shared connection pool"
            )
        except Exception as e:
            # Clean up on initialization failure
if self._redis_url: RedisConnectionManager.release_pool(self._redis_url) - logger.error(f"Failed to initialize Redis doc status storage: {e}") + logger.error( + f"[{self.workspace}] Failed to initialize Redis doc status storage: {e}" + ) raise async def initialize(self): @@ -570,11 +608,13 @@ class RedisDocStatusStorage(DocStatusStorage): async with self._get_redis_connection() as redis: await redis.ping() logger.info( - f"Connected to Redis for doc status namespace {self.namespace}" + f"[{self.workspace}] Connected to Redis for doc status namespace {self.namespace}" ) self._initialized = True except Exception as e: - logger.error(f"Failed to connect to Redis for doc status: {e}") + logger.error( + f"[{self.workspace}] Failed to connect to Redis for doc status: {e}" + ) # Clean up on connection failure await self.close() raise @@ -589,14 +629,18 @@ class RedisDocStatusStorage(DocStatusStorage): # Use the existing Redis instance with shared pool yield self._redis except ConnectionError as e: - logger.error(f"Redis connection error in doc status {self.namespace}: {e}") + logger.error( + f"[{self.workspace}] Redis connection error in doc status {self.namespace}: {e}" + ) raise except RedisError as e: - logger.error(f"Redis operation error in doc status {self.namespace}: {e}") + logger.error( + f"[{self.workspace}] Redis operation error in doc status {self.namespace}: {e}" + ) raise except Exception as e: logger.error( - f"Unexpected error in Redis doc status operation for {self.namespace}: {e}" + f"[{self.workspace}] Unexpected error in Redis doc status operation for {self.namespace}: {e}" ) raise @@ -605,9 +649,11 @@ class RedisDocStatusStorage(DocStatusStorage): if hasattr(self, "_redis") and self._redis: try: await self._redis.close() - logger.debug(f"Closed Redis connection for doc status {self.namespace}") + logger.debug( + f"[{self.workspace}] Closed Redis connection for doc status {self.namespace}" + ) except Exception as e: - logger.error(f"Error closing Redis connection: {e}") + logger.error(f"[{self.workspace}] Error closing Redis connection: {e}") finally: self._redis = None @@ -616,7 +662,7 @@ class RedisDocStatusStorage(DocStatusStorage): RedisConnectionManager.release_pool(self._redis_url) self._pool = None logger.debug( - f"Released Redis connection pool reference for doc status {self.namespace}" + f"[{self.workspace}] Released Redis connection pool reference for doc status {self.namespace}" ) async def __aenter__(self): @@ -633,7 +679,7 @@ class RedisDocStatusStorage(DocStatusStorage): pipe = redis.pipeline() keys_list = list(keys) for key in keys_list: - pipe.exists(f"{self.namespace}:{key}") + pipe.exists(f"{self.final_namespace}:{key}") results = await pipe.execute() existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists} @@ -645,7 +691,7 @@ class RedisDocStatusStorage(DocStatusStorage): try: pipe = redis.pipeline() for id in ids: - pipe.get(f"{self.namespace}:{id}") + pipe.get(f"{self.final_namespace}:{id}") results = await pipe.execute() for result_data in results: @@ -653,10 +699,12 @@ class RedisDocStatusStorage(DocStatusStorage): try: result.append(json.loads(result_data)) except json.JSONDecodeError as e: - logger.error(f"JSON decode error in get_by_ids: {e}") + logger.error( + f"[{self.workspace}] JSON decode error in get_by_ids: {e}" + ) continue except Exception as e: - logger.error(f"Error in get_by_ids: {e}") + logger.error(f"[{self.workspace}] Error in get_by_ids: {e}") return result async def get_status_counts(self) -> dict[str, int]: @@ 
-668,7 +716,7 @@ class RedisDocStatusStorage(DocStatusStorage): cursor = 0 while True: cursor, keys = await redis.scan( - cursor, match=f"{self.namespace}:*", count=1000 + cursor, match=f"{self.final_namespace}:*", count=1000 ) if keys: # Get all values in batch @@ -691,7 +739,7 @@ class RedisDocStatusStorage(DocStatusStorage): if cursor == 0: break except Exception as e: - logger.error(f"Error getting status counts: {e}") + logger.error(f"[{self.workspace}] Error getting status counts: {e}") return counts @@ -706,7 +754,7 @@ class RedisDocStatusStorage(DocStatusStorage): cursor = 0 while True: cursor, keys = await redis.scan( - cursor, match=f"{self.namespace}:*", count=1000 + cursor, match=f"{self.final_namespace}:*", count=1000 ) if keys: # Get all values in batch @@ -740,14 +788,14 @@ class RedisDocStatusStorage(DocStatusStorage): result[doc_id] = DocProcessingStatus(**data) except (json.JSONDecodeError, KeyError) as e: logger.error( - f"Error processing document {key}: {e}" + f"[{self.workspace}] Error processing document {key}: {e}" ) continue if cursor == 0: break except Exception as e: - logger.error(f"Error getting docs by status: {e}") + logger.error(f"[{self.workspace}] Error getting docs by status: {e}") return result @@ -762,7 +810,7 @@ class RedisDocStatusStorage(DocStatusStorage): cursor = 0 while True: cursor, keys = await redis.scan( - cursor, match=f"{self.namespace}:*", count=1000 + cursor, match=f"{self.final_namespace}:*", count=1000 ) if keys: # Get all values in batch @@ -796,14 +844,14 @@ class RedisDocStatusStorage(DocStatusStorage): result[doc_id] = DocProcessingStatus(**data) except (json.JSONDecodeError, KeyError) as e: logger.error( - f"Error processing document {key}: {e}" + f"[{self.workspace}] Error processing document {key}: {e}" ) continue if cursor == 0: break except Exception as e: - logger.error(f"Error getting docs by track_id: {e}") + logger.error(f"[{self.workspace}] Error getting docs by track_id: {e}") return result @@ -817,7 +865,9 @@ class RedisDocStatusStorage(DocStatusStorage): if not data: return - logger.debug(f"Inserting {len(data)} records to {self.namespace}") + logger.debug( + f"[{self.workspace}] Inserting {len(data)} records to {self.namespace}" + ) async with self._get_redis_connection() as redis: try: # Ensure chunks_list field exists for new documents @@ -827,20 +877,20 @@ class RedisDocStatusStorage(DocStatusStorage): pipe = redis.pipeline() for k, v in data.items(): - pipe.set(f"{self.namespace}:{k}", json.dumps(v)) + pipe.set(f"{self.final_namespace}:{k}", json.dumps(v)) await pipe.execute() except json.JSONDecodeError as e: - logger.error(f"JSON decode error during upsert: {e}") + logger.error(f"[{self.workspace}] JSON decode error during upsert: {e}") raise @redis_retry async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async with self._get_redis_connection() as redis: try: - data = await redis.get(f"{self.namespace}:{id}") + data = await redis.get(f"{self.final_namespace}:{id}") return json.loads(data) if data else None except json.JSONDecodeError as e: - logger.error(f"JSON decode error for id {id}: {e}") + logger.error(f"[{self.workspace}] JSON decode error for id {id}: {e}") return None async def delete(self, doc_ids: list[str]) -> None: @@ -851,12 +901,12 @@ class RedisDocStatusStorage(DocStatusStorage): async with self._get_redis_connection() as redis: pipe = redis.pipeline() for doc_id in doc_ids: - pipe.delete(f"{self.namespace}:{doc_id}") + pipe.delete(f"{self.final_namespace}:{doc_id}") results = await 
pipe.execute() deleted_count = sum(results) logger.info( - f"Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}" + f"[{self.workspace}] Deleted {deleted_count} of {len(doc_ids)} doc status entries from {self.namespace}" ) async def get_docs_paginated( @@ -903,7 +953,7 @@ class RedisDocStatusStorage(DocStatusStorage): cursor = 0 while True: cursor, keys = await redis.scan( - cursor, match=f"{self.namespace}:*", count=1000 + cursor, match=f"{self.final_namespace}:*", count=1000 ) if keys: # Get all values in batch @@ -950,7 +1000,7 @@ class RedisDocStatusStorage(DocStatusStorage): except (json.JSONDecodeError, KeyError) as e: logger.error( - f"Error processing document {key}: {e}" + f"[{self.workspace}] Error processing document {key}: {e}" ) continue @@ -958,7 +1008,7 @@ class RedisDocStatusStorage(DocStatusStorage): break except Exception as e: - logger.error(f"Error getting paginated docs: {e}") + logger.error(f"[{self.workspace}] Error getting paginated docs: {e}") return [], 0 # Sort documents using the separate sort key @@ -996,7 +1046,7 @@ class RedisDocStatusStorage(DocStatusStorage): try: async with self._get_redis_connection() as redis: # Use SCAN to find all keys with the namespace prefix - pattern = f"{self.namespace}:*" + pattern = f"{self.final_namespace}:*" cursor = 0 deleted_count = 0 @@ -1014,9 +1064,11 @@ class RedisDocStatusStorage(DocStatusStorage): break logger.info( - f"Dropped {deleted_count} doc status keys from {self.namespace}" + f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}" ) return {"status": "success", "message": "data dropped"} except Exception as e: - logger.error(f"Error dropping doc status {self.namespace}: {e}") + logger.error( + f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}" + ) return {"status": "error", "message": str(e)} diff --git a/lightrag/namespace.py b/lightrag/namespace.py index 5c042713..2acfe9a4 100644 --- a/lightrag/namespace.py +++ b/lightrag/namespace.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import Iterable +# All namespace should not be changed class NameSpace: KV_STORE_FULL_DOCS = "full_docs" KV_STORE_TEXT_CHUNKS = "text_chunks" diff --git a/lightrag/operate.py b/lightrag/operate.py index 7725caca..dd8d54be 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -1248,7 +1248,7 @@ async def merge_nodes_and_edges( semaphore = asyncio.Semaphore(graph_max_async) # ===== Phase 1: Process all entities concurrently ===== - log_message = f"Phase 1: Processing {total_entities_count} entities (async: {graph_max_async})" + log_message = f"Phase 1: Processing {total_entities_count} entities from {doc_id} (async: {graph_max_async})" logger.info(log_message) async with pipeline_status_lock: pipeline_status["latest_message"] = log_message @@ -1312,7 +1312,7 @@ async def merge_nodes_and_edges( processed_entities = [task.result() for task in entity_tasks] # ===== Phase 2: Process all relationships concurrently ===== - log_message = f"Phase 2: Processing {total_relations_count} relations (async: {graph_max_async})" + log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})" logger.info(log_message) async with pipeline_status_lock: pipeline_status["latest_message"] = log_message @@ -1422,6 +1422,12 @@ async def merge_nodes_and_edges( relation_pair = tuple(sorted([src_id, tgt_id])) final_relation_pairs.add(relation_pair) + log_message = f"Phase 3: Updating final 
{len(final_entity_names)}({len(processed_entities)}+{len(all_added_entities)}) entities and {len(final_relation_pairs)} relations from {doc_id}" + logger.info(log_message) + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + # Update storage if final_entity_names: await full_entities_storage.upsert(