diff --git a/lightrag/kg/faiss_impl.py b/lightrag/kg/faiss_impl.py index a46db3ec..6de640b7 100644 --- a/lightrag/kg/faiss_impl.py +++ b/lightrag/kg/faiss_impl.py @@ -10,7 +10,7 @@ from lightrag.utils import logger, compute_mdhash_id from lightrag.base import BaseVectorStorage from .shared_storage import ( - get_storage_lock, + get_namespace_lock, get_update_flag, set_all_update_flags, ) @@ -39,21 +39,16 @@ class FaissVectorDBStorage(BaseVectorStorage): # Where to save index file if you want persistent storage working_dir = self.global_config["working_dir"] - - # Get composite workspace (supports multi-tenant isolation) - composite_workspace = self._get_composite_workspace() - - if composite_workspace and composite_workspace != "_": - # Include composite workspace in the file path for data isolation - # For multi-tenant: tenant_id:kb_id:workspace - # For single-tenant: just workspace - workspace_dir = os.path.join(working_dir, composite_workspace) - self.final_namespace = f"{composite_workspace}_{self.namespace}" + if self.workspace: + # Include workspace in the file path for data isolation + workspace_dir = os.path.join(working_dir, self.workspace) + self.final_namespace = f"{self.workspace}_{self.namespace}" + else: # Default behavior when workspace is empty - workspace_dir = working_dir self.final_namespace = self.namespace - composite_workspace = "_" + workspace_dir = working_dir + self.workspace = "" os.makedirs(workspace_dir, exist_ok=True) self._faiss_index_file = os.path.join( @@ -78,9 +73,13 @@ class FaissVectorDBStorage(BaseVectorStorage): async def initialize(self): """Initialize storage data""" # Get the update flag for cross-process update notification - self.storage_updated = await get_update_flag(self.final_namespace) + self.storage_updated = await get_update_flag( + self.final_namespace, workspace=self.workspace + ) # Get the storage lock for use in other methods - self._storage_lock = get_storage_lock() + self._storage_lock = get_namespace_lock( + 
self.final_namespace, workspace=self.workspace + ) async def _get_index(self): """Check if the shtorage should be reloaded""" @@ -405,7 +404,9 @@ class FaissVectorDBStorage(BaseVectorStorage): # Save data to disk self._save_faiss_index() # Notify other processes that data has been updated - await set_all_update_flags(self.final_namespace) + await set_all_update_flags( + self.final_namespace, workspace=self.workspace + ) # Reset own update flag to avoid self-reloading self.storage_updated.value = False except Exception as e: @@ -455,23 +456,23 @@ class FaissVectorDBStorage(BaseVectorStorage): if not ids: return [] - results = [] + results: list[dict[str, Any] | None] = [] for id in ids: + record = None fid = self._find_faiss_id_by_custom_id(id) if fid is not None: - metadata = self._id_to_meta.get(fid, {}) + metadata = self._id_to_meta.get(fid) if metadata: # Filter out __vector__ from metadata to avoid returning large vector data filtered_metadata = { k: v for k, v in metadata.items() if k != "__vector__" } - results.append( - { - **filtered_metadata, - "id": metadata.get("__id__"), - "created_at": metadata.get("__created_at__"), - } - ) + record = { + **filtered_metadata, + "id": metadata.get("__id__"), + "created_at": metadata.get("__created_at__"), + } + results.append(record) return results @@ -532,7 +533,9 @@ class FaissVectorDBStorage(BaseVectorStorage): self._load_faiss_index() # Notify other processes - await set_all_update_flags(self.final_namespace) + await set_all_update_flags( + self.final_namespace, workspace=self.workspace + ) self.storage_updated.value = False logger.info( diff --git a/lightrag/kg/json_doc_status_impl.py b/lightrag/kg/json_doc_status_impl.py index a5648e28..a4ac792b 100644 --- a/lightrag/kg/json_doc_status_impl.py +++ b/lightrag/kg/json_doc_status_impl.py @@ -16,7 +16,7 @@ from lightrag.utils import ( from lightrag.exceptions import StorageNotInitializedError from .shared_storage import ( get_namespace_data, - get_storage_lock, + 
get_namespace_lock, get_data_init_lock, get_update_flag, set_all_update_flags, @@ -32,21 +32,15 @@ class JsonDocStatusStorage(DocStatusStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] - - # Get composite workspace (supports multi-tenant isolation) - composite_workspace = self._get_composite_workspace() - - if composite_workspace and composite_workspace != "_": - # Include composite workspace in the file path for data isolation - # For multi-tenant: tenant_id:kb_id:workspace - # For single-tenant: just workspace - workspace_dir = os.path.join(working_dir, composite_workspace) - self.final_namespace = f"{composite_workspace}_{self.namespace}" + if self.workspace: + # Include workspace in the file path for data isolation + workspace_dir = os.path.join(working_dir, self.workspace) + self.final_namespace = f"{self.workspace}_{self.namespace}" else: # Default behavior when workspace is empty workspace_dir = working_dir self.final_namespace = self.namespace - composite_workspace = "_" + self.workspace = "" os.makedirs(workspace_dir, exist_ok=True) self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json") @@ -56,12 +50,20 @@ class JsonDocStatusStorage(DocStatusStorage): async def initialize(self): """Initialize storage data""" - self._storage_lock = get_storage_lock() - self.storage_updated = await get_update_flag(self.final_namespace) + self._storage_lock = get_namespace_lock( + self.namespace, workspace=self.workspace + ) + self.storage_updated = await get_update_flag( + self.namespace, workspace=self.workspace + ) async with get_data_init_lock(): # check need_init must before get_namespace_data - need_init = await try_initialize_namespace(self.final_namespace) - self._data = await get_namespace_data(self.final_namespace) + need_init = await try_initialize_namespace( + self.namespace, workspace=self.workspace + ) + self._data = await get_namespace_data( + self.namespace, workspace=self.workspace + ) if need_init: 
loaded_data = load_json(self._file_name) or {} async with self._storage_lock: @@ -96,21 +98,8 @@ class JsonDocStatusStorage(DocStatusStorage): if self._storage_lock is None: raise StorageNotInitializedError("JsonDocStatusStorage") async with self._storage_lock: - for doc_id, doc in self._data.items(): - try: - status = doc.get("status") - if status in counts: - counts[status] += 1 - else: - # Log warning for unknown status but don't fail - logger.warning( - f"[{self.workspace}] Unknown status '{status}' for document {doc_id}" - ) - except Exception as e: - logger.error( - f"[{self.workspace}] Error counting status for document {doc_id}: {e}" - ) - continue + for doc in self._data.values(): + counts[doc["status"]] += 1 return counts async def get_docs_by_status( @@ -190,11 +179,11 @@ class JsonDocStatusStorage(DocStatusStorage): f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}" ) cleaned_data = load_json(self._file_name) - if cleaned_data: + if cleaned_data is not None: self._data.clear() self._data.update(cleaned_data) - await clear_all_update_flags(self.final_namespace) + await clear_all_update_flags(self.namespace, workspace=self.workspace) async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """ @@ -215,10 +204,24 @@ class JsonDocStatusStorage(DocStatusStorage): if "chunks_list" not in doc_data: doc_data["chunks_list"] = [] self._data.update(data) - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) await self.index_done_callback() + async def is_empty(self) -> bool: + """Check if the storage is empty + + Returns: + bool: True if storage is empty, False otherwise + + Raises: + StorageNotInitializedError: If storage is not initialized + """ + if self._storage_lock is None: + raise StorageNotInitializedError("JsonDocStatusStorage") + async with self._storage_lock: + return len(self._data) == 0 + async def get_by_id(self, id: str) -> Union[dict[str, 
Any], None]: async with self._storage_lock: return self._data.get(id) @@ -260,9 +263,6 @@ class JsonDocStatusStorage(DocStatusStorage): # For JSON storage, we load all data and sort/filter in memory all_docs = [] - if self._storage_lock is None: - raise StorageNotInitializedError("JsonDocStatusStorage") - async with self._storage_lock: for doc_id, doc_data in self._data.items(): # Apply status filter @@ -283,12 +283,7 @@ class JsonDocStatusStorage(DocStatusStorage): if "error_msg" not in data: data["error_msg"] = None - # Filter data to only include valid fields for DocProcessingStatus - # This prevents TypeError if extra fields are present in the JSON - valid_fields = DocProcessingStatus.__dataclass_fields__.keys() - filtered_data = {k: v for k, v in data.items() if k in valid_fields} - - doc_status = DocProcessingStatus(**filtered_data) + doc_status = DocProcessingStatus(**data) # Add sort key for sorting if sort_field == "id": @@ -302,16 +297,11 @@ class JsonDocStatusStorage(DocStatusStorage): all_docs.append((doc_id, doc_status)) - except (KeyError, TypeError, ValueError) as e: + except KeyError as e: logger.error( f"[{self.workspace}] Error processing document {doc_id}: {e}" ) continue - except Exception as e: - logger.error( - f"[{self.workspace}] Unexpected error processing document {doc_id}: {e}" - ) - continue # Sort documents reverse_sort = sort_direction.lower() == "desc" @@ -368,7 +358,7 @@ class JsonDocStatusStorage(DocStatusStorage): any_deleted = True if any_deleted: - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) async def get_doc_by_file_path(self, file_path: str) -> Union[dict[str, Any], None]: """Get document by file path @@ -391,28 +381,6 @@ class JsonDocStatusStorage(DocStatusStorage): return None - async def get_doc_by_external_id( - self, external_id: str - ) -> Union[dict[str, Any], None]: - """Get document by external ID for idempotency checks. 
- - Args: - external_id: The external ID to search for (client-provided unique identifier) - - Returns: - Union[dict[str, Any], None]: Document data if found, None otherwise - Returns the same format as get_by_id method - """ - if self._storage_lock is None: - raise StorageNotInitializedError("JsonDocStatusStorage") - - async with self._storage_lock: - for doc_id, doc_data in self._data.items(): - if doc_data.get("external_id") == external_id: - return doc_data - - return None - async def drop(self) -> dict[str, str]: """Drop all document status data from storage and clean up resources @@ -429,7 +397,7 @@ class JsonDocStatusStorage(DocStatusStorage): try: async with self._storage_lock: self._data.clear() - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) await self.index_done_callback() logger.info( diff --git a/lightrag/kg/json_kv_impl.py b/lightrag/kg/json_kv_impl.py index 88df469a..b1151e73 100644 --- a/lightrag/kg/json_kv_impl.py +++ b/lightrag/kg/json_kv_impl.py @@ -13,7 +13,7 @@ from lightrag.utils import ( from lightrag.exceptions import StorageNotInitializedError from .shared_storage import ( get_namespace_data, - get_storage_lock, + get_namespace_lock, get_data_init_lock, get_update_flag, set_all_update_flags, @@ -27,21 +27,15 @@ from .shared_storage import ( class JsonKVStorage(BaseKVStorage): def __post_init__(self): working_dir = self.global_config["working_dir"] - - # Get composite workspace (supports multi-tenant isolation) - composite_workspace = self._get_composite_workspace() - - if composite_workspace and composite_workspace != "_": - # Include composite workspace in the file path for data isolation - # For multi-tenant: tenant_id:kb_id:workspace - # For single-tenant: just workspace - workspace_dir = os.path.join(working_dir, composite_workspace) - self.final_namespace = f"{composite_workspace}_{self.namespace}" + if self.workspace: + # Include workspace in the file path for 
data isolation + workspace_dir = os.path.join(working_dir, self.workspace) + self.final_namespace = f"{self.workspace}_{self.namespace}" else: # Default behavior when workspace is empty workspace_dir = working_dir self.final_namespace = self.namespace - composite_workspace = "_" + self.workspace = "" os.makedirs(workspace_dir, exist_ok=True) self._file_name = os.path.join(workspace_dir, f"kv_store_{self.namespace}.json") @@ -52,12 +46,20 @@ class JsonKVStorage(BaseKVStorage): async def initialize(self): """Initialize storage data""" - self._storage_lock = get_storage_lock() - self.storage_updated = await get_update_flag(self.final_namespace) + self._storage_lock = get_namespace_lock( + self.namespace, workspace=self.workspace + ) + self.storage_updated = await get_update_flag( + self.namespace, workspace=self.workspace + ) async with get_data_init_lock(): # check need_init must before get_namespace_data - need_init = await try_initialize_namespace(self.final_namespace) - self._data = await get_namespace_data(self.final_namespace) + need_init = await try_initialize_namespace( + self.namespace, workspace=self.workspace + ) + self._data = await get_namespace_data( + self.namespace, workspace=self.workspace + ) if need_init: loaded_data = load_json(self._file_name) or {} async with self._storage_lock: @@ -97,31 +99,11 @@ class JsonKVStorage(BaseKVStorage): f"[{self.workspace}] Reloading sanitized data into shared memory for {self.namespace}" ) cleaned_data = load_json(self._file_name) - if cleaned_data: + if cleaned_data is not None: self._data.clear() self._data.update(cleaned_data) - await clear_all_update_flags(self.final_namespace) - - async def get_all(self) -> dict[str, Any]: - """Get all data from storage - - Returns: - Dictionary containing all stored data - """ - async with self._storage_lock: - result = {} - for key, value in self._data.items(): - if value: - # Create a copy to avoid modifying the original data - data = dict(value) - # Ensure time fields are 
present, provide default values for old data - data.setdefault("create_time", 0) - data.setdefault("update_time", 0) - result[key] = data - else: - result[key] = value - return result + await clear_all_update_flags(self.namespace, workspace=self.workspace) async def get_by_id(self, id: str) -> dict[str, Any] | None: async with self._storage_lock: @@ -194,7 +176,7 @@ class JsonKVStorage(BaseKVStorage): v["_id"] = k self._data.update(data) - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) async def delete(self, ids: list[str]) -> None: """Delete specific records from storage by their IDs @@ -217,7 +199,16 @@ class JsonKVStorage(BaseKVStorage): any_deleted = True if any_deleted: - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) + + async def is_empty(self) -> bool: + """Check if the storage is empty + + Returns: + bool: True if storage contains no data, False otherwise + """ + async with self._storage_lock: + return len(self._data) == 0 async def drop(self) -> dict[str, str]: """Drop all data from storage and clean up resources @@ -236,7 +227,7 @@ class JsonKVStorage(BaseKVStorage): try: async with self._storage_lock: self._data.clear() - await set_all_update_flags(self.final_namespace) + await set_all_update_flags(self.namespace, workspace=self.workspace) await self.index_done_callback() logger.info( @@ -254,7 +245,7 @@ class JsonKVStorage(BaseKVStorage): data: Original data dictionary that may contain legacy structure Returns: - Migrated data dictionary with flattened cache keys + Migrated data dictionary with flattened cache keys (sanitized if needed) """ from lightrag.utils import generate_cache_key @@ -291,8 +282,17 @@ class JsonKVStorage(BaseKVStorage): logger.info( f"[{self.workspace}] Migrated {migration_count} legacy cache entries to flattened structure" ) - # Persist migrated data immediately - 
write_json(migrated_data, self._file_name) + # Persist migrated data immediately and check if sanitization was applied + needs_reload = write_json(migrated_data, self._file_name) + + # If data was sanitized during write, reload cleaned data + if needs_reload: + logger.info( + f"[{self.workspace}] Reloading sanitized migration data for {self.namespace}" + ) + cleaned_data = load_json(self._file_name) + if cleaned_data is not None: + return cleaned_data # Return cleaned data to update shared memory return migrated_data diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 3649aeb2..ff53b0c7 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -944,8 +944,8 @@ class MilvusVectorDBStorage(BaseVectorStorage): else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace + self.workspace = "" logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") - self.workspace = "_" kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {}) cosine_threshold = kwargs.get("cosine_better_than_threshold") diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index a62c3031..fdc00e1a 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -120,7 +120,7 @@ class MongoKVStorage(BaseKVStorage): else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" logger.debug( f"[{self.workspace}] Final namespace (no workspace): '{self.namespace}'" ) @@ -352,7 +352,7 @@ class MongoDocStatusStorage(DocStatusStorage): else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") self._collection_name = self.final_namespace @@ -505,7 +505,7 @@ class MongoDocStatusStorage(DocStatusStorage): 
collation_config = {"locale": "zh", "numericOrdering": True} # Use workspace-specific index names to avoid cross-workspace conflicts - workspace_prefix = f"{self.workspace}_" if self.workspace != "_" else "" + workspace_prefix = f"{self.workspace}_" if self.workspace != "" else "" # 1. Define all indexes needed with workspace-specific names all_indexes = [ @@ -763,7 +763,7 @@ class MongoGraphStorage(BaseGraphStorage): else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") self._collection_name = self.final_namespace @@ -2116,7 +2116,7 @@ class MongoVectorDBStorage(BaseVectorStorage): else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") # Set index name based on workspace for backward compatibility diff --git a/lightrag/kg/nano_vector_db_impl.py b/lightrag/kg/nano_vector_db_impl.py index d54bb56f..cc7a411d 100644 --- a/lightrag/kg/nano_vector_db_impl.py +++ b/lightrag/kg/nano_vector_db_impl.py @@ -53,7 +53,7 @@ class NanoVectorDBStorage(BaseVectorStorage): else: # Default behavior when workspace is empty self.final_namespace = self.namespace - self.workspace = "_" + self.workspace = "" workspace_dir = working_dir os.makedirs(workspace_dir, exist_ok=True) diff --git a/lightrag/kg/networkx_impl.py b/lightrag/kg/networkx_impl.py index bb4049b9..74ee1690 100644 --- a/lightrag/kg/networkx_impl.py +++ b/lightrag/kg/networkx_impl.py @@ -53,7 +53,7 @@ class NetworkXStorage(BaseGraphStorage): # Default behavior when workspace is empty self.final_namespace = self.namespace workspace_dir = working_dir - self.workspace = "_" + self.workspace = "" os.makedirs(workspace_dir, exist_ok=True) self._graphml_xml_file = os.path.join( 
diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 13768385..a254d4ee 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -14,7 +14,6 @@ if not pm.is_installed("redis"): from redis.asyncio import Redis, ConnectionPool # type: ignore from redis.exceptions import RedisError, ConnectionError, TimeoutError # type: ignore from lightrag.utils import logger, get_pinyin_sort_key -from lightrag.utils_context import get_current_tenant_id from lightrag.base import ( BaseKVStorage, @@ -22,7 +21,7 @@ from lightrag.base import ( DocStatus, DocProcessingStatus, ) -from ..kg.shared_storage import get_data_init_lock, get_storage_lock +from ..kg.shared_storage import get_data_init_lock import json # Import tenacity for retry logic @@ -147,21 +146,14 @@ class RedisKVStorage(BaseKVStorage): # Build final_namespace with workspace prefix for data isolation # Keep original namespace unchanged for type detection logic if effective_workspace: - self.workspace = effective_workspace - else: - self.workspace = "_" - - # Get composite workspace (supports multi-tenant isolation) - composite_workspace = self._get_composite_workspace() - - if composite_workspace and composite_workspace != "_": - self.final_namespace = f"{composite_workspace}_{self.namespace}" + self.final_namespace = f"{effective_workspace}_{self.namespace}" logger.debug( - f"Final namespace with composite workspace: '{self.final_namespace}'" + f"Final namespace with workspace prefix: '{self.final_namespace}'" ) else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace + self.workspace = "" logger.debug(f"Final namespace (no workspace): '{self.final_namespace}'") self._redis_url = os.environ.get( @@ -271,20 +263,11 @@ class RedisKVStorage(BaseKVStorage): """Ensure Redis resources are cleaned up when exiting context.""" await self.close() - def _get_key_prefix(self) -> str: - tenant_id = get_current_tenant_id() - if tenant_id: - 
return f"{self.final_namespace}:{tenant_id}" - return self.final_namespace - - def _get_full_key(self, id: str) -> str: - return f"{self._get_key_prefix()}:{id}" - @redis_retry async def get_by_id(self, id: str) -> dict[str, Any] | None: async with self._get_redis_connection() as redis: try: - data = await redis.get(self._get_full_key(id)) + data = await redis.get(f"{self.final_namespace}:{id}") if data: result = json.loads(data) # Ensure time fields are present, provide default values for old data @@ -302,7 +285,7 @@ class RedisKVStorage(BaseKVStorage): try: pipe = redis.pipeline() for id in ids: - pipe.get(self._get_full_key(id)) + pipe.get(f"{self.final_namespace}:{id}") results = await pipe.execute() processed_results = [] @@ -321,62 +304,12 @@ class RedisKVStorage(BaseKVStorage): logger.error(f"[{self.workspace}] JSON decode error in batch get: {e}") return [None] * len(ids) - async def get_all(self) -> dict[str, Any]: - """Get all data from storage - - Returns: - Dictionary containing all stored data - """ - async with self._get_redis_connection() as redis: - try: - # Get all keys for this namespace - prefix = self._get_key_prefix() - keys = await redis.keys(f"{prefix}:*") - - if not keys: - return {} - - # Get all values in batch - pipe = redis.pipeline() - for key in keys: - pipe.get(key) - values = await pipe.execute() - - # Build result dictionary - result = {} - for key, value in zip(keys, values): - if value: - # Extract the ID part (after prefix:) - # key is prefix:id - # prefix might contain colons, so we need to be careful - # key_id = key[len(prefix) + 1:] - key_id = key[len(prefix) + 1 :] - try: - data = json.loads(value) - # Ensure time fields are present for all documents - data.setdefault("create_time", 0) - data.setdefault("update_time", 0) - result[key_id] = data - except json.JSONDecodeError as e: - logger.error( - f"[{self.workspace}] JSON decode error for key {key}: {e}" - ) - continue - - return result - except Exception as e: - 
logger.error( - f"[{self.workspace}] Error getting all data from Redis: {e}" - ) - return {} - async def filter_keys(self, keys: set[str]) -> set[str]: - """Return keys that should be processed (not in storage or not successfully processed)""" async with self._get_redis_connection() as redis: pipe = redis.pipeline() keys_list = list(keys) # Convert set to list for indexing for key in keys_list: - pipe.exists(self._get_full_key(key)) + pipe.exists(f"{self.final_namespace}:{key}") results = await pipe.execute() existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists} @@ -390,14 +323,13 @@ class RedisKVStorage(BaseKVStorage): import time current_time = int(time.time()) # Get current Unix timestamp - tenant_id = get_current_tenant_id() async with self._get_redis_connection() as redis: try: # Check which keys already exist to determine create vs update pipe = redis.pipeline() for k in data.keys(): - pipe.exists(self._get_full_key(k)) + pipe.exists(f"{self.final_namespace}:{k}") exists_results = await pipe.execute() # Add timestamps to data @@ -415,13 +347,11 @@ class RedisKVStorage(BaseKVStorage): v["update_time"] = current_time v["_id"] = k - if tenant_id: - v["tenant_id"] = tenant_id # Store the data pipe = redis.pipeline() for k, v in data.items(): - pipe.set(self._get_full_key(k), json.dumps(v)) + pipe.set(f"{self.final_namespace}:{k}", json.dumps(v)) await pipe.execute() except json.JSONDecodeError as e: @@ -432,15 +362,32 @@ class RedisKVStorage(BaseKVStorage): # Redis handles persistence automatically pass + async def is_empty(self) -> bool: + """Check if the storage is empty for the current workspace and namespace + + Returns: + bool: True if storage is empty, False otherwise + """ + pattern = f"{self.final_namespace}:*" + try: + async with self._get_redis_connection() as redis: + # Use scan to check if any keys exist + async for key in redis.scan_iter(match=pattern, count=1): + return False # Found at least one key + return True # No keys 
found + except Exception as e: + logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}") + return True + async def delete(self, ids: list[str]) -> None: - """Delete entries with specified IDs""" + """Delete specific records from storage by their IDs""" if not ids: return async with self._get_redis_connection() as redis: pipe = redis.pipeline() for id in ids: - pipe.delete(self._get_full_key(id)) + pipe.delete(f"{self.final_namespace}:{id}") results = await pipe.execute() deleted_count = sum(results) @@ -454,43 +401,39 @@ class RedisKVStorage(BaseKVStorage): Returns: dict[str, str]: Status of the operation with keys 'status' and 'message' """ - async with get_storage_lock(): - async with self._get_redis_connection() as redis: - try: - # Use SCAN to find all keys with the namespace prefix - prefix = self._get_key_prefix() - pattern = f"{prefix}:*" - cursor = 0 - deleted_count = 0 + async with self._get_redis_connection() as redis: + try: + # Use SCAN to find all keys with the namespace prefix + pattern = f"{self.final_namespace}:*" + cursor = 0 + deleted_count = 0 - while True: - cursor, keys = await redis.scan( - cursor, match=pattern, count=1000 - ) - if keys: - # Delete keys in batches - pipe = redis.pipeline() - for key in keys: - pipe.delete(key) - results = await pipe.execute() - deleted_count += sum(results) + while True: + cursor, keys = await redis.scan(cursor, match=pattern, count=1000) + if keys: + # Delete keys in batches + pipe = redis.pipeline() + for key in keys: + pipe.delete(key) + results = await pipe.execute() + deleted_count += sum(results) - if cursor == 0: - break + if cursor == 0: + break - logger.info( - f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}" - ) - return { - "status": "success", - "message": f"{deleted_count} keys dropped", - } + logger.info( + f"[{self.workspace}] Dropped {deleted_count} keys from {self.namespace}" + ) + return { + "status": "success", + "message": f"{deleted_count} keys 
dropped", + } - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}" - ) - return {"status": "error", "message": str(e)} + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping keys from {self.namespace}: {e}" + ) + return {"status": "error", "message": str(e)} async def _migrate_legacy_cache_structure(self): """Migrate legacy nested cache structure to flattened structure for Redis @@ -504,8 +447,7 @@ class RedisKVStorage(BaseKVStorage): async with self._get_redis_connection() as redis: # Get all keys for this namespace - prefix = self._get_key_prefix() - keys = await redis.keys(f"{prefix}:*") + keys = await redis.keys(f"{self.final_namespace}:*") if not keys: return @@ -515,8 +457,8 @@ class RedisKVStorage(BaseKVStorage): keys_to_migrate = [] for key in keys: - # Extract the ID part (after prefix:) - key_id = key[len(prefix) + 1 :] + # Extract the ID part (after namespace:) + key_id = key.split(":", 1)[1] # Check if already in flattened format (contains exactly 2 colons for mode:cache_type:hash) if ":" in key_id and len(key_id.split(":")) == 3: @@ -559,7 +501,7 @@ class RedisKVStorage(BaseKVStorage): for cache_hash, cache_entry in nested_data.items(): cache_type = cache_entry.get("cache_type", "extract") flattened_key = generate_cache_key(mode, cache_type, cache_hash) - full_key = f"{prefix}:{flattened_key}" + full_key = f"{self.final_namespace}:{flattened_key}" pipe.set(full_key, json.dumps(cache_entry)) migration_count += 1 @@ -597,21 +539,14 @@ class RedisDocStatusStorage(DocStatusStorage): # Build final_namespace with workspace prefix for data isolation # Keep original namespace unchanged for type detection logic if effective_workspace: - self.workspace = effective_workspace - else: - self.workspace = "_" - - # Get composite workspace (supports multi-tenant isolation) - composite_workspace = self._get_composite_workspace() - - if composite_workspace and composite_workspace != "_": 
- self.final_namespace = f"{composite_workspace}_{self.namespace}" + self.final_namespace = f"{effective_workspace}_{self.namespace}" logger.debug( - f"[{self.workspace}] Final namespace with composite workspace: '{self.namespace}'" + f"[{self.workspace}] Final namespace with workspace prefix: '{self.namespace}'" ) else: # When workspace is empty, final_namespace equals original namespace self.final_namespace = self.namespace + self.workspace = "" logger.debug( f"[{self.workspace}] Final namespace (no workspace): '{self.namespace}'" ) @@ -714,22 +649,13 @@ class RedisDocStatusStorage(DocStatusStorage): """Ensure Redis resources are cleaned up when exiting context.""" await self.close() - def _get_key_prefix(self) -> str: - tenant_id = get_current_tenant_id() - if tenant_id: - return f"{self.final_namespace}:{tenant_id}" - return self.final_namespace - - def _get_full_key(self, id: str) -> str: - return f"{self._get_key_prefix()}:{id}" - async def filter_keys(self, keys: set[str]) -> set[str]: """Return keys that should be processed (not in storage or not successfully processed)""" async with self._get_redis_connection() as redis: pipe = redis.pipeline() keys_list = list(keys) for key in keys_list: - pipe.exists(self._get_full_key(key)) + pipe.exists(f"{self.final_namespace}:{key}") results = await pipe.execute() existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists} @@ -741,7 +667,7 @@ class RedisDocStatusStorage(DocStatusStorage): try: pipe = redis.pipeline() for id in ids: - pipe.get(self._get_full_key(id)) + pipe.get(f"{self.final_namespace}:{id}") results = await pipe.execute() for result_data in results: @@ -765,11 +691,10 @@ class RedisDocStatusStorage(DocStatusStorage): async with self._get_redis_connection() as redis: try: # Use SCAN to iterate through all keys in the namespace - prefix = self._get_key_prefix() cursor = 0 while True: cursor, keys = await redis.scan( - cursor, 
match=f"{self.final_namespace}:*", count=1000 ) if keys: # Get all values in batch @@ -804,11 +729,10 @@ class RedisDocStatusStorage(DocStatusStorage): async with self._get_redis_connection() as redis: try: # Use SCAN to iterate through all keys in the namespace - prefix = self._get_key_prefix() cursor = 0 while True: cursor, keys = await redis.scan( - cursor, match=f"{prefix}:*", count=1000 + cursor, match=f"{self.final_namespace}:*", count=1000 ) if keys: # Get all values in batch @@ -824,8 +748,7 @@ class RedisDocStatusStorage(DocStatusStorage): doc_data = json.loads(value) if doc_data.get("status") == status.value: # Extract document ID from key - # key is prefix:id - doc_id = key[len(prefix) + 1 :] + doc_id = key.split(":", 1)[1] # Make a copy of the data to avoid modifying the original data = doc_data.copy() @@ -862,11 +785,10 @@ class RedisDocStatusStorage(DocStatusStorage): async with self._get_redis_connection() as redis: try: # Use SCAN to iterate through all keys in the namespace - prefix = self._get_key_prefix() cursor = 0 while True: cursor, keys = await redis.scan( - cursor, match=f"{prefix}:*", count=1000 + cursor, match=f"{self.final_namespace}:*", count=1000 ) if keys: # Get all values in batch @@ -882,7 +804,7 @@ class RedisDocStatusStorage(DocStatusStorage): doc_data = json.loads(value) if doc_data.get("track_id") == track_id: # Extract document ID from key - doc_id = key[len(prefix) + 1 :] + doc_id = key.split(":", 1)[1] # Make a copy of the data to avoid modifying the original data = doc_data.copy() @@ -915,6 +837,23 @@ class RedisDocStatusStorage(DocStatusStorage): """Redis handles persistence automatically""" pass + async def is_empty(self) -> bool: + """Check if the storage is empty for the current workspace and namespace + + Returns: + bool: True if storage is empty, False otherwise + """ + pattern = f"{self.final_namespace}:*" + try: + async with self._get_redis_connection() as redis: + # Use scan to check if any keys exist + async for key 
in redis.scan_iter(match=pattern, count=1): + return False # Found at least one key + return True # No keys found + except Exception as e: + logger.error(f"[{self.workspace}] Error checking if storage is empty: {e}") + return True + @redis_retry async def upsert(self, data: dict[str, dict[str, Any]]) -> None: """Insert or update document status data""" @@ -933,7 +872,7 @@ class RedisDocStatusStorage(DocStatusStorage): pipe = redis.pipeline() for k, v in data.items(): - pipe.set(self._get_full_key(k), json.dumps(v)) + pipe.set(f"{self.final_namespace}:{k}", json.dumps(v)) await pipe.execute() except json.JSONDecodeError as e: logger.error(f"[{self.workspace}] JSON decode error during upsert: {e}") @@ -943,7 +882,7 @@ class RedisDocStatusStorage(DocStatusStorage): async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: async with self._get_redis_connection() as redis: try: - data = await redis.get(self._get_full_key(id)) + data = await redis.get(f"{self.final_namespace}:{id}") return json.loads(data) if data else None except json.JSONDecodeError as e: logger.error(f"[{self.workspace}] JSON decode error for id {id}: {e}") @@ -957,7 +896,7 @@ class RedisDocStatusStorage(DocStatusStorage): async with self._get_redis_connection() as redis: pipe = redis.pipeline() for doc_id in doc_ids: - pipe.delete(self._get_full_key(doc_id)) + pipe.delete(f"{self.final_namespace}:{doc_id}") results = await pipe.execute() deleted_count = sum(results) @@ -1006,11 +945,10 @@ class RedisDocStatusStorage(DocStatusStorage): async with self._get_redis_connection() as redis: try: # Use SCAN to iterate through all keys in the namespace - prefix = self._get_key_prefix() cursor = 0 while True: cursor, keys = await redis.scan( - cursor, match=f"{prefix}:*", count=1000 + cursor, match=f"{self.final_namespace}:*", count=1000 ) if keys: # Get all values in batch @@ -1034,7 +972,7 @@ class RedisDocStatusStorage(DocStatusStorage): continue # Extract document ID from key - doc_id = 
key[len(prefix) + 1 :] + doc_id = key.split(":", 1)[1] # Prepare document data data = doc_data.copy() @@ -1115,11 +1053,10 @@ class RedisDocStatusStorage(DocStatusStorage): async with self._get_redis_connection() as redis: try: # Use SCAN to iterate through all keys in the namespace - prefix = self._get_key_prefix() cursor = 0 while True: cursor, keys = await redis.scan( - cursor, match=f"{prefix}:*", count=1000 + cursor, match=f"{self.final_namespace}:*", count=1000 ) if keys: # Get all values in batch @@ -1149,94 +1086,34 @@ class RedisDocStatusStorage(DocStatusStorage): logger.error(f"[{self.workspace}] Error in get_doc_by_file_path: {e}") return None - async def get_doc_by_external_id( - self, external_id: str - ) -> Union[dict[str, Any], None]: - """Get document by external ID for idempotency checks. - - Args: - external_id: The external ID to search for (client-provided unique identifier) - - Returns: - Union[dict[str, Any], None]: Document data if found, None otherwise - Returns the same format as get_by_id method - - Note: - This method scans all documents in the namespace since Redis doesn't - support secondary indexes. For high-volume workloads, consider using - a separate hash index or switching to PostgreSQL storage. 
- """ - async with self._get_redis_connection() as redis: - try: - # Use SCAN to iterate through all keys in the namespace - prefix = self._get_key_prefix() + async def drop(self) -> dict[str, str]: + """Drop all document status data from storage and clean up resources""" + try: + async with self._get_redis_connection() as redis: + # Use SCAN to find all keys with the namespace prefix + pattern = f"{self.final_namespace}:*" cursor = 0 + deleted_count = 0 + while True: - cursor, keys = await redis.scan( - cursor, match=f"{prefix}:*", count=1000 - ) + cursor, keys = await redis.scan(cursor, match=pattern, count=1000) if keys: - # Get all values in batch + # Delete keys in batches pipe = redis.pipeline() for key in keys: - pipe.get(key) - values = await pipe.execute() - - # Check each document for matching external_id - for value in values: - if value: - try: - doc_data = json.loads(value) - if doc_data.get("external_id") == external_id: - return doc_data - except json.JSONDecodeError as e: - logger.error( - f"[{self.workspace}] JSON decode error in get_doc_by_external_id: {e}" - ) - continue + pipe.delete(key) + results = await pipe.execute() + deleted_count += sum(results) if cursor == 0: break - return None - except Exception as e: - logger.error( - f"[{self.workspace}] Error in get_doc_by_external_id: {e}" + logger.info( + f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}" ) - return None - - async def drop(self) -> dict[str, str]: - """Drop all document status data from storage and clean up resources""" - async with get_storage_lock(): - try: - async with self._get_redis_connection() as redis: - # Use SCAN to find all keys with the namespace prefix - prefix = self._get_key_prefix() - pattern = f"{prefix}:*" - cursor = 0 - deleted_count = 0 - - while True: - cursor, keys = await redis.scan( - cursor, match=pattern, count=1000 - ) - if keys: - # Delete keys in batches - pipe = redis.pipeline() - for key in keys: - pipe.delete(key) - 
results = await pipe.execute() - deleted_count += sum(results) - - if cursor == 0: - break - - logger.info( - f"[{self.workspace}] Dropped {deleted_count} doc status keys from {self.namespace}" - ) - return {"status": "success", "message": "data dropped"} - except Exception as e: - logger.error( - f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}" - ) - return {"status": "error", "message": str(e)} + return {"status": "success", "message": "data dropped"} + except Exception as e: + logger.error( + f"[{self.workspace}] Error dropping doc status {self.namespace}: {e}" + ) + return {"status": "error", "message": str(e)}