diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 8de03283..376dec5d 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -56,6 +56,8 @@ from lightrag.api.routers.ollama_api import OllamaAPI from lightrag.utils import logger, set_verbose_debug from lightrag.kg.shared_storage import ( get_namespace_data, + get_default_workspace, + # set_default_workspace, initialize_pipeline_status, cleanup_keyed_lock, finalize_share_data, @@ -350,8 +352,9 @@ def create_app(args): try: # Initialize database connections + # set_default_workspace(rag.workspace) # comment this line to test auto default workspace setting in initialize_storages await rag.initialize_storages() - await initialize_pipeline_status() + await initialize_pipeline_status() # with default workspace # Data migration regardless of storage implementation await rag.check_and_migrate_data() @@ -452,7 +455,7 @@ def create_app(args): # Create combined auth dependency for all endpoints combined_auth = get_combined_auth_dependency(api_key) - def get_workspace_from_request(request: Request) -> str: + def get_workspace_from_request(request: Request) -> str | None: """ Extract workspace from HTTP request header or use default. @@ -469,9 +472,8 @@ def create_app(args): # Check custom header first workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() - # Fall back to server default if header not provided if not workspace: - workspace = args.workspace + workspace = None return workspace @@ -641,33 +643,108 @@ def create_app(args): def create_optimized_embedding_function( config_cache: LLMConfigCache, binding, model, host, api_key, args - ): + ) -> EmbeddingFunc: """ - Create optimized embedding function with pre-processed configuration for applicable bindings. - Uses lazy imports for all bindings and avoids repeated configuration parsing. + Create optimized embedding function and return an EmbeddingFunc instance + with proper max_token_size inheritance from provider defaults. + + This function: + 1. Imports the provider embedding function + 2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc + 3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping) + 4. 
Returns a properly configured EmbeddingFunc instance """ + # Step 1: Import provider function and extract default attributes + provider_func = None + provider_max_token_size = None + provider_embedding_dim = None + + try: + if binding == "openai": + from lightrag.llm.openai import openai_embed + + provider_func = openai_embed + elif binding == "ollama": + from lightrag.llm.ollama import ollama_embed + + provider_func = ollama_embed + elif binding == "gemini": + from lightrag.llm.gemini import gemini_embed + + provider_func = gemini_embed + elif binding == "jina": + from lightrag.llm.jina import jina_embed + + provider_func = jina_embed + elif binding == "azure_openai": + from lightrag.llm.azure_openai import azure_openai_embed + + provider_func = azure_openai_embed + elif binding == "aws_bedrock": + from lightrag.llm.bedrock import bedrock_embed + + provider_func = bedrock_embed + elif binding == "lollms": + from lightrag.llm.lollms import lollms_embed + + provider_func = lollms_embed + + # Extract attributes if provider is an EmbeddingFunc + if provider_func and isinstance(provider_func, EmbeddingFunc): + provider_max_token_size = provider_func.max_token_size + provider_embedding_dim = provider_func.embedding_dim + logger.debug( + f"Extracted from {binding} provider: " + f"max_token_size={provider_max_token_size}, " + f"embedding_dim={provider_embedding_dim}" + ) + except ImportError as e: + logger.warning(f"Could not import provider function for {binding}: {e}") + + # Step 2: Apply priority (user config > provider default) + # For max_token_size: explicit env var > provider default > None + final_max_token_size = args.embedding_token_limit or provider_max_token_size + # For embedding_dim: user config (always has value) takes priority + # Only use provider default if user config is explicitly None (which shouldn't happen) + final_embedding_dim = ( + args.embedding_dim if args.embedding_dim else provider_embedding_dim + ) + + # Step 3: Create optimized embedding function (calls underlying function directly) async def optimized_embedding_function(texts, embedding_dim=None): try: if binding == "lollms": from lightrag.llm.lollms import lollms_embed - return await lollms_embed( + # Get real function, skip EmbeddingFunc wrapper if present + actual_func = ( + lollms_embed.func + if isinstance(lollms_embed, EmbeddingFunc) + else lollms_embed + ) + return await actual_func( texts, embed_model=model, host=host, api_key=api_key ) elif binding == "ollama": from lightrag.llm.ollama import ollama_embed - # Use pre-processed configuration if available, otherwise fallback to dynamic parsing + # Get real function, skip EmbeddingFunc wrapper if present + actual_func = ( + ollama_embed.func + if isinstance(ollama_embed, EmbeddingFunc) + else ollama_embed + ) + + # Use pre-processed configuration if available if config_cache.ollama_embedding_options is not None: ollama_options = config_cache.ollama_embedding_options else: - # Fallback for cases where config cache wasn't initialized properly from lightrag.llm.binding_options import OllamaEmbeddingOptions ollama_options = OllamaEmbeddingOptions.options_dict(args) - return await ollama_embed( + return await actual_func( texts, embed_model=model, host=host, @@ -677,15 +754,30 @@ def create_app(args): elif binding == "azure_openai": from lightrag.llm.azure_openai import azure_openai_embed - return await azure_openai_embed(texts, model=model, api_key=api_key) + actual_func = ( + azure_openai_embed.func + if isinstance(azure_openai_embed, EmbeddingFunc) + else 
azure_openai_embed + ) + return await actual_func(texts, model=model, api_key=api_key) elif binding == "aws_bedrock": from lightrag.llm.bedrock import bedrock_embed - return await bedrock_embed(texts, model=model) + actual_func = ( + bedrock_embed.func + if isinstance(bedrock_embed, EmbeddingFunc) + else bedrock_embed + ) + return await actual_func(texts, model=model) elif binding == "jina": from lightrag.llm.jina import jina_embed - return await jina_embed( + actual_func = ( + jina_embed.func + if isinstance(jina_embed, EmbeddingFunc) + else jina_embed + ) + return await actual_func( texts, embedding_dim=embedding_dim, base_url=host, @@ -694,16 +786,21 @@ def create_app(args): elif binding == "gemini": from lightrag.llm.gemini import gemini_embed - # Use pre-processed configuration if available, otherwise fallback to dynamic parsing + actual_func = ( + gemini_embed.func + if isinstance(gemini_embed, EmbeddingFunc) + else gemini_embed + ) + + # Use pre-processed configuration if available if config_cache.gemini_embedding_options is not None: gemini_options = config_cache.gemini_embedding_options else: - # Fallback for cases where config cache wasn't initialized properly from lightrag.llm.binding_options import GeminiEmbeddingOptions gemini_options = GeminiEmbeddingOptions.options_dict(args) - return await gemini_embed( + return await actual_func( texts, model=model, base_url=host, @@ -714,7 +811,12 @@ def create_app(args): else: # openai and compatible from lightrag.llm.openai import openai_embed - return await openai_embed( + actual_func = ( + openai_embed.func + if isinstance(openai_embed, EmbeddingFunc) + else openai_embed + ) + return await actual_func( texts, model=model, base_url=host, @@ -724,7 +826,21 @@ def create_app(args): except ImportError as e: raise Exception(f"Failed to import {binding} embedding: {e}") - return optimized_embedding_function + # Step 4: Wrap in EmbeddingFunc and return + embedding_func_instance = EmbeddingFunc( + embedding_dim=final_embedding_dim, + func=optimized_embedding_function, + max_token_size=final_max_token_size, + send_dimensions=False, # Will be set later based on binding requirements + ) + + # Log final embedding configuration + logger.info( + f"Embedding config: binding={binding} model={model} " + f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}" + ) + + return embedding_func_instance llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int) embedding_timeout = get_env_value( @@ -758,25 +874,24 @@ def create_app(args): **kwargs, ) - # Create embedding function with optimized configuration + # Create embedding function with optimized configuration and max_token_size inheritance import inspect - # Create the optimized embedding function - optimized_embedding_func = create_optimized_embedding_function( + # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size) + embedding_func = create_optimized_embedding_function( config_cache=config_cache, binding=args.embedding_binding, model=args.embedding_model, host=args.embedding_binding_host, api_key=args.embedding_binding_api_key, - args=args, # Pass args object for fallback option generation + args=args, ) # Get embedding_send_dim from centralized configuration embedding_send_dim = args.embedding_send_dim - # Check if the function signature has embedding_dim parameter - # Note: Since optimized_embedding_func is an async function, inspect its signature - sig = inspect.signature(optimized_embedding_func) + # Check if the underlying 
function signature has embedding_dim parameter + sig = inspect.signature(embedding_func.func) has_embedding_dim_param = "embedding_dim" in sig.parameters # Determine send_dimensions value based on binding type @@ -794,18 +909,27 @@ def create_app(args): else: dimension_control = "by not hasparam" + # Set send_dimensions on the EmbeddingFunc instance + embedding_func.send_dimensions = send_dimensions + logger.info( f"Send embedding dimension: {send_dimensions} {dimension_control} " - f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, " + f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, " f"binding={args.embedding_binding})" ) - # Create EmbeddingFunc with send_dimensions attribute - embedding_func = EmbeddingFunc( - embedding_dim=args.embedding_dim, - func=optimized_embedding_func, - send_dimensions=send_dimensions, - ) + # Log max_token_size source + if embedding_func.max_token_size: + source = ( + "env variable" + if args.embedding_token_limit + else f"{args.embedding_binding} provider default" + ) + logger.info( + f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})" + ) + else: + logger.info("Embedding max_token_size: not set (90% token warning disabled)") # Configure rerank function based on args.rerank_bindingparameter rerank_model_func = None @@ -1017,14 +1141,13 @@ def create_app(args): async def get_status(request: Request): """Get current system status""" try: - # Extract workspace from request header or use default workspace = get_workspace_from_request(request) - - # Construct namespace (following GraphDB pattern) - namespace = f"{workspace}:pipeline" if workspace else "pipeline_status" - - # Get workspace-specific pipeline status - pipeline_status = await get_namespace_data(namespace) + default_workspace = get_default_workspace() + if workspace is None: + workspace = default_workspace + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=workspace + ) if not auth_configured: auth_mode = "disabled" @@ -1055,8 +1178,7 @@ def create_app(args): "vector_storage": args.vector_storage, "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, "enable_llm_cache": args.enable_llm_cache, - "workspace": workspace, - "default_workspace": args.workspace, + "workspace": default_workspace, "max_graph_nodes": args.max_graph_nodes, # Rerank configuration "enable_rerank": rerank_model_func is not None, @@ -1245,6 +1367,12 @@ def check_and_install_dependencies(): def main(): + # Explicitly initialize configuration for clarity + # (The proxy will auto-initialize anyway, but this makes intent clear) + from .config import initialize_config + + initialize_config() + # Check if running under Gunicorn if "GUNICORN_CMD_ARGS" in os.environ: # If started with Gunicorn, return directly as Gunicorn will call get_application diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index fda7a70b..8925c2db 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -1641,26 +1641,15 @@ async def background_delete_documents( """Background task to delete multiple documents""" from lightrag.kg.shared_storage import ( get_namespace_data, - get_storage_keyed_lock, - initialize_pipeline_status, + get_namespace_lock, ) - # Step 1: Get workspace - workspace = rag.workspace - - # Step 2: Construct namespace - namespace = f"{workspace}:pipeline" if workspace else "pipeline_status" - - # Step 3: Ensure initialization - await 
initialize_pipeline_status(workspace) - - # Step 4: Get lock - pipeline_status_lock = get_storage_keyed_lock( - keys="status", namespace=namespace, enable_logging=False + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace ) - - # Step 5: Get data - pipeline_status = await get_namespace_data(namespace) total_docs = len(doc_ids) successful_deletions = [] @@ -2149,27 +2138,16 @@ def create_document_routes( """ from lightrag.kg.shared_storage import ( get_namespace_data, - get_storage_keyed_lock, - initialize_pipeline_status, + get_namespace_lock, ) # Get pipeline status and lock - # Step 1: Get workspace - workspace = rag.workspace - - # Step 2: Construct namespace - namespace = f"{workspace}:pipeline" if workspace else "pipeline_status" - - # Step 3: Ensure initialization - await initialize_pipeline_status(workspace) - - # Step 4: Get lock - pipeline_status_lock = get_storage_keyed_lock( - keys="status", namespace=namespace, enable_logging=False + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace ) - - # Step 5: Get data - pipeline_status = await get_namespace_data(namespace) # Check and set status with lock async with pipeline_status_lock: @@ -2360,15 +2338,16 @@ def create_document_routes( try: from lightrag.kg.shared_storage import ( get_namespace_data, + get_namespace_lock, get_all_update_flags_status, - initialize_pipeline_status, ) - # Get workspace-specific pipeline status - workspace = rag.workspace - namespace = f"{workspace}:pipeline" if workspace else "pipeline_status" - await initialize_pipeline_status(workspace) - pipeline_status = await get_namespace_data(namespace) + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace + ) # Get update flags status for all namespaces update_status = await get_all_update_flags_status() @@ -2385,8 +2364,9 @@ def create_document_routes( processed_flags.append(bool(flag)) processed_update_status[namespace] = processed_flags - # Convert to regular dict if it's a Manager.dict - status_dict = dict(pipeline_status) + async with pipeline_status_lock: + # Convert to regular dict if it's a Manager.dict + status_dict = dict(pipeline_status) # Add processed update_status to the status dictionary status_dict["update_status"] = processed_update_status @@ -2575,20 +2555,15 @@ def create_document_routes( try: from lightrag.kg.shared_storage import ( get_namespace_data, - get_storage_keyed_lock, - initialize_pipeline_status, + get_namespace_lock, ) - # Get workspace-specific pipeline status - workspace = rag.workspace - namespace = f"{workspace}:pipeline" if workspace else "pipeline_status" - await initialize_pipeline_status(workspace) - - # Use workspace-aware lock to check busy flag - pipeline_status_lock = get_storage_keyed_lock( - keys="status", namespace=namespace, enable_logging=False + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace ) - pipeline_status = await get_namespace_data(namespace) # Check if pipeline is busy with proper lock async with pipeline_status_lock: @@ -2993,26 +2968,15 @@ def create_document_routes( try: from 
lightrag.kg.shared_storage import ( get_namespace_data, - get_storage_keyed_lock, - initialize_pipeline_status, + get_namespace_lock, ) - # Step 1: Get workspace - workspace = rag.workspace - - # Step 2: Construct namespace - namespace = f"{workspace}:pipeline" if workspace else "pipeline_status" - - # Step 3: Ensure initialization - await initialize_pipeline_status(workspace) - - # Step 4: Get lock - pipeline_status_lock = get_storage_keyed_lock( - keys="status", namespace=namespace, enable_logging=False + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace ) - - # Step 5: Get data - pipeline_status = await get_namespace_data(namespace) async with pipeline_status_lock: if not pipeline_status.get("busy", False): diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 0d55db3d..113bda1c 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -84,10 +84,7 @@ _init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized _update_flags: Optional[Dict[str, bool]] = None # namespace -> updated # locks for mutex access -_storage_lock: Optional[LockType] = None _internal_lock: Optional[LockType] = None -_pipeline_status_lock: Optional[LockType] = None -_graph_db_lock: Optional[LockType] = None _data_init_lock: Optional[LockType] = None # Manager for all keyed locks _storage_keyed_lock: Optional["KeyedUnifiedLock"] = None @@ -98,6 +95,22 @@ _async_locks: Optional[Dict[str, asyncio.Lock]] = None _debug_n_locks_acquired: int = 0 +def get_final_namespace(namespace: str, workspace: str | None = None): + global _default_workspace + if workspace is None: + workspace = _default_workspace + + if workspace is None: + direct_log( + f"Error: Invoke namespace operation without workspace, pid={os.getpid()}", + level="ERROR", + ) + raise ValueError("Invoke namespace operation without workspace") + + final_namespace = f"{workspace}:{namespace}" if workspace else f"{namespace}" + return final_namespace + + def inc_debug_n_locks_acquired(): global _debug_n_locks_acquired if DEBUG_LOCKS: @@ -1056,40 +1069,10 @@ def get_internal_lock(enable_logging: bool = False) -> UnifiedLock: ) -def get_storage_lock(enable_logging: bool = False) -> UnifiedLock: - """return unified storage lock for data consistency""" - async_lock = _async_locks.get("storage_lock") if _is_multiprocess else None - return UnifiedLock( - lock=_storage_lock, - is_async=not _is_multiprocess, - name="storage_lock", - enable_logging=enable_logging, - async_lock=async_lock, - ) - - -def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock: - """return unified storage lock for data consistency""" - async_lock = _async_locks.get("pipeline_status_lock") if _is_multiprocess else None - return UnifiedLock( - lock=_pipeline_status_lock, - is_async=not _is_multiprocess, - name="pipeline_status_lock", - enable_logging=enable_logging, - async_lock=async_lock, - ) - - -def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock: - """return unified graph database lock for ensuring atomic operations""" - async_lock = _async_locks.get("graph_db_lock") if _is_multiprocess else None - return UnifiedLock( - lock=_graph_db_lock, - is_async=not _is_multiprocess, - name="graph_db_lock", - enable_logging=enable_logging, - async_lock=async_lock, - ) +# Workspace based storage_lock is implemented by get_storage_keyed_lock instead. 
+# Workspace based pipeline_status_lock is implemented by get_storage_keyed_lock instead.
+# No need to implement graph_db_lock:
+# data integrity is ensured by entity level keyed-lock and allowing only one process to hold pipeline at a time.


 def get_storage_keyed_lock(
@@ -1193,14 +1176,11 @@ def initialize_share_data(workers: int = 1):
         _manager, \
         _workers, \
         _is_multiprocess, \
-        _storage_lock, \
         _lock_registry, \
         _lock_registry_count, \
         _lock_cleanup_data, \
         _registry_guard, \
         _internal_lock, \
-        _pipeline_status_lock, \
-        _graph_db_lock, \
         _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
@@ -1228,9 +1208,6 @@
         _lock_cleanup_data = _manager.dict()
         _registry_guard = _manager.RLock()
         _internal_lock = _manager.Lock()
-        _storage_lock = _manager.Lock()
-        _pipeline_status_lock = _manager.Lock()
-        _graph_db_lock = _manager.Lock()
         _data_init_lock = _manager.Lock()
         _shared_dicts = _manager.dict()
         _init_flags = _manager.dict()
@@ -1241,8 +1218,6 @@
         # Initialize async locks for multiprocess mode
         _async_locks = {
             "internal_lock": asyncio.Lock(),
-            "storage_lock": asyncio.Lock(),
-            "pipeline_status_lock": asyncio.Lock(),
             "graph_db_lock": asyncio.Lock(),
             "data_init_lock": asyncio.Lock(),
         }
@@ -1253,9 +1228,6 @@
     else:
         _is_multiprocess = False
         _internal_lock = asyncio.Lock()
-        _storage_lock = asyncio.Lock()
-        _pipeline_status_lock = asyncio.Lock()
-        _graph_db_lock = asyncio.Lock()
         _data_init_lock = asyncio.Lock()
         _shared_dicts = {}
         _init_flags = {}
@@ -1273,29 +1245,19 @@
     _initialized = True


-async def initialize_pipeline_status(workspace: str = ""):
+async def initialize_pipeline_status(workspace: str | None = None):
     """
-    Initialize pipeline namespace with default values.
+    Initialize pipeline_status shared data with default values.
+    This function can be called during the FASTAPI lifespan for each worker.

     Args:
-        workspace: Optional workspace identifier for multi-tenant isolation.
-                  If empty string, uses the default workspace set by
-                  set_default_workspace(). If no default is set, uses
-                  global "pipeline_status" namespace.
-
-    This function is called during FASTAPI lifespan for each worker.
+        workspace: Optional workspace identifier selecting which workspace's pipeline_status to initialize.
+                  If None or empty string, uses the default workspace set by
+                  set_default_workspace().
""" - # Backward compatibility: use default workspace if not provided - if not workspace: - workspace = get_default_workspace() - - # Construct namespace (following GraphDB pattern) - if workspace: - namespace = f"{workspace}:pipeline" - else: - namespace = "pipeline_status" # Global namespace for backward compatibility - - pipeline_namespace = await get_namespace_data(namespace, first_init=True) + pipeline_namespace = await get_namespace_data( + "pipeline_status", first_init=True, workspace=workspace + ) async with get_internal_lock(): # Check if already initialized by checking for required fields @@ -1318,12 +1280,14 @@ async def initialize_pipeline_status(workspace: str = ""): "history_messages": history_messages, # 使用共享列表对象 } ) + + final_namespace = get_final_namespace("pipeline_status", workspace) direct_log( - f"Process {os.getpid()} Pipeline namespace '{namespace}' initialized" + f"Process {os.getpid()} Pipeline namespace '{final_namespace}' initialized" ) -async def get_update_flag(namespace: str): +async def get_update_flag(namespace: str, workspace: str | None = None): """ Create a namespace's update flag for a workers. Returen the update flag to caller for referencing or reset. @@ -1332,14 +1296,16 @@ async def get_update_flag(namespace: str): if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") + final_namespace = get_final_namespace(namespace, workspace) + async with get_internal_lock(): - if namespace not in _update_flags: + if final_namespace not in _update_flags: if _is_multiprocess and _manager is not None: - _update_flags[namespace] = _manager.list() + _update_flags[final_namespace] = _manager.list() else: - _update_flags[namespace] = [] + _update_flags[final_namespace] = [] direct_log( - f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]" + f"Process {os.getpid()} initialized updated flags for namespace: [{final_namespace}]" ) if _is_multiprocess and _manager is not None: @@ -1352,39 +1318,43 @@ async def get_update_flag(namespace: str): new_update_flag = MutableBoolean(False) - _update_flags[namespace].append(new_update_flag) + _update_flags[final_namespace].append(new_update_flag) return new_update_flag -async def set_all_update_flags(namespace: str): +async def set_all_update_flags(namespace: str, workspace: str | None = None): """Set all update flag of namespace indicating all workers need to reload data from files""" global _update_flags if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") + final_namespace = get_final_namespace(namespace, workspace) + async with get_internal_lock(): - if namespace not in _update_flags: - raise ValueError(f"Namespace {namespace} not found in update flags") + if final_namespace not in _update_flags: + raise ValueError(f"Namespace {final_namespace} not found in update flags") # Update flags for both modes - for i in range(len(_update_flags[namespace])): - _update_flags[namespace][i].value = True + for i in range(len(_update_flags[final_namespace])): + _update_flags[final_namespace][i].value = True -async def clear_all_update_flags(namespace: str): +async def clear_all_update_flags(namespace: str, workspace: str | None = None): """Clear all update flag of namespace indicating all workers need to reload data from files""" global _update_flags if _update_flags is None: raise ValueError("Try to create namespace before Shared-Data is initialized") + final_namespace = get_final_namespace(namespace, workspace) + 
     async with get_internal_lock():
-        if namespace not in _update_flags:
-            raise ValueError(f"Namespace {namespace} not found in update flags")
+        if final_namespace not in _update_flags:
+            raise ValueError(f"Namespace {final_namespace} not found in update flags")

         # Update flags for both modes
-        for i in range(len(_update_flags[namespace])):
-            _update_flags[namespace][i].value = False
+        for i in range(len(_update_flags[final_namespace])):
+            _update_flags[final_namespace][i].value = False


-async def get_all_update_flags_status() -> Dict[str, list]:
+async def get_all_update_flags_status(workspace: str | None = None) -> Dict[str, list]:
     """
     Get update flags status for all namespaces.

@@ -1394,9 +1364,17 @@ async def get_all_update_flags_status() -> Dict[str, list]:
     if _update_flags is None:
         return {}

+    if workspace is None:
+        workspace = get_default_workspace()
+
     result = {}
     async with get_internal_lock():
         for namespace, flags in _update_flags.items():
+            namespace_split = namespace.split(":")
+            if workspace and not namespace_split[0] == workspace:
+                continue
+            if not workspace and namespace_split[0]:
+                continue
             worker_statuses = []
             for flag in flags:
                 if _is_multiprocess:
@@ -1408,7 +1386,9 @@
     return result


-async def try_initialize_namespace(namespace: str) -> bool:
+async def try_initialize_namespace(
+    namespace: str, workspace: str | None = None
+) -> bool:
     """
     Returns True if the current worker(process) gets initialization permission for loading data later.
     The worker does not get the permission is prohibited to load data from files.
     if _init_flags is None:
         raise ValueError("Try to create nanmespace before Shared-Data is initialized")

+    final_namespace = get_final_namespace(namespace, workspace)
+
     async with get_internal_lock():
-        if namespace not in _init_flags:
-            _init_flags[namespace] = True
+        if final_namespace not in _init_flags:
+            _init_flags[final_namespace] = True
             direct_log(
-                f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+                f"Process {os.getpid()} ready to initialize storage namespace: [{final_namespace}]"
             )
             return True
         direct_log(
-            f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
+            f"Process {os.getpid()} storage namespace already initialized: [{final_namespace}]"
         )
         return False


 async def get_namespace_data(
-    namespace: str, first_init: bool = False
+    namespace: str, first_init: bool = False, workspace: str | None = None
 ) -> Dict[str, Any]:
     """get the shared data reference for specific namespace

     Args:
         namespace: The namespace to retrieve
-        allow_create: If True, allows creation of the namespace if it doesn't exist.
-                     Used internally by initialize_pipeline_status().
+        first_init: If True, allows the pipeline_status namespace to be created if it doesn't exist.
+                    Prevents getting the pipeline_status namespace before initialize_pipeline_status() has run.
+                    This parameter is used internally by initialize_pipeline_status().
+        workspace: Workspace identifier (may be empty string for global namespace)
     """
     if _shared_dicts is None:
         direct_log(
-            f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}",
+            f"Error: Try to get namespace before it is initialized, pid={os.getpid()}",
             level="ERROR",
         )
         raise ValueError("Shared dictionaries not initialized")

-    async with get_internal_lock():
-        if namespace not in _shared_dicts:
-            # Special handling for pipeline_status namespace
-            # Supports both global "pipeline_status" and workspace-specific "{workspace}:pipeline"
-            is_pipeline = namespace == "pipeline_status" or namespace.endswith(
-                ":pipeline"
-            )
+    final_namespace = get_final_namespace(namespace, workspace)

-            if is_pipeline and not first_init:
+    async with get_internal_lock():
+        if final_namespace not in _shared_dicts:
+            # Special handling for pipeline_status namespace
+            if final_namespace.endswith(":pipeline_status") and not first_init:
                 # Check if pipeline_status should have been initialized but wasn't
-                # This helps users understand they need to call initialize_pipeline_status()
-                raise PipelineNotInitializedError(namespace)
+                # This reminds users to call initialize_pipeline_status() before get_namespace_data()
+                raise PipelineNotInitializedError(final_namespace)

             # For other namespaces or when allow_create=True, create them dynamically
             if _is_multiprocess and _manager is not None:
-                _shared_dicts[namespace] = _manager.dict()
+                _shared_dicts[final_namespace] = _manager.dict()
             else:
-                _shared_dicts[namespace] = {}
+                _shared_dicts[final_namespace] = {}

-    return _shared_dicts[namespace]
+    return _shared_dicts[final_namespace]
+
+
+def get_namespace_lock(
+    namespace: str, workspace: str | None = None, enable_logging: bool = False
+):
+    """Get a workspace-aware lock for a namespace.
+
+    Args:
+        namespace: The namespace to get the lock for.
+        workspace: Workspace identifier (may be empty string for global namespace)
+
+    Returns:
+        A keyed lock for the namespace (from get_storage_keyed_lock), usable with ``async with``.
+    """
+    final_namespace = get_final_namespace(namespace, workspace)
+    return get_storage_keyed_lock(
+        ["default_key"], namespace=final_namespace, enable_logging=enable_logging
+    )


 def finalize_share_data():
@@ -1484,10 +1483,7 @@
     global \
         _manager, \
         _is_multiprocess, \
-        _storage_lock, \
         _internal_lock, \
-        _pipeline_status_lock, \
-        _graph_db_lock, \
         _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
@@ -1552,10 +1548,7 @@
     _is_multiprocess = None
     _shared_dicts = None
     _init_flags = None
-    _storage_lock = None
     _internal_lock = None
-    _pipeline_status_lock = None
-    _graph_db_lock = None
     _data_init_lock = None
     _update_flags = None
     _async_locks = None
@@ -1563,21 +1556,23 @@
     direct_log(f"Process {os.getpid()} storage data finalization complete")


-def set_default_workspace(workspace: str):
+def set_default_workspace(workspace: str | None = None):
     """
-    Set default workspace for backward compatibility.
+    Set default workspace for namespace operations for backward compatibility.

-    This allows initialize_pipeline_status() to automatically use the correct
-    workspace when called without parameters, maintaining compatibility with
-    legacy code that doesn't pass workspace explicitly.
+    This allows get_namespace_data(), get_namespace_lock() or initialize_pipeline_status() to
+    automatically use the correct workspace when called without workspace parameters,
+    maintaining compatibility with legacy code that doesn't pass workspace explicitly.
Args: workspace: Workspace identifier (may be empty string for global namespace) """ global _default_workspace + if workspace is None: + workspace = "" _default_workspace = workspace direct_log( - f"Default workspace set to: '{workspace}' (empty means global)", + f"Default workspace set to: '{_default_workspace}' (empty means global)", level="DEBUG", ) @@ -1587,7 +1582,7 @@ def get_default_workspace() -> str: Get default workspace for backward compatibility. Returns: - The default workspace string. Empty string means global namespace. + The default workspace string. Empty string means global namespace. None means not set. """ global _default_workspace - return _default_workspace if _default_workspace is not None else "" + return _default_workspace diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index dff80aad..cd32a78a 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -3,6 +3,7 @@ from __future__ import annotations import traceback import asyncio import configparser +import inspect import os import time import warnings @@ -12,6 +13,7 @@ from functools import partial from typing import ( Any, AsyncIterator, + Awaitable, Callable, Iterator, cast, @@ -20,6 +22,7 @@ from typing import ( Optional, List, Dict, + Union, ) from lightrag.prompt import PROMPTS from lightrag.exceptions import PipelineCancelledException @@ -61,10 +64,10 @@ from lightrag.kg import ( from lightrag.kg.shared_storage import ( get_namespace_data, - get_graph_db_lock, get_data_init_lock, - get_storage_keyed_lock, - initialize_pipeline_status, + get_default_workspace, + set_default_workspace, + get_namespace_lock, ) from lightrag.base import ( @@ -88,7 +91,7 @@ from lightrag.operate import ( merge_nodes_and_edges, kg_query, naive_query, - _rebuild_knowledge_from_chunks, + rebuild_knowledge_from_chunks, ) from lightrag.constants import GRAPH_FIELD_SEP from lightrag.utils import ( @@ -244,11 +247,13 @@ class LightRAG: int, int, ], - List[Dict[str, Any]], + Union[List[Dict[str, Any]], Awaitable[List[Dict[str, Any]]]], ] = field(default_factory=lambda: chunking_by_token_size) """ Custom chunking function for splitting text into chunks before processing. + The function can be either synchronous or asynchronous. + The function should take the following parameters: - `tokenizer`: A Tokenizer instance to use for tokenization. @@ -258,7 +263,8 @@ class LightRAG: - `chunk_token_size`: The maximum number of tokens per chunk. - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks. - The function should return a list of dictionaries, where each dictionary contains the following keys: + The function should return a list of dictionaries (or an awaitable that resolves to a list), + where each dictionary contains the following keys: - `tokens`: The number of tokens in the chunk. - `content`: The text content of the chunk. @@ -271,6 +277,9 @@ class LightRAG: embedding_func: EmbeddingFunc | None = field(default=None) """Function for computing text embeddings. Must be set before use.""" + embedding_token_limit: int | None = field(default=None, init=False) + """Token limit for embedding model. 
Set automatically from embedding_func.max_token_size in __post_init__.""" + embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10))) """Batch size for embedding computations.""" @@ -514,6 +523,16 @@ class LightRAG: logger.debug(f"LightRAG init with param:\n {_print_config}\n") # Init Embedding + # Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes) + embedding_max_token_size = None + if self.embedding_func and hasattr(self.embedding_func, "max_token_size"): + embedding_max_token_size = self.embedding_func.max_token_size + logger.debug( + f"Captured embedding max_token_size: {embedding_max_token_size}" + ) + self.embedding_token_limit = embedding_max_token_size + + # Step 2: Apply priority wrapper decorator self.embedding_func = priority_limit_async_func_call( self.embedding_func_max_async, llm_timeout=self.default_embedding_timeout, @@ -640,12 +659,11 @@ class LightRAG: async def initialize_storages(self): """Storage initialization must be called one by one to prevent deadlock""" if self._storages_status == StoragesStatus.CREATED: - # Set default workspace for backward compatibility - # This allows initialize_pipeline_status() called without parameters - # to use the correct workspace - from lightrag.kg.shared_storage import set_default_workspace - - set_default_workspace(self.workspace) + # Set the first initialized workspace will set the default workspace + # Allows namespace operation without specifying workspace for backward compatibility + default_workspace = get_default_workspace() + if default_workspace is None: + set_default_workspace(self.workspace) for storage in ( self.full_docs, @@ -718,7 +736,7 @@ class LightRAG: async def check_and_migrate_data(self): """Check if data migration is needed and perform migration if necessary""" - async with get_data_init_lock(enable_logging=True): + async with get_data_init_lock(): try: # Check if migration is needed: # 1. 
chunk_entity_relation_graph has entities and relations (count > 0) @@ -1581,22 +1599,12 @@ class LightRAG: """ # Get pipeline status shared data and lock - # Step 1: Get workspace - workspace = self.workspace - - # Step 2: Construct namespace (following GraphDB pattern) - namespace = f"{workspace}:pipeline" if workspace else "pipeline_status" - - # Step 3: Ensure initialization (on first access) - await initialize_pipeline_status(workspace) - - # Step 4: Get lock - pipeline_status_lock = get_storage_keyed_lock( - keys="status", namespace=namespace, enable_logging=False + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=self.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=self.workspace ) - - # Step 5: Get data - pipeline_status = await get_namespace_data(namespace) # Check if another process is already processing the queue async with pipeline_status_lock: @@ -1778,7 +1786,28 @@ class LightRAG: ) content = content_data["content"] - # Generate chunks from document + # Call chunking function, supporting both sync and async implementations + chunking_result = self.chunking_func( + self.tokenizer, + content, + split_by_character, + split_by_character_only, + self.chunk_overlap_token_size, + self.chunk_token_size, + ) + + # If result is awaitable, await to get actual result + if inspect.isawaitable(chunking_result): + chunking_result = await chunking_result + + # Validate return type + if not isinstance(chunking_result, (list, tuple)): + raise TypeError( + f"chunking_func must return a list or tuple of dicts, " + f"got {type(chunking_result)}" + ) + + # Build chunks dictionary chunks: dict[str, Any] = { compute_mdhash_id(dp["content"], prefix="chunk-"): { **dp, @@ -1786,14 +1815,7 @@ class LightRAG: "file_path": file_path, # Add file path to each chunk "llm_cache_list": [], # Initialize empty LLM cache list for each chunk } - for dp in self.chunking_func( - self.tokenizer, - content, - split_by_character, - split_by_character_only, - self.chunk_overlap_token_size, - self.chunk_token_size, - ) + for dp in chunking_result } if not chunks: @@ -1893,9 +1915,14 @@ class LightRAG: if task and not task.done(): task.cancel() - # Persistent llm cache + # Persistent llm cache with error handling if self.llm_response_cache: - await self.llm_response_cache.index_done_callback() + try: + await self.llm_response_cache.index_done_callback() + except Exception as persist_error: + logger.error( + f"Failed to persist LLM cache: {persist_error}" + ) # Record processing end time for failed case processing_end_time = int(time.time()) @@ -2015,9 +2042,14 @@ class LightRAG: error_msg ) - # Persistent llm cache + # Persistent llm cache with error handling if self.llm_response_cache: - await self.llm_response_cache.index_done_callback() + try: + await self.llm_response_cache.index_done_callback() + except Exception as persist_error: + logger.error( + f"Failed to persist LLM cache: {persist_error}" + ) # Record processing end time for failed case processing_end_time = int(time.time()) @@ -2924,22 +2956,12 @@ class LightRAG: doc_llm_cache_ids: list[str] = [] # Get pipeline status shared data and lock for status updates - # Step 1: Get workspace - workspace = self.workspace - - # Step 2: Construct namespace (following GraphDB pattern) - namespace = f"{workspace}:pipeline" if workspace else "pipeline_status" - - # Step 3: Ensure initialization (on first access) - await initialize_pipeline_status(workspace) - - # Step 4: Get lock - pipeline_status_lock = 
get_storage_keyed_lock( - keys="status", namespace=namespace, enable_logging=False + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=self.workspace + ) + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=self.workspace ) - - # Step 5: Get data - pipeline_status = await get_namespace_data(namespace) async with pipeline_status_lock: log_message = f"Starting deletion process for document {doc_id}" @@ -3172,6 +3194,9 @@ class LightRAG: ] if not existing_sources: + # No chunk references means this entity should be deleted + entities_to_delete.add(node_label) + entity_chunk_updates[node_label] = [] continue remaining_sources = subtract_source_ids(existing_sources, chunk_ids) @@ -3193,6 +3218,7 @@ class LightRAG: # Process relationships for edge_data in affected_edges: + # source target is not in normalize order in graph db property src = edge_data.get("source") tgt = edge_data.get("target") @@ -3229,6 +3255,9 @@ class LightRAG: ] if not existing_sources: + # No chunk references means this relationship should be deleted + relationships_to_delete.add(edge_tuple) + relation_chunk_updates[edge_tuple] = [] continue remaining_sources = subtract_source_ids(existing_sources, chunk_ids) @@ -3254,38 +3283,31 @@ class LightRAG: if entity_chunk_updates and self.entity_chunks: entity_upsert_payload = {} - entity_delete_ids: set[str] = set() for entity_name, remaining in entity_chunk_updates.items(): if not remaining: - entity_delete_ids.add(entity_name) - else: - entity_upsert_payload[entity_name] = { - "chunk_ids": remaining, - "count": len(remaining), - "updated_at": current_time, - } - - if entity_delete_ids: - await self.entity_chunks.delete(list(entity_delete_ids)) + # Empty entities are deleted alongside graph nodes later + continue + entity_upsert_payload[entity_name] = { + "chunk_ids": remaining, + "count": len(remaining), + "updated_at": current_time, + } if entity_upsert_payload: await self.entity_chunks.upsert(entity_upsert_payload) if relation_chunk_updates and self.relation_chunks: relation_upsert_payload = {} - relation_delete_ids: set[str] = set() for edge_tuple, remaining in relation_chunk_updates.items(): - storage_key = make_relation_chunk_key(*edge_tuple) if not remaining: - relation_delete_ids.add(storage_key) - else: - relation_upsert_payload[storage_key] = { - "chunk_ids": remaining, - "count": len(remaining), - "updated_at": current_time, - } + # Empty relations are deleted alongside graph edges later + continue + storage_key = make_relation_chunk_key(*edge_tuple) + relation_upsert_payload[storage_key] = { + "chunk_ids": remaining, + "count": len(remaining), + "updated_at": current_time, + } - if relation_delete_ids: - await self.relation_chunks.delete(list(relation_delete_ids)) if relation_upsert_payload: await self.relation_chunks.upsert(relation_upsert_payload) @@ -3293,56 +3315,111 @@ class LightRAG: logger.error(f"Failed to process graph analysis results: {e}") raise Exception(f"Failed to process graph dependencies: {e}") from e - # Use graph database lock to prevent dirty read - graph_db_lock = get_graph_db_lock(enable_logging=False) - async with graph_db_lock: - # 5. 
Delete chunks from storage - if chunk_ids: - try: - await self.chunks_vdb.delete(chunk_ids) - await self.text_chunks.delete(chunk_ids) + # Data integrity is ensured by allowing only one process to hold pipeline at a time(no graph db lock is needed anymore) - async with pipeline_status_lock: - log_message = f"Successfully deleted {len(chunk_ids)} chunks from storage" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + # 5. Delete chunks from storage + if chunk_ids: + try: + await self.chunks_vdb.delete(chunk_ids) + await self.text_chunks.delete(chunk_ids) - except Exception as e: - logger.error(f"Failed to delete chunks: {e}") - raise Exception(f"Failed to delete document chunks: {e}") from e + async with pipeline_status_lock: + log_message = ( + f"Successfully deleted {len(chunk_ids)} chunks from storage" + ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) - # 6. Delete entities that have no remaining sources - if entities_to_delete: - try: - # Delete from vector database - entity_vdb_ids = [ - compute_mdhash_id(entity, prefix="ent-") - for entity in entities_to_delete + except Exception as e: + logger.error(f"Failed to delete chunks: {e}") + raise Exception(f"Failed to delete document chunks: {e}") from e + + # 6. Delete relationships that have no remaining sources + if relationships_to_delete: + try: + # Delete from relation vdb + rel_ids_to_delete = [] + for src, tgt in relationships_to_delete: + rel_ids_to_delete.extend( + [ + compute_mdhash_id(src + tgt, prefix="rel-"), + compute_mdhash_id(tgt + src, prefix="rel-"), + ] + ) + await self.relationships_vdb.delete(rel_ids_to_delete) + + # Delete from graph + await self.chunk_entity_relation_graph.remove_edges( + list(relationships_to_delete) + ) + + # Delete from relation_chunks storage + if self.relation_chunks: + relation_storage_keys = [ + make_relation_chunk_key(src, tgt) + for src, tgt in relationships_to_delete ] - await self.entities_vdb.delete(entity_vdb_ids) + await self.relation_chunks.delete(relation_storage_keys) - # Delete from graph - await self.chunk_entity_relation_graph.remove_nodes( + async with pipeline_status_lock: + log_message = f"Successfully deleted {len(relationships_to_delete)} relations" + logger.info(log_message) + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + except Exception as e: + logger.error(f"Failed to delete relationships: {e}") + raise Exception(f"Failed to delete relationships: {e}") from e + + # 7. 
Delete entities that have no remaining sources + if entities_to_delete: + try: + # Batch get all edges for entities to avoid N+1 query problem + nodes_edges_dict = ( + await self.chunk_entity_relation_graph.get_nodes_edges_batch( list(entities_to_delete) ) + ) - async with pipeline_status_lock: - log_message = f"Successfully deleted {len(entities_to_delete)} entities" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + # Debug: Check and log all edges before deleting nodes + edges_to_delete = set() + edges_still_exist = 0 - except Exception as e: - logger.error(f"Failed to delete entities: {e}") - raise Exception(f"Failed to delete entities: {e}") from e + for entity, edges in nodes_edges_dict.items(): + if edges: + for src, tgt in edges: + # Normalize edge representation (sorted for consistency) + edge_tuple = tuple(sorted((src, tgt))) + edges_to_delete.add(edge_tuple) - # 7. Delete relationships that have no remaining sources - if relationships_to_delete: - try: - # Delete from vector database + if ( + src in entities_to_delete + and tgt in entities_to_delete + ): + logger.warning( + f"Edge still exists: {src} <-> {tgt}" + ) + elif src in entities_to_delete: + logger.warning( + f"Edge still exists: {src} --> {tgt}" + ) + else: + logger.warning( + f"Edge still exists: {src} <-- {tgt}" + ) + edges_still_exist += 1 + + if edges_still_exist: + logger.warning( + f"⚠️ {edges_still_exist} entities still has edges before deletion" + ) + + # Clean residual edges from VDB and storage before deleting nodes + if edges_to_delete: + # Delete from relationships_vdb rel_ids_to_delete = [] - for src, tgt in relationships_to_delete: + for src, tgt in edges_to_delete: rel_ids_to_delete.extend( [ compute_mdhash_id(src + tgt, prefix="rel-"), @@ -3351,28 +3428,53 @@ class LightRAG: ) await self.relationships_vdb.delete(rel_ids_to_delete) - # Delete from graph - await self.chunk_entity_relation_graph.remove_edges( - list(relationships_to_delete) + # Delete from relation_chunks storage + if self.relation_chunks: + relation_storage_keys = [ + make_relation_chunk_key(src, tgt) + for src, tgt in edges_to_delete + ] + await self.relation_chunks.delete(relation_storage_keys) + + logger.info( + f"Cleaned {len(edges_to_delete)} residual edges from VDB and chunk-tracking storage" ) - async with pipeline_status_lock: - log_message = f"Successfully deleted {len(relationships_to_delete)} relations" - logger.info(log_message) - pipeline_status["latest_message"] = log_message - pipeline_status["history_messages"].append(log_message) + # Delete from graph (edges will be auto-deleted with nodes) + await self.chunk_entity_relation_graph.remove_nodes( + list(entities_to_delete) + ) - except Exception as e: - logger.error(f"Failed to delete relationships: {e}") - raise Exception(f"Failed to delete relationships: {e}") from e + # Delete from vector vdb + entity_vdb_ids = [ + compute_mdhash_id(entity, prefix="ent-") + for entity in entities_to_delete + ] + await self.entities_vdb.delete(entity_vdb_ids) - # Persist changes to graph database before releasing graph database lock - await self._insert_done() + # Delete from entity_chunks storage + if self.entity_chunks: + await self.entity_chunks.delete(list(entities_to_delete)) + + async with pipeline_status_lock: + log_message = ( + f"Successfully deleted {len(entities_to_delete)} entities" + ) + logger.info(log_message) + pipeline_status["latest_message"] = log_message + 
pipeline_status["history_messages"].append(log_message) + + except Exception as e: + logger.error(f"Failed to delete entities: {e}") + raise Exception(f"Failed to delete entities: {e}") from e + + # Persist changes to graph database before entity and relationship rebuild + await self._insert_done() # 8. Rebuild entities and relationships from remaining chunks if entities_to_rebuild or relationships_to_rebuild: try: - await _rebuild_knowledge_from_chunks( + await rebuild_knowledge_from_chunks( entities_to_rebuild=entities_to_rebuild, relationships_to_rebuild=relationships_to_rebuild, knowledge_graph_inst=self.chunk_entity_relation_graph, @@ -3590,16 +3692,22 @@ class LightRAG: ) async def aedit_entity( - self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True + self, + entity_name: str, + updated_data: dict[str, str], + allow_rename: bool = True, + allow_merge: bool = False, ) -> dict[str, Any]: """Asynchronously edit entity information. Updates entity information in the knowledge graph and re-embeds the entity in the vector database. + Also synchronizes entity_chunks_storage and relation_chunks_storage to track chunk references. Args: entity_name: Name of the entity to edit updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "entity_type": "new type"} allow_rename: Whether to allow entity renaming, defaults to True + allow_merge: Whether to merge into an existing entity when renaming to an existing name Returns: Dictionary containing updated entity information @@ -3613,14 +3721,21 @@ class LightRAG: entity_name, updated_data, allow_rename, + allow_merge, + self.entity_chunks, + self.relation_chunks, ) def edit_entity( - self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True + self, + entity_name: str, + updated_data: dict[str, str], + allow_rename: bool = True, + allow_merge: bool = False, ) -> dict[str, Any]: loop = always_get_an_event_loop() return loop.run_until_complete( - self.aedit_entity(entity_name, updated_data, allow_rename) + self.aedit_entity(entity_name, updated_data, allow_rename, allow_merge) ) async def aedit_relation( @@ -3629,6 +3744,7 @@ class LightRAG: """Asynchronously edit relation information. Updates relation (edge) information in the knowledge graph and re-embeds the relation in the vector database. + Also synchronizes the relation_chunks_storage to track which chunks reference this relation. Args: source_entity: Name of the source entity @@ -3647,6 +3763,7 @@ class LightRAG: source_entity, target_entity, updated_data, + self.relation_chunks, ) def edit_relation( @@ -3758,6 +3875,8 @@ class LightRAG: target_entity, merge_strategy, target_entity_data, + self.entity_chunks, + self.relation_chunks, ) def merge_entities(