Fix workspace isolation for pipeline status across all operations

- Fix final_namespace error in get_namespace_data()
- Fix get_workspace_from_request return type
- Add workspace param to pipeline status calls

(cherry picked from commit 52c812b9a0)
yangdx 2025-11-17 03:45:51 +08:00 committed by Raphaël MANSUY
parent fe1576943f
commit dfab175c16
4 changed files with 573 additions and 367 deletions
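The diffs below replace the old hand-rolled `f"{workspace}:pipeline"` namespace construction with workspace-aware shared-storage helpers. For orientation, the new call pattern looks like the following sketch, assembled from the signatures introduced in this diff (illustration only, not a verbatim excerpt; server wiring and error handling are omitted):

import asyncio

from lightrag.kg.shared_storage import (
    initialize_share_data,
    initialize_pipeline_status,
    set_default_workspace,
    get_namespace_data,
    get_namespace_lock,
)


async def demo() -> None:
    initialize_share_data(workers=1)    # single-process shared data
    set_default_workspace("tenant_a")   # default for legacy callers
    await initialize_pipeline_status()  # initializes "tenant_a:pipeline_status"

    # An explicit workspace argument always overrides the default
    status = await get_namespace_data("pipeline_status", workspace="tenant_a")
    lock = get_namespace_lock("pipeline_status", workspace="tenant_a")
    async with lock:
        print(status.get("busy", False))


asyncio.run(demo())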


@@ -56,6 +56,8 @@ from lightrag.api.routers.ollama_api import OllamaAPI
 from lightrag.utils import logger, set_verbose_debug
 from lightrag.kg.shared_storage import (
     get_namespace_data,
+    get_default_workspace,
+    # set_default_workspace,
     initialize_pipeline_status,
     cleanup_keyed_lock,
     finalize_share_data,
@@ -350,8 +352,9 @@ def create_app(args):
         try:
             # Initialize database connections
+            # set_default_workspace(rag.workspace) # comment this line to test auto default workspace setting in initialize_storages
             await rag.initialize_storages()
-            await initialize_pipeline_status()
+            await initialize_pipeline_status()  # with default workspace
 
             # Data migration regardless of storage implementation
             await rag.check_and_migrate_data()
@@ -452,7 +455,7 @@ def create_app(args):
     # Create combined auth dependency for all endpoints
     combined_auth = get_combined_auth_dependency(api_key)
 
-    def get_workspace_from_request(request: Request) -> str:
+    def get_workspace_from_request(request: Request) -> str | None:
         """
         Extract workspace from HTTP request header or use default.
@@ -469,9 +472,8 @@ def create_app(args):
         # Check custom header first
         workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()
 
-        # Fall back to server default if header not provided
         if not workspace:
-            workspace = args.workspace
+            workspace = None
 
         return workspace
@@ -641,33 +643,108 @@ def create_app(args):
     def create_optimized_embedding_function(
         config_cache: LLMConfigCache, binding, model, host, api_key, args
-    ):
+    ) -> EmbeddingFunc:
         """
-        Create optimized embedding function with pre-processed configuration for applicable bindings.
-        Uses lazy imports for all bindings and avoids repeated configuration parsing.
+        Create optimized embedding function and return an EmbeddingFunc instance
+        with proper max_token_size inheritance from provider defaults.
+
+        This function:
+        1. Imports the provider embedding function
+        2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
+        3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
+        4. Returns a properly configured EmbeddingFunc instance
         """
+        # Step 1: Import provider function and extract default attributes
+        provider_func = None
+        provider_max_token_size = None
+        provider_embedding_dim = None
+
+        try:
+            if binding == "openai":
+                from lightrag.llm.openai import openai_embed
+
+                provider_func = openai_embed
+            elif binding == "ollama":
+                from lightrag.llm.ollama import ollama_embed
+
+                provider_func = ollama_embed
+            elif binding == "gemini":
+                from lightrag.llm.gemini import gemini_embed
+
+                provider_func = gemini_embed
+            elif binding == "jina":
+                from lightrag.llm.jina import jina_embed
+
+                provider_func = jina_embed
+            elif binding == "azure_openai":
+                from lightrag.llm.azure_openai import azure_openai_embed
+
+                provider_func = azure_openai_embed
+            elif binding == "aws_bedrock":
+                from lightrag.llm.bedrock import bedrock_embed
+
+                provider_func = bedrock_embed
+            elif binding == "lollms":
+                from lightrag.llm.lollms import lollms_embed
+
+                provider_func = lollms_embed
+
+            # Extract attributes if provider is an EmbeddingFunc
+            if provider_func and isinstance(provider_func, EmbeddingFunc):
+                provider_max_token_size = provider_func.max_token_size
+                provider_embedding_dim = provider_func.embedding_dim
+                logger.debug(
+                    f"Extracted from {binding} provider: "
+                    f"max_token_size={provider_max_token_size}, "
+                    f"embedding_dim={provider_embedding_dim}"
+                )
+        except ImportError as e:
+            logger.warning(f"Could not import provider function for {binding}: {e}")
+
+        # Step 2: Apply priority (user config > provider default)
+        # For max_token_size: explicit env var > provider default > None
+        final_max_token_size = args.embedding_token_limit or provider_max_token_size
+
+        # For embedding_dim: user config (always has value) takes priority
+        # Only use provider default if user config is explicitly None (which shouldn't happen)
+        final_embedding_dim = (
+            args.embedding_dim if args.embedding_dim else provider_embedding_dim
+        )
+
+        # Step 3: Create optimized embedding function (calls underlying function directly)
        async def optimized_embedding_function(texts, embedding_dim=None):
            try:
                if binding == "lollms":
                    from lightrag.llm.lollms import lollms_embed
 
-                    return await lollms_embed(
+                    # Get real function, skip EmbeddingFunc wrapper if present
+                    actual_func = (
+                        lollms_embed.func
+                        if isinstance(lollms_embed, EmbeddingFunc)
+                        else lollms_embed
+                    )
+                    return await actual_func(
                        texts, embed_model=model, host=host, api_key=api_key
                    )
                elif binding == "ollama":
                    from lightrag.llm.ollama import ollama_embed
 
-                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
+                    # Get real function, skip EmbeddingFunc wrapper if present
+                    actual_func = (
+                        ollama_embed.func
+                        if isinstance(ollama_embed, EmbeddingFunc)
+                        else ollama_embed
+                    )
+                    # Use pre-processed configuration if available
                    if config_cache.ollama_embedding_options is not None:
                        ollama_options = config_cache.ollama_embedding_options
                    else:
-                        # Fallback for cases where config cache wasn't initialized properly
                        from lightrag.llm.binding_options import OllamaEmbeddingOptions
 
                        ollama_options = OllamaEmbeddingOptions.options_dict(args)
-                    return await ollama_embed(
+                    return await actual_func(
                        texts,
                        embed_model=model,
                        host=host,
@@ -677,15 +754,30 @@ def create_app(args):
                elif binding == "azure_openai":
                    from lightrag.llm.azure_openai import azure_openai_embed
 
-                    return await azure_openai_embed(texts, model=model, api_key=api_key)
+                    actual_func = (
+                        azure_openai_embed.func
+                        if isinstance(azure_openai_embed, EmbeddingFunc)
+                        else azure_openai_embed
+                    )
+                    return await actual_func(texts, model=model, api_key=api_key)
                elif binding == "aws_bedrock":
                    from lightrag.llm.bedrock import bedrock_embed
 
-                    return await bedrock_embed(texts, model=model)
+                    actual_func = (
+                        bedrock_embed.func
+                        if isinstance(bedrock_embed, EmbeddingFunc)
+                        else bedrock_embed
+                    )
+                    return await actual_func(texts, model=model)
                elif binding == "jina":
                    from lightrag.llm.jina import jina_embed
 
-                    return await jina_embed(
+                    actual_func = (
+                        jina_embed.func
+                        if isinstance(jina_embed, EmbeddingFunc)
+                        else jina_embed
+                    )
+                    return await actual_func(
                        texts,
                        embedding_dim=embedding_dim,
                        base_url=host,
@@ -694,16 +786,21 @@ def create_app(args):
                elif binding == "gemini":
                    from lightrag.llm.gemini import gemini_embed
 
-                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
+                    actual_func = (
+                        gemini_embed.func
+                        if isinstance(gemini_embed, EmbeddingFunc)
+                        else gemini_embed
+                    )
+                    # Use pre-processed configuration if available
                    if config_cache.gemini_embedding_options is not None:
                        gemini_options = config_cache.gemini_embedding_options
                    else:
-                        # Fallback for cases where config cache wasn't initialized properly
                        from lightrag.llm.binding_options import GeminiEmbeddingOptions
 
                        gemini_options = GeminiEmbeddingOptions.options_dict(args)
-                    return await gemini_embed(
+                    return await actual_func(
                        texts,
                        model=model,
                        base_url=host,
@@ -714,7 +811,12 @@ def create_app(args):
                else:  # openai and compatible
                    from lightrag.llm.openai import openai_embed
 
-                    return await openai_embed(
+                    actual_func = (
+                        openai_embed.func
+                        if isinstance(openai_embed, EmbeddingFunc)
+                        else openai_embed
+                    )
+                    return await actual_func(
                        texts,
                        model=model,
                        base_url=host,
@@ -724,7 +826,21 @@ def create_app(args):
            except ImportError as e:
                raise Exception(f"Failed to import {binding} embedding: {e}")
 
-        return optimized_embedding_function
+        # Step 4: Wrap in EmbeddingFunc and return
+        embedding_func_instance = EmbeddingFunc(
+            embedding_dim=final_embedding_dim,
+            func=optimized_embedding_function,
+            max_token_size=final_max_token_size,
+            send_dimensions=False,  # Will be set later based on binding requirements
+        )
+
+        # Log final embedding configuration
+        logger.info(
+            f"Embedding config: binding={binding} model={model} "
+            f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
+        )
+
+        return embedding_func_instance
 
    llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
    embedding_timeout = get_env_value(
@@ -758,25 +874,24 @@ def create_app(args):
        **kwargs,
    )
 
-    # Create embedding function with optimized configuration
+    # Create embedding function with optimized configuration and max_token_size inheritance
    import inspect
 
-    # Create the optimized embedding function
-    optimized_embedding_func = create_optimized_embedding_function(
+    # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
+    embedding_func = create_optimized_embedding_function(
        config_cache=config_cache,
        binding=args.embedding_binding,
        model=args.embedding_model,
        host=args.embedding_binding_host,
        api_key=args.embedding_binding_api_key,
-        args=args,  # Pass args object for fallback option generation
+        args=args,
    )
 
    # Get embedding_send_dim from centralized configuration
    embedding_send_dim = args.embedding_send_dim
 
-    # Check if the function signature has embedding_dim parameter
-    # Note: Since optimized_embedding_func is an async function, inspect its signature
-    sig = inspect.signature(optimized_embedding_func)
+    # Check if the underlying function signature has embedding_dim parameter
+    sig = inspect.signature(embedding_func.func)
    has_embedding_dim_param = "embedding_dim" in sig.parameters
 
    # Determine send_dimensions value based on binding type
@@ -794,18 +909,27 @@ def create_app(args):
    else:
        dimension_control = "by not hasparam"
 
+    # Set send_dimensions on the EmbeddingFunc instance
+    embedding_func.send_dimensions = send_dimensions
+
    logger.info(
        f"Send embedding dimension: {send_dimensions} {dimension_control} "
-        f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
+        f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
        f"binding={args.embedding_binding})"
    )
 
-    # Create EmbeddingFunc with send_dimensions attribute
-    embedding_func = EmbeddingFunc(
-        embedding_dim=args.embedding_dim,
-        func=optimized_embedding_func,
-        send_dimensions=send_dimensions,
-    )
+    # Log max_token_size source
+    if embedding_func.max_token_size:
+        source = (
+            "env variable"
+            if args.embedding_token_limit
+            else f"{args.embedding_binding} provider default"
+        )
+        logger.info(
+            f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
+        )
+    else:
+        logger.info("Embedding max_token_size: not set (90% token warning disabled)")
 
    # Configure rerank function based on args.rerank_binding parameter
    rerank_model_func = None
@@ -1017,14 +1141,13 @@ def create_app(args):
    async def get_status(request: Request):
        """Get current system status"""
        try:
-            # Extract workspace from request header or use default
            workspace = get_workspace_from_request(request)
+            default_workspace = get_default_workspace()
 
-            # Construct namespace (following GraphDB pattern)
-            namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-
-            # Get workspace-specific pipeline status
-            pipeline_status = await get_namespace_data(namespace)
+            if workspace is None:
+                workspace = default_workspace
+            pipeline_status = await get_namespace_data(
+                "pipeline_status", workspace=workspace
+            )
 
            if not auth_configured:
                auth_mode = "disabled"
@@ -1055,8 +1178,7 @@ def create_app(args):
                "vector_storage": args.vector_storage,
                "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
                "enable_llm_cache": args.enable_llm_cache,
-                "workspace": workspace,
-                "default_workspace": args.workspace,
+                "workspace": default_workspace,
                "max_graph_nodes": args.max_graph_nodes,
                # Rerank configuration
                "enable_rerank": rerank_model_func is not None,
@@ -1245,6 +1367,12 @@ def check_and_install_dependencies():
 def main():
+    # Explicitly initialize configuration for clarity
+    # (The proxy will auto-initialize anyway, but this makes intent clear)
+    from .config import initialize_config
+
+    initialize_config()
+
     # Check if running under Gunicorn
     if "GUNICORN_CMD_ARGS" in os.environ:
         # If started with Gunicorn, return directly as Gunicorn will call get_application
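The recurring `actual_func = X.func if isinstance(X, EmbeddingFunc) else X` blocks above exist to avoid double-wrapping: some providers ship their embed callable already wrapped in an EmbeddingFunc, and the server wraps the optimized dispatcher in its own EmbeddingFunc at the end. A self-contained sketch of the idea, using a stand-in EmbeddingFunc with the same fields the diff relies on (the real class lives in lightrag.utils and may differ in detail):

import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable


@dataclass
class EmbeddingFunc:
    # Stand-in with the fields referenced in the diff; illustration only.
    embedding_dim: int
    func: Callable[..., Awaitable[Any]]
    max_token_size: int | None = None
    send_dimensions: bool = False

    async def __call__(self, *args, **kwargs):
        return await self.func(*args, **kwargs)


async def _raw_embed(texts, model="demo"):
    # Dummy vectors standing in for a real provider call
    return [[0.0] * 4 for _ in texts]


provider_embed = EmbeddingFunc(embedding_dim=4, func=_raw_embed, max_token_size=8192)

# Unwrap before calling, exactly as the branches above do, so the provider
# wrapper is not stacked underneath the server's own EmbeddingFunc wrapper:
actual_func = (
    provider_embed.func
    if isinstance(provider_embed, EmbeddingFunc)
    else provider_embed
)

print(asyncio.run(actual_func(["hello"])))  # [[0.0, 0.0, 0.0, 0.0]]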


@@ -1641,26 +1641,15 @@ async def background_delete_documents(
     """Background task to delete multiple documents"""
     from lightrag.kg.shared_storage import (
         get_namespace_data,
-        get_storage_keyed_lock,
-        initialize_pipeline_status,
+        get_namespace_lock,
     )
 
-    # Step 1: Get workspace
-    workspace = rag.workspace
-
-    # Step 2: Construct namespace
-    namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-
-    # Step 3: Ensure initialization
-    await initialize_pipeline_status(workspace)
-
-    # Step 4: Get lock
-    pipeline_status_lock = get_storage_keyed_lock(
-        keys="status", namespace=namespace, enable_logging=False
-    )
-
-    # Step 5: Get data
-    pipeline_status = await get_namespace_data(namespace)
+    pipeline_status = await get_namespace_data(
+        "pipeline_status", workspace=rag.workspace
+    )
+    pipeline_status_lock = get_namespace_lock(
+        "pipeline_status", workspace=rag.workspace
+    )
 
     total_docs = len(doc_ids)
     successful_deletions = []
@@ -2149,27 +2138,16 @@ def create_document_routes(
         """
         from lightrag.kg.shared_storage import (
             get_namespace_data,
-            get_storage_keyed_lock,
-            initialize_pipeline_status,
+            get_namespace_lock,
         )
 
         # Get pipeline status and lock
-        # Step 1: Get workspace
-        workspace = rag.workspace
-
-        # Step 2: Construct namespace
-        namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-
-        # Step 3: Ensure initialization
-        await initialize_pipeline_status(workspace)
-
-        # Step 4: Get lock
-        pipeline_status_lock = get_storage_keyed_lock(
-            keys="status", namespace=namespace, enable_logging=False
-        )
-
-        # Step 5: Get data
-        pipeline_status = await get_namespace_data(namespace)
+        pipeline_status = await get_namespace_data(
+            "pipeline_status", workspace=rag.workspace
+        )
+        pipeline_status_lock = get_namespace_lock(
+            "pipeline_status", workspace=rag.workspace
+        )
 
         # Check and set status with lock
         async with pipeline_status_lock:
@@ -2360,15 +2338,16 @@ def create_document_routes(
         try:
             from lightrag.kg.shared_storage import (
                 get_namespace_data,
+                get_namespace_lock,
                 get_all_update_flags_status,
-                initialize_pipeline_status,
             )
 
-            # Get workspace-specific pipeline status
-            workspace = rag.workspace
-            namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-            await initialize_pipeline_status(workspace)
-            pipeline_status = await get_namespace_data(namespace)
+            pipeline_status = await get_namespace_data(
+                "pipeline_status", workspace=rag.workspace
+            )
+            pipeline_status_lock = get_namespace_lock(
+                "pipeline_status", workspace=rag.workspace
+            )
 
             # Get update flags status for all namespaces
             update_status = await get_all_update_flags_status()
@@ -2385,8 +2364,9 @@ def create_document_routes(
                     processed_flags.append(bool(flag))
                 processed_update_status[namespace] = processed_flags
 
-            # Convert to regular dict if it's a Manager.dict
-            status_dict = dict(pipeline_status)
+            async with pipeline_status_lock:
+                # Convert to regular dict if it's a Manager.dict
+                status_dict = dict(pipeline_status)
 
             # Add processed update_status to the status dictionary
             status_dict["update_status"] = processed_update_status
@@ -2575,20 +2555,15 @@ def create_document_routes(
         try:
             from lightrag.kg.shared_storage import (
                 get_namespace_data,
-                get_storage_keyed_lock,
-                initialize_pipeline_status,
+                get_namespace_lock,
             )
 
-            # Get workspace-specific pipeline status
-            workspace = rag.workspace
-            namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-            await initialize_pipeline_status(workspace)
-
-            # Use workspace-aware lock to check busy flag
-            pipeline_status_lock = get_storage_keyed_lock(
-                keys="status", namespace=namespace, enable_logging=False
-            )
-            pipeline_status = await get_namespace_data(namespace)
+            pipeline_status = await get_namespace_data(
+                "pipeline_status", workspace=rag.workspace
+            )
+            pipeline_status_lock = get_namespace_lock(
+                "pipeline_status", workspace=rag.workspace
+            )
 
             # Check if pipeline is busy with proper lock
             async with pipeline_status_lock:
@@ -2993,26 +2968,15 @@ def create_document_routes(
         try:
             from lightrag.kg.shared_storage import (
                 get_namespace_data,
-                get_storage_keyed_lock,
-                initialize_pipeline_status,
+                get_namespace_lock,
             )
 
-            # Step 1: Get workspace
-            workspace = rag.workspace
-
-            # Step 2: Construct namespace
-            namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-
-            # Step 3: Ensure initialization
-            await initialize_pipeline_status(workspace)
-
-            # Step 4: Get lock
-            pipeline_status_lock = get_storage_keyed_lock(
-                keys="status", namespace=namespace, enable_logging=False
-            )
-
-            # Step 5: Get data
-            pipeline_status = await get_namespace_data(namespace)
+            pipeline_status = await get_namespace_data(
+                "pipeline_status", workspace=rag.workspace
+            )
+            pipeline_status_lock = get_namespace_lock(
+                "pipeline_status", workspace=rag.workspace
+            )
 
             async with pipeline_status_lock:
                 if not pipeline_status.get("busy", False):
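Every route in this file now follows the same two-call pattern instead of the old five-step dance. For fork or plugin authors tracking this change, the migration is mechanical; a sketch (assuming `rag` is the usual LightRAG instance and that `initialize_pipeline_status()` already ran during server startup, as the lifespan hook does):

from lightrag.kg.shared_storage import get_namespace_data, get_namespace_lock


async def read_pipeline_status(rag):
    # Old (removed): build f"{workspace}:pipeline" by hand, call
    # initialize_pipeline_status(workspace), then get_storage_keyed_lock(...).
    # New: pass the workspace and let shared_storage build the namespace.
    pipeline_status = await get_namespace_data(
        "pipeline_status", workspace=rag.workspace
    )
    pipeline_status_lock = get_namespace_lock(
        "pipeline_status", workspace=rag.workspace
    )
    async with pipeline_status_lock:
        return dict(pipeline_status)  # copy out of a possible Manager.dict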


@@ -84,10 +84,7 @@ _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
 _update_flags: Optional[Dict[str, bool]] = None  # namespace -> updated
 
 # locks for mutex access
-_storage_lock: Optional[LockType] = None
 _internal_lock: Optional[LockType] = None
-_pipeline_status_lock: Optional[LockType] = None
-_graph_db_lock: Optional[LockType] = None
 _data_init_lock: Optional[LockType] = None
 
 # Manager for all keyed locks
 _storage_keyed_lock: Optional["KeyedUnifiedLock"] = None
@@ -98,6 +95,22 @@ _async_locks: Optional[Dict[str, asyncio.Lock]] = None
 _debug_n_locks_acquired: int = 0
 
 
+def get_final_namespace(namespace: str, workspace: str | None = None):
+    global _default_workspace
+    if workspace is None:
+        workspace = _default_workspace
+
+    if workspace is None:
+        direct_log(
+            f"Error: Invoke namespace operation without workspace, pid={os.getpid()}",
+            level="ERROR",
+        )
+        raise ValueError("Invoke namespace operation without workspace")
+
+    final_namespace = f"{workspace}:{namespace}" if workspace else f"{namespace}"
+    return final_namespace
+
+
 def inc_debug_n_locks_acquired():
     global _debug_n_locks_acquired
     if DEBUG_LOCKS:
@@ -1056,40 +1069,10 @@ def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     )
 
 
-def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
-    """return unified storage lock for data consistency"""
-    async_lock = _async_locks.get("storage_lock") if _is_multiprocess else None
-    return UnifiedLock(
-        lock=_storage_lock,
-        is_async=not _is_multiprocess,
-        name="storage_lock",
-        enable_logging=enable_logging,
-        async_lock=async_lock,
-    )
-
-
-def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
-    """return unified storage lock for data consistency"""
-    async_lock = _async_locks.get("pipeline_status_lock") if _is_multiprocess else None
-    return UnifiedLock(
-        lock=_pipeline_status_lock,
-        is_async=not _is_multiprocess,
-        name="pipeline_status_lock",
-        enable_logging=enable_logging,
-        async_lock=async_lock,
-    )
-
-
-def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
-    """return unified graph database lock for ensuring atomic operations"""
-    async_lock = _async_locks.get("graph_db_lock") if _is_multiprocess else None
-    return UnifiedLock(
-        lock=_graph_db_lock,
-        is_async=not _is_multiprocess,
-        name="graph_db_lock",
-        enable_logging=enable_logging,
-        async_lock=async_lock,
-    )
+# Workspace based storage_lock is implemented by get_storage_keyed_lock instead.
+# Workspace based pipeline_status_lock is implemented by get_storage_keyed_lock instead.
+# No need to implement graph_db_lock:
+# data integrity is ensured by entity level keyed-lock and allowing only one process to hold pipeline at a time.
 
 
 def get_storage_keyed_lock(
@@ -1193,14 +1176,11 @@ def initialize_share_data(workers: int = 1):
         _manager, \
         _workers, \
         _is_multiprocess, \
-        _storage_lock, \
         _lock_registry, \
         _lock_registry_count, \
         _lock_cleanup_data, \
         _registry_guard, \
         _internal_lock, \
-        _pipeline_status_lock, \
-        _graph_db_lock, \
         _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
@@ -1228,9 +1208,6 @@ def initialize_share_data(workers: int = 1):
         _lock_cleanup_data = _manager.dict()
         _registry_guard = _manager.RLock()
         _internal_lock = _manager.Lock()
-        _storage_lock = _manager.Lock()
-        _pipeline_status_lock = _manager.Lock()
-        _graph_db_lock = _manager.Lock()
         _data_init_lock = _manager.Lock()
         _shared_dicts = _manager.dict()
         _init_flags = _manager.dict()
@@ -1241,8 +1218,6 @@ def initialize_share_data(workers: int = 1):
         # Initialize async locks for multiprocess mode
         _async_locks = {
             "internal_lock": asyncio.Lock(),
-            "storage_lock": asyncio.Lock(),
-            "pipeline_status_lock": asyncio.Lock(),
             "graph_db_lock": asyncio.Lock(),
             "data_init_lock": asyncio.Lock(),
         }
@@ -1253,9 +1228,6 @@ def initialize_share_data(workers: int = 1):
     else:
         _is_multiprocess = False
         _internal_lock = asyncio.Lock()
-        _storage_lock = asyncio.Lock()
-        _pipeline_status_lock = asyncio.Lock()
-        _graph_db_lock = asyncio.Lock()
         _data_init_lock = asyncio.Lock()
         _shared_dicts = {}
         _init_flags = {}
@@ -1273,29 +1245,19 @@ def initialize_share_data(workers: int = 1):
     _initialized = True
 
 
-async def initialize_pipeline_status(workspace: str = ""):
+async def initialize_pipeline_status(workspace: str | None = None):
     """
-    Initialize pipeline namespace with default values.
+    Initialize pipeline_status share data with default values.
+    This function could be called during FASTAPI lifespan for each worker.
 
     Args:
-        workspace: Optional workspace identifier for multi-tenant isolation.
-                   If empty string, uses the default workspace set by
-                   set_default_workspace(). If no default is set, uses
-                   global "pipeline_status" namespace.
-
-    This function is called during FASTAPI lifespan for each worker.
+        workspace: Optional workspace identifier for pipeline_status of specific workspace.
+                   If None or empty string, uses the default workspace set by
+                   set_default_workspace().
     """
-    # Backward compatibility: use default workspace if not provided
-    if not workspace:
-        workspace = get_default_workspace()
-
-    # Construct namespace (following GraphDB pattern)
-    if workspace:
-        namespace = f"{workspace}:pipeline"
-    else:
-        namespace = "pipeline_status"  # Global namespace for backward compatibility
-
-    pipeline_namespace = await get_namespace_data(namespace, first_init=True)
+    pipeline_namespace = await get_namespace_data(
+        "pipeline_status", first_init=True, workspace=workspace
+    )
 
     async with get_internal_lock():
         # Check if already initialized by checking for required fields
@@ -1318,12 +1280,14 @@ async def initialize_pipeline_status(workspace: str = ""):
                     "history_messages": history_messages,  # use a shared list object
                 }
             )
+
+            final_namespace = get_final_namespace("pipeline_status", workspace)
             direct_log(
-                f"Process {os.getpid()} Pipeline namespace '{namespace}' initialized"
+                f"Process {os.getpid()} Pipeline namespace '{final_namespace}' initialized"
             )
 
 
-async def get_update_flag(namespace: str):
+async def get_update_flag(namespace: str, workspace: str | None = None):
     """
     Create a namespace's update flag for a worker.
     Return the update flag to caller for referencing or reset.
@@ -1332,14 +1296,16 @@ async def get_update_flag(namespace: str):
     if _update_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")
 
+    final_namespace = get_final_namespace(namespace, workspace)
+
     async with get_internal_lock():
-        if namespace not in _update_flags:
+        if final_namespace not in _update_flags:
             if _is_multiprocess and _manager is not None:
-                _update_flags[namespace] = _manager.list()
+                _update_flags[final_namespace] = _manager.list()
             else:
-                _update_flags[namespace] = []
+                _update_flags[final_namespace] = []
             direct_log(
-                f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
+                f"Process {os.getpid()} initialized updated flags for namespace: [{final_namespace}]"
             )
 
         if _is_multiprocess and _manager is not None:
@@ -1352,39 +1318,43 @@ async def get_update_flag(namespace: str):
             new_update_flag = MutableBoolean(False)
 
-        _update_flags[namespace].append(new_update_flag)
+        _update_flags[final_namespace].append(new_update_flag)
         return new_update_flag
 
 
-async def set_all_update_flags(namespace: str):
+async def set_all_update_flags(namespace: str, workspace: str | None = None):
     """Set all update flags of a namespace, indicating all workers need to reload data from files"""
     global _update_flags
     if _update_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")
 
+    final_namespace = get_final_namespace(namespace, workspace)
+
     async with get_internal_lock():
-        if namespace not in _update_flags:
-            raise ValueError(f"Namespace {namespace} not found in update flags")
+        if final_namespace not in _update_flags:
+            raise ValueError(f"Namespace {final_namespace} not found in update flags")
         # Update flags for both modes
-        for i in range(len(_update_flags[namespace])):
-            _update_flags[namespace][i].value = True
+        for i in range(len(_update_flags[final_namespace])):
+            _update_flags[final_namespace][i].value = True
 
 
-async def clear_all_update_flags(namespace: str):
+async def clear_all_update_flags(namespace: str, workspace: str | None = None):
     """Clear all update flags of a namespace, indicating all workers need to reload data from files"""
     global _update_flags
     if _update_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")
 
+    final_namespace = get_final_namespace(namespace, workspace)
+
     async with get_internal_lock():
-        if namespace not in _update_flags:
-            raise ValueError(f"Namespace {namespace} not found in update flags")
+        if final_namespace not in _update_flags:
+            raise ValueError(f"Namespace {final_namespace} not found in update flags")
        # Update flags for both modes
-        for i in range(len(_update_flags[namespace])):
-            _update_flags[namespace][i].value = False
+        for i in range(len(_update_flags[final_namespace])):
+            _update_flags[final_namespace][i].value = False
 
 
-async def get_all_update_flags_status() -> Dict[str, list]:
+async def get_all_update_flags_status(workspace: str | None = None) -> Dict[str, list]:
     """
     Get update flags status for all namespaces.
 
@@ -1394,9 +1364,17 @@ async def get_all_update_flags_status() -> Dict[str, list]:
     if _update_flags is None:
         return {}
 
+    if workspace is None:
+        workspace = get_default_workspace()
+
     result = {}
     async with get_internal_lock():
         for namespace, flags in _update_flags.items():
+            namespace_split = namespace.split(":")
+            if workspace and not namespace_split[0] == workspace:
+                continue
+            if not workspace and namespace_split[0]:
+                continue
             worker_statuses = []
             for flag in flags:
                 if _is_multiprocess:
@@ -1408,7 +1386,9 @@ async def get_all_update_flags_status() -> Dict[str, list]:
     return result
 
 
-async def try_initialize_namespace(namespace: str) -> bool:
+async def try_initialize_namespace(
+    namespace: str, workspace: str | None = None
+) -> bool:
     """
     Returns True if the current worker(process) gets initialization permission for loading data later.
     A worker that does not get the permission is prohibited from loading data from files.
@@ -1418,57 +1398,76 @@ async def try_initialize_namespace(namespace: str) -> bool:
     if _init_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")
 
+    final_namespace = get_final_namespace(namespace, workspace)
+
     async with get_internal_lock():
-        if namespace not in _init_flags:
-            _init_flags[namespace] = True
+        if final_namespace not in _init_flags:
+            _init_flags[final_namespace] = True
             direct_log(
-                f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+                f"Process {os.getpid()} ready to initialize storage namespace: [{final_namespace}]"
             )
             return True
         direct_log(
-            f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
+            f"Process {os.getpid()} storage namespace already initialized: [{final_namespace}]"
        )
 
    return False
 
 
 async def get_namespace_data(
-    namespace: str, first_init: bool = False
+    namespace: str, first_init: bool = False, workspace: str | None = None
 ) -> Dict[str, Any]:
     """get the shared data reference for specific namespace
 
     Args:
         namespace: The namespace to retrieve
-        allow_create: If True, allows creation of the namespace if it doesn't exist.
-                      Used internally by initialize_pipeline_status().
+        first_init: If True, allows pipeline_status namespace to create namespace if it doesn't exist.
+                    Prevent getting pipeline_status namespace without initialize_pipeline_status().
+                    This parameter is used internally by initialize_pipeline_status().
+        workspace: Workspace identifier (may be empty string for global namespace)
     """
     if _shared_dicts is None:
         direct_log(
-            f"Error: try to get namespace before it is initialized, pid={os.getpid()}",
+            f"Error: Try to get namespace before it is initialized, pid={os.getpid()}",
             level="ERROR",
         )
         raise ValueError("Shared dictionaries not initialized")
 
-    async with get_internal_lock():
-        if namespace not in _shared_dicts:
-            # Special handling for pipeline_status namespace
-            # Supports both global "pipeline_status" and workspace-specific "{workspace}:pipeline"
-            is_pipeline = namespace == "pipeline_status" or namespace.endswith(
-                ":pipeline"
-            )
+    final_namespace = get_final_namespace(namespace, workspace)
 
-            if is_pipeline and not first_init:
+    async with get_internal_lock():
+        if final_namespace not in _shared_dicts:
+            # Special handling for pipeline_status namespace
+            if final_namespace.endswith(":pipeline_status") and not first_init:
                 # Check if pipeline_status should have been initialized but wasn't
-                # This helps users understand they need to call initialize_pipeline_status()
-                raise PipelineNotInitializedError(namespace)
+                # This helps users to call initialize_pipeline_status() before get_namespace_data()
+                raise PipelineNotInitializedError(final_namespace)
 
             # For other namespaces or when allow_create=True, create them dynamically
             if _is_multiprocess and _manager is not None:
-                _shared_dicts[namespace] = _manager.dict()
+                _shared_dicts[final_namespace] = _manager.dict()
             else:
-                _shared_dicts[namespace] = {}
+                _shared_dicts[final_namespace] = {}
 
-    return _shared_dicts[namespace]
+    return _shared_dicts[final_namespace]
+
+
+def get_namespace_lock(
+    namespace: str, workspace: str | None = None, enable_logging: bool = False
+):
+    """Get the keyed lock guarding a namespace.
+
+    Args:
+        namespace: The namespace to get the lock for.
+        workspace: Workspace identifier (may be empty string for global namespace)
+
+    Returns:
+        The keyed lock for the namespace.
+    """
+    final_namespace = get_final_namespace(namespace, workspace)
+    return get_storage_keyed_lock(
+        ["default_key"], namespace=final_namespace, enable_logging=enable_logging
+    )
 
 
 def finalize_share_data():
@@ -1484,10 +1483,7 @@ def finalize_share_data():
     global \
         _manager, \
         _is_multiprocess, \
-        _storage_lock, \
         _internal_lock, \
-        _pipeline_status_lock, \
-        _graph_db_lock, \
         _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
@@ -1552,10 +1548,7 @@ def finalize_share_data():
     _is_multiprocess = None
     _shared_dicts = None
     _init_flags = None
-    _storage_lock = None
     _internal_lock = None
-    _pipeline_status_lock = None
-    _graph_db_lock = None
     _data_init_lock = None
     _update_flags = None
     _async_locks = None
@@ -1563,21 +1556,23 @@ def finalize_share_data():
     direct_log(f"Process {os.getpid()} storage data finalization complete")
 
 
-def set_default_workspace(workspace: str):
+def set_default_workspace(workspace: str | None = None):
     """
-    Set default workspace for backward compatibility.
+    Set default workspace for namespace operations for backward compatibility.
 
-    This allows initialize_pipeline_status() to automatically use the correct
-    workspace when called without parameters, maintaining compatibility with
-    legacy code that doesn't pass workspace explicitly.
+    This allows get_namespace_data(), get_namespace_lock() or initialize_pipeline_status() to
+    automatically use the correct workspace when called without workspace parameters,
+    maintaining compatibility with legacy code that doesn't pass workspace explicitly.
 
     Args:
         workspace: Workspace identifier (may be empty string for global namespace)
     """
     global _default_workspace
+    if workspace is None:
+        workspace = ""
     _default_workspace = workspace
     direct_log(
-        f"Default workspace set to: '{workspace}' (empty means global)",
+        f"Default workspace set to: '{_default_workspace}' (empty means global)",
        level="DEBUG",
    )
@@ -1587,7 +1582,7 @@ def get_default_workspace() -> str:
     Get default workspace for backward compatibility.
 
     Returns:
-        The default workspace string. Empty string means global namespace.
+        The default workspace string. Empty string means global namespace. None means not set.
     """
     global _default_workspace
-    return _default_workspace if _default_workspace is not None else ""
+    return _default_workspace
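To make the new resolution rules concrete, here is a standalone re-implementation of get_final_namespace()'s decision table (illustration only; the real function reads the module-level _default_workspace instead of taking it as a parameter):

def final_namespace(namespace: str, workspace: str | None, default: str | None) -> str:
    # Mirrors get_final_namespace(): an explicit workspace wins, then the
    # default; None means "never set" and is an error, "" means global.
    if workspace is None:
        workspace = default
    if workspace is None:
        raise ValueError("Invoke namespace operation without workspace")
    return f"{workspace}:{namespace}" if workspace else namespace


assert final_namespace("pipeline_status", "tenant_a", None) == "tenant_a:pipeline_status"
assert final_namespace("pipeline_status", None, "tenant_a") == "tenant_a:pipeline_status"
assert final_namespace("pipeline_status", "", "tenant_a") == "pipeline_status"  # explicit global
assert final_namespace("pipeline_status", None, "") == "pipeline_status"        # global default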


@ -3,6 +3,7 @@ from __future__ import annotations
import traceback import traceback
import asyncio import asyncio
import configparser import configparser
import inspect
import os import os
import time import time
import warnings import warnings
@ -12,6 +13,7 @@ from functools import partial
from typing import ( from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
Awaitable,
Callable, Callable,
Iterator, Iterator,
cast, cast,
@ -20,6 +22,7 @@ from typing import (
Optional, Optional,
List, List,
Dict, Dict,
Union,
) )
from lightrag.prompt import PROMPTS from lightrag.prompt import PROMPTS
from lightrag.exceptions import PipelineCancelledException from lightrag.exceptions import PipelineCancelledException
@ -61,10 +64,10 @@ from lightrag.kg import (
from lightrag.kg.shared_storage import ( from lightrag.kg.shared_storage import (
get_namespace_data, get_namespace_data,
get_graph_db_lock,
get_data_init_lock, get_data_init_lock,
get_storage_keyed_lock, get_default_workspace,
initialize_pipeline_status, set_default_workspace,
get_namespace_lock,
) )
from lightrag.base import ( from lightrag.base import (
@ -88,7 +91,7 @@ from lightrag.operate import (
merge_nodes_and_edges, merge_nodes_and_edges,
kg_query, kg_query,
naive_query, naive_query,
_rebuild_knowledge_from_chunks, rebuild_knowledge_from_chunks,
) )
from lightrag.constants import GRAPH_FIELD_SEP from lightrag.constants import GRAPH_FIELD_SEP
from lightrag.utils import ( from lightrag.utils import (
@ -244,11 +247,13 @@ class LightRAG:
int, int,
int, int,
], ],
List[Dict[str, Any]], Union[List[Dict[str, Any]], Awaitable[List[Dict[str, Any]]]],
] = field(default_factory=lambda: chunking_by_token_size) ] = field(default_factory=lambda: chunking_by_token_size)
""" """
Custom chunking function for splitting text into chunks before processing. Custom chunking function for splitting text into chunks before processing.
The function can be either synchronous or asynchronous.
The function should take the following parameters: The function should take the following parameters:
- `tokenizer`: A Tokenizer instance to use for tokenization. - `tokenizer`: A Tokenizer instance to use for tokenization.
@ -258,7 +263,8 @@ class LightRAG:
- `chunk_token_size`: The maximum number of tokens per chunk. - `chunk_token_size`: The maximum number of tokens per chunk.
- `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks. - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.
The function should return a list of dictionaries, where each dictionary contains the following keys: The function should return a list of dictionaries (or an awaitable that resolves to a list),
where each dictionary contains the following keys:
- `tokens`: The number of tokens in the chunk. - `tokens`: The number of tokens in the chunk.
- `content`: The text content of the chunk. - `content`: The text content of the chunk.
@ -271,6 +277,9 @@ class LightRAG:
embedding_func: EmbeddingFunc | None = field(default=None) embedding_func: EmbeddingFunc | None = field(default=None)
"""Function for computing text embeddings. Must be set before use.""" """Function for computing text embeddings. Must be set before use."""
embedding_token_limit: int | None = field(default=None, init=False)
"""Token limit for embedding model. Set automatically from embedding_func.max_token_size in __post_init__."""
embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10))) embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
"""Batch size for embedding computations.""" """Batch size for embedding computations."""
@ -514,6 +523,16 @@ class LightRAG:
logger.debug(f"LightRAG init with param:\n {_print_config}\n") logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init Embedding # Init Embedding
# Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
embedding_max_token_size = None
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
embedding_max_token_size = self.embedding_func.max_token_size
logger.debug(
f"Captured embedding max_token_size: {embedding_max_token_size}"
)
self.embedding_token_limit = embedding_max_token_size
# Step 2: Apply priority wrapper decorator
self.embedding_func = priority_limit_async_func_call( self.embedding_func = priority_limit_async_func_call(
self.embedding_func_max_async, self.embedding_func_max_async,
llm_timeout=self.default_embedding_timeout, llm_timeout=self.default_embedding_timeout,
@ -640,12 +659,11 @@ class LightRAG:
async def initialize_storages(self): async def initialize_storages(self):
"""Storage initialization must be called one by one to prevent deadlock""" """Storage initialization must be called one by one to prevent deadlock"""
if self._storages_status == StoragesStatus.CREATED: if self._storages_status == StoragesStatus.CREATED:
# Set default workspace for backward compatibility # Set the first initialized workspace will set the default workspace
# This allows initialize_pipeline_status() called without parameters # Allows namespace operation without specifying workspace for backward compatibility
# to use the correct workspace default_workspace = get_default_workspace()
from lightrag.kg.shared_storage import set_default_workspace if default_workspace is None:
set_default_workspace(self.workspace)
set_default_workspace(self.workspace)
for storage in ( for storage in (
self.full_docs, self.full_docs,
@ -718,7 +736,7 @@ class LightRAG:
async def check_and_migrate_data(self): async def check_and_migrate_data(self):
"""Check if data migration is needed and perform migration if necessary""" """Check if data migration is needed and perform migration if necessary"""
async with get_data_init_lock(enable_logging=True): async with get_data_init_lock():
try: try:
# Check if migration is needed: # Check if migration is needed:
# 1. chunk_entity_relation_graph has entities and relations (count > 0) # 1. chunk_entity_relation_graph has entities and relations (count > 0)
@ -1581,22 +1599,12 @@ class LightRAG:
""" """
# Get pipeline status shared data and lock # Get pipeline status shared data and lock
# Step 1: Get workspace pipeline_status = await get_namespace_data(
workspace = self.workspace "pipeline_status", workspace=self.workspace
)
# Step 2: Construct namespace (following GraphDB pattern) pipeline_status_lock = get_namespace_lock(
namespace = f"{workspace}:pipeline" if workspace else "pipeline_status" "pipeline_status", workspace=self.workspace
# Step 3: Ensure initialization (on first access)
await initialize_pipeline_status(workspace)
# Step 4: Get lock
pipeline_status_lock = get_storage_keyed_lock(
keys="status", namespace=namespace, enable_logging=False
) )
# Step 5: Get data
pipeline_status = await get_namespace_data(namespace)
# Check if another process is already processing the queue # Check if another process is already processing the queue
async with pipeline_status_lock: async with pipeline_status_lock:
@ -1778,7 +1786,28 @@ class LightRAG:
) )
content = content_data["content"] content = content_data["content"]
# Generate chunks from document # Call chunking function, supporting both sync and async implementations
chunking_result = self.chunking_func(
self.tokenizer,
content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
)
# If result is awaitable, await to get actual result
if inspect.isawaitable(chunking_result):
chunking_result = await chunking_result
# Validate return type
if not isinstance(chunking_result, (list, tuple)):
raise TypeError(
f"chunking_func must return a list or tuple of dicts, "
f"got {type(chunking_result)}"
)
# Build chunks dictionary
chunks: dict[str, Any] = { chunks: dict[str, Any] = {
compute_mdhash_id(dp["content"], prefix="chunk-"): { compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp, **dp,
@ -1786,14 +1815,7 @@ class LightRAG:
"file_path": file_path, # Add file path to each chunk "file_path": file_path, # Add file path to each chunk
"llm_cache_list": [], # Initialize empty LLM cache list for each chunk "llm_cache_list": [], # Initialize empty LLM cache list for each chunk
} }
for dp in self.chunking_func( for dp in chunking_result
self.tokenizer,
content,
split_by_character,
split_by_character_only,
self.chunk_overlap_token_size,
self.chunk_token_size,
)
} }
if not chunks: if not chunks:
@ -1893,9 +1915,14 @@ class LightRAG:
if task and not task.done(): if task and not task.done():
task.cancel() task.cancel()
# Persistent llm cache # Persistent llm cache with error handling
if self.llm_response_cache: if self.llm_response_cache:
await self.llm_response_cache.index_done_callback() try:
await self.llm_response_cache.index_done_callback()
except Exception as persist_error:
logger.error(
f"Failed to persist LLM cache: {persist_error}"
)
# Record processing end time for failed case # Record processing end time for failed case
processing_end_time = int(time.time()) processing_end_time = int(time.time())
@ -2015,9 +2042,14 @@ class LightRAG:
error_msg error_msg
) )
# Persistent llm cache # Persistent llm cache with error handling
if self.llm_response_cache: if self.llm_response_cache:
await self.llm_response_cache.index_done_callback() try:
await self.llm_response_cache.index_done_callback()
except Exception as persist_error:
logger.error(
f"Failed to persist LLM cache: {persist_error}"
)
# Record processing end time for failed case # Record processing end time for failed case
processing_end_time = int(time.time()) processing_end_time = int(time.time())
@ -2924,22 +2956,12 @@ class LightRAG:
doc_llm_cache_ids: list[str] = [] doc_llm_cache_ids: list[str] = []
# Get pipeline status shared data and lock for status updates # Get pipeline status shared data and lock for status updates
# Step 1: Get workspace pipeline_status = await get_namespace_data(
workspace = self.workspace "pipeline_status", workspace=self.workspace
)
# Step 2: Construct namespace (following GraphDB pattern) pipeline_status_lock = get_namespace_lock(
namespace = f"{workspace}:pipeline" if workspace else "pipeline_status" "pipeline_status", workspace=self.workspace
# Step 3: Ensure initialization (on first access)
await initialize_pipeline_status(workspace)
# Step 4: Get lock
pipeline_status_lock = get_storage_keyed_lock(
keys="status", namespace=namespace, enable_logging=False
) )
# Step 5: Get data
pipeline_status = await get_namespace_data(namespace)
async with pipeline_status_lock: async with pipeline_status_lock:
log_message = f"Starting deletion process for document {doc_id}" log_message = f"Starting deletion process for document {doc_id}"
@ -3172,6 +3194,9 @@ class LightRAG:
] ]
if not existing_sources: if not existing_sources:
# No chunk references means this entity should be deleted
entities_to_delete.add(node_label)
entity_chunk_updates[node_label] = []
continue continue
remaining_sources = subtract_source_ids(existing_sources, chunk_ids) remaining_sources = subtract_source_ids(existing_sources, chunk_ids)
@ -3193,6 +3218,7 @@ class LightRAG:
# Process relationships # Process relationships
for edge_data in affected_edges: for edge_data in affected_edges:
# source target is not in normalize order in graph db property
src = edge_data.get("source") src = edge_data.get("source")
tgt = edge_data.get("target") tgt = edge_data.get("target")
@ -3229,6 +3255,9 @@ class LightRAG:
] ]
if not existing_sources: if not existing_sources:
# No chunk references means this relationship should be deleted
relationships_to_delete.add(edge_tuple)
relation_chunk_updates[edge_tuple] = []
continue continue
remaining_sources = subtract_source_ids(existing_sources, chunk_ids) remaining_sources = subtract_source_ids(existing_sources, chunk_ids)
@ -3254,38 +3283,31 @@ class LightRAG:
if entity_chunk_updates and self.entity_chunks: if entity_chunk_updates and self.entity_chunks:
entity_upsert_payload = {} entity_upsert_payload = {}
entity_delete_ids: set[str] = set()
for entity_name, remaining in entity_chunk_updates.items(): for entity_name, remaining in entity_chunk_updates.items():
if not remaining: if not remaining:
entity_delete_ids.add(entity_name) # Empty entities are deleted alongside graph nodes later
else: continue
entity_upsert_payload[entity_name] = { entity_upsert_payload[entity_name] = {
"chunk_ids": remaining, "chunk_ids": remaining,
"count": len(remaining), "count": len(remaining),
"updated_at": current_time, "updated_at": current_time,
} }
if entity_delete_ids:
await self.entity_chunks.delete(list(entity_delete_ids))
if entity_upsert_payload: if entity_upsert_payload:
await self.entity_chunks.upsert(entity_upsert_payload) await self.entity_chunks.upsert(entity_upsert_payload)
if relation_chunk_updates and self.relation_chunks: if relation_chunk_updates and self.relation_chunks:
relation_upsert_payload = {} relation_upsert_payload = {}
relation_delete_ids: set[str] = set()
for edge_tuple, remaining in relation_chunk_updates.items(): for edge_tuple, remaining in relation_chunk_updates.items():
storage_key = make_relation_chunk_key(*edge_tuple)
if not remaining: if not remaining:
relation_delete_ids.add(storage_key) # Empty relations are deleted alongside graph edges later
else: continue
relation_upsert_payload[storage_key] = { storage_key = make_relation_chunk_key(*edge_tuple)
"chunk_ids": remaining, relation_upsert_payload[storage_key] = {
"count": len(remaining), "chunk_ids": remaining,
"updated_at": current_time, "count": len(remaining),
} "updated_at": current_time,
}
if relation_delete_ids:
await self.relation_chunks.delete(list(relation_delete_ids))
if relation_upsert_payload: if relation_upsert_payload:
await self.relation_chunks.upsert(relation_upsert_payload) await self.relation_chunks.upsert(relation_upsert_payload)
@@ -3293,56 +3315,111 @@ class LightRAG:
             logger.error(f"Failed to process graph analysis results: {e}")
             raise Exception(f"Failed to process graph dependencies: {e}") from e

-        # Use graph database lock to prevent dirty read
-        graph_db_lock = get_graph_db_lock(enable_logging=False)
-        async with graph_db_lock:
-            # 5. Delete chunks from storage
-            if chunk_ids:
-                try:
-                    await self.chunks_vdb.delete(chunk_ids)
-                    await self.text_chunks.delete(chunk_ids)
-                    async with pipeline_status_lock:
-                        log_message = f"Successfully deleted {len(chunk_ids)} chunks from storage"
-                        logger.info(log_message)
-                        pipeline_status["latest_message"] = log_message
-                        pipeline_status["history_messages"].append(log_message)
-                except Exception as e:
-                    logger.error(f"Failed to delete chunks: {e}")
-                    raise Exception(f"Failed to delete document chunks: {e}") from e
-
-            # 6. Delete entities that have no remaining sources
-            if entities_to_delete:
-                try:
-                    # Delete from vector database
-                    entity_vdb_ids = [
-                        compute_mdhash_id(entity, prefix="ent-")
-                        for entity in entities_to_delete
-                    ]
-                    await self.entities_vdb.delete(entity_vdb_ids)
-                    # Delete from graph
-                    await self.chunk_entity_relation_graph.remove_nodes(
-                        list(entities_to_delete)
-                    )
-                    async with pipeline_status_lock:
-                        log_message = f"Successfully deleted {len(entities_to_delete)} entities"
-                        logger.info(log_message)
-                        pipeline_status["latest_message"] = log_message
-                        pipeline_status["history_messages"].append(log_message)
-                except Exception as e:
-                    logger.error(f"Failed to delete entities: {e}")
-                    raise Exception(f"Failed to delete entities: {e}") from e
-
-            # 7. Delete relationships that have no remaining sources
-            if relationships_to_delete:
-                try:
-                    # Delete from vector database
-                    rel_ids_to_delete = []
-                    for src, tgt in relationships_to_delete:
-                        rel_ids_to_delete.extend(
-                            [
-                                compute_mdhash_id(src + tgt, prefix="rel-"),
+        # Data integrity is ensured by allowing only one process to hold the
+        # pipeline at a time (no graph db lock is needed anymore)
+
+        # 5. Delete chunks from storage
+        if chunk_ids:
+            try:
+                await self.chunks_vdb.delete(chunk_ids)
+                await self.text_chunks.delete(chunk_ids)
+                async with pipeline_status_lock:
+                    log_message = (
+                        f"Successfully deleted {len(chunk_ids)} chunks from storage"
+                    )
+                    logger.info(log_message)
+                    pipeline_status["latest_message"] = log_message
+                    pipeline_status["history_messages"].append(log_message)
+            except Exception as e:
+                logger.error(f"Failed to delete chunks: {e}")
+                raise Exception(f"Failed to delete document chunks: {e}") from e
+
+        # 6. Delete relationships that have no remaining sources
+        if relationships_to_delete:
+            try:
+                # Delete from relation vdb
+                rel_ids_to_delete = []
+                for src, tgt in relationships_to_delete:
+                    rel_ids_to_delete.extend(
+                        [
+                            compute_mdhash_id(src + tgt, prefix="rel-"),
+                            compute_mdhash_id(tgt + src, prefix="rel-"),
+                        ]
+                    )
+                await self.relationships_vdb.delete(rel_ids_to_delete)
+                # Delete from graph
+                await self.chunk_entity_relation_graph.remove_edges(
+                    list(relationships_to_delete)
+                )
+                # Delete from relation_chunks storage
+                if self.relation_chunks:
+                    relation_storage_keys = [
+                        make_relation_chunk_key(src, tgt)
+                        for src, tgt in relationships_to_delete
+                    ]
+                    await self.relation_chunks.delete(relation_storage_keys)
+                async with pipeline_status_lock:
+                    log_message = f"Successfully deleted {len(relationships_to_delete)} relations"
+                    logger.info(log_message)
+                    pipeline_status["latest_message"] = log_message
+                    pipeline_status["history_messages"].append(log_message)
+            except Exception as e:
+                logger.error(f"Failed to delete relationships: {e}")
+                raise Exception(f"Failed to delete relationships: {e}") from e
+
+        # 7. Delete entities that have no remaining sources
+        if entities_to_delete:
+            try:
+                # Batch get all edges for entities to avoid N+1 query problem
+                nodes_edges_dict = (
+                    await self.chunk_entity_relation_graph.get_nodes_edges_batch(
+                        list(entities_to_delete)
+                    )
+                )
+                # Debug: Check and log all edges before deleting nodes
+                edges_to_delete = set()
+                edges_still_exist = 0
+                for entity, edges in nodes_edges_dict.items():
+                    if edges:
+                        for src, tgt in edges:
+                            # Normalize edge representation (sorted for consistency)
+                            edge_tuple = tuple(sorted((src, tgt)))
+                            edges_to_delete.add(edge_tuple)
+                            if (
+                                src in entities_to_delete
+                                and tgt in entities_to_delete
+                            ):
+                                logger.warning(
+                                    f"Edge still exists: {src} <-> {tgt}"
+                                )
+                            elif src in entities_to_delete:
+                                logger.warning(
+                                    f"Edge still exists: {src} --> {tgt}"
+                                )
+                            else:
+                                logger.warning(
+                                    f"Edge still exists: {src} <-- {tgt}"
+                                )
+                            edges_still_exist += 1
+                if edges_still_exist:
+                    logger.warning(
+                        f"⚠️ {edges_still_exist} entities still has edges before deletion"
+                    )
+
+                # Clean residual edges from VDB and storage before deleting nodes
+                if edges_to_delete:
+                    # Delete from relationships_vdb
+                    rel_ids_to_delete = []
+                    for src, tgt in edges_to_delete:
+                        rel_ids_to_delete.extend(
+                            [
+                                compute_mdhash_id(src + tgt, prefix="rel-"),
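The hunk deletes two IDs per undirected edge because the stored vector ID depends on which direction the relation was inserted with. A runnable sketch, modeling compute_mdhash_id as prefix plus an MD5 hex digest (an assumption; the real helper lives in lightrag.utils):

from hashlib import md5

def compute_mdhash_id(content: str, prefix: str = "") -> str:
    return prefix + md5(content.encode()).hexdigest()

src, tgt = "Alice", "Bob"
candidate_ids = [
    compute_mdhash_id(src + tgt, prefix="rel-"),
    compute_mdhash_id(tgt + src, prefix="rel-"),
]
print(candidate_ids)  # two distinct IDs; deleting both covers either insertion order

# Normalized undirected form used for the edges_to_delete set:
edge_tuple = tuple(sorted((src, tgt)))
assert edge_tuple == tuple(sorted((tgt, src)))  # direction-independent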
@@ -3351,28 +3428,53 @@ class LightRAG:
                                 compute_mdhash_id(tgt + src, prefix="rel-"),
                             ]
                         )
                     await self.relationships_vdb.delete(rel_ids_to_delete)
-                    # Delete from graph
-                    await self.chunk_entity_relation_graph.remove_edges(
-                        list(relationships_to_delete)
-                    )
-                    async with pipeline_status_lock:
-                        log_message = f"Successfully deleted {len(relationships_to_delete)} relations"
-                        logger.info(log_message)
-                        pipeline_status["latest_message"] = log_message
-                        pipeline_status["history_messages"].append(log_message)
-                except Exception as e:
-                    logger.error(f"Failed to delete relationships: {e}")
-                    raise Exception(f"Failed to delete relationships: {e}") from e
-
-            # Persist changes to graph database before releasing graph database lock
-            await self._insert_done()
+                    # Delete from relation_chunks storage
+                    if self.relation_chunks:
+                        relation_storage_keys = [
+                            make_relation_chunk_key(src, tgt)
+                            for src, tgt in edges_to_delete
+                        ]
+                        await self.relation_chunks.delete(relation_storage_keys)
+                    logger.info(
+                        f"Cleaned {len(edges_to_delete)} residual edges from VDB and chunk-tracking storage"
+                    )
+
+                # Delete from graph (edges will be auto-deleted with nodes)
+                await self.chunk_entity_relation_graph.remove_nodes(
+                    list(entities_to_delete)
+                )
+
+                # Delete from vector vdb
+                entity_vdb_ids = [
+                    compute_mdhash_id(entity, prefix="ent-")
+                    for entity in entities_to_delete
+                ]
+                await self.entities_vdb.delete(entity_vdb_ids)
+
+                # Delete from entity_chunks storage
+                if self.entity_chunks:
+                    await self.entity_chunks.delete(list(entities_to_delete))
+
+                async with pipeline_status_lock:
+                    log_message = (
+                        f"Successfully deleted {len(entities_to_delete)} entities"
+                    )
+                    logger.info(log_message)
+                    pipeline_status["latest_message"] = log_message
+                    pipeline_status["history_messages"].append(log_message)
+            except Exception as e:
+                logger.error(f"Failed to delete entities: {e}")
+                raise Exception(f"Failed to delete entities: {e}") from e
+
+        # Persist changes to graph database before entity and relationship rebuild
+        await self._insert_done()

         # 8. Rebuild entities and relationships from remaining chunks
         if entities_to_rebuild or relationships_to_rebuild:
             try:
-                await _rebuild_knowledge_from_chunks(
+                await rebuild_knowledge_from_chunks(
                     entities_to_rebuild=entities_to_rebuild,
                     relationships_to_rebuild=relationships_to_rebuild,
                     knowledge_graph_inst=self.chunk_entity_relation_graph,
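Steps 5 through 8 above run inside LightRAG's document-deletion path. A hedged sketch of driving it end to end; adelete_by_doc_id is the public async entry point in the LightRAG API, while embedding/LLM configuration and the doc ID are elided or illustrative:

from lightrag import LightRAG
from lightrag.kg.shared_storage import initialize_pipeline_status

async def delete_doc(rag: LightRAG, doc_id: str) -> None:
    await rag.initialize_storages()
    await initialize_pipeline_status()
    # Runs: delete chunks -> relations -> entities -> persist -> rebuild
    await rag.adelete_by_doc_id(doc_id)

# With a configured instance: asyncio.run(delete_doc(rag, "doc-123"))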
@@ -3590,16 +3692,22 @@ class LightRAG:
         )

     async def aedit_entity(
-        self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
+        self,
+        entity_name: str,
+        updated_data: dict[str, str],
+        allow_rename: bool = True,
+        allow_merge: bool = False,
     ) -> dict[str, Any]:
         """Asynchronously edit entity information.

         Updates entity information in the knowledge graph and re-embeds the entity in the vector database.
+        Also synchronizes entity_chunks_storage and relation_chunks_storage to track chunk references.

         Args:
             entity_name: Name of the entity to edit
             updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "entity_type": "new type"}
             allow_rename: Whether to allow entity renaming, defaults to True
+            allow_merge: Whether to merge into an existing entity when renaming to an existing name

         Returns:
             Dictionary containing updated entity information
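With the new allow_merge flag, renaming onto an existing name can merge into that entity instead of failing. A hypothetical call shape (the rag instance, entity names, and field values are illustrative):

async def rename_with_merge(rag) -> dict:
    return await rag.aedit_entity(
        "Tesla Inc",
        {"entity_name": "Tesla", "description": "Electric vehicle maker"},
        allow_rename=True,
        allow_merge=True,  # fold into the existing "Tesla" node instead of erroring
    )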
@@ -3613,14 +3721,21 @@ class LightRAG:
             entity_name,
             updated_data,
             allow_rename,
+            allow_merge,
+            self.entity_chunks,
+            self.relation_chunks,
         )

     def edit_entity(
-        self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
+        self,
+        entity_name: str,
+        updated_data: dict[str, str],
+        allow_rename: bool = True,
+        allow_merge: bool = False,
     ) -> dict[str, Any]:
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
-            self.aedit_entity(entity_name, updated_data, allow_rename)
+            self.aedit_entity(entity_name, updated_data, allow_rename, allow_merge)
         )

     async def aedit_relation(
@@ -3629,6 +3744,7 @@ class LightRAG:
         """Asynchronously edit relation information.

         Updates relation (edge) information in the knowledge graph and re-embeds the relation in the vector database.
+        Also synchronizes the relation_chunks_storage to track which chunks reference this relation.

         Args:
             source_entity: Name of the source entity
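The relation editor keeps its public surface unchanged; self.relation_chunks is injected internally (see the hunk below). A hypothetical call shape with illustrative names:

async def update_relation(rag) -> dict:
    return await rag.aedit_relation(
        source_entity="Tesla",
        target_entity="Elon Musk",
        updated_data={"description": "founded by Elon Musk"},
    )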
@@ -3647,6 +3763,7 @@ class LightRAG:
             source_entity,
             target_entity,
             updated_data,
+            self.relation_chunks,
         )

     def edit_relation(
@@ -3758,6 +3875,8 @@ class LightRAG:
             target_entity,
             merge_strategy,
             target_entity_data,
+            self.entity_chunks,
+            self.relation_chunks,
         )

     def merge_entities(
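Entity merging likewise forwards the chunk-tracking stores internally, so callers stay the same. A hypothetical call with assumed parameter names and merge-strategy values:

def merge(rag) -> dict:
    return rag.merge_entities(
        source_entities=["Tesla Inc", "Tesla Motors"],
        target_entity="Tesla",
        merge_strategy={"description": "concatenate"},  # assumed strategy keyword
    )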