Fix workspace isolation for pipeline status across all operations

- Fix final_namespace error in get_namespace_data()
- Fix get_workspace_from_request return type
- Add workspace param to pipeline status calls

(cherry picked from commit 52c812b9a0)

parent fe1576943f
commit dfab175c16
4 changed files with 573 additions and 367 deletions
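The diff below converges every pipeline-status access on one workspace-aware pattern. A minimal sketch of that pattern, assuming a `rag` object with a `workspace` attribute (the surrounding setup is illustrative; the three shared_storage calls are the ones introduced or reworked in this commit):

    from lightrag.kg.shared_storage import (
        get_namespace_data,
        get_namespace_lock,
        initialize_pipeline_status,
    )

    async def report(rag, message: str) -> None:
        # Initialization is per workspace; None falls back to the default
        # workspace registered by the first initialize_storages() call.
        await initialize_pipeline_status(rag.workspace)
        pipeline_status = await get_namespace_data(
            "pipeline_status", workspace=rag.workspace
        )
        pipeline_status_lock = get_namespace_lock(
            "pipeline_status", workspace=rag.workspace
        )
        async with pipeline_status_lock:
            pipeline_status["latest_message"] = message
            pipeline_status["history_messages"].append(message)

Internally the shared-data key becomes "<workspace>:pipeline_status", so two workspaces never share status dictionaries or locks.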
lightrag/api/lightrag_server.py

@@ -56,6 +56,8 @@ from lightrag.api.routers.ollama_api import OllamaAPI
 from lightrag.utils import logger, set_verbose_debug
 from lightrag.kg.shared_storage import (
     get_namespace_data,
+    get_default_workspace,
+    # set_default_workspace,
     initialize_pipeline_status,
     cleanup_keyed_lock,
     finalize_share_data,
@@ -350,8 +352,9 @@ def create_app(args):

         try:
             # Initialize database connections
+            # set_default_workspace(rag.workspace) # comment this line to test auto default workspace setting in initialize_storages
             await rag.initialize_storages()
-            await initialize_pipeline_status()
+            await initialize_pipeline_status()  # with default workspace

             # Data migration regardless of storage implementation
             await rag.check_and_migrate_data()
@@ -452,7 +455,7 @@ def create_app(args):
     # Create combined auth dependency for all endpoints
     combined_auth = get_combined_auth_dependency(api_key)

-    def get_workspace_from_request(request: Request) -> str:
+    def get_workspace_from_request(request: Request) -> str | None:
         """
         Extract workspace from HTTP request header or use default.

@@ -469,9 +472,8 @@ def create_app(args):
         # Check custom header first
         workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()

-        # Fall back to server default if header not provided
         if not workspace:
-            workspace = args.workspace
+            workspace = None

         return workspace

@@ -641,33 +643,108 @@ def create_app(args):

     def create_optimized_embedding_function(
         config_cache: LLMConfigCache, binding, model, host, api_key, args
-    ):
+    ) -> EmbeddingFunc:
         """
-        Create optimized embedding function with pre-processed configuration for applicable bindings.
-        Uses lazy imports for all bindings and avoids repeated configuration parsing.
+        Create optimized embedding function and return an EmbeddingFunc instance
+        with proper max_token_size inheritance from provider defaults.
+
+        This function:
+        1. Imports the provider embedding function
+        2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
+        3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
+        4. Returns a properly configured EmbeddingFunc instance
         """
+
+        # Step 1: Import provider function and extract default attributes
+        provider_func = None
+        provider_max_token_size = None
+        provider_embedding_dim = None
+
+        try:
+            if binding == "openai":
+                from lightrag.llm.openai import openai_embed
+
+                provider_func = openai_embed
+            elif binding == "ollama":
+                from lightrag.llm.ollama import ollama_embed
+
+                provider_func = ollama_embed
+            elif binding == "gemini":
+                from lightrag.llm.gemini import gemini_embed
+
+                provider_func = gemini_embed
+            elif binding == "jina":
+                from lightrag.llm.jina import jina_embed
+
+                provider_func = jina_embed
+            elif binding == "azure_openai":
+                from lightrag.llm.azure_openai import azure_openai_embed
+
+                provider_func = azure_openai_embed
+            elif binding == "aws_bedrock":
+                from lightrag.llm.bedrock import bedrock_embed
+
+                provider_func = bedrock_embed
+            elif binding == "lollms":
+                from lightrag.llm.lollms import lollms_embed
+
+                provider_func = lollms_embed
+
+            # Extract attributes if provider is an EmbeddingFunc
+            if provider_func and isinstance(provider_func, EmbeddingFunc):
+                provider_max_token_size = provider_func.max_token_size
+                provider_embedding_dim = provider_func.embedding_dim
+                logger.debug(
+                    f"Extracted from {binding} provider: "
+                    f"max_token_size={provider_max_token_size}, "
+                    f"embedding_dim={provider_embedding_dim}"
+                )
+        except ImportError as e:
+            logger.warning(f"Could not import provider function for {binding}: {e}")
+
+        # Step 2: Apply priority (user config > provider default)
+        # For max_token_size: explicit env var > provider default > None
+        final_max_token_size = args.embedding_token_limit or provider_max_token_size
+        # For embedding_dim: user config (always has value) takes priority
+        # Only use provider default if user config is explicitly None (which shouldn't happen)
+        final_embedding_dim = (
+            args.embedding_dim if args.embedding_dim else provider_embedding_dim
+        )
+
+        # Step 3: Create optimized embedding function (calls underlying function directly)
         async def optimized_embedding_function(texts, embedding_dim=None):
             try:
                 if binding == "lollms":
                     from lightrag.llm.lollms import lollms_embed

-                    return await lollms_embed(
+                    # Get real function, skip EmbeddingFunc wrapper if present
+                    actual_func = (
+                        lollms_embed.func
+                        if isinstance(lollms_embed, EmbeddingFunc)
+                        else lollms_embed
+                    )
+                    return await actual_func(
                         texts, embed_model=model, host=host, api_key=api_key
                     )
                 elif binding == "ollama":
                     from lightrag.llm.ollama import ollama_embed

-                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
+                    # Get real function, skip EmbeddingFunc wrapper if present
+                    actual_func = (
+                        ollama_embed.func
+                        if isinstance(ollama_embed, EmbeddingFunc)
+                        else ollama_embed
+                    )
+
+                    # Use pre-processed configuration if available
                     if config_cache.ollama_embedding_options is not None:
                         ollama_options = config_cache.ollama_embedding_options
                     else:
                         # Fallback for cases where config cache wasn't initialized properly
                         from lightrag.llm.binding_options import OllamaEmbeddingOptions

                         ollama_options = OllamaEmbeddingOptions.options_dict(args)

-                    return await ollama_embed(
+                    return await actual_func(
                         texts,
                         embed_model=model,
                         host=host,
@@ -677,15 +754,30 @@ def create_app(args):
                 elif binding == "azure_openai":
                     from lightrag.llm.azure_openai import azure_openai_embed

-                    return await azure_openai_embed(texts, model=model, api_key=api_key)
+                    actual_func = (
+                        azure_openai_embed.func
+                        if isinstance(azure_openai_embed, EmbeddingFunc)
+                        else azure_openai_embed
+                    )
+                    return await actual_func(texts, model=model, api_key=api_key)
                 elif binding == "aws_bedrock":
                     from lightrag.llm.bedrock import bedrock_embed

-                    return await bedrock_embed(texts, model=model)
+                    actual_func = (
+                        bedrock_embed.func
+                        if isinstance(bedrock_embed, EmbeddingFunc)
+                        else bedrock_embed
+                    )
+                    return await actual_func(texts, model=model)
                 elif binding == "jina":
                     from lightrag.llm.jina import jina_embed

-                    return await jina_embed(
+                    actual_func = (
+                        jina_embed.func
+                        if isinstance(jina_embed, EmbeddingFunc)
+                        else jina_embed
+                    )
+                    return await actual_func(
                         texts,
                         embedding_dim=embedding_dim,
                         base_url=host,
@@ -694,16 +786,21 @@ def create_app(args):
                 elif binding == "gemini":
                     from lightrag.llm.gemini import gemini_embed

-                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
+                    actual_func = (
+                        gemini_embed.func
+                        if isinstance(gemini_embed, EmbeddingFunc)
+                        else gemini_embed
+                    )
+
+                    # Use pre-processed configuration if available
                     if config_cache.gemini_embedding_options is not None:
                         gemini_options = config_cache.gemini_embedding_options
                     else:
                         # Fallback for cases where config cache wasn't initialized properly
                         from lightrag.llm.binding_options import GeminiEmbeddingOptions

                         gemini_options = GeminiEmbeddingOptions.options_dict(args)

-                    return await gemini_embed(
+                    return await actual_func(
                         texts,
                         model=model,
                         base_url=host,
@@ -714,7 +811,12 @@ def create_app(args):
                 else:  # openai and compatible
                     from lightrag.llm.openai import openai_embed

-                    return await openai_embed(
+                    actual_func = (
+                        openai_embed.func
+                        if isinstance(openai_embed, EmbeddingFunc)
+                        else openai_embed
+                    )
+                    return await actual_func(
                         texts,
                         model=model,
                         base_url=host,
@@ -724,7 +826,21 @@ def create_app(args):
             except ImportError as e:
                 raise Exception(f"Failed to import {binding} embedding: {e}")

-        return optimized_embedding_function
+        # Step 4: Wrap in EmbeddingFunc and return
+        embedding_func_instance = EmbeddingFunc(
+            embedding_dim=final_embedding_dim,
+            func=optimized_embedding_function,
+            max_token_size=final_max_token_size,
+            send_dimensions=False,  # Will be set later based on binding requirements
+        )
+
+        # Log final embedding configuration
+        logger.info(
+            f"Embedding config: binding={binding} model={model} "
+            f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
+        )
+
+        return embedding_func_instance

     llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
     embedding_timeout = get_env_value(
@@ -758,25 +874,24 @@ def create_app(args):
         **kwargs,
     )

-    # Create embedding function with optimized configuration
+    # Create embedding function with optimized configuration and max_token_size inheritance
     import inspect

-    # Create the optimized embedding function
-    optimized_embedding_func = create_optimized_embedding_function(
+    # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
+    embedding_func = create_optimized_embedding_function(
         config_cache=config_cache,
         binding=args.embedding_binding,
         model=args.embedding_model,
         host=args.embedding_binding_host,
         api_key=args.embedding_binding_api_key,
-        args=args,  # Pass args object for fallback option generation
+        args=args,
     )

     # Get embedding_send_dim from centralized configuration
     embedding_send_dim = args.embedding_send_dim

-    # Check if the function signature has embedding_dim parameter
-    # Note: Since optimized_embedding_func is an async function, inspect its signature
-    sig = inspect.signature(optimized_embedding_func)
+    # Check if the underlying function signature has embedding_dim parameter
+    sig = inspect.signature(embedding_func.func)
     has_embedding_dim_param = "embedding_dim" in sig.parameters

     # Determine send_dimensions value based on binding type
@@ -794,18 +909,27 @@ def create_app(args):
     else:
         dimension_control = "by not hasparam"

+    # Set send_dimensions on the EmbeddingFunc instance
+    embedding_func.send_dimensions = send_dimensions
+
     logger.info(
         f"Send embedding dimension: {send_dimensions} {dimension_control} "
-        f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
+        f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
         f"binding={args.embedding_binding})"
     )

-    # Create EmbeddingFunc with send_dimensions attribute
-    embedding_func = EmbeddingFunc(
-        embedding_dim=args.embedding_dim,
-        func=optimized_embedding_func,
-        send_dimensions=send_dimensions,
-    )
+    # Log max_token_size source
+    if embedding_func.max_token_size:
+        source = (
+            "env variable"
+            if args.embedding_token_limit
+            else f"{args.embedding_binding} provider default"
+        )
+        logger.info(
+            f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
+        )
+    else:
+        logger.info("Embedding max_token_size: not set (90% token warning disabled)")

     # Configure rerank function based on args.rerank_bindingparameter
     rerank_model_func = None
@@ -1017,14 +1141,13 @@ def create_app(args):
     async def get_status(request: Request):
         """Get current system status"""
         try:
-            # Extract workspace from request header or use default
             workspace = get_workspace_from_request(request)

-            # Construct namespace (following GraphDB pattern)
-            namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-
-            # Get workspace-specific pipeline status
-            pipeline_status = await get_namespace_data(namespace)
+            default_workspace = get_default_workspace()
+            if workspace is None:
+                workspace = default_workspace
+            pipeline_status = await get_namespace_data(
+                "pipeline_status", workspace=workspace
+            )

             if not auth_configured:
                 auth_mode = "disabled"
@@ -1055,8 +1178,7 @@ def create_app(args):
                 "vector_storage": args.vector_storage,
                 "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
                 "enable_llm_cache": args.enable_llm_cache,
-                "workspace": workspace,
-                "default_workspace": args.workspace,
+                "workspace": default_workspace,
                 "max_graph_nodes": args.max_graph_nodes,
                 # Rerank configuration
                 "enable_rerank": rerank_model_func is not None,
@@ -1245,6 +1367,12 @@ def check_and_install_dependencies():


 def main():
+    # Explicitly initialize configuration for clarity
+    # (The proxy will auto-initialize anyway, but this makes intent clear)
+    from .config import initialize_config
+
+    initialize_config()
+
     # Check if running under Gunicorn
     if "GUNICORN_CMD_ARGS" in os.environ:
         # If started with Gunicorn, return directly as Gunicorn will call get_application
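One behavioral consequence of the get_workspace_from_request change above: a request without a LIGHTRAG-WORKSPACE header now resolves to None instead of the server default, and each handler maps None to the registered default itself. A hedged sketch of the resolution order (the request object here is illustrative):

    # Header wins; a blank or missing header yields None rather than args.workspace
    workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() or None
    if workspace is None:
        workspace = get_default_workspace()  # default registered at startup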
lightrag/api/routers/document_routes.py
@@ -1641,26 +1641,15 @@ async def background_delete_documents(
     """Background task to delete multiple documents"""
     from lightrag.kg.shared_storage import (
         get_namespace_data,
-        get_storage_keyed_lock,
-        initialize_pipeline_status,
+        get_namespace_lock,
     )

-    # Step 1: Get workspace
-    workspace = rag.workspace
-
-    # Step 2: Construct namespace
-    namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-
-    # Step 3: Ensure initialization
-    await initialize_pipeline_status(workspace)
-
-    # Step 4: Get lock
-    pipeline_status_lock = get_storage_keyed_lock(
-        keys="status", namespace=namespace, enable_logging=False
+    pipeline_status = await get_namespace_data(
+        "pipeline_status", workspace=rag.workspace
+    )
+    pipeline_status_lock = get_namespace_lock(
+        "pipeline_status", workspace=rag.workspace
     )

-    # Step 5: Get data
-    pipeline_status = await get_namespace_data(namespace)
-
     total_docs = len(doc_ids)
     successful_deletions = []
@@ -2149,27 +2138,16 @@ def create_document_routes(
         """
         from lightrag.kg.shared_storage import (
             get_namespace_data,
-            get_storage_keyed_lock,
-            initialize_pipeline_status,
+            get_namespace_lock,
         )

-        # Get pipeline status and lock
-        # Step 1: Get workspace
-        workspace = rag.workspace
-
-        # Step 2: Construct namespace
-        namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-
-        # Step 3: Ensure initialization
-        await initialize_pipeline_status(workspace)
-
-        # Step 4: Get lock
-        pipeline_status_lock = get_storage_keyed_lock(
-            keys="status", namespace=namespace, enable_logging=False
+        pipeline_status = await get_namespace_data(
+            "pipeline_status", workspace=rag.workspace
+        )
+        pipeline_status_lock = get_namespace_lock(
+            "pipeline_status", workspace=rag.workspace
         )

-        # Step 5: Get data
-        pipeline_status = await get_namespace_data(namespace)
-
         # Check and set status with lock
         async with pipeline_status_lock:
@@ -2360,15 +2338,16 @@ def create_document_routes(
         try:
             from lightrag.kg.shared_storage import (
                 get_namespace_data,
+                get_namespace_lock,
                 get_all_update_flags_status,
-                initialize_pipeline_status,
             )

-            # Get workspace-specific pipeline status
-            workspace = rag.workspace
-            namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-            await initialize_pipeline_status(workspace)
-            pipeline_status = await get_namespace_data(namespace)
+            pipeline_status = await get_namespace_data(
+                "pipeline_status", workspace=rag.workspace
+            )
+            pipeline_status_lock = get_namespace_lock(
+                "pipeline_status", workspace=rag.workspace
+            )

             # Get update flags status for all namespaces
             update_status = await get_all_update_flags_status()
@@ -2385,8 +2364,9 @@ def create_document_routes(
                     processed_flags.append(bool(flag))
                 processed_update_status[namespace] = processed_flags

-            # Convert to regular dict if it's a Manager.dict
-            status_dict = dict(pipeline_status)
+            async with pipeline_status_lock:
+                # Convert to regular dict if it's a Manager.dict
+                status_dict = dict(pipeline_status)

             # Add processed update_status to the status dictionary
             status_dict["update_status"] = processed_update_status
@@ -2575,20 +2555,15 @@ def create_document_routes(
         try:
             from lightrag.kg.shared_storage import (
                 get_namespace_data,
-                get_storage_keyed_lock,
-                initialize_pipeline_status,
+                get_namespace_lock,
             )

-            # Get workspace-specific pipeline status
-            workspace = rag.workspace
-            namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-            await initialize_pipeline_status(workspace)
-
-            # Use workspace-aware lock to check busy flag
-            pipeline_status_lock = get_storage_keyed_lock(
-                keys="status", namespace=namespace, enable_logging=False
+            pipeline_status = await get_namespace_data(
+                "pipeline_status", workspace=rag.workspace
+            )
+            pipeline_status_lock = get_namespace_lock(
+                "pipeline_status", workspace=rag.workspace
             )
-            pipeline_status = await get_namespace_data(namespace)

             # Check if pipeline is busy with proper lock
             async with pipeline_status_lock:
@@ -2993,26 +2968,15 @@ def create_document_routes(
         try:
             from lightrag.kg.shared_storage import (
                 get_namespace_data,
-                get_storage_keyed_lock,
-                initialize_pipeline_status,
+                get_namespace_lock,
             )

-            # Step 1: Get workspace
-            workspace = rag.workspace
-
-            # Step 2: Construct namespace
-            namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-
-            # Step 3: Ensure initialization
-            await initialize_pipeline_status(workspace)
-
-            # Step 4: Get lock
-            pipeline_status_lock = get_storage_keyed_lock(
-                keys="status", namespace=namespace, enable_logging=False
+            pipeline_status = await get_namespace_data(
+                "pipeline_status", workspace=rag.workspace
+            )
+            pipeline_status_lock = get_namespace_lock(
+                "pipeline_status", workspace=rag.workspace
             )

-            # Step 5: Get data
-            pipeline_status = await get_namespace_data(namespace)
-
             async with pipeline_status_lock:
                 if not pipeline_status.get("busy", False):
lightrag/kg/shared_storage.py
@@ -84,10 +84,7 @@ _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
 _update_flags: Optional[Dict[str, bool]] = None  # namespace -> updated

 # locks for mutex access
-_storage_lock: Optional[LockType] = None
 _internal_lock: Optional[LockType] = None
-_pipeline_status_lock: Optional[LockType] = None
-_graph_db_lock: Optional[LockType] = None
 _data_init_lock: Optional[LockType] = None
 # Manager for all keyed locks
 _storage_keyed_lock: Optional["KeyedUnifiedLock"] = None
@@ -98,6 +95,22 @@ _async_locks: Optional[Dict[str, asyncio.Lock]] = None
 _debug_n_locks_acquired: int = 0


+def get_final_namespace(namespace: str, workspace: str | None = None):
+    global _default_workspace
+    if workspace is None:
+        workspace = _default_workspace
+
+    if workspace is None:
+        direct_log(
+            f"Error: Invoke namespace operation without workspace, pid={os.getpid()}",
+            level="ERROR",
+        )
+        raise ValueError("Invoke namespace operation without workspace")
+
+    final_namespace = f"{workspace}:{namespace}" if workspace else f"{namespace}"
+    return final_namespace
+
+
 def inc_debug_n_locks_acquired():
     global _debug_n_locks_acquired
     if DEBUG_LOCKS:
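get_final_namespace is the single point where workspace and namespace are combined, so every keyed structure (_shared_dicts, _init_flags, _update_flags) is indexed the same way. A hedged standalone sketch of the scheme, with the default argument standing in for the module-global _default_workspace:

    def final_key(namespace: str, workspace: str | None, default: str | None) -> str:
        # Explicit workspace wins; otherwise fall back to the registered default
        ws = workspace if workspace is not None else default
        if ws is None:
            # No workspace at all is a programming error, not a silent global
            raise ValueError("Invoke namespace operation without workspace")
        # Empty string selects the legacy global key, e.g. "pipeline_status";
        # anything else is prefixed, e.g. "prod:pipeline_status"
        return f"{ws}:{namespace}" if ws else namespace

Note the empty-string case: "" is a valid default meaning "global", which is why the guard checks `is None` rather than falsiness.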
@@ -1056,40 +1069,10 @@ def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     )


-def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
-    """return unified storage lock for data consistency"""
-    async_lock = _async_locks.get("storage_lock") if _is_multiprocess else None
-    return UnifiedLock(
-        lock=_storage_lock,
-        is_async=not _is_multiprocess,
-        name="storage_lock",
-        enable_logging=enable_logging,
-        async_lock=async_lock,
-    )
-
-
-def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
-    """return unified storage lock for data consistency"""
-    async_lock = _async_locks.get("pipeline_status_lock") if _is_multiprocess else None
-    return UnifiedLock(
-        lock=_pipeline_status_lock,
-        is_async=not _is_multiprocess,
-        name="pipeline_status_lock",
-        enable_logging=enable_logging,
-        async_lock=async_lock,
-    )
-
-
-def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
-    """return unified graph database lock for ensuring atomic operations"""
-    async_lock = _async_locks.get("graph_db_lock") if _is_multiprocess else None
-    return UnifiedLock(
-        lock=_graph_db_lock,
-        is_async=not _is_multiprocess,
-        name="graph_db_lock",
-        enable_logging=enable_logging,
-        async_lock=async_lock,
-    )
+# Workspace based storage_lock is implemented by get_storage_keyed_lock instead.
+# Workspace based pipeline_status_lock is implemented by get_storage_keyed_lock instead.
+# No need to implement graph_db_lock:
+# data integrity is ensured by entity level keyed-lock and allowing only one process to hold pipeline at a time.


 def get_storage_keyed_lock(
@@ -1193,14 +1176,11 @@ def initialize_share_data(workers: int = 1):
         _manager, \
         _workers, \
         _is_multiprocess, \
-        _storage_lock, \
         _lock_registry, \
         _lock_registry_count, \
         _lock_cleanup_data, \
         _registry_guard, \
         _internal_lock, \
-        _pipeline_status_lock, \
-        _graph_db_lock, \
         _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
@@ -1228,9 +1208,6 @@ def initialize_share_data(workers: int = 1):
         _lock_cleanup_data = _manager.dict()
         _registry_guard = _manager.RLock()
         _internal_lock = _manager.Lock()
-        _storage_lock = _manager.Lock()
-        _pipeline_status_lock = _manager.Lock()
-        _graph_db_lock = _manager.Lock()
         _data_init_lock = _manager.Lock()
         _shared_dicts = _manager.dict()
         _init_flags = _manager.dict()
@@ -1241,8 +1218,6 @@ def initialize_share_data(workers: int = 1):
         # Initialize async locks for multiprocess mode
         _async_locks = {
             "internal_lock": asyncio.Lock(),
-            "storage_lock": asyncio.Lock(),
-            "pipeline_status_lock": asyncio.Lock(),
             "graph_db_lock": asyncio.Lock(),
             "data_init_lock": asyncio.Lock(),
         }
@@ -1253,9 +1228,6 @@ def initialize_share_data(workers: int = 1):
     else:
         _is_multiprocess = False
         _internal_lock = asyncio.Lock()
-        _storage_lock = asyncio.Lock()
-        _pipeline_status_lock = asyncio.Lock()
-        _graph_db_lock = asyncio.Lock()
         _data_init_lock = asyncio.Lock()
         _shared_dicts = {}
         _init_flags = {}
@@ -1273,29 +1245,19 @@ def initialize_share_data(workers: int = 1):
     _initialized = True


-async def initialize_pipeline_status(workspace: str = ""):
+async def initialize_pipeline_status(workspace: str | None = None):
     """
-    Initialize pipeline namespace with default values.
-    This function could be called before during FASTAPI lifespan for each worker.
+    Initialize pipeline_status share data with default values.
+    This function is called during FASTAPI lifespan for each worker.

     Args:
-        workspace: Optional workspace identifier for multi-tenant isolation.
-                   If empty string, uses the default workspace set by
-                   set_default_workspace(). If no default is set, uses
-                   global "pipeline_status" namespace.
+        workspace: Optional workspace identifier for pipeline_status of specific workspace.
+                   If None or empty string, uses the default workspace set by
+                   set_default_workspace().
     """
-    # Backward compatibility: use default workspace if not provided
-    if not workspace:
-        workspace = get_default_workspace()
-
-    # Construct namespace (following GraphDB pattern)
-    if workspace:
-        namespace = f"{workspace}:pipeline"
-    else:
-        namespace = "pipeline_status"  # Global namespace for backward compatibility
-
-    pipeline_namespace = await get_namespace_data(namespace, first_init=True)
+    pipeline_namespace = await get_namespace_data(
+        "pipeline_status", first_init=True, workspace=workspace
+    )

     async with get_internal_lock():
         # Check if already initialized by checking for required fields
@@ -1318,12 +1280,14 @@ async def initialize_pipeline_status(workspace: str = ""):
                 "history_messages": history_messages,  # use the shared list object
             }
         )

+        final_namespace = get_final_namespace("pipeline_status", workspace)
         direct_log(
-            f"Process {os.getpid()} Pipeline namespace '{namespace}' initialized"
+            f"Process {os.getpid()} Pipeline namespace '{final_namespace}' initialized"
         )


-async def get_update_flag(namespace: str):
+async def get_update_flag(namespace: str, workspace: str | None = None):
     """
     Create a namespace's update flag for a workers.
     Returen the update flag to caller for referencing or reset.
@@ -1332,14 +1296,16 @@ async def get_update_flag(namespace: str):
     if _update_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")

+    final_namespace = get_final_namespace(namespace, workspace)
+
     async with get_internal_lock():
-        if namespace not in _update_flags:
+        if final_namespace not in _update_flags:
             if _is_multiprocess and _manager is not None:
-                _update_flags[namespace] = _manager.list()
+                _update_flags[final_namespace] = _manager.list()
             else:
-                _update_flags[namespace] = []
+                _update_flags[final_namespace] = []
             direct_log(
-                f"Process {os.getpid()} initialized updated flags for namespace: [{namespace}]"
+                f"Process {os.getpid()} initialized updated flags for namespace: [{final_namespace}]"
             )

     if _is_multiprocess and _manager is not None:
@@ -1352,39 +1318,43 @@ async def get_update_flag(namespace: str):

     new_update_flag = MutableBoolean(False)

-    _update_flags[namespace].append(new_update_flag)
+    _update_flags[final_namespace].append(new_update_flag)
     return new_update_flag


-async def set_all_update_flags(namespace: str):
+async def set_all_update_flags(namespace: str, workspace: str | None = None):
     """Set all update flag of namespace indicating all workers need to reload data from files"""
     global _update_flags
     if _update_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")

+    final_namespace = get_final_namespace(namespace, workspace)
+
     async with get_internal_lock():
-        if namespace not in _update_flags:
-            raise ValueError(f"Namespace {namespace} not found in update flags")
+        if final_namespace not in _update_flags:
+            raise ValueError(f"Namespace {final_namespace} not found in update flags")
         # Update flags for both modes
-        for i in range(len(_update_flags[namespace])):
-            _update_flags[namespace][i].value = True
+        for i in range(len(_update_flags[final_namespace])):
+            _update_flags[final_namespace][i].value = True


-async def clear_all_update_flags(namespace: str):
+async def clear_all_update_flags(namespace: str, workspace: str | None = None):
     """Clear all update flag of namespace indicating all workers need to reload data from files"""
     global _update_flags
     if _update_flags is None:
         raise ValueError("Try to create namespace before Shared-Data is initialized")

+    final_namespace = get_final_namespace(namespace, workspace)
+
     async with get_internal_lock():
-        if namespace not in _update_flags:
-            raise ValueError(f"Namespace {namespace} not found in update flags")
+        if final_namespace not in _update_flags:
+            raise ValueError(f"Namespace {final_namespace} not found in update flags")
         # Update flags for both modes
-        for i in range(len(_update_flags[namespace])):
-            _update_flags[namespace][i].value = False
+        for i in range(len(_update_flags[final_namespace])):
+            _update_flags[final_namespace][i].value = False


-async def get_all_update_flags_status() -> Dict[str, list]:
+async def get_all_update_flags_status(workspace: str | None = None) -> Dict[str, list]:
     """
     Get update flags status for all namespaces.

@@ -1394,9 +1364,17 @@ async def get_all_update_flags_status() -> Dict[str, list]:
     if _update_flags is None:
         return {}

+    if workspace is None:
+        workspace = get_default_workspace
+
     result = {}
     async with get_internal_lock():
         for namespace, flags in _update_flags.items():
+            namespace_split = namespace.split(":")
+            if workspace and not namespace_split[0] == workspace:
+                continue
+            if not workspace and namespace_split[0]:
+                continue
             worker_statuses = []
             for flag in flags:
                 if _is_multiprocess:
@@ -1408,7 +1386,9 @@ async def get_all_update_flags_status() -> Dict[str, list]:
     return result


-async def try_initialize_namespace(namespace: str) -> bool:
+async def try_initialize_namespace(
+    namespace: str, workspace: str | None = None
+) -> bool:
     """
     Returns True if the current worker(process) gets initialization permission for loading data later.
     The worker does not get the permission is prohibited to load data from files.
@@ -1418,57 +1398,76 @@ async def try_initialize_namespace(namespace: str) -> bool:
     if _init_flags is None:
         raise ValueError("Try to create nanmespace before Shared-Data is initialized")

+    final_namespace = get_final_namespace(namespace, workspace)
+
     async with get_internal_lock():
-        if namespace not in _init_flags:
-            _init_flags[namespace] = True
+        if final_namespace not in _init_flags:
+            _init_flags[final_namespace] = True
             direct_log(
-                f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
+                f"Process {os.getpid()} ready to initialize storage namespace: [{final_namespace}]"
             )
             return True
         direct_log(
-            f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
+            f"Process {os.getpid()} storage namespace already initialized: [{final_namespace}]"
         )

     return False


 async def get_namespace_data(
-    namespace: str, first_init: bool = False
+    namespace: str, first_init: bool = False, workspace: str | None = None
 ) -> Dict[str, Any]:
     """get the shared data reference for specific namespace

     Args:
         namespace: The namespace to retrieve
-        allow_create: If True, allows creation of the namespace if it doesn't exist.
-                      Used internally by initialize_pipeline_status().
+        first_init: If True, allows pipeline_status namespace to create namespace if it doesn't exist.
+                    Prevent getting pipeline_status namespace without initialize_pipeline_status().
+                    This parameter is used internally by initialize_pipeline_status().
+        workspace: Workspace identifier (may be empty string for global namespace)
     """
     if _shared_dicts is None:
         direct_log(
-            f"Error: try to getnanmespace before it is initialized, pid={os.getpid()}",
+            f"Error: Try to getnanmespace before it is initialized, pid={os.getpid()}",
             level="ERROR",
         )
         raise ValueError("Shared dictionaries not initialized")

-    async with get_internal_lock():
-        if namespace not in _shared_dicts:
-            # Special handling for pipeline_status namespace
-            # Supports both global "pipeline_status" and workspace-specific "{workspace}:pipeline"
-            is_pipeline = namespace == "pipeline_status" or namespace.endswith(
-                ":pipeline"
-            )
+    final_namespace = get_final_namespace(namespace, workspace)

-            if is_pipeline and not first_init:
+    async with get_internal_lock():
+        if final_namespace not in _shared_dicts:
+            # Special handling for pipeline_status namespace
+            if final_namespace.endswith(":pipeline_status") and not first_init:
                 # Check if pipeline_status should have been initialized but wasn't
-                # This helps users understand they need to call initialize_pipeline_status()
-                raise PipelineNotInitializedError(namespace)
+                # This helps users to call initialize_pipeline_status() before get_namespace_data()
+                raise PipelineNotInitializedError(final_namespace)

             # For other namespaces or when allow_create=True, create them dynamically
             if _is_multiprocess and _manager is not None:
-                _shared_dicts[namespace] = _manager.dict()
+                _shared_dicts[final_namespace] = _manager.dict()
             else:
-                _shared_dicts[namespace] = {}
+                _shared_dicts[final_namespace] = {}

-    return _shared_dicts[namespace]
+    return _shared_dicts[final_namespace]


+def get_namespace_lock(
+    namespace: str, workspace: str | None = None, enable_logging: bool = False
+) -> str:
+    """Get the lock key for a namespace.
+
+    Args:
+        namespace: The namespace to get the lock key for.
+        workspace: Workspace identifier (may be empty string for global namespace)
+
+    Returns:
+        str: The lock key for the namespace.
+    """
+    final_namespace = get_final_namespace(namespace, workspace)
+    return get_storage_keyed_lock(
+        ["default_key"], namespace=final_namespace, enable_logging=enable_logging
+    )
+
+
 def finalize_share_data():
@@ -1484,10 +1483,7 @@ def finalize_share_data():
     global \
         _manager, \
         _is_multiprocess, \
-        _storage_lock, \
         _internal_lock, \
-        _pipeline_status_lock, \
-        _graph_db_lock, \
         _data_init_lock, \
         _shared_dicts, \
         _init_flags, \
@@ -1552,10 +1548,7 @@ def finalize_share_data():
     _is_multiprocess = None
     _shared_dicts = None
     _init_flags = None
-    _storage_lock = None
     _internal_lock = None
-    _pipeline_status_lock = None
-    _graph_db_lock = None
     _data_init_lock = None
     _update_flags = None
     _async_locks = None
@@ -1563,21 +1556,23 @@ def finalize_share_data():
     direct_log(f"Process {os.getpid()} storage data finalization complete")


-def set_default_workspace(workspace: str):
+def set_default_workspace(workspace: str | None = None):
     """
-    Set default workspace for backward compatibility.
+    Set default workspace for namespace operations for backward compatibility.

-    This allows initialize_pipeline_status() to automatically use the correct
-    workspace when called without parameters, maintaining compatibility with
-    legacy code that doesn't pass workspace explicitly.
+    This allows get_namespace_data(),get_namespace_lock() or initialize_pipeline_status() to
+    automatically use the correct workspace when called without workspace parameters,
+    maintaining compatibility with legacy code that doesn't pass workspace explicitly.

     Args:
         workspace: Workspace identifier (may be empty string for global namespace)
     """
     global _default_workspace
+    if workspace is None:
+        workspace = ""
     _default_workspace = workspace
     direct_log(
-        f"Default workspace set to: '{workspace}' (empty means global)",
+        f"Default workspace set to: '{_default_workspace}' (empty means global)",
         level="DEBUG",
     )

@@ -1587,7 +1582,7 @@ def get_default_workspace() -> str:
     Get default workspace for backward compatibility.

     Returns:
-        The default workspace string. Empty string means global namespace.
+        The default workspace string. Empty string means global namespace. None means not set.
     """
     global _default_workspace
-    return _default_workspace if _default_workspace is not None else ""
+    return _default_workspace
lightrag/lightrag.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import traceback
 import asyncio
 import configparser
+import inspect
 import os
 import time
 import warnings
@@ -12,6 +13,7 @@ from functools import partial
 from typing import (
     Any,
     AsyncIterator,
+    Awaitable,
     Callable,
     Iterator,
     cast,
@@ -20,6 +22,7 @@ from typing import (
     Optional,
     List,
     Dict,
+    Union,
 )
 from lightrag.prompt import PROMPTS
 from lightrag.exceptions import PipelineCancelledException
@@ -61,10 +64,10 @@ from lightrag.kg import (

 from lightrag.kg.shared_storage import (
     get_namespace_data,
-    get_graph_db_lock,
     get_data_init_lock,
-    get_storage_keyed_lock,
-    initialize_pipeline_status,
+    get_default_workspace,
+    set_default_workspace,
+    get_namespace_lock,
 )

 from lightrag.base import (
@@ -88,7 +91,7 @@ from lightrag.operate import (
     merge_nodes_and_edges,
     kg_query,
     naive_query,
-    _rebuild_knowledge_from_chunks,
+    rebuild_knowledge_from_chunks,
 )
 from lightrag.constants import GRAPH_FIELD_SEP
 from lightrag.utils import (
@@ -244,11 +247,13 @@ class LightRAG:
             int,
             int,
         ],
-        List[Dict[str, Any]],
+        Union[List[Dict[str, Any]], Awaitable[List[Dict[str, Any]]]],
     ] = field(default_factory=lambda: chunking_by_token_size)
     """
     Custom chunking function for splitting text into chunks before processing.

+    The function can be either synchronous or asynchronous.
+
     The function should take the following parameters:

     - `tokenizer`: A Tokenizer instance to use for tokenization.
@@ -258,7 +263,8 @@ class LightRAG:
     - `chunk_token_size`: The maximum number of tokens per chunk.
     - `chunk_overlap_token_size`: The number of overlapping tokens between consecutive chunks.

-    The function should return a list of dictionaries, where each dictionary contains the following keys:
+    The function should return a list of dictionaries (or an awaitable that resolves to a list),
+    where each dictionary contains the following keys:
     - `tokens`: The number of tokens in the chunk.
     - `content`: The text content of the chunk.

@@ -271,6 +277,9 @@ class LightRAG:
     embedding_func: EmbeddingFunc | None = field(default=None)
     """Function for computing text embeddings. Must be set before use."""

+    embedding_token_limit: int | None = field(default=None, init=False)
+    """Token limit for embedding model. Set automatically from embedding_func.max_token_size in __post_init__."""
+
     embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
     """Batch size for embedding computations."""

@@ -514,6 +523,16 @@ class LightRAG:
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")

         # Init Embedding
+        # Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
+        embedding_max_token_size = None
+        if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
+            embedding_max_token_size = self.embedding_func.max_token_size
+            logger.debug(
+                f"Captured embedding max_token_size: {embedding_max_token_size}"
+            )
+        self.embedding_token_limit = embedding_max_token_size
+
+        # Step 2: Apply priority wrapper decorator
         self.embedding_func = priority_limit_async_func_call(
             self.embedding_func_max_async,
             llm_timeout=self.default_embedding_timeout,
@@ -640,12 +659,11 @@ class LightRAG:
     async def initialize_storages(self):
         """Storage initialization must be called one by one to prevent deadlock"""
         if self._storages_status == StoragesStatus.CREATED:
-            # Set default workspace for backward compatibility
-            # This allows initialize_pipeline_status() called without parameters
-            # to use the correct workspace
-            from lightrag.kg.shared_storage import set_default_workspace
-
-            set_default_workspace(self.workspace)
+            # Set the first initialized workspace will set the default workspace
+            # Allows namespace operation without specifying workspace for backward compatibility
+            default_workspace = get_default_workspace()
+            if default_workspace is None:
+                set_default_workspace(self.workspace)

             for storage in (
                 self.full_docs,
@@ -718,7 +736,7 @@ class LightRAG:

     async def check_and_migrate_data(self):
         """Check if data migration is needed and perform migration if necessary"""
-        async with get_data_init_lock(enable_logging=True):
+        async with get_data_init_lock():
             try:
                 # Check if migration is needed:
                 # 1. chunk_entity_relation_graph has entities and relations (count > 0)
@@ -1581,22 +1599,12 @@ class LightRAG:
         """

-        # Get pipeline status shared data and lock
-        # Step 1: Get workspace
-        workspace = self.workspace
-
-        # Step 2: Construct namespace (following GraphDB pattern)
-        namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-
-        # Step 3: Ensure initialization (on first access)
-        await initialize_pipeline_status(workspace)
-
-        # Step 4: Get lock
-        pipeline_status_lock = get_storage_keyed_lock(
-            keys="status", namespace=namespace, enable_logging=False
+        pipeline_status = await get_namespace_data(
+            "pipeline_status", workspace=self.workspace
+        )
+        pipeline_status_lock = get_namespace_lock(
+            "pipeline_status", workspace=self.workspace
         )

-        # Step 5: Get data
-        pipeline_status = await get_namespace_data(namespace)
-
         # Check if another process is already processing the queue
         async with pipeline_status_lock:
@@ -1778,7 +1786,28 @@ class LightRAG:
                 )
                 content = content_data["content"]

-                # Generate chunks from document
+                # Call chunking function, supporting both sync and async implementations
+                chunking_result = self.chunking_func(
+                    self.tokenizer,
+                    content,
+                    split_by_character,
+                    split_by_character_only,
+                    self.chunk_overlap_token_size,
+                    self.chunk_token_size,
+                )
+
+                # If result is awaitable, await to get actual result
+                if inspect.isawaitable(chunking_result):
+                    chunking_result = await chunking_result
+
+                # Validate return type
+                if not isinstance(chunking_result, (list, tuple)):
+                    raise TypeError(
+                        f"chunking_func must return a list or tuple of dicts, "
+                        f"got {type(chunking_result)}"
+                    )
+
+                # Build chunks dictionary
                 chunks: dict[str, Any] = {
                     compute_mdhash_id(dp["content"], prefix="chunk-"): {
                         **dp,
@@ -1786,14 +1815,7 @@ class LightRAG:
                         "file_path": file_path,  # Add file path to each chunk
                         "llm_cache_list": [],  # Initialize empty LLM cache list for each chunk
                     }
-                    for dp in self.chunking_func(
-                        self.tokenizer,
-                        content,
-                        split_by_character,
-                        split_by_character_only,
-                        self.chunk_overlap_token_size,
-                        self.chunk_token_size,
-                    )
+                    for dp in chunking_result
                 }

                 if not chunks:
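The two hunks above split chunk generation from dictionary construction so that chunking_func may be a coroutine function. A hedged standalone sketch of the dispatch step (helper name and argument list are illustrative):

    import inspect

    async def run_chunking(chunking_func, *args):
        result = chunking_func(*args)  # may be a list or a coroutine
        if inspect.isawaitable(result):
            result = await result      # async implementations are awaited
        if not isinstance(result, (list, tuple)):
            raise TypeError(
                f"chunking_func must return a list or tuple of dicts, got {type(result)}"
            )
        return result

This is why `import inspect` was added at the top of the file: isawaitable covers plain coroutines as well as other awaitables, so synchronous and asynchronous chunkers share one call site.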
@@ -1893,9 +1915,14 @@ class LightRAG:
                     if task and not task.done():
                         task.cancel()

-                    # Persistent llm cache
+                    # Persistent llm cache with error handling
                     if self.llm_response_cache:
-                        await self.llm_response_cache.index_done_callback()
+                        try:
+                            await self.llm_response_cache.index_done_callback()
+                        except Exception as persist_error:
+                            logger.error(
+                                f"Failed to persist LLM cache: {persist_error}"
+                            )

                     # Record processing end time for failed case
                     processing_end_time = int(time.time())
@@ -2015,9 +2042,14 @@ class LightRAG:
                         error_msg
                     )

-                    # Persistent llm cache
+                    # Persistent llm cache with error handling
                     if self.llm_response_cache:
-                        await self.llm_response_cache.index_done_callback()
+                        try:
+                            await self.llm_response_cache.index_done_callback()
+                        except Exception as persist_error:
+                            logger.error(
+                                f"Failed to persist LLM cache: {persist_error}"
+                            )

                     # Record processing end time for failed case
                     processing_end_time = int(time.time())
@@ -2924,22 +2956,12 @@ class LightRAG:
         doc_llm_cache_ids: list[str] = []

-        # Get pipeline status shared data and lock for status updates
-        # Step 1: Get workspace
-        workspace = self.workspace
-
-        # Step 2: Construct namespace (following GraphDB pattern)
-        namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
-
-        # Step 3: Ensure initialization (on first access)
-        await initialize_pipeline_status(workspace)
-
-        # Step 4: Get lock
-        pipeline_status_lock = get_storage_keyed_lock(
-            keys="status", namespace=namespace, enable_logging=False
+        pipeline_status = await get_namespace_data(
+            "pipeline_status", workspace=self.workspace
+        )
+        pipeline_status_lock = get_namespace_lock(
+            "pipeline_status", workspace=self.workspace
         )

-        # Step 5: Get data
-        pipeline_status = await get_namespace_data(namespace)
-
         async with pipeline_status_lock:
             log_message = f"Starting deletion process for document {doc_id}"
@@ -3172,6 +3194,9 @@ class LightRAG:
                     ]

                     if not existing_sources:
+                        # No chunk references means this entity should be deleted
+                        entities_to_delete.add(node_label)
+                        entity_chunk_updates[node_label] = []
                         continue

                     remaining_sources = subtract_source_ids(existing_sources, chunk_ids)
@@ -3193,6 +3218,7 @@ class LightRAG:

                 # Process relationships
                 for edge_data in affected_edges:
+                    # source target is not in normalize order in graph db property
                     src = edge_data.get("source")
                     tgt = edge_data.get("target")

@@ -3229,6 +3255,9 @@ class LightRAG:
                     ]

                     if not existing_sources:
+                        # No chunk references means this relationship should be deleted
+                        relationships_to_delete.add(edge_tuple)
+                        relation_chunk_updates[edge_tuple] = []
                         continue

                     remaining_sources = subtract_source_ids(existing_sources, chunk_ids)
@@ -3254,38 +3283,31 @@ class LightRAG:

                 if entity_chunk_updates and self.entity_chunks:
                     entity_upsert_payload = {}
-                    entity_delete_ids: set[str] = set()
                     for entity_name, remaining in entity_chunk_updates.items():
                         if not remaining:
-                            entity_delete_ids.add(entity_name)
-                        else:
-                            entity_upsert_payload[entity_name] = {
-                                "chunk_ids": remaining,
-                                "count": len(remaining),
-                                "updated_at": current_time,
-                            }
-
-                    if entity_delete_ids:
-                        await self.entity_chunks.delete(list(entity_delete_ids))
+                            # Empty entities are deleted alongside graph nodes later
+                            continue
+                        entity_upsert_payload[entity_name] = {
+                            "chunk_ids": remaining,
+                            "count": len(remaining),
+                            "updated_at": current_time,
+                        }
                     if entity_upsert_payload:
                         await self.entity_chunks.upsert(entity_upsert_payload)

                 if relation_chunk_updates and self.relation_chunks:
                     relation_upsert_payload = {}
-                    relation_delete_ids: set[str] = set()
                     for edge_tuple, remaining in relation_chunk_updates.items():
-                        storage_key = make_relation_chunk_key(*edge_tuple)
                         if not remaining:
-                            relation_delete_ids.add(storage_key)
-                        else:
-                            relation_upsert_payload[storage_key] = {
-                                "chunk_ids": remaining,
-                                "count": len(remaining),
-                                "updated_at": current_time,
-                            }
-
-                    if relation_delete_ids:
-                        await self.relation_chunks.delete(list(relation_delete_ids))
+                            # Empty relations are deleted alongside graph edges later
+                            continue
+                        storage_key = make_relation_chunk_key(*edge_tuple)
+                        relation_upsert_payload[storage_key] = {
+                            "chunk_ids": remaining,
+                            "count": len(remaining),
+                            "updated_at": current_time,
+                        }
                     if relation_upsert_payload:
                         await self.relation_chunks.upsert(relation_upsert_payload)

@@ -3293,56 +3315,111 @@ class LightRAG:
                 logger.error(f"Failed to process graph analysis results: {e}")
                 raise Exception(f"Failed to process graph dependencies: {e}") from e

-            # Use graph database lock to prevent dirty read
-            graph_db_lock = get_graph_db_lock(enable_logging=False)
-            async with graph_db_lock:
-                # 5. Delete chunks from storage
-                if chunk_ids:
-                    try:
-                        await self.chunks_vdb.delete(chunk_ids)
-                        await self.text_chunks.delete(chunk_ids)
+            # Data integrity is ensured by allowing only one process to hold pipeline at a time(no graph db lock is needed anymore)

-                        async with pipeline_status_lock:
-                            log_message = f"Successfully deleted {len(chunk_ids)} chunks from storage"
-                            logger.info(log_message)
-                            pipeline_status["latest_message"] = log_message
-                            pipeline_status["history_messages"].append(log_message)
+            # 5. Delete chunks from storage
+            if chunk_ids:
+                try:
+                    await self.chunks_vdb.delete(chunk_ids)
+                    await self.text_chunks.delete(chunk_ids)

-                    except Exception as e:
-                        logger.error(f"Failed to delete chunks: {e}")
-                        raise Exception(f"Failed to delete document chunks: {e}") from e
+                    async with pipeline_status_lock:
+                        log_message = (
+                            f"Successfully deleted {len(chunk_ids)} chunks from storage"
+                        )
+                        logger.info(log_message)
+                        pipeline_status["latest_message"] = log_message
+                        pipeline_status["history_messages"].append(log_message)

-                # 6. Delete entities that have no remaining sources
-                if entities_to_delete:
-                    try:
-                        # Delete from vector database
-                        entity_vdb_ids = [
-                            compute_mdhash_id(entity, prefix="ent-")
-                            for entity in entities_to_delete
-                        ]
-                        await self.entities_vdb.delete(entity_vdb_ids)
+                except Exception as e:
+                    logger.error(f"Failed to delete chunks: {e}")
+                    raise Exception(f"Failed to delete document chunks: {e}") from e

-                        # Delete from graph
-                        await self.chunk_entity_relation_graph.remove_nodes(
-                            list(entities_to_delete)
-                        )
+            # 6. Delete relationships that have no remaining sources
+            if relationships_to_delete:
+                try:
+                    # Delete from relation vdb
+                    rel_ids_to_delete = []
+                    for src, tgt in relationships_to_delete:
+                        rel_ids_to_delete.extend(
+                            [
+                                compute_mdhash_id(src + tgt, prefix="rel-"),
+                                compute_mdhash_id(tgt + src, prefix="rel-"),
+                            ]
+                        )
+                    await self.relationships_vdb.delete(rel_ids_to_delete)

-                        async with pipeline_status_lock:
-                            log_message = f"Successfully deleted {len(entities_to_delete)} entities"
-                            logger.info(log_message)
-                            pipeline_status["latest_message"] = log_message
-                            pipeline_status["history_messages"].append(log_message)
+                    # Delete from graph
+                    await self.chunk_entity_relation_graph.remove_edges(
+                        list(relationships_to_delete)
+                    )

-                    except Exception as e:
-                        logger.error(f"Failed to delete entities: {e}")
-                        raise Exception(f"Failed to delete entities: {e}") from e
+                    # Delete from relation_chunks storage
+                    if self.relation_chunks:
+                        relation_storage_keys = [
+                            make_relation_chunk_key(src, tgt)
+                            for src, tgt in relationships_to_delete
+                        ]
+                        await self.relation_chunks.delete(relation_storage_keys)

-                # 7. Delete relationships that have no remaining sources
-                if relationships_to_delete:
-                    try:
-                        # Delete from vector database
+                    async with pipeline_status_lock:
+                        log_message = f"Successfully deleted {len(relationships_to_delete)} relations"
+                        logger.info(log_message)
+                        pipeline_status["latest_message"] = log_message
+                        pipeline_status["history_messages"].append(log_message)
+
+                except Exception as e:
+                    logger.error(f"Failed to delete relationships: {e}")
+                    raise Exception(f"Failed to delete relationships: {e}") from e
+
+            # 7. Delete entities that have no remaining sources
+            if entities_to_delete:
+                try:
+                    # Batch get all edges for entities to avoid N+1 query problem
+                    nodes_edges_dict = (
+                        await self.chunk_entity_relation_graph.get_nodes_edges_batch(
+                            list(entities_to_delete)
+                        )
+                    )
+
+                    # Debug: Check and log all edges before deleting nodes
+                    edges_to_delete = set()
+                    edges_still_exist = 0
+
+                    for entity, edges in nodes_edges_dict.items():
+                        if edges:
+                            for src, tgt in edges:
+                                # Normalize edge representation (sorted for consistency)
+                                edge_tuple = tuple(sorted((src, tgt)))
+                                edges_to_delete.add(edge_tuple)
+
+                                if (
+                                    src in entities_to_delete
+                                    and tgt in entities_to_delete
+                                ):
+                                    logger.warning(
+                                        f"Edge still exists: {src} <-> {tgt}"
+                                    )
+                                elif src in entities_to_delete:
+                                    logger.warning(
+                                        f"Edge still exists: {src} --> {tgt}"
+                                    )
+                                else:
+                                    logger.warning(
+                                        f"Edge still exists: {src} <-- {tgt}"
+                                    )
+                                edges_still_exist += 1
+
+                    if edges_still_exist:
+                        logger.warning(
+                            f"⚠️ {edges_still_exist} entities still has edges before deletion"
+                        )
+
+                    # Clean residual edges from VDB and storage before deleting nodes
+                    if edges_to_delete:
+                        # Delete from relationships_vdb
                         rel_ids_to_delete = []
-                        for src, tgt in relationships_to_delete:
+                        for src, tgt in edges_to_delete:
                             rel_ids_to_delete.extend(
                                 [
                                     compute_mdhash_id(src + tgt, prefix="rel-"),
@@ -3351,28 +3428,53 @@ class LightRAG:
                             )
                         await self.relationships_vdb.delete(rel_ids_to_delete)

-                        # Delete from graph
-                        await self.chunk_entity_relation_graph.remove_edges(
-                            list(relationships_to_delete)
-                        )
+                        # Delete from relation_chunks storage
+                        if self.relation_chunks:
+                            relation_storage_keys = [
+                                make_relation_chunk_key(src, tgt)
+                                for src, tgt in edges_to_delete
+                            ]
+                            await self.relation_chunks.delete(relation_storage_keys)

-                        async with pipeline_status_lock:
-                            log_message = f"Successfully deleted {len(relationships_to_delete)} relations"
-                            logger.info(log_message)
-                            pipeline_status["latest_message"] = log_message
-                            pipeline_status["history_messages"].append(log_message)
+                        logger.info(
+                            f"Cleaned {len(edges_to_delete)} residual edges from VDB and chunk-tracking storage"
+                        )

-                    except Exception as e:
-                        logger.error(f"Failed to delete relationships: {e}")
-                        raise Exception(f"Failed to delete relationships: {e}") from e
+                    # Delete from graph (edges will be auto-deleted with nodes)
+                    await self.chunk_entity_relation_graph.remove_nodes(
+                        list(entities_to_delete)
+                    )

-            # Persist changes to graph database before releasing graph database lock
-            await self._insert_done()
+                    # Delete from vector vdb
+                    entity_vdb_ids = [
+                        compute_mdhash_id(entity, prefix="ent-")
+                        for entity in entities_to_delete
+                    ]
+                    await self.entities_vdb.delete(entity_vdb_ids)
+
+                    # Delete from entity_chunks storage
+                    if self.entity_chunks:
+                        await self.entity_chunks.delete(list(entities_to_delete))
+
+                    async with pipeline_status_lock:
+                        log_message = (
+                            f"Successfully deleted {len(entities_to_delete)} entities"
+                        )
+                        logger.info(log_message)
+                        pipeline_status["latest_message"] = log_message
+                        pipeline_status["history_messages"].append(log_message)
+
+                except Exception as e:
+                    logger.error(f"Failed to delete entities: {e}")
+                    raise Exception(f"Failed to delete entities: {e}") from e
+
+            # Persist changes to graph database before entity and relationship rebuild
+            await self._insert_done()

             # 8. Rebuild entities and relationships from remaining chunks
             if entities_to_rebuild or relationships_to_rebuild:
                 try:
-                    await _rebuild_knowledge_from_chunks(
+                    await rebuild_knowledge_from_chunks(
                         entities_to_rebuild=entities_to_rebuild,
                         relationships_to_rebuild=relationships_to_rebuild,
                         knowledge_graph_inst=self.chunk_entity_relation_graph,
@@ -3590,16 +3692,22 @@ class LightRAG:
         )

     async def aedit_entity(
-        self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
+        self,
+        entity_name: str,
+        updated_data: dict[str, str],
+        allow_rename: bool = True,
+        allow_merge: bool = False,
     ) -> dict[str, Any]:
         """Asynchronously edit entity information.

         Updates entity information in the knowledge graph and re-embeds the entity in the vector database.
+        Also synchronizes entity_chunks_storage and relation_chunks_storage to track chunk references.

         Args:
             entity_name: Name of the entity to edit
             updated_data: Dictionary containing updated attributes, e.g. {"description": "new description", "entity_type": "new type"}
             allow_rename: Whether to allow entity renaming, defaults to True
+            allow_merge: Whether to merge into an existing entity when renaming to an existing name

         Returns:
             Dictionary containing updated entity information
@@ -3613,14 +3721,21 @@ class LightRAG:
             entity_name,
             updated_data,
             allow_rename,
+            allow_merge,
+            self.entity_chunks,
+            self.relation_chunks,
         )

     def edit_entity(
-        self, entity_name: str, updated_data: dict[str, str], allow_rename: bool = True
+        self,
+        entity_name: str,
+        updated_data: dict[str, str],
+        allow_rename: bool = True,
+        allow_merge: bool = False,
     ) -> dict[str, Any]:
         loop = always_get_an_event_loop()
         return loop.run_until_complete(
-            self.aedit_entity(entity_name, updated_data, allow_rename)
+            self.aedit_entity(entity_name, updated_data, allow_rename, allow_merge)
         )

     async def aedit_relation(
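A hedged usage sketch for the widened edit_entity signature above; the entity names and fields are illustrative:

    # Renaming onto a name that already exists used to fail; with
    # allow_merge=True the two entities are merged instead.
    result = rag.edit_entity(
        "Acme Corp.",
        {"entity_name": "Acme Corporation", "description": "Consolidated record"},
        allow_rename=True,
        allow_merge=True,
    )

The trailing entity_chunks / relation_chunks arguments are threaded through internally so the chunk-tracking stores stay in sync with the edit.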
@@ -3629,6 +3744,7 @@ class LightRAG:
         """Asynchronously edit relation information.

         Updates relation (edge) information in the knowledge graph and re-embeds the relation in the vector database.
+        Also synchronizes the relation_chunks_storage to track which chunks reference this relation.

         Args:
             source_entity: Name of the source entity
@@ -3647,6 +3763,7 @@ class LightRAG:
             source_entity,
             target_entity,
             updated_data,
+            self.relation_chunks,
         )

     def edit_relation(
@@ -3758,6 +3875,8 @@ class LightRAG:
             target_entity,
             merge_strategy,
             target_entity_data,
+            self.entity_chunks,
+            self.relation_chunks,
         )

     def merge_entities(