diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index b29e39b2..8de03283 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -56,8 +56,7 @@ from lightrag.api.routers.ollama_api import OllamaAPI
 from lightrag.utils import logger, set_verbose_debug
 from lightrag.kg.shared_storage import (
     get_namespace_data,
-    get_default_workspace,
-    # set_default_workspace,
+    initialize_pipeline_status,
     cleanup_keyed_lock,
     finalize_share_data,
 )
@@ -351,8 +350,8 @@ def create_app(args):
 
         try:
             # Initialize database connections
-            # Note: initialize_storages() now auto-initializes pipeline_status for rag.workspace
             await rag.initialize_storages()
+            await initialize_pipeline_status()
 
             # Data migration regardless of storage implementation
             await rag.check_and_migrate_data()
@@ -453,7 +452,7 @@ def create_app(args):
     # Create combined auth dependency for all endpoints
     combined_auth = get_combined_auth_dependency(api_key)
 
-    def get_workspace_from_request(request: Request) -> str | None:
+    def get_workspace_from_request(request: Request) -> str:
         """
         Extract workspace from HTTP request header or use default.
 
@@ -470,8 +469,9 @@ def create_app(args):
         # Check custom header first
         workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()
 
+        # Fall back to server default if header not provided
         if not workspace:
-            workspace = None
+            workspace = args.workspace
 
         return workspace
 
@@ -641,108 +641,33 @@ def create_app(args):
     def create_optimized_embedding_function(
         config_cache: LLMConfigCache, binding, model, host, api_key, args
-    ) -> EmbeddingFunc:
+    ):
         """
-        Create optimized embedding function and return an EmbeddingFunc instance
-        with proper max_token_size inheritance from provider defaults.
-
-        This function:
-        1. Imports the provider embedding function
-        2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
-        3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
-        4. Returns a properly configured EmbeddingFunc instance
+        Create optimized embedding function with pre-processed configuration for applicable bindings.
+        Uses lazy imports for all bindings and avoids repeated configuration parsing.
""" - # Step 1: Import provider function and extract default attributes - provider_func = None - provider_max_token_size = None - provider_embedding_dim = None - - try: - if binding == "openai": - from lightrag.llm.openai import openai_embed - - provider_func = openai_embed - elif binding == "ollama": - from lightrag.llm.ollama import ollama_embed - - provider_func = ollama_embed - elif binding == "gemini": - from lightrag.llm.gemini import gemini_embed - - provider_func = gemini_embed - elif binding == "jina": - from lightrag.llm.jina import jina_embed - - provider_func = jina_embed - elif binding == "azure_openai": - from lightrag.llm.azure_openai import azure_openai_embed - - provider_func = azure_openai_embed - elif binding == "aws_bedrock": - from lightrag.llm.bedrock import bedrock_embed - - provider_func = bedrock_embed - elif binding == "lollms": - from lightrag.llm.lollms import lollms_embed - - provider_func = lollms_embed - - # Extract attributes if provider is an EmbeddingFunc - if provider_func and isinstance(provider_func, EmbeddingFunc): - provider_max_token_size = provider_func.max_token_size - provider_embedding_dim = provider_func.embedding_dim - logger.debug( - f"Extracted from {binding} provider: " - f"max_token_size={provider_max_token_size}, " - f"embedding_dim={provider_embedding_dim}" - ) - except ImportError as e: - logger.warning(f"Could not import provider function for {binding}: {e}") - - # Step 2: Apply priority (user config > provider default) - # For max_token_size: explicit env var > provider default > None - final_max_token_size = args.embedding_token_limit or provider_max_token_size - # For embedding_dim: user config (always has value) takes priority - # Only use provider default if user config is explicitly None (which shouldn't happen) - final_embedding_dim = ( - args.embedding_dim if args.embedding_dim else provider_embedding_dim - ) - - # Step 3: Create optimized embedding function (calls underlying function directly) async def optimized_embedding_function(texts, embedding_dim=None): try: if binding == "lollms": from lightrag.llm.lollms import lollms_embed - # Get real function, skip EmbeddingFunc wrapper if present - actual_func = ( - lollms_embed.func - if isinstance(lollms_embed, EmbeddingFunc) - else lollms_embed - ) - return await actual_func( + return await lollms_embed( texts, embed_model=model, host=host, api_key=api_key ) elif binding == "ollama": from lightrag.llm.ollama import ollama_embed - # Get real function, skip EmbeddingFunc wrapper if present - actual_func = ( - ollama_embed.func - if isinstance(ollama_embed, EmbeddingFunc) - else ollama_embed - ) - - # Use pre-processed configuration if available + # Use pre-processed configuration if available, otherwise fallback to dynamic parsing if config_cache.ollama_embedding_options is not None: ollama_options = config_cache.ollama_embedding_options else: + # Fallback for cases where config cache wasn't initialized properly from lightrag.llm.binding_options import OllamaEmbeddingOptions ollama_options = OllamaEmbeddingOptions.options_dict(args) - return await actual_func( + return await ollama_embed( texts, embed_model=model, host=host, @@ -752,30 +677,15 @@ def create_app(args): elif binding == "azure_openai": from lightrag.llm.azure_openai import azure_openai_embed - actual_func = ( - azure_openai_embed.func - if isinstance(azure_openai_embed, EmbeddingFunc) - else azure_openai_embed - ) - return await actual_func(texts, model=model, api_key=api_key) + return await azure_openai_embed(texts, 
                 elif binding == "aws_bedrock":
                     from lightrag.llm.bedrock import bedrock_embed
 
-                    actual_func = (
-                        bedrock_embed.func
-                        if isinstance(bedrock_embed, EmbeddingFunc)
-                        else bedrock_embed
-                    )
-                    return await actual_func(texts, model=model)
+                    return await bedrock_embed(texts, model=model)
                 elif binding == "jina":
                     from lightrag.llm.jina import jina_embed
 
-                    actual_func = (
-                        jina_embed.func
-                        if isinstance(jina_embed, EmbeddingFunc)
-                        else jina_embed
-                    )
-                    return await actual_func(
+                    return await jina_embed(
                         texts,
                         embedding_dim=embedding_dim,
                         base_url=host,
@@ -784,21 +694,16 @@ def create_app(args):
                 elif binding == "gemini":
                     from lightrag.llm.gemini import gemini_embed
 
-                    actual_func = (
-                        gemini_embed.func
-                        if isinstance(gemini_embed, EmbeddingFunc)
-                        else gemini_embed
-                    )
-
-                    # Use pre-processed configuration if available
+                    # Use pre-processed configuration if available, otherwise fall back to dynamic parsing
                     if config_cache.gemini_embedding_options is not None:
                         gemini_options = config_cache.gemini_embedding_options
                     else:
+                        # Fallback for cases where config cache wasn't initialized properly
                         from lightrag.llm.binding_options import GeminiEmbeddingOptions
 
                         gemini_options = GeminiEmbeddingOptions.options_dict(args)
-                    return await actual_func(
+                    return await gemini_embed(
                         texts,
                         model=model,
                         base_url=host,
@@ -809,12 +714,7 @@ def create_app(args):
                 else:  # openai and compatible
                     from lightrag.llm.openai import openai_embed
 
-                    actual_func = (
-                        openai_embed.func
-                        if isinstance(openai_embed, EmbeddingFunc)
-                        else openai_embed
-                    )
-                    return await actual_func(
+                    return await openai_embed(
                         texts,
                         model=model,
                         base_url=host,
@@ -824,21 +724,7 @@ def create_app(args):
             except ImportError as e:
                 raise Exception(f"Failed to import {binding} embedding: {e}")
 
-        # Step 4: Wrap in EmbeddingFunc and return
-        embedding_func_instance = EmbeddingFunc(
-            embedding_dim=final_embedding_dim,
-            func=optimized_embedding_function,
-            max_token_size=final_max_token_size,
-            send_dimensions=False,  # Will be set later based on binding requirements
-        )
-
-        # Log final embedding configuration
-        logger.info(
-            f"Embedding config: binding={binding} model={model} "
-            f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
-        )
-
-        return embedding_func_instance
+        return optimized_embedding_function
 
     llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
     embedding_timeout = get_env_value(
@@ -872,24 +758,25 @@ def create_app(args):
             **kwargs,
         )
 
-    # Create embedding function with optimized configuration and max_token_size inheritance
+    # Create embedding function with optimized configuration
    import inspect
 
-    # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
-    embedding_func = create_optimized_embedding_function(
+    # Create the optimized embedding function
+    optimized_embedding_func = create_optimized_embedding_function(
         config_cache=config_cache,
         binding=args.embedding_binding,
         model=args.embedding_model,
         host=args.embedding_binding_host,
         api_key=args.embedding_binding_api_key,
-        args=args,
+        args=args,  # Pass args object for fallback option generation
     )
 
     # Get embedding_send_dim from centralized configuration
     embedding_send_dim = args.embedding_send_dim
 
-    # Check if the underlying function signature has embedding_dim parameter
-    sig = inspect.signature(embedding_func.func)
+    # Check if the function signature has embedding_dim parameter
+    # Note: optimized_embedding_func is a plain async function, so inspect it directly
+    sig = inspect.signature(optimized_embedding_func)
     has_embedding_dim_param = "embedding_dim" in sig.parameters
 
     # Determine send_dimensions value based on binding type
@@ -907,27 +794,18 @@ def create_app(args):
     else:
         dimension_control = "by not hasparam"
 
-    # Set send_dimensions on the EmbeddingFunc instance
-    embedding_func.send_dimensions = send_dimensions
-
     logger.info(
         f"Send embedding dimension: {send_dimensions} {dimension_control} "
-        f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
+        f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
         f"binding={args.embedding_binding})"
     )
 
-    # Log max_token_size source
-    if embedding_func.max_token_size:
-        source = (
-            "env variable"
-            if args.embedding_token_limit
-            else f"{args.embedding_binding} provider default"
-        )
-        logger.info(
-            f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
-        )
-    else:
-        logger.info("Embedding max_token_size: not set (90% token warning disabled)")
+    # Create EmbeddingFunc with send_dimensions attribute
+    embedding_func = EmbeddingFunc(
+        embedding_dim=args.embedding_dim,
+        func=optimized_embedding_func,
+        send_dimensions=send_dimensions,
+    )
 
     # Configure rerank function based on args.rerank_binding parameter
     rerank_model_func = None
@@ -1139,13 +1017,14 @@ def create_app(args):
     async def get_status(request: Request):
         """Get current system status"""
         try:
+            # Extract workspace from request header or use default
             workspace = get_workspace_from_request(request)
-            default_workspace = get_default_workspace()
-            if workspace is None:
-                workspace = default_workspace
-            pipeline_status = await get_namespace_data(
-                "pipeline_status", workspace=workspace
-            )
+
+            # Construct namespace (following GraphDB pattern)
+            namespace = f"{workspace}:pipeline" if workspace else "pipeline_status"
+
+            # Get workspace-specific pipeline status
+            pipeline_status = await get_namespace_data(namespace)
 
             if not auth_configured:
                 auth_mode = "disabled"
@@ -1176,7 +1055,8 @@ def create_app(args):
                 "vector_storage": args.vector_storage,
                 "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
                 "enable_llm_cache": args.enable_llm_cache,
-                "workspace": default_workspace,
+                "workspace": workspace,
+                "default_workspace": args.workspace,
                 "max_graph_nodes": args.max_graph_nodes,
                 # Rerank configuration
                 "enable_rerank": rerank_model_func is not None,
@@ -1365,12 +1245,6 @@ def check_and_install_dependencies():
 
 
 def main():
-    # Explicitly initialize configuration for clarity
-    # (The proxy will auto-initialize anyway, but this makes intent clear)
-    from .config import initialize_config
-
-    initialize_config()
-
     # Check if running under Gunicorn
     if "GUNICORN_CMD_ARGS" in os.environ:
         # If started with Gunicorn, return directly as Gunicorn will call get_application
diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py
index 8f34e64c..0d55db3d 100644
--- a/lightrag/kg/shared_storage.py
+++ b/lightrag/kg/shared_storage.py
@@ -75,6 +75,9 @@ _last_mp_cleanup_time: Optional[float] = None
 
 _initialized = None
 
+# Default workspace for backward compatibility
+_default_workspace: Optional[str] = None
+
 # shared data for storage across processes
 _shared_dicts: Optional[Dict[str, Any]] = None
 _init_flags: Optional[Dict[str, bool]] = None  # namespace -> initialized
@@ -1276,15 +1279,21 @@ async def initialize_pipeline_status(workspace: str = ""):
 
     Args:
         workspace: Optional workspace identifier for multi-tenant isolation.
-            Empty string (default) uses global "pipeline_status" namespace.
+            If empty string, uses the default workspace set by
+            set_default_workspace(). If no default is set, uses the
+            global "pipeline_status" namespace.
 
     This function is called during FastAPI lifespan for each worker.
     """
+    # Backward compatibility: use default workspace if not provided
+    if not workspace:
+        workspace = get_default_workspace()
+
     # Construct namespace (following GraphDB pattern)
     if workspace:
         namespace = f"{workspace}:pipeline"
     else:
-        namespace = "pipeline_status"  # Backward compatibility
+        namespace = "pipeline_status"  # Global namespace for backward compatibility
 
     pipeline_namespace = await get_namespace_data(namespace, first_init=True)
 
@@ -1552,3 +1561,33 @@ def finalize_share_data():
         _async_locks = None
 
     direct_log(f"Process {os.getpid()} storage data finalization complete")
+
+
+def set_default_workspace(workspace: str):
+    """
+    Set default workspace for backward compatibility.
+
+    This allows initialize_pipeline_status() to automatically use the correct
+    workspace when called without parameters, maintaining compatibility with
+    legacy code that doesn't pass workspace explicitly.
+
+    Args:
+        workspace: Workspace identifier (may be empty string for global namespace)
+    """
+    global _default_workspace
+    _default_workspace = workspace
+    direct_log(
+        f"Default workspace set to: '{workspace}' (empty means global)",
+        level="DEBUG",
+    )
+
+
+def get_default_workspace() -> str:
+    """
+    Get default workspace for backward compatibility.
+
+    Returns:
+        The default workspace string. Empty string means global namespace.
+    """
+    global _default_workspace
+    return _default_workspace if _default_workspace is not None else ""
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index 4b3fc3a6..dff80aad 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -640,6 +640,13 @@ class LightRAG:
     async def initialize_storages(self):
         """Storage initialization must be called one by one to prevent deadlock"""
        if self._storages_status == StoragesStatus.CREATED:
+            # Set default workspace for backward compatibility
+            # This allows initialize_pipeline_status(), when called without
+            # arguments, to use the correct workspace
+            from lightrag.kg.shared_storage import set_default_workspace
+
+            set_default_workspace(self.workspace)
+
             for storage in (
                 self.full_docs,
                 self.text_chunks,
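
Taken together, the three files restore an explicit two-step startup: `initialize_storages()` records the instance's workspace as the process-wide default, and a subsequent bare `initialize_pipeline_status()` resolves that default into the matching `"<workspace>:pipeline"` namespace. A minimal sketch of the resulting flow follows; the `LightRAG` constructor arguments shown are illustrative only (a real setup also needs the usual storage, embedding, and LLM parameters):

```python
import asyncio

from lightrag import LightRAG
from lightrag.kg.shared_storage import (
    get_default_workspace,
    initialize_pipeline_status,
)


async def main():
    # "working_dir" and "workspace" are illustrative; real deployments
    # also need embedding/LLM configuration.
    rag = LightRAG(working_dir="./rag_storage", workspace="tenant_a")

    # initialize_storages() calls set_default_workspace("tenant_a") ...
    await rag.initialize_storages()
    assert get_default_workspace() == "tenant_a"

    # ... so the legacy no-argument call now lands in "tenant_a:pipeline"
    # instead of the global "pipeline_status" namespace.
    await initialize_pipeline_status()


asyncio.run(main())
```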
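
On the request path, workspace selection is now header-driven: `get_workspace_from_request()` prefers the `LIGHTRAG-WORKSPACE` header, falls back to the server's `args.workspace` default, and `get_status` maps the result onto the same namespace scheme used by `initialize_pipeline_status()`. The sketch below restates that resolution rule in isolation so it can be exercised directly; the helper name `resolve_pipeline_namespace` is hypothetical and exists only to make the mapping testable:

```python
def resolve_pipeline_namespace(header_value: str, default_workspace: str) -> str:
    """Mirror of the server-side lookup: the header wins, then the server
    default; an empty workspace maps to the global namespace."""
    workspace = (header_value or "").strip() or default_workspace
    return f"{workspace}:pipeline" if workspace else "pipeline_status"


# Examples (assuming a server started with default workspace "prod"):
assert resolve_pipeline_namespace("tenant_b", "prod") == "tenant_b:pipeline"
assert resolve_pipeline_namespace("", "prod") == "prod:pipeline"
assert resolve_pipeline_namespace("", "") == "pipeline_status"
```

In practice a client selects a workspace per request simply by sending the `LIGHTRAG-WORKSPACE` header on the status endpoint; omitting it yields the server default reported in the response's `default_workspace` field.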