From 033098c1bc500b72d94237accdc568ac7dd22854 Mon Sep 17 00:00:00 2001
From: yangdx
Date: Mon, 7 Jul 2025 00:57:21 +0800
Subject: [PATCH] Feat: Add WORKSPACE support to all storage types

---
 env.example                             |  52 +++----
 lightrag/api/config.py                  |   6 +-
 lightrag/api/lightrag_server.py         |   6 +-
 lightrag/api/routers/document_routes.py |  16 +-
 lightrag/base.py                        |   1 +
 lightrag/kg/faiss_impl.py               |  16 +-
 lightrag/kg/json_doc_status_impl.py     |  13 +-
 lightrag/kg/json_kv_impl.py             |  13 +-
 lightrag/kg/milvus_impl.py              |  27 +++-
 lightrag/kg/mongo_impl.py               | 127 +++++++++++++++-
 lightrag/kg/nano_vector_db_impl.py      |  16 +-
 lightrag/kg/neo4j_impl.py               | 188 ++++++++++++++----------
 lightrag/kg/networkx_impl.py            |  16 +-
 lightrag/kg/postgres_impl.py            | 102 ++++++++++++-
 lightrag/kg/qdrant_impl.py              |  35 +++++
 lightrag/kg/redis_impl.py               |  46 ++++++
 lightrag/lightrag.py                    |  23 ++-
 17 files changed, 566 insertions(+), 137 deletions(-)

diff --git a/env.example b/env.example
index 514f8425..d2be92c9 100644
--- a/env.example
+++ b/env.example
@@ -111,25 +111,37 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
###########################
### Data storage selection
###########################
+### In-memory database with data persistence to local files
+# LIGHTRAG_KV_STORAGE=JsonKVStorage
+# LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage
+# LIGHTRAG_GRAPH_STORAGE=NetworkXStorage
+# LIGHTRAG_VECTOR_STORAGE=NanoVectorDBStorage
+# LIGHTRAG_VECTOR_STORAGE=FaissVectorDBStorage
### PostgreSQL
# LIGHTRAG_KV_STORAGE=PGKVStorage
# LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage
-# LIGHTRAG_VECTOR_STORAGE=PGVectorStorage
# LIGHTRAG_GRAPH_STORAGE=PGGraphStorage
-### MongoDB
+# LIGHTRAG_VECTOR_STORAGE=PGVectorStorage
+### MongoDB (recommended for production deployment)
# LIGHTRAG_KV_STORAGE=MongoKVStorage
# LIGHTRAG_DOC_STATUS_STORAGE=MongoDocStatusStorage
-# LIGHTRAG_VECTOR_STORAGE=MongoVectorDBStorage
# LIGHTRAG_GRAPH_STORAGE=MongoGraphStorage
-### KV Storage
+# LIGHTRAG_VECTOR_STORAGE=MongoVectorDBStorage
+### Redis Storage (recommended for production deployment)
# LIGHTRAG_KV_STORAGE=RedisKVStorage
# LIGHTRAG_DOC_STATUS_STORAGE=RedisDocStatusStorage
-### Vector Storage
-# LIGHTRAG_VECTOR_STORAGE=FaissVectorDBStorage
+### Vector Storage (recommended for production deployment)
# LIGHTRAG_VECTOR_STORAGE=MilvusVectorDBStorage
-### Graph Storage
+# LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage
+### Graph Storage (recommended for production deployment)
# LIGHTRAG_GRAPH_STORAGE=Neo4JStorage
-# LIGHTRAG_GRAPH_STORAGE=MemgraphStorage
+
+####################################################################
+### Default workspace for all storage types
+### Isolates the data of each LightRAG instance
+### Valid characters: a-z, A-Z, 0-9, and _
+####################################################################
+# WORKSPACE=doc

### PostgreSQL Configuration
POSTGRES_HOST=localhost
@@ -138,31 +150,18 @@ POSTGRES_USER=your_username
POSTGRES_PASSWORD='your_password'
POSTGRES_DATABASE=your_database
POSTGRES_MAX_CONNECTIONS=12
-### separating all data from difference Lightrag instances
-# POSTGRES_WORKSPACE=default
+# POSTGRES_WORKSPACE=forced_workspace_name

### Neo4j Configuration
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
NEO4J_USERNAME=neo4j
NEO4J_PASSWORD='your_password'

-### Independent AGM Configuration(not for AMG embedded in PostreSQL)
-# AGE_POSTGRES_DB=
-# AGE_POSTGRES_USER=
-# AGE_POSTGRES_PASSWORD=
-# AGE_POSTGRES_HOST=
-# AGE_POSTGRES_PORT=8529
-
-# AGE Graph Name(apply to PostgreSQL and independent AGM)
### AGE_GRAPH_NAME is deprecated
-# 
AGE_GRAPH_NAME=lightrag - ### MongoDB Configuration MONGO_URI=mongodb://root:root@localhost:27017/ +#MONGO_URI=mongodb+srv://root:rooot@cluster0.xxxx.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0 MONGO_DATABASE=LightRAG -### separating all data from difference Lightrag instances(deprecating) -### separating all data from difference Lightrag instances -# MONGODB_WORKSPACE=default +# MONGODB_WORKSPACE=forced_workspace_name ### Milvus Configuration MILVUS_URI=http://localhost:19530 @@ -170,10 +169,13 @@ MILVUS_DB_NAME=lightrag # MILVUS_USER=root # MILVUS_PASSWORD=your_password # MILVUS_TOKEN=your_token +# MILVUS_WORKSPACE=forced_workspace_name ### Qdrant -QDRANT_URL=http://localhost:16333 +QDRANT_URL=http://localhost:6333 # QDRANT_API_KEY=your-api-key +# QDRANT_WORKSPACE=forced_workspace_name ### Redis REDIS_URI=redis://localhost:6379 +# REDIS_WORKSPACE=forced_workspace_name diff --git a/lightrag/api/config.py b/lightrag/api/config.py index 006ab452..2e15fd3a 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -184,10 +184,10 @@ def parse_args() -> argparse.Namespace: # Namespace parser.add_argument( - "--namespace-prefix", + "--workspace", type=str, - default=get_env_value("NAMESPACE_PREFIX", ""), - help="Prefix of the namespace", + default=get_env_value("WORKSPACE", ""), + help="Default workspace for all storage", ) parser.add_argument( diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 0930c1cd..eaf2e5f6 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -112,8 +112,8 @@ def create_app(args): # Check if API key is provided either through env var or args api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - # Initialize document manager - doc_manager = DocumentManager(args.input_dir) + # Initialize document manager with workspace support for data isolation + doc_manager = DocumentManager(args.input_dir, workspace=args.workspace) @asynccontextmanager async def lifespan(app: FastAPI): @@ -295,6 +295,7 @@ def create_app(args): if args.llm_binding in ["lollms", "ollama", "openai"]: rag = LightRAG( working_dir=args.working_dir, + workspace=args.workspace, llm_model_func=lollms_model_complete if args.llm_binding == "lollms" else ollama_model_complete @@ -330,6 +331,7 @@ def create_app(args): else: # azure_openai rag = LightRAG( working_dir=args.working_dir, + workspace=args.workspace, llm_model_func=azure_openai_model_complete, chunk_token_size=int(args.chunk_size), chunk_overlap_token_size=int(args.chunk_overlap_size), diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 4f22947c..288a66e1 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -475,6 +475,7 @@ class DocumentManager: def __init__( self, input_dir: str, + workspace: str = "", # New parameter for workspace isolation supported_extensions: tuple = ( ".txt", ".md", @@ -515,10 +516,19 @@ class DocumentManager: ".less", # LESS CSS ), ): - self.input_dir = Path(input_dir) + # Store the base input directory and workspace + self.base_input_dir = Path(input_dir) + self.workspace = workspace self.supported_extensions = supported_extensions self.indexed_files = set() + # Create workspace-specific input directory + # If workspace is provided, create a subdirectory for data isolation + if workspace: + self.input_dir = self.base_input_dir / workspace + else: + self.input_dir = self.base_input_dir + # Create input directory if it doesn't exist 
self.input_dir.mkdir(parents=True, exist_ok=True) @@ -716,7 +726,9 @@ async def pipeline_enqueue_file(rag: LightRAG, file_path: Path) -> bool: if content: # Check if content contains only whitespace characters if not content.strip(): - logger.warning(f"File contains only whitespace characters. file_paths={file_path.name}") + logger.warning( + f"File contains only whitespace characters. file_paths={file_path.name}" + ) await rag.apipeline_enqueue_documents(content, file_paths=file_path.name) logger.info(f"Successfully fetched and enqueued file: {file_path.name}") diff --git a/lightrag/base.py b/lightrag/base.py index 7820b4da..57cb2ac6 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -103,6 +103,7 @@ class QueryParam: @dataclass class StorageNameSpace(ABC): namespace: str + workspace: str global_config: dict[str, Any] async def initialize(self): diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index af691458..c6ee099d 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -38,9 +38,19 @@ class FaissVectorDBStorage(BaseVectorStorage): self.cosine_better_than_threshold = cosine_threshold # Where to save index file if you want persistent storage - self._faiss_index_file = os.path.join( - self.global_config["working_dir"], f"faiss_index_{self.namespace}.index" - ) + working_dir = self.global_config["working_dir"] + if self.workspace: + # Include workspace in the file path for data isolation + workspace_dir = os.path.join(working_dir, self.workspace) + os.makedirs(workspace_dir, exist_ok=True) + self._faiss_index_file = os.path.join( + workspace_dir, f"faiss_index_{self.namespace}.index" + ) + else: + # Default behavior when workspace is empty + self._faiss_index_file = os.path.join( + working_dir, f"faiss_index_{self.namespace}.index" + ) self._meta_file = self._faiss_index_file + ".meta.json" self._max_batch_size = self.global_config["embedding_batch_num"] diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index ab6ab390..317509b3 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -30,7 +30,18 @@ class JsonDocStatusStorage(DocStatusStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] - self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") + if self.workspace: + # Include workspace in the file path for data isolation + workspace_dir = os.path.join(working_dir, self.workspace) + os.makedirs(workspace_dir, exist_ok=True) + self._file_name = os.path.join( + workspace_dir, f"kv_store_{self.namespace}.json" + ) + else: + # Default behavior when workspace is empty + self._file_name = os.path.join( + working_dir, f"kv_store_{self.namespace}.json" + ) self._data = None self._storage_lock = None self.storage_updated = None diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 98835f8c..f76c9e13 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -26,7 +26,18 @@ from .shared_storage import ( class JsonKVStorage(BaseKVStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] - self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") + if self.workspace: + # Include workspace in the file path for data isolation + workspace_dir = os.path.join(working_dir, self.workspace) + os.makedirs(workspace_dir, exist_ok=True) + self._file_name = os.path.join( + workspace_dir, f"kv_store_{self.namespace}.json" + ) + else: + # Default behavior when 
workspace is empty + self._file_name = os.path.join( + working_dir, f"kv_store_{self.namespace}.json" + ) self._data = None self._storage_lock = None self.storage_updated = None diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index c64239a9..2414787a 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -7,10 +7,6 @@ from lightrag.utils import logger, compute_mdhash_id from ..base import BaseVectorStorage import pipmaster as pm - -if not pm.is_installed("configparser"): - pm.install("configparser") - if not pm.is_installed("pymilvus"): pm.install("pymilvus") @@ -660,6 +656,29 @@ class MilvusVectorDBStorage(BaseVectorStorage): raise def __post_init__(self): + # Check for MILVUS_WORKSPACE environment variable first (higher priority) + # This allows administrators to force a specific workspace for all Milvus storage instances + milvus_workspace = os.environ.get("MILVUS_WORKSPACE") + if milvus_workspace and milvus_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = milvus_workspace.strip() + logger.info( + f"Using MILVUS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) + + # Build namespace with workspace prefix for data isolation + if effective_workspace: + self.namespace = f"{effective_workspace}_{self.namespace}" + logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") + # When workspace is empty, keep the original namespace unchanged + kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") if cosine_threshold is None: diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 2ac3aff2..5e742e02 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -25,6 +25,7 @@ if not pm.is_installed("pymongo"): pm.install("pymongo") from pymongo import AsyncMongoClient # type: ignore +from pymongo import UpdateOne # type: ignore from pymongo.asynchronous.database import AsyncDatabase # type: ignore from pymongo.asynchronous.collection import AsyncCollection # type: ignore from pymongo.operations import SearchIndexModel # type: ignore @@ -81,7 +82,39 @@ class MongoKVStorage(BaseKVStorage): db: AsyncDatabase = field(default=None) _data: AsyncCollection = field(default=None) + def __init__(self, namespace, global_config, embedding_func, workspace=None): + super().__init__( + namespace=namespace, + workspace=workspace or "", + global_config=global_config, + embedding_func=embedding_func, + ) + self.__post_init__() + def __post_init__(self): + # Check for MONGODB_WORKSPACE environment variable first (higher priority) + # This allows administrators to force a specific workspace for all MongoDB storage instances + mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") + if mongodb_workspace and mongodb_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = mongodb_workspace.strip() + logger.info( + f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if 
effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) + + # Build namespace with workspace prefix for data isolation + if effective_workspace: + self.namespace = f"{effective_workspace}_{self.namespace}" + logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") + # When workspace is empty, keep the original namespace unchanged + self._collection_name = self.namespace async def initialize(self): @@ -142,7 +175,6 @@ class MongoKVStorage(BaseKVStorage): # Unified handling for all namespaces with flattened keys # Use bulk_write for better performance - from pymongo import UpdateOne operations = [] current_time = int(time.time()) # Get current Unix timestamp @@ -252,7 +284,39 @@ class MongoDocStatusStorage(DocStatusStorage): db: AsyncDatabase = field(default=None) _data: AsyncCollection = field(default=None) + def __init__(self, namespace, global_config, embedding_func, workspace=None): + super().__init__( + namespace=namespace, + workspace=workspace or "", + global_config=global_config, + embedding_func=embedding_func, + ) + self.__post_init__() + def __post_init__(self): + # Check for MONGODB_WORKSPACE environment variable first (higher priority) + # This allows administrators to force a specific workspace for all MongoDB storage instances + mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") + if mongodb_workspace and mongodb_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = mongodb_workspace.strip() + logger.info( + f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) + + # Build namespace with workspace prefix for data isolation + if effective_workspace: + self.namespace = f"{effective_workspace}_{self.namespace}" + logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") + # When workspace is empty, keep the original namespace unchanged + self._collection_name = self.namespace async def initialize(self): @@ -367,12 +431,36 @@ class MongoGraphStorage(BaseGraphStorage): # edge collection storing source_node_id, target_node_id, and edge_properties edgeCollection: AsyncCollection = field(default=None) - def __init__(self, namespace, global_config, embedding_func): + def __init__(self, namespace, global_config, embedding_func, workspace=None): super().__init__( namespace=namespace, + workspace=workspace or "", global_config=global_config, embedding_func=embedding_func, ) + # Check for MONGODB_WORKSPACE environment variable first (higher priority) + # This allows administrators to force a specific workspace for all MongoDB storage instances + mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") + if mongodb_workspace and mongodb_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = mongodb_workspace.strip() + logger.info( + f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) + + # Build 
namespace with workspace prefix for data isolation + if effective_workspace: + self.namespace = f"{effective_workspace}_{self.namespace}" + logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") + # When workspace is empty, keep the original namespace unchanged + self._collection_name = self.namespace self._edge_collection_name = f"{self._collection_name}_edges" @@ -1231,7 +1319,42 @@ class MongoVectorDBStorage(BaseVectorStorage): db: AsyncDatabase | None = field(default=None) _data: AsyncCollection | None = field(default=None) + def __init__( + self, namespace, global_config, embedding_func, workspace=None, meta_fields=None + ): + super().__init__( + namespace=namespace, + workspace=workspace or "", + global_config=global_config, + embedding_func=embedding_func, + meta_fields=meta_fields or set(), + ) + self.__post_init__() + def __post_init__(self): + # Check for MONGODB_WORKSPACE environment variable first (higher priority) + # This allows administrators to force a specific workspace for all MongoDB storage instances + mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") + if mongodb_workspace and mongodb_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = mongodb_workspace.strip() + logger.info( + f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) + + # Build namespace with workspace prefix for data isolation + if effective_workspace: + self.namespace = f"{effective_workspace}_{self.namespace}" + logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") + # When workspace is empty, keep the original namespace unchanged + kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") if cosine_threshold is None: diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index fa56a214..4eae2db5 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -41,9 +41,19 @@ class NanoVectorDBStorage(BaseVectorStorage): ) self.cosine_better_than_threshold = cosine_threshold - self._client_file_name = os.path.join( - self.global_config["working_dir"], f"vdb_{self.namespace}.json" - ) + working_dir = self.global_config["working_dir"] + if self.workspace: + # Include workspace in the file path for data isolation + workspace_dir = os.path.join(working_dir, self.workspace) + os.makedirs(workspace_dir, exist_ok=True) + self._client_file_name = os.path.join( + workspace_dir, f"vdb_{self.namespace}.json" + ) + else: + # Default behavior when workspace is empty + self._client_file_name = os.path.join( + working_dir, f"vdb_{self.namespace}.json" + ) self._max_batch_size = self.global_config["embedding_batch_num"] self._client = NanoVectorDB( diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index d4fbc59c..756dc7c8 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -50,14 +50,20 @@ logging.getLogger("neo4j").setLevel(logging.ERROR) @final @dataclass class Neo4JStorage(BaseGraphStorage): - def __init__(self, namespace, global_config, embedding_func): + def __init__(self, namespace, global_config, embedding_func, workspace=None): super().__init__( 
namespace=namespace, + workspace=workspace or "", global_config=global_config, embedding_func=embedding_func, ) self._driver = None + def _get_workspace_label(self) -> str: + """Get workspace label, return 'base' for compatibility when workspace is empty""" + workspace = getattr(self, "workspace", None) + return workspace if workspace else "base" + async def initialize(self): URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None)) USERNAME = os.environ.get( @@ -153,13 +159,14 @@ class Neo4JStorage(BaseGraphStorage): raise e if connected: - # Create index for base nodes on entity_id if it doesn't exist + # Create index for workspace nodes on entity_id if it doesn't exist + workspace_label = self._get_workspace_label() try: async with self._driver.session(database=database) as session: # Check if index exists first - check_query = """ + check_query = f""" CALL db.indexes() YIELD name, labelsOrTypes, properties - WHERE labelsOrTypes = ['base'] AND properties = ['entity_id'] + WHERE labelsOrTypes = ['{workspace_label}'] AND properties = ['entity_id'] RETURN count(*) > 0 AS exists """ try: @@ -172,16 +179,16 @@ class Neo4JStorage(BaseGraphStorage): if not index_exists: # Create index only if it doesn't exist result = await session.run( - "CREATE INDEX FOR (n:base) ON (n.entity_id)" + f"CREATE INDEX FOR (n:`{workspace_label}`) ON (n.entity_id)" ) await result.consume() logger.info( - f"Created index for base nodes on entity_id in {database}" + f"Created index for {workspace_label} nodes on entity_id in {database}" ) except Exception: # Fallback if db.indexes() is not supported in this Neo4j version result = await session.run( - "CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)" + f"CREATE INDEX IF NOT EXISTS FOR (n:`{workspace_label}`) ON (n.entity_id)" ) await result.consume() except Exception as e: @@ -216,11 +223,12 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If node_id is invalid Exception: If there is an error executing the query """ + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists" + query = f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN count(n) > 0 AS node_exists" result = await session.run(query, entity_id=node_id) single_result = await result.single() await result.consume() # Ensure result is fully consumed @@ -245,12 +253,13 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If either node_id is invalid Exception: If there is an error executing the query """ + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: query = ( - "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) " + f"MATCH (a:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(b:`{workspace_label}` {{entity_id: $target_entity_id}}) " "RETURN COUNT(r) > 0 AS edgeExists" ) result = await session.run( @@ -282,11 +291,14 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If node_id is invalid Exception: If there is an error executing the query """ + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = "MATCH (n:base {entity_id: $entity_id}) RETURN n" + query = ( + f"MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN n" + ) 
result = await session.run(query, entity_id=node_id) try: records = await result.fetch( @@ -300,12 +312,12 @@ class Neo4JStorage(BaseGraphStorage): if records: node = records[0]["n"] node_dict = dict(node) - # Remove base label from labels list if it exists + # Remove workspace label from labels list if it exists if "labels" in node_dict: node_dict["labels"] = [ label for label in node_dict["labels"] - if label != "base" + if label != workspace_label ] # logger.debug(f"Neo4j query node {query} return: {node_dict}") return node_dict @@ -326,12 +338,13 @@ class Neo4JStorage(BaseGraphStorage): Returns: A dictionary mapping each node_id to its node data (or None if not found). """ + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ + query = f""" UNWIND $node_ids AS id - MATCH (n:base {entity_id: id}) + MATCH (n:`{workspace_label}` {{entity_id: id}}) RETURN n.entity_id AS entity_id, n """ result = await session.run(query, node_ids=node_ids) @@ -340,10 +353,12 @@ class Neo4JStorage(BaseGraphStorage): entity_id = record["entity_id"] node = record["n"] node_dict = dict(node) - # Remove the 'base' label if present in a 'labels' property + # Remove the workspace label if present in a 'labels' property if "labels" in node_dict: node_dict["labels"] = [ - label for label in node_dict["labels"] if label != "base" + label + for label in node_dict["labels"] + if label != workspace_label ] nodes[entity_id] = node_dict await result.consume() # Make sure to consume the result fully @@ -364,12 +379,13 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If node_id is invalid Exception: If there is an error executing the query """ + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = """ - MATCH (n:base {entity_id: $entity_id}) + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) OPTIONAL MATCH (n)-[r]-() RETURN COUNT(r) AS degree """ @@ -403,13 +419,14 @@ class Neo4JStorage(BaseGraphStorage): A dictionary mapping each node_id to its degree (number of relationships). If a node is not found, its degree will be set to 0. """ + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ + query = f""" UNWIND $node_ids AS id - MATCH (n:base {entity_id: id}) - RETURN n.entity_id AS entity_id, count { (n)--() } AS degree; + MATCH (n:`{workspace_label}` {{entity_id: id}}) + RETURN n.entity_id AS entity_id, count {{ (n)--() }} AS degree; """ result = await session.run(query, node_ids=node_ids) degrees = {} @@ -489,12 +506,13 @@ class Neo4JStorage(BaseGraphStorage): ValueError: If either node_id is invalid Exception: If there is an error executing the query """ + workspace_label = self._get_workspace_label() try: async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ - MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id}) + query = f""" + MATCH (start:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(end:`{workspace_label}` {{entity_id: $target_entity_id}}) RETURN properties(r) as edge_properties """ result = await session.run( @@ -571,12 +589,13 @@ class Neo4JStorage(BaseGraphStorage): Returns: A dictionary mapping (src, tgt) tuples to their edge properties. 
""" + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ + query = f""" UNWIND $pairs AS pair - MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt}) + MATCH (start:`{workspace_label}` {{entity_id: pair.src}})-[r:DIRECTED]-(end:`{workspace_label}` {{entity_id: pair.tgt}}) RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges """ result = await session.run(query, pairs=pairs) @@ -627,8 +646,9 @@ class Neo4JStorage(BaseGraphStorage): database=self._DATABASE, default_access_mode="READ" ) as session: try: - query = """MATCH (n:base {entity_id: $entity_id}) - OPTIONAL MATCH (n)-[r]-(connected:base) + workspace_label = self._get_workspace_label() + query = f"""MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) + OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`) WHERE connected.entity_id IS NOT NULL RETURN n, r, connected""" results = await session.run(query, entity_id=source_node_id) @@ -689,10 +709,11 @@ class Neo4JStorage(BaseGraphStorage): database=self._DATABASE, default_access_mode="READ" ) as session: # Query to get both outgoing and incoming edges - query = """ + workspace_label = self._get_workspace_label() + query = f""" UNWIND $node_ids AS id - MATCH (n:base {entity_id: id}) - OPTIONAL MATCH (n)-[r]-(connected:base) + MATCH (n:`{workspace_label}` {{entity_id: id}}) + OPTIONAL MATCH (n)-[r]-(connected:`{workspace_label}`) RETURN id AS queried_id, n.entity_id AS node_entity_id, connected.entity_id AS connected_entity_id, startNode(r).entity_id AS start_entity_id @@ -727,12 +748,13 @@ class Neo4JStorage(BaseGraphStorage): return edges_dict async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ + query = f""" UNWIND $chunk_ids AS chunk_id - MATCH (n:base) + MATCH (n:`{workspace_label}`) WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, $sep) RETURN DISTINCT n """ @@ -748,12 +770,13 @@ class Neo4JStorage(BaseGraphStorage): return nodes async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ + query = f""" UNWIND $chunk_ids AS chunk_id - MATCH (a:base)-[r]-(b:base) + MATCH (a:`{workspace_label}`)-[r]-(b:`{workspace_label}`) WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, $sep) RETURN DISTINCT a.entity_id AS source, b.entity_id AS target, properties(r) AS properties """ @@ -787,6 +810,7 @@ class Neo4JStorage(BaseGraphStorage): node_id: The unique identifier for the node (used as label) node_data: Dictionary of node properties """ + workspace_label = self._get_workspace_label() properties = node_data entity_type = properties["entity_type"] if "entity_id" not in properties: @@ -796,14 +820,11 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session(database=self._DATABASE) as session: async def execute_upsert(tx: AsyncManagedTransaction): - query = ( - """ - MERGE (n:base {entity_id: $entity_id}) + query = f""" + MERGE (n:`{workspace_label}` {{entity_id: $entity_id}}) SET n += $properties - SET n:`%s` + SET n:`{entity_type}` """ - % entity_type - ) result = await tx.run( query, entity_id=node_id, properties=properties 
) @@ -847,10 +868,11 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session(database=self._DATABASE) as session: async def execute_upsert(tx: AsyncManagedTransaction): - query = """ - MATCH (source:base {entity_id: $source_entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}}) WITH source - MATCH (target:base {entity_id: $target_entity_id}) + MATCH (target:`{workspace_label}` {{entity_id: $target_entity_id}}) MERGE (source)-[r:DIRECTED]-(target) SET r += $properties RETURN r, source, target @@ -889,6 +911,7 @@ class Neo4JStorage(BaseGraphStorage): KnowledgeGraph object containing nodes and edges, with an is_truncated flag indicating whether the graph was truncated due to max_nodes limit """ + workspace_label = self._get_workspace_label() result = KnowledgeGraph() seen_nodes = set() seen_edges = set() @@ -899,7 +922,9 @@ class Neo4JStorage(BaseGraphStorage): try: if node_label == "*": # First check total node count to determine if graph is truncated - count_query = "MATCH (n) RETURN count(n) as total" + count_query = ( + f"MATCH (n:`{workspace_label}`) RETURN count(n) as total" + ) count_result = None try: count_result = await session.run(count_query) @@ -915,13 +940,13 @@ class Neo4JStorage(BaseGraphStorage): await count_result.consume() # Run main query to get nodes with highest degree - main_query = """ - MATCH (n) + main_query = f""" + MATCH (n:`{workspace_label}`) OPTIONAL MATCH (n)-[r]-() WITH n, COALESCE(count(r), 0) AS degree ORDER BY degree DESC LIMIT $max_nodes - WITH collect({node: n}) AS filtered_nodes + WITH collect({{node: n}}) AS filtered_nodes UNWIND filtered_nodes AS node_info WITH collect(node_info.node) AS kept_nodes, filtered_nodes OPTIONAL MATCH (a)-[r]-(b) @@ -943,20 +968,21 @@ class Neo4JStorage(BaseGraphStorage): else: # return await self._robust_fallback(node_label, max_depth, max_nodes) # First try without limit to check if we need to truncate - full_query = """ - MATCH (start) + full_query = f""" + MATCH (start:`{workspace_label}`) WHERE start.entity_id = $entity_id WITH start - CALL apoc.path.subgraphAll(start, { + CALL apoc.path.subgraphAll(start, {{ relationshipFilter: '', + labelFilter: '{workspace_label}', minLevel: 0, maxLevel: $max_depth, bfs: true - }) + }}) YIELD nodes, relationships WITH nodes, relationships, size(nodes) AS total_nodes UNWIND nodes AS node - WITH collect({node: node}) AS node_info, relationships, total_nodes + WITH collect({{node: node}}) AS node_info, relationships, total_nodes RETURN node_info, relationships, total_nodes """ @@ -994,20 +1020,21 @@ class Neo4JStorage(BaseGraphStorage): ) # Run limited query - limited_query = """ - MATCH (start) + limited_query = f""" + MATCH (start:`{workspace_label}`) WHERE start.entity_id = $entity_id WITH start - CALL apoc.path.subgraphAll(start, { + CALL apoc.path.subgraphAll(start, {{ relationshipFilter: '', + labelFilter: '{workspace_label}', minLevel: 0, maxLevel: $max_depth, limit: $max_nodes, bfs: true - }) + }}) YIELD nodes, relationships UNWIND nodes AS node - WITH collect({node: node}) AS node_info, relationships + WITH collect({{node: node}}) AS node_info, relationships RETURN node_info, relationships """ result_set = None @@ -1094,11 +1121,12 @@ class Neo4JStorage(BaseGraphStorage): visited_edge_pairs = set() # Get the starting node's data + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: 
- query = """ - MATCH (n:base {entity_id: $entity_id}) + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) RETURN id(n) as node_id, n """ node_result = await session.run(query, entity_id=node_label) @@ -1156,8 +1184,9 @@ class Neo4JStorage(BaseGraphStorage): async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: - query = """ - MATCH (a:base {entity_id: $entity_id})-[r]-(b) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (a:`{workspace_label}` {{entity_id: $entity_id}})-[r]-(b) WITH r, b, id(r) as edge_id, id(b) as target_id RETURN r, b, edge_id, target_id """ @@ -1241,6 +1270,7 @@ class Neo4JStorage(BaseGraphStorage): Returns: ["Person", "Company", ...] # Alphabetically sorted label list """ + workspace_label = self._get_workspace_label() async with self._driver.session( database=self._DATABASE, default_access_mode="READ" ) as session: @@ -1248,8 +1278,8 @@ class Neo4JStorage(BaseGraphStorage): # query = "CALL db.labels() YIELD label RETURN label" # Method 2: Query compatible with older versions - query = """ - MATCH (n:base) + query = f""" + MATCH (n:`{workspace_label}`) WHERE n.entity_id IS NOT NULL RETURN DISTINCT n.entity_id AS label ORDER BY label @@ -1285,8 +1315,9 @@ class Neo4JStorage(BaseGraphStorage): """ async def _do_delete(tx: AsyncManagedTransaction): - query = """ - MATCH (n:base {entity_id: $entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (n:`{workspace_label}` {{entity_id: $entity_id}}) DETACH DELETE n """ result = await tx.run(query, entity_id=node_id) @@ -1342,8 +1373,9 @@ class Neo4JStorage(BaseGraphStorage): for source, target in edges: async def _do_delete_edge(tx: AsyncManagedTransaction): - query = """ - MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id}) + workspace_label = self._get_workspace_label() + query = f""" + MATCH (source:`{workspace_label}` {{entity_id: $source_entity_id}})-[r]-(target:`{workspace_label}` {{entity_id: $target_entity_id}}) DELETE r """ result = await tx.run( @@ -1360,26 +1392,32 @@ class Neo4JStorage(BaseGraphStorage): raise async def drop(self) -> dict[str, str]: - """Drop all data from storage and clean up resources + """Drop all data from current workspace storage and clean up resources - This method will delete all nodes and relationships in the Neo4j database. + This method will delete all nodes and relationships in the current workspace only. 
Returns: dict[str, str]: Operation status and message - - On success: {"status": "success", "message": "data dropped"} + - On success: {"status": "success", "message": "workspace data dropped"} - On failure: {"status": "error", "message": ""} """ + workspace_label = self._get_workspace_label() try: async with self._driver.session(database=self._DATABASE) as session: - # Delete all nodes and relationships - query = "MATCH (n) DETACH DELETE n" + # Delete all nodes and relationships in current workspace only + query = f"MATCH (n:`{workspace_label}`) DETACH DELETE n" result = await session.run(query) await result.consume() # Ensure result is fully consumed logger.info( - f"Process {os.getpid()} drop Neo4j database {self._DATABASE}" + f"Process {os.getpid()} drop Neo4j workspace '{workspace_label}' in database {self._DATABASE}" ) - return {"status": "success", "message": "data dropped"} + return { + "status": "success", + "message": f"workspace '{workspace_label}' data dropped", + } except Exception as e: - logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}") + logger.error( + f"Error dropping Neo4j workspace '{workspace_label}' in database {self._DATABASE}: {e}" + ) return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index bb7233b4..faff6a96 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -46,9 +46,19 @@ class NetworkXStorage(BaseGraphStorage): nx.write_graphml(graph, file_name) def __post_init__(self): - self._graphml_xml_file = os.path.join( - self.global_config["working_dir"], f"graph_{self.namespace}.graphml" - ) + working_dir = self.global_config["working_dir"] + if self.workspace: + # Include workspace in the file path for data isolation + workspace_dir = os.path.join(working_dir, self.workspace) + os.makedirs(workspace_dir, exist_ok=True) + self._graphml_xml_file = os.path.join( + workspace_dir, f"graph_{self.namespace}.graphml" + ) + else: + # Default behavior when workspace is empty + self._graphml_xml_file = os.path.join( + working_dir, f"graph_{self.namespace}.graphml" + ) self._storage_lock = None self.storage_updated = None self._graph = None diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index d8447664..a5f9a46e 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1,6 +1,7 @@ import asyncio import json import os +import re import datetime from datetime import timezone from dataclasses import dataclass, field @@ -319,7 +320,7 @@ class PostgreSQLDB: # Get all old format data old_data_sql = """ SELECT id, mode, original_prompt, return_value, chunk_id, - create_time, update_time + workspace, create_time, update_time FROM LIGHTRAG_LLM_CACHE WHERE id NOT LIKE '%:%' """ @@ -364,7 +365,9 @@ class PostgreSQLDB: await self.execute( insert_sql, { - "workspace": self.workspace, + "workspace": record[ + "workspace" + ], # Use original record's workspace "id": new_key, "mode": record["mode"], "original_prompt": record["original_prompt"], @@ -384,7 +387,9 @@ class PostgreSQLDB: await self.execute( delete_sql, { - "workspace": self.workspace, + "workspace": record[ + "workspace" + ], # Use original record's workspace "mode": record["mode"], "id": record["id"], # Old id }, @@ -670,7 +675,7 @@ class ClientManager: ), "workspace": os.environ.get( "POSTGRES_WORKSPACE", - config.get("postgres", "workspace", fallback="default"), + config.get("postgres", "workspace", fallback=None), ), "max_connections": os.environ.get( 
"POSTGRES_MAX_CONNECTIONS", @@ -716,6 +721,18 @@ class PGKVStorage(BaseKVStorage): async def initialize(self): if self.db is None: self.db = await ClientManager.get_client() + # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" + if self.db.workspace: + # Use PostgreSQLDB's workspace (highest priority) + final_workspace = self.db.workspace + elif hasattr(self, "workspace") and self.workspace: + # Use storage class's workspace (medium priority) + final_workspace = self.workspace + self.db.workspace = final_workspace + else: + # Use "default" for compatibility (lowest priority) + final_workspace = "default" + self.db.workspace = final_workspace async def finalize(self): if self.db is not None: @@ -1047,6 +1064,18 @@ class PGVectorStorage(BaseVectorStorage): async def initialize(self): if self.db is None: self.db = await ClientManager.get_client() + # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" + if self.db.workspace: + # Use PostgreSQLDB's workspace (highest priority) + final_workspace = self.db.workspace + elif hasattr(self, "workspace") and self.workspace: + # Use storage class's workspace (medium priority) + final_workspace = self.workspace + self.db.workspace = final_workspace + else: + # Use "default" for compatibility (lowest priority) + final_workspace = "default" + self.db.workspace = final_workspace async def finalize(self): if self.db is not None: @@ -1328,6 +1357,18 @@ class PGDocStatusStorage(DocStatusStorage): async def initialize(self): if self.db is None: self.db = await ClientManager.get_client() + # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" + if self.db.workspace: + # Use PostgreSQLDB's workspace (highest priority) + final_workspace = self.db.workspace + elif hasattr(self, "workspace") and self.workspace: + # Use storage class's workspace (medium priority) + final_workspace = self.workspace + self.db.workspace = final_workspace + else: + # Use "default" for compatibility (lowest priority) + final_workspace = "default" + self.db.workspace = final_workspace async def finalize(self): if self.db is not None: @@ -1606,9 +1647,34 @@ class PGGraphQueryException(Exception): @dataclass class PGGraphStorage(BaseGraphStorage): def __post_init__(self): - self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag") + # Graph name will be dynamically generated in initialize() based on workspace self.db: PostgreSQLDB | None = None + def _get_workspace_graph_name(self) -> str: + """ + Generate graph name based on workspace and namespace for data isolation. 
+ Rules: + - If workspace is empty: graph_name = namespace + - If workspace has value: graph_name = workspace_namespace + + Args: + None + + Returns: + str: The graph name for the current workspace + """ + workspace = getattr(self, "workspace", None) + namespace = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag") + + if workspace and workspace.strip(): + # Ensure names comply with PostgreSQL identifier specifications + safe_workspace = re.sub(r"[^a-zA-Z0-9_]", "_", workspace.strip()) + safe_namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace) + return f"{safe_workspace}_{safe_namespace}" + else: + # When workspace is empty, use namespace directly + return re.sub(r"[^a-zA-Z0-9_]", "_", namespace) + @staticmethod def _normalize_node_id(node_id: str) -> str: """ @@ -1629,6 +1695,27 @@ class PGGraphStorage(BaseGraphStorage): async def initialize(self): if self.db is None: self.db = await ClientManager.get_client() + # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > None + if self.db.workspace: + # Use PostgreSQLDB's workspace (highest priority) + final_workspace = self.db.workspace + elif hasattr(self, "workspace") and self.workspace: + # Use storage class's workspace (medium priority) + final_workspace = self.workspace + self.db.workspace = final_workspace + else: + # Use None for compatibility (lowest priority) + final_workspace = None + self.db.workspace = final_workspace + + # Dynamically generate graph name based on workspace + self.workspace = self.db.workspace + self.graph_name = self._get_workspace_graph_name() + + # Log the graph initialization for debugging + logger.info( + f"PostgreSQL Graph initialized: workspace='{self.workspace}', graph_name='{self.graph_name}'" + ) # Execute each statement separately and ignore errors queries = [ @@ -2833,7 +2920,10 @@ class PGGraphStorage(BaseGraphStorage): $$) AS (result agtype)""" await self._query(drop_query, readonly=False) - return {"status": "success", "message": "graph data dropped"} + return { + "status": "success", + "message": f"workspace '{self.workspace}' graph data dropped", + } except Exception as e: logger.error(f"Error dropping graph: {e}") return {"status": "error", "message": str(e)} diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index dada278a..9cdeff7a 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -50,6 +50,18 @@ def compute_mdhash_id_for_qdrant( @final @dataclass class QdrantVectorDBStorage(BaseVectorStorage): + def __init__( + self, namespace, global_config, embedding_func, workspace=None, meta_fields=None + ): + super().__init__( + namespace=namespace, + workspace=workspace or "", + global_config=global_config, + embedding_func=embedding_func, + meta_fields=meta_fields or set(), + ) + self.__post_init__() + @staticmethod def create_collection_if_not_exist( client: QdrantClient, collection_name: str, **kwargs @@ -59,6 +71,29 @@ class QdrantVectorDBStorage(BaseVectorStorage): client.create_collection(collection_name, **kwargs) def __post_init__(self): + # Check for QDRANT_WORKSPACE environment variable first (higher priority) + # This allows administrators to force a specific workspace for all Qdrant storage instances + qdrant_workspace = os.environ.get("QDRANT_WORKSPACE") + if qdrant_workspace and qdrant_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = qdrant_workspace.strip() + logger.info( + f"Using QDRANT_WORKSPACE environment variable: '{effective_workspace}' 
(overriding passed workspace: '{self.workspace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) + + # Build namespace with workspace prefix for data isolation + if effective_workspace: + self.namespace = f"{effective_workspace}_{self.namespace}" + logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") + # When workspace is empty, keep the original namespace unchanged + kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") if cosine_threshold is None: diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index dba228ca..06e51384 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -71,6 +71,29 @@ class RedisConnectionManager: @dataclass class RedisKVStorage(BaseKVStorage): def __post_init__(self): + # Check for REDIS_WORKSPACE environment variable first (higher priority) + # This allows administrators to force a specific workspace for all Redis storage instances + redis_workspace = os.environ.get("REDIS_WORKSPACE") + if redis_workspace and redis_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = redis_workspace.strip() + logger.info( + f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) + + # Build namespace with workspace prefix for data isolation + if effective_workspace: + self.namespace = f"{effective_workspace}_{self.namespace}" + logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") + # When workspace is empty, keep the original namespace unchanged + redis_url = os.environ.get( "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") ) @@ -461,6 +484,29 @@ class RedisDocStatusStorage(DocStatusStorage): """Redis implementation of document status storage""" def __post_init__(self): + # Check for REDIS_WORKSPACE environment variable first (higher priority) + # This allows administrators to force a specific workspace for all Redis storage instances + redis_workspace = os.environ.get("REDIS_WORKSPACE") + if redis_workspace and redis_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = redis_workspace.strip() + logger.info( + f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding passed workspace: '{self.workspace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) + + # Build namespace with workspace prefix for data isolation + if effective_workspace: + self.namespace = f"{effective_workspace}_{self.namespace}" + logger.debug(f"Final namespace with workspace prefix: '{self.namespace}'") + # When workspace is empty, keep the original namespace unchanged + redis_url = os.environ.get( "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379") ) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 
17681f89..5d96aeba 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -97,9 +97,7 @@ class LightRAG: # Directory # --- - working_dir: str = field( - default=f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" - ) + working_dir: str = field(default="./rag_storage") """Directory where cache and temporary files are stored.""" # Storage @@ -117,6 +115,12 @@ class LightRAG: doc_status_storage: str = field(default="JsonDocStatusStorage") """Storage type for tracking document processing statuses.""" + # Workspace + # --- + + workspace: str = field(default_factory=lambda: os.getenv("WORKSPACE", "")) + """Workspace for data isolation. Defaults to empty string if WORKSPACE environment variable is not set.""" + # Logging (Deprecated, use setup_logger in utils.py instead) # --- log_level: int | None = field(default=None) @@ -242,7 +246,6 @@ class LightRAG: vector_db_storage_cls_kwargs: dict[str, Any] = field(default_factory=dict) """Additional parameters for vector database storage.""" - enable_llm_cache: bool = field(default=True) """Enables caching for LLM responses to avoid redundant computations.""" @@ -380,39 +383,44 @@ class LightRAG: self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE, - global_config=asdict( - self - ), # Add global_config to ensure cache works properly + workspace=self.workspace, + global_config=global_config, embedding_func=self.embedding_func, ) self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore namespace=NameSpace.KV_STORE_FULL_DOCS, + workspace=self.workspace, embedding_func=self.embedding_func, ) self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore namespace=NameSpace.KV_STORE_TEXT_CHUNKS, + workspace=self.workspace, embedding_func=self.embedding_func, ) self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION, + workspace=self.workspace, embedding_func=self.embedding_func, ) self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=NameSpace.VECTOR_STORE_ENTITIES, + workspace=self.workspace, embedding_func=self.embedding_func, meta_fields={"entity_name", "source_id", "content", "file_path"}, ) self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=NameSpace.VECTOR_STORE_RELATIONSHIPS, + workspace=self.workspace, embedding_func=self.embedding_func, meta_fields={"src_id", "tgt_id", "source_id", "content", "file_path"}, ) self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore namespace=NameSpace.VECTOR_STORE_CHUNKS, + workspace=self.workspace, embedding_func=self.embedding_func, meta_fields={"full_doc_id", "content", "file_path"}, ) @@ -420,6 +428,7 @@ class LightRAG: # Initialize document status storage self.doc_status: DocStatusStorage = self.doc_status_storage_cls( namespace=NameSpace.DOC_STATUS, + workspace=self.workspace, global_config=global_config, embedding_func=None, )
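
The Milvus, MongoDB, Qdrant, and Redis backends above all resolve the effective workspace the same way: a backend-specific environment variable (MILVUS_WORKSPACE, MONGODB_WORKSPACE, QDRANT_WORKSPACE, REDIS_WORKSPACE) overrides the workspace passed at construction, and a non-empty result is prepended to the collection/key namespace. A minimal sketch of that shared logic; the helper name resolve_namespace is illustrative, not part of the patch:

    import os

    def resolve_namespace(namespace: str, workspace: str, env_var: str) -> str:
        """Prefix the namespace with the effective workspace for data isolation.

        The backend-specific environment variable takes priority over the
        workspace passed in at construction time.
        """
        env_workspace = os.environ.get(env_var, "").strip()
        effective = env_workspace or workspace
        # An empty workspace keeps the original namespace unchanged,
        # which preserves compatibility with pre-workspace data.
        return f"{effective}_{namespace}" if effective else namespace

    assert resolve_namespace("chunks", "", "MILVUS_WORKSPACE") == "chunks"
    os.environ["MILVUS_WORKSPACE"] = "tenant_a"
    assert resolve_namespace("chunks", "tenant_b", "MILVUS_WORKSPACE") == "tenant_a_chunks"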
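
The file-backed storages (JsonKVStorage, JsonDocStatusStorage, NanoVectorDBStorage, FaissVectorDBStorage, NetworkXStorage) isolate workspaces on disk instead: when a workspace is set, their files move into a per-workspace subdirectory of working_dir. A sketch of the path logic, with the hypothetical helper name workspace_path:

    import os

    def workspace_path(working_dir: str, workspace: str, file_name: str) -> str:
        """Return <working_dir>/<workspace>/<file_name> when a workspace is
        set, otherwise fall back to the legacy flat layout."""
        if workspace:
            workspace_dir = os.path.join(working_dir, workspace)
            os.makedirs(workspace_dir, exist_ok=True)
            return os.path.join(workspace_dir, file_name)
        return os.path.join(working_dir, file_name)

    # e.g. ./rag_storage/tenant_a/vdb_chunks.json with workspace="tenant_a",
    # but ./rag_storage/vdb_chunks.json with workspace=""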
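
Neo4JStorage takes a different route: the workspace becomes a node label, with "base" kept as the fallback label so graphs created before this change stay readable. Every Cypher statement is then scoped by that label, including index creation and drop(). A small sketch of how the label feeds into the generated queries:

    def get_workspace_label(workspace: str | None) -> str:
        # "base" preserves compatibility when no workspace is configured
        return workspace if workspace else "base"

    label = get_workspace_label("tenant_a")
    get_node = f"MATCH (n:`{label}` {{entity_id: $entity_id}}) RETURN n"
    drop_all = f"MATCH (n:`{label}`) DETACH DELETE n"  # deletes one workspace only
    assert get_node == "MATCH (n:`tenant_a` {entity_id: $entity_id}) RETURN n"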
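
The three relational PostgreSQL storages (PGKVStorage, PGVectorStorage, PGDocStatusStorage) share one PostgreSQLDB client, so workspace resolution happens in initialize() with a fixed precedence: the client's POSTGRES_WORKSPACE (whose config fallback is now None rather than "default") wins, then the storage's own workspace, then the literal "default" for backward compatibility with existing rows. A condensed sketch with an illustrative helper name; the patch's side effect of writing the result back onto the shared client is omitted:

    def effective_pg_workspace(db_workspace: str | None, storage_workspace: str | None) -> str:
        """POSTGRES_WORKSPACE > storage workspace > "default"."""
        if db_workspace:
            return db_workspace        # forced by the administrator
        if storage_workspace:
            return storage_workspace   # per-instance isolation
        return "default"               # legacy rows

    assert effective_pg_workspace(None, "tenant_a") == "tenant_a"
    assert effective_pg_workspace("ops", "tenant_a") == "ops"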
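
PGGraphStorage cannot rely on a workspace column, so it derives a separate AGE graph per workspace instead: graph_name becomes workspace_namespace when a workspace is set, or the namespace alone otherwise, and both parts are sanitized into valid PostgreSQL identifiers. The sanitization in _get_workspace_graph_name reduces to the following (standalone restatement, not the patch's method itself):

    import re

    def workspace_graph_name(workspace: str | None, namespace: str) -> str:
        """Join workspace and namespace into a safe AGE graph name."""
        safe_ns = re.sub(r"[^a-zA-Z0-9_]", "_", namespace)
        if workspace and workspace.strip():
            safe_ws = re.sub(r"[^a-zA-Z0-9_]", "_", workspace.strip())
            return f"{safe_ws}_{safe_ns}"
        return safe_ns

    assert workspace_graph_name("tenant-a", "chunk_entity_relation") == "tenant_a_chunk_entity_relation"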
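
Finally, the LightRAG dataclass picks up its default workspace from the WORKSPACE environment variable through a default_factory, so the value is read at instantiation rather than at import time, and every storage then receives it through the new workspace= constructor argument. A self-contained illustration of that pattern, using a stand-in class rather than LightRAG itself:

    import os
    from dataclasses import dataclass, field

    @dataclass
    class Settings:  # stand-in for the relevant LightRAG field
        workspace: str = field(default_factory=lambda: os.getenv("WORKSPACE", ""))

    os.environ["WORKSPACE"] = "tenant_a"
    assert Settings().workspace == "tenant_a"  # read when instantiated, not at import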