cherry-pick 33a1482f

parent e11e30be0e
commit cacea8ab56
5 changed files with 149 additions and 611 deletions
env.example

@@ -233,6 +233,13 @@ OLLAMA_LLM_NUM_CTX=32768
 ### EMBEDDING_BINDING_HOST: host only for Ollama, endpoint for other Embedding service
 #######################################################################################
 # EMBEDDING_TIMEOUT=30
+
+### Control whether to send embedding_dim parameter to embedding API
+### Set to 'true' to enable dynamic dimension adjustment (only works for OpenAI and Jina)
+### Set to 'false' (default) to disable sending dimension parameter
+### Note: This is automatically ignored for backends that don't support dimension parameter
+# EMBEDDING_SEND_DIM=false
+
 EMBEDDING_BINDING=ollama
 EMBEDDING_MODEL=bge-m3:latest
 EMBEDDING_DIM=1024
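Note on the new flag: `EMBEDDING_SEND_DIM` is read server-side as a plain boolean environment variable (see the `lightrag_server.py` hunk below). A minimal, illustrative sketch of that parsing:

```python
import os

# "true" (any case) enables sending the dimension; anything else disables it.
# This mirrors the os.getenv(...) check added in lightrag_server.py.
embedding_send_dim = os.getenv("EMBEDDING_SEND_DIM", "false").lower() == "true"
print(f"send dimensions: {embedding_send_dim}")
```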
lightrag/api/lightrag_server.py

@@ -56,8 +56,6 @@ from lightrag.api.routers.ollama_api import OllamaAPI
 from lightrag.utils import logger, set_verbose_debug
 from lightrag.kg.shared_storage import (
     get_namespace_data,
-    get_default_workspace,
-    # set_default_workspace,
     initialize_pipeline_status,
     cleanup_keyed_lock,
     finalize_share_data,
@@ -91,7 +89,6 @@ class LLMConfigCache:
         # Initialize configurations based on binding conditions
         self.openai_llm_options = None
         self.gemini_llm_options = None
-        self.gemini_embedding_options = None
         self.ollama_llm_options = None
         self.ollama_embedding_options = None

@@ -138,23 +135,6 @@ class LLMConfigCache:
                 )
                 self.ollama_embedding_options = {}

-        # Only initialize and log Gemini Embedding options when using Gemini Embedding binding
-        if args.embedding_binding == "gemini":
-            try:
-                from lightrag.llm.binding_options import GeminiEmbeddingOptions
-
-                self.gemini_embedding_options = GeminiEmbeddingOptions.options_dict(
-                    args
-                )
-                logger.info(
-                    f"Gemini Embedding Options: {self.gemini_embedding_options}"
-                )
-            except ImportError:
-                logger.warning(
-                    "GeminiEmbeddingOptions not available, using default configuration"
-                )
-                self.gemini_embedding_options = {}
-

 def check_frontend_build():
     """Check if frontend is built and optionally check if source is up-to-date
@@ -316,7 +296,6 @@ def create_app(args):
         "azure_openai",
         "aws_bedrock",
         "jina",
-        "gemini",
     ]:
         raise Exception("embedding binding not supported")

@@ -352,9 +331,8 @@ def create_app(args):

        try:
            # Initialize database connections
-            # set_default_workspace(rag.workspace) # comment this line to test auto default workspace setting in initialize_storages
            await rag.initialize_storages()
-            await initialize_pipeline_status()  # with default workspace
+            await initialize_pipeline_status()

            # Data migration regardless of storage implementation
            await rag.check_and_migrate_data()
@@ -455,29 +433,6 @@ def create_app(args):
     # Create combined auth dependency for all endpoints
     combined_auth = get_combined_auth_dependency(api_key)

-    def get_workspace_from_request(request: Request) -> str:
-        """
-        Extract workspace from HTTP request header or use default.
-
-        This enables multi-workspace API support by checking the custom
-        'LIGHTRAG-WORKSPACE' header. If not present, falls back to the
-        server's default workspace configuration.
-
-        Args:
-            request: FastAPI Request object
-
-        Returns:
-            Workspace identifier (may be empty string for global namespace)
-        """
-        # Check custom header first
-        workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()
-
-        # Fall back to server default if header not provided
-        if not workspace:
-            workspace = args.workspace
-
-        return workspace
-
     # Create working directory if it doesn't exist
     Path(args.working_dir).mkdir(parents=True, exist_ok=True)

@@ -556,9 +511,7 @@ def create_app(args):

         return optimized_azure_openai_model_complete

-    def create_optimized_gemini_llm_func(
-        config_cache: LLMConfigCache, args, llm_timeout: int
-    ):
+    def create_optimized_gemini_llm_func(config_cache: LLMConfigCache, args):
         """Create optimized Gemini LLM function with cached configuration"""

         async def optimized_gemini_model_complete(
@@ -573,8 +526,6 @@ def create_app(args):
            if history_messages is None:
                history_messages = []

-            # Use pre-processed configuration to avoid repeated parsing
-            kwargs["timeout"] = llm_timeout
            if (
                config_cache.gemini_llm_options is not None
                and "generation_config" not in kwargs
@@ -616,7 +567,7 @@ def create_app(args):
             config_cache, args, llm_timeout
         )
     elif binding == "gemini":
-        return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
+        return create_optimized_gemini_llm_func(config_cache, args)
     else:  # openai and compatible
         # Use optimized function with pre-processed configuration
         return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
@@ -644,108 +595,33 @@ def create_app(args):

     def create_optimized_embedding_function(
         config_cache: LLMConfigCache, binding, model, host, api_key, args
-    ) -> EmbeddingFunc:
+    ):
         """
-        Create optimized embedding function and return an EmbeddingFunc instance
-        with proper max_token_size inheritance from provider defaults.
-
-        This function:
-        1. Imports the provider embedding function
-        2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
-        3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
-        4. Returns a properly configured EmbeddingFunc instance
+        Create optimized embedding function with pre-processed configuration for applicable bindings.
+        Uses lazy imports for all bindings and avoids repeated configuration parsing.
         """

-        # Step 1: Import provider function and extract default attributes
-        provider_func = None
-        provider_max_token_size = None
-        provider_embedding_dim = None
-
-        try:
-            if binding == "openai":
-                from lightrag.llm.openai import openai_embed
-
-                provider_func = openai_embed
-            elif binding == "ollama":
-                from lightrag.llm.ollama import ollama_embed
-
-                provider_func = ollama_embed
-            elif binding == "gemini":
-                from lightrag.llm.gemini import gemini_embed
-
-                provider_func = gemini_embed
-            elif binding == "jina":
-                from lightrag.llm.jina import jina_embed
-
-                provider_func = jina_embed
-            elif binding == "azure_openai":
-                from lightrag.llm.azure_openai import azure_openai_embed
-
-                provider_func = azure_openai_embed
-            elif binding == "aws_bedrock":
-                from lightrag.llm.bedrock import bedrock_embed
-
-                provider_func = bedrock_embed
-            elif binding == "lollms":
-                from lightrag.llm.lollms import lollms_embed
-
-                provider_func = lollms_embed
-
-            # Extract attributes if provider is an EmbeddingFunc
-            if provider_func and isinstance(provider_func, EmbeddingFunc):
-                provider_max_token_size = provider_func.max_token_size
-                provider_embedding_dim = provider_func.embedding_dim
-                logger.debug(
-                    f"Extracted from {binding} provider: "
-                    f"max_token_size={provider_max_token_size}, "
-                    f"embedding_dim={provider_embedding_dim}"
-                )
-        except ImportError as e:
-            logger.warning(f"Could not import provider function for {binding}: {e}")
-
-        # Step 2: Apply priority (user config > provider default)
-        # For max_token_size: explicit env var > provider default > None
-        final_max_token_size = args.embedding_token_limit or provider_max_token_size
-        # For embedding_dim: user config (always has value) takes priority
-        # Only use provider default if user config is explicitly None (which shouldn't happen)
-        final_embedding_dim = (
-            args.embedding_dim if args.embedding_dim else provider_embedding_dim
-        )
-
-        # Step 3: Create optimized embedding function (calls underlying function directly)
-        async def optimized_embedding_function(texts, embedding_dim=None):
+        async def optimized_embedding_function(texts):
             try:
                 if binding == "lollms":
                     from lightrag.llm.lollms import lollms_embed

-                    # Get real function, skip EmbeddingFunc wrapper if present
-                    actual_func = (
-                        lollms_embed.func
-                        if isinstance(lollms_embed, EmbeddingFunc)
-                        else lollms_embed
-                    )
-                    return await actual_func(
+                    return await lollms_embed(
                         texts, embed_model=model, host=host, api_key=api_key
                     )
                 elif binding == "ollama":
                     from lightrag.llm.ollama import ollama_embed

-                    # Get real function, skip EmbeddingFunc wrapper if present
-                    actual_func = (
-                        ollama_embed.func
-                        if isinstance(ollama_embed, EmbeddingFunc)
-                        else ollama_embed
-                    )
-
-                    # Use pre-processed configuration if available
+                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
                     if config_cache.ollama_embedding_options is not None:
                         ollama_options = config_cache.ollama_embedding_options
                     else:
+                        # Fallback for cases where config cache wasn't initialized properly
                         from lightrag.llm.binding_options import OllamaEmbeddingOptions

                         ollama_options = OllamaEmbeddingOptions.options_dict(args)

-                    return await actual_func(
+                    return await ollama_embed(
                         texts,
                         embed_model=model,
                         host=host,
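The rewritten factory above returns a bare async function that resolves its provider lazily, inside the call, instead of pre-importing and unwrapping providers up front. A self-contained sketch of the same lazy-dispatch shape (the provider here is a stub, not a real backend):

```python
import asyncio

async def _stub_ollama_embed(texts):
    # Stand-in for lightrag.llm.ollama.ollama_embed; real code imports the
    # provider inside embed() below, on first use.
    return [[0.0, 0.0, 0.0] for _ in texts]

async def embed(texts, binding="ollama"):
    try:
        if binding == "ollama":
            provider = _stub_ollama_embed  # real code: from lightrag.llm.ollama import ollama_embed
        else:
            raise ImportError(f"no provider for {binding} in this sketch")
        return await provider(texts)
    except ImportError as e:
        raise Exception(f"Failed to import {binding} embedding: {e}")

print(len(asyncio.run(embed(["hello", "world"]))))  # -> 2
```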
@@ -755,93 +631,27 @@ def create_app(args):
                 elif binding == "azure_openai":
                     from lightrag.llm.azure_openai import azure_openai_embed

-                    actual_func = (
-                        azure_openai_embed.func
-                        if isinstance(azure_openai_embed, EmbeddingFunc)
-                        else azure_openai_embed
-                    )
-                    return await actual_func(texts, model=model, api_key=api_key)
+                    return await azure_openai_embed(texts, model=model, api_key=api_key)
                 elif binding == "aws_bedrock":
                     from lightrag.llm.bedrock import bedrock_embed

-                    actual_func = (
-                        bedrock_embed.func
-                        if isinstance(bedrock_embed, EmbeddingFunc)
-                        else bedrock_embed
-                    )
-                    return await actual_func(texts, model=model)
+                    return await bedrock_embed(texts, model=model)
                 elif binding == "jina":
                     from lightrag.llm.jina import jina_embed

-                    actual_func = (
-                        jina_embed.func
-                        if isinstance(jina_embed, EmbeddingFunc)
-                        else jina_embed
-                    )
-                    return await actual_func(
-                        texts,
-                        embedding_dim=embedding_dim,
-                        base_url=host,
-                        api_key=api_key,
-                    )
-                elif binding == "gemini":
-                    from lightrag.llm.gemini import gemini_embed
-
-                    actual_func = (
-                        gemini_embed.func
-                        if isinstance(gemini_embed, EmbeddingFunc)
-                        else gemini_embed
-                    )
-
-                    # Use pre-processed configuration if available
-                    if config_cache.gemini_embedding_options is not None:
-                        gemini_options = config_cache.gemini_embedding_options
-                    else:
-                        from lightrag.llm.binding_options import GeminiEmbeddingOptions
-
-                        gemini_options = GeminiEmbeddingOptions.options_dict(args)
-
-                    return await actual_func(
-                        texts,
-                        model=model,
-                        base_url=host,
-                        api_key=api_key,
-                        embedding_dim=embedding_dim,
-                        task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
+                    return await jina_embed(
+                        texts, base_url=host, api_key=api_key
                     )
                 else:  # openai and compatible
                     from lightrag.llm.openai import openai_embed

-                    actual_func = (
-                        openai_embed.func
-                        if isinstance(openai_embed, EmbeddingFunc)
-                        else openai_embed
-                    )
-                    return await actual_func(
-                        texts,
-                        model=model,
-                        base_url=host,
-                        api_key=api_key,
-                        embedding_dim=embedding_dim,
+                    return await openai_embed(
+                        texts, model=model, base_url=host, api_key=api_key
                     )
             except ImportError as e:
                 raise Exception(f"Failed to import {binding} embedding: {e}")

-        # Step 4: Wrap in EmbeddingFunc and return
-        embedding_func_instance = EmbeddingFunc(
-            embedding_dim=final_embedding_dim,
-            func=optimized_embedding_function,
-            max_token_size=final_max_token_size,
-            send_dimensions=False,  # Will be set later based on binding requirements
-        )
-
-        # Log final embedding configuration
-        logger.info(
-            f"Embedding config: binding={binding} model={model} "
-            f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
-        )
-
-        return embedding_func_instance
+        return optimized_embedding_function

     llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
     embedding_timeout = get_env_value(
@@ -875,62 +685,45 @@ def create_app(args):
         **kwargs,
     )

-    # Create embedding function with optimized configuration and max_token_size inheritance
+    # Create embedding function with optimized configuration
     import inspect

-    # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
-    embedding_func = create_optimized_embedding_function(
+    # Create the optimized embedding function
+    optimized_embedding_func = create_optimized_embedding_function(
         config_cache=config_cache,
         binding=args.embedding_binding,
         model=args.embedding_model,
         host=args.embedding_binding_host,
         api_key=args.embedding_binding_api_key,
-        args=args,
+        args=args,  # Pass args object for fallback option generation
     )

-    # Get embedding_send_dim from centralized configuration
-    embedding_send_dim = args.embedding_send_dim
+    # Check environment variable for sending dimensions
+    embedding_send_dim = os.getenv("EMBEDDING_SEND_DIM", "false").lower() == "true"

-    # Check if the underlying function signature has embedding_dim parameter
-    sig = inspect.signature(embedding_func.func)
-    has_embedding_dim_param = "embedding_dim" in sig.parameters
+    # Check if the function signature has embedding_dim parameter
+    # Note: Since optimized_embedding_func is an async function, inspect its signature
+    sig = inspect.signature(optimized_embedding_func)
+    has_embedding_dim_param = 'embedding_dim' in sig.parameters

-    # Determine send_dimensions value based on binding type
-    # Jina and Gemini REQUIRE dimension parameter (forced to True)
-    # OpenAI and others: controlled by EMBEDDING_SEND_DIM environment variable
-    if args.embedding_binding in ["jina", "gemini"]:
-        # Jina and Gemini APIs require dimension parameter - always send it
-        send_dimensions = has_embedding_dim_param
-        dimension_control = f"forced by {args.embedding_binding.title()} API"
-    else:
-        # For OpenAI and other bindings, respect EMBEDDING_SEND_DIM setting
-        send_dimensions = embedding_send_dim and has_embedding_dim_param
-        if send_dimensions or not embedding_send_dim:
-            dimension_control = "by env var"
-        else:
-            dimension_control = "by not hasparam"
-
-    # Set send_dimensions on the EmbeddingFunc instance
-    embedding_func.send_dimensions = send_dimensions
+    # Determine send_dimensions value
+    # Only send dimensions if both conditions are met:
+    # 1. EMBEDDING_SEND_DIM environment variable is true
+    # 2. The function has embedding_dim parameter
+    send_dimensions = embedding_send_dim and has_embedding_dim_param

     logger.info(
-        f"Send embedding dimension: {send_dimensions} {dimension_control} "
-        f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
+        f"Embedding configuration: send_dimensions={send_dimensions} "
+        f"(env_var={embedding_send_dim}, has_param={has_embedding_dim_param}, "
         f"binding={args.embedding_binding})"
     )

-    # Log max_token_size source
-    if embedding_func.max_token_size:
-        source = (
-            "env variable"
-            if args.embedding_token_limit
-            else f"{args.embedding_binding} provider default"
-        )
-        logger.info(
-            f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
-        )
-    else:
-        logger.info("Embedding max_token_size: not set (90% token warning disabled)")
+    # Create EmbeddingFunc with send_dimensions attribute
+    embedding_func = EmbeddingFunc(
+        embedding_dim=args.embedding_dim,
+        func=optimized_embedding_func,
+        send_dimensions=send_dimensions,
+    )

     # Configure rerank function based on args.rerank_bindingparameter
     rerank_model_func = None
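The hunk above reduces the send-dimension decision to two conditions: the env flag must be on, and the callable must actually accept an `embedding_dim` keyword. A condensed, runnable sketch of that check (the provider function is a placeholder):

```python
import inspect
import os

async def provider_embed(texts, embedding_dim=None):
    # Placeholder provider; only its signature matters for this check.
    return [[0.0] * (embedding_dim or 4) for _ in texts]

embedding_send_dim = os.getenv("EMBEDDING_SEND_DIM", "false").lower() == "true"
sig = inspect.signature(provider_embed)
has_embedding_dim_param = "embedding_dim" in sig.parameters

# Both conditions must hold before EmbeddingFunc injects the dimension.
send_dimensions = embedding_send_dim and has_embedding_dim_param
print(send_dimensions)
```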
@@ -970,27 +763,15 @@ def create_app(args):
            query: str, documents: list, top_n: int = None, extra_body: dict = None
        ):
            """Server rerank function with configuration from environment variables"""
-            # Prepare kwargs for rerank function
-            kwargs = {
-                "query": query,
-                "documents": documents,
-                "top_n": top_n,
-                "api_key": args.rerank_binding_api_key,
-                "model": args.rerank_model,
-                "base_url": args.rerank_binding_host,
-            }
-
-            # Add Cohere-specific parameters if using cohere binding
-            if args.rerank_binding == "cohere":
-                # Enable chunking if configured (useful for models with token limits like ColBERT)
-                kwargs["enable_chunking"] = (
-                    os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true"
-                )
-                kwargs["max_tokens_per_doc"] = int(
-                    os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")
-                )
-
-            return await selected_rerank_func(**kwargs, extra_body=extra_body)
+            return await selected_rerank_func(
+                query=query,
+                documents=documents,
+                top_n=top_n,
+                api_key=args.rerank_binding_api_key,
+                model=args.rerank_model,
+                base_url=args.rerank_binding_host,
+                extra_body=extra_body,
+            )

        rerank_model_func = server_rerank_func
        logger.info(
@@ -1151,10 +932,9 @@ def create_app(args):
     }

     @app.get("/health", dependencies=[Depends(combined_auth)])
-    async def get_status(request: Request):
+    async def get_status():
         """Get current system status"""
         try:
-            default_workspace = get_default_workspace()
             pipeline_status = await get_namespace_data("pipeline_status")

             if not auth_configured:
@@ -1186,7 +966,7 @@ def create_app(args):
                "vector_storage": args.vector_storage,
                "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
                "enable_llm_cache": args.enable_llm_cache,
-                "workspace": default_workspace,
+                "workspace": args.workspace,
                "max_graph_nodes": args.max_graph_nodes,
                # Rerank configuration
                "enable_rerank": rerank_model_func is not None,
@@ -1375,12 +1155,6 @@ def check_and_install_dependencies():


 def main():
-    # Explicitly initialize configuration for clarity
-    # (The proxy will auto-initialize anyway, but this makes intent clear)
-    from .config import initialize_config
-
-    initialize_config()
-
     # Check if running under Gunicorn
     if "GUNICORN_CMD_ARGS" in os.environ:
         # If started with Gunicorn, return directly as Gunicorn will call get_application
lightrag/llm/jina.py

@@ -1,6 +1,4 @@
 import os
-from typing import Final

 import pipmaster as pm  # Pipmaster for dynamic library install

 # install specific modules
@@ -21,9 +19,6 @@ from tenacity import (
 from lightrag.utils import wrap_embedding_func_with_attrs, logger

-
-DEFAULT_JINA_EMBED_DIM: Final[int] = 2048
-

 async def fetch_data(url, headers, data):
     async with aiohttp.ClientSession() as session:
         async with session.post(url, headers=headers, json=data) as response:
@@ -63,7 +58,7 @@ async def fetch_data(url, headers, data):
     return data_list


-@wrap_embedding_func_with_attrs(embedding_dim=DEFAULT_JINA_EMBED_DIM)
+@wrap_embedding_func_with_attrs(embedding_dim=2048)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -74,7 +69,7 @@ async def fetch_data(url, headers, data):
 )
 async def jina_embed(
     texts: list[str],
-    embedding_dim: int | None = DEFAULT_JINA_EMBED_DIM,
+    embedding_dim: int = 2048,
     late_chunking: bool = False,
     base_url: str = None,
     api_key: str = None,
@@ -100,10 +95,6 @@ async def jina_embed(
         aiohttp.ClientError: If there is a connection error with the Jina API.
         aiohttp.ClientResponseError: If the Jina API returns an error response.
     """
-    resolved_embedding_dim = (
-        embedding_dim if embedding_dim is not None else DEFAULT_JINA_EMBED_DIM
-    )
-
     if api_key:
         os.environ["JINA_API_KEY"] = api_key

@@ -118,7 +109,7 @@ async def jina_embed(
     data = {
         "model": "jina-embeddings-v4",
         "task": "text-matching",
-        "dimensions": resolved_embedding_dim,
+        "dimensions": embedding_dim,
         "embedding_type": "base64",
         "input": texts,
     }
@@ -128,7 +119,7 @@ async def jina_embed(
        data["late_chunking"] = late_chunking

    logger.debug(
-        f"Jina embedding request: {len(texts)} texts, dimensions: {resolved_embedding_dim}"
+        f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}"
    )

    try:
lightrag/llm/openai.py

@@ -11,7 +11,6 @@ if not pm.is_installed("openai"):
     pm.install("openai")

 from openai import (
-    AsyncOpenAI,
     APIConnectionError,
     RateLimitError,
     APITimeoutError,
@@ -27,6 +26,7 @@ from lightrag.utils import (
     safe_unicode_decode,
     logger,
 )

 from lightrag.types import GPTKeywordExtractionFormat
 from lightrag.api import __api_version__
+
@@ -36,6 +36,32 @@ from typing import Any, Union

 from dotenv import load_dotenv

+# Try to import Langfuse for LLM observability (optional)
+# Falls back to standard OpenAI client if not available
+# Langfuse requires proper configuration to work correctly
+LANGFUSE_ENABLED = False
+try:
+    # Check if required Langfuse environment variables are set
+    langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
+    langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
+
+    # Only enable Langfuse if both keys are configured
+    if langfuse_public_key and langfuse_secret_key:
+        from langfuse.openai import AsyncOpenAI
+
+        LANGFUSE_ENABLED = True
+        logger.info("Langfuse observability enabled for OpenAI client")
+    else:
+        from openai import AsyncOpenAI
+
+        logger.debug(
+            "Langfuse environment variables not configured, using standard OpenAI client"
+        )
+except ImportError:
+    from openai import AsyncOpenAI
+
+    logger.debug("Langfuse not available, using standard OpenAI client")
+
 # use the .env that is inside the current folder
 # allows to use different .env file for each lightrag instance
 # the OS environment variables take precedence over the .env file
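The import block above keeps a single `AsyncOpenAI` name regardless of which package supplies it, so call sites need no changes. A stripped-down sketch of the same optional-dependency fallback (assumes the `openai` package is installed; `langfuse` may or may not be):

```python
import logging
import os

logger = logging.getLogger(__name__)

LANGFUSE_ENABLED = False
try:
    if os.environ.get("LANGFUSE_PUBLIC_KEY") and os.environ.get("LANGFUSE_SECRET_KEY"):
        # Instrumented drop-in client from the optional dependency.
        from langfuse.openai import AsyncOpenAI
        LANGFUSE_ENABLED = True
    else:
        from openai import AsyncOpenAI
except ImportError:
    # Optional package missing entirely: fall back to the vanilla client.
    from openai import AsyncOpenAI

logger.debug("Langfuse enabled: %s (client: %s)", LANGFUSE_ENABLED, AsyncOpenAI.__name__)
```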
@@ -370,18 +396,23 @@ async def openai_complete_if_cache(
            )

        # Ensure resources are released even if no exception occurs
-        if (
-            iteration_started
-            and hasattr(response, "aclose")
-            and callable(getattr(response, "aclose", None))
-        ):
-            try:
-                await response.aclose()
-                logger.debug("Successfully closed stream response")
-            except Exception as close_error:
-                logger.warning(
-                    f"Failed to close stream response in finally block: {close_error}"
-                )
+        # Note: Some wrapped clients (e.g., Langfuse) may not implement aclose() properly
+        if iteration_started and hasattr(response, "aclose"):
+            aclose_method = getattr(response, "aclose", None)
+            if callable(aclose_method):
+                try:
+                    await response.aclose()
+                    logger.debug("Successfully closed stream response")
+                except (AttributeError, TypeError) as close_error:
+                    # Some wrapper objects may report hasattr(aclose) but fail when called
+                    # This is expected behavior for certain client wrappers
+                    logger.debug(
+                        f"Stream response cleanup not supported by client wrapper: {close_error}"
+                    )
+                except Exception as close_error:
+                    logger.warning(
+                        f"Unexpected error during stream response cleanup: {close_error}"
+                    )

        # This prevents resource leaks since the caller doesn't handle closing
        try:
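The rewritten cleanup treats a failing `aclose()` on a wrapped client as expected rather than as an error. A standalone sketch of that defensive shape, exercised against a deliberately broken stream object:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)

async def close_stream(response):
    # Probe first: wrapped clients may expose a missing or non-callable aclose().
    aclose_method = getattr(response, "aclose", None)
    if not callable(aclose_method):
        return
    try:
        await aclose_method()
    except (AttributeError, TypeError) as exc:
        # Expected for some client wrappers; downgrade to debug.
        logger.debug("stream cleanup not supported: %s", exc)
    except Exception as exc:
        logger.warning("unexpected error during stream cleanup: %s", exc)

class BrokenStream:
    async def aclose(self):
        raise TypeError("wrapper does not support aclose")

asyncio.run(close_stream(BrokenStream()))  # logs at debug level, raises nothing
```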
@@ -578,6 +609,7 @@ async def openai_embed(
     model: str = "text-embedding-3-small",
     base_url: str | None = None,
     api_key: str | None = None,
+    embedding_dim: int | None = None,
     client_configs: dict[str, Any] | None = None,
     token_tracker: Any | None = None,
 ) -> np.ndarray:
@@ -588,6 +620,12 @@ async def openai_embed(
         model: The OpenAI embedding model to use.
         base_url: Optional base URL for the OpenAI API.
         api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
+        embedding_dim: Optional embedding dimension for dynamic dimension reduction.
+            **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
+            Do NOT manually pass this parameter when calling the function directly.
+            The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
+            Manually passing a different value will trigger a warning and be ignored.
+            When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
         client_configs: Additional configuration options for the AsyncOpenAI client.
             These will override any default configurations but will be overridden by
             explicit parameters (api_key, base_url).
@@ -607,17 +645,27 @@ async def openai_embed(
     )

     async with openai_async_client:
-        response = await openai_async_client.embeddings.create(
-            model=model, input=texts, encoding_format="base64"
-        )
+        # Prepare API call parameters
+        api_params = {
+            "model": model,
+            "input": texts,
+            "encoding_format": "base64",
+        }
+
+        # Add dimensions parameter only if embedding_dim is provided
+        if embedding_dim is not None:
+            api_params["dimensions"] = embedding_dim
+
+        # Make API call
+        response = await openai_async_client.embeddings.create(**api_params)

         if token_tracker and hasattr(response, "usage"):
             token_counts = {
                 "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
                 "total_tokens": getattr(response.usage, "total_tokens", 0),
             }
             token_tracker.add_usage(token_counts)

         return np.array(
             [
                 np.array(dp.embedding, dtype=np.float32)
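Building the request as a dict lets the `dimensions` key be omitted entirely when no reduction is requested; OpenAI's `text-embedding-3-*` models accept an optional `dimensions` parameter for exactly this purpose. A tiny illustration of the conditional-kwargs shape (no API call is made):

```python
# Base parameters are always present.
api_params = {"model": "text-embedding-3-small", "input": ["hi"], "encoding_format": "base64"}

# Injected by the EmbeddingFunc wrapper when send_dimensions is enabled.
embedding_dim = 1024
if embedding_dim is not None:
    api_params["dimensions"] = embedding_dim

print(sorted(api_params))  # ['dimensions', 'encoding_format', 'input', 'model']
```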
lightrag/utils.py

@@ -1,8 +1,6 @@
 from __future__ import annotations
 import weakref

-import sys
-
 import asyncio
 import html
 import csv
@@ -42,35 +40,6 @@ from lightrag.constants import (
     SOURCE_IDS_LIMIT_METHOD_FIFO,
 )

-# Precompile regex pattern for JSON sanitization (module-level, compiled once)
-_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
-
-
-class SafeStreamHandler(logging.StreamHandler):
-    """StreamHandler that gracefully handles closed streams during shutdown.
-
-    This handler prevents "ValueError: I/O operation on closed file" errors
-    that can occur when pytest or other test frameworks close stdout/stderr
-    before Python's logging cleanup runs.
-    """
-
-    def flush(self):
-        """Flush the stream, ignoring errors if the stream is closed."""
-        try:
-            super().flush()
-        except (ValueError, OSError):
-            # Stream is closed or otherwise unavailable, silently ignore
-            pass
-
-    def close(self):
-        """Close the handler, ignoring errors if the stream is already closed."""
-        try:
-            super().close()
-        except (ValueError, OSError):
-            # Stream is closed or otherwise unavailable, silently ignore
-            pass
-
-
 # Initialize logger with basic configuration
 logger = logging.getLogger("lightrag")
 logger.propagate = False  # prevent log message send to root logger
@@ -78,7 +47,7 @@ logger.setLevel(logging.INFO)

 # Add console handler if no handlers exist
 if not logger.handlers:
-    console_handler = SafeStreamHandler()
+    console_handler = logging.StreamHandler()
     console_handler.setLevel(logging.INFO)
     formatter = logging.Formatter("%(levelname)s: %(message)s")
     console_handler.setFormatter(formatter)
@@ -87,33 +56,6 @@ if not logger.handlers:
 # Set httpx logging level to WARNING
 logging.getLogger("httpx").setLevel(logging.WARNING)


-def _patch_ascii_colors_console_handler() -> None:
-    """Prevent ascii_colors from printing flush errors during interpreter exit."""
-
-    try:
-        from ascii_colors import ConsoleHandler
-    except ImportError:
-        return
-
-    if getattr(ConsoleHandler, "_lightrag_patched", False):
-        return
-
-    original_handle_error = ConsoleHandler.handle_error
-
-    def _safe_handle_error(self, message: str) -> None:  # type: ignore[override]
-        exc_type, _, _ = sys.exc_info()
-        if exc_type in (ValueError, OSError) and "close" in message.lower():
-            return
-        original_handle_error(self, message)
-
-    ConsoleHandler.handle_error = _safe_handle_error  # type: ignore[assignment]
-    ConsoleHandler._lightrag_patched = True  # type: ignore[attr-defined]
-
-
-_patch_ascii_colors_console_handler()
-
-
 # Global import for pypinyin with startup-time logging
 try:
     import pypinyin
@@ -341,8 +283,8 @@ def setup_logger(
     logger_instance.handlers = []  # Clear existing handlers
     logger_instance.propagate = False

-    # Add console handler with safe stream handling
-    console_handler = SafeStreamHandler()
+    # Add console handler
+    console_handler = logging.StreamHandler()
     console_handler.setFormatter(simple_formatter)
     console_handler.setLevel(level)
     logger_instance.addHandler(console_handler)
@@ -408,69 +350,28 @@ class TaskState:

 @dataclass
 class EmbeddingFunc:
-    """Embedding function wrapper with dimension validation
-
-    This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
-    This class should not be wrapped multiple times.
-
-    Args:
-        embedding_dim: Expected dimension of the embeddings
-        func: The actual embedding function to wrap
-        max_token_size: Optional token limit for the embedding model
-        send_dimensions: Whether to inject embedding_dim as a keyword argument
-    """
-
     embedding_dim: int
     func: callable
-    max_token_size: int | None = None  # Token limit for the embedding model
-    send_dimensions: bool = (
-        False  # Control whether to send embedding_dim to the function
-    )
+    max_token_size: int | None = None  # deprecated keep it for compatible only
+    send_dimensions: bool = False  # Control whether to send embedding_dim to the function

     async def __call__(self, *args, **kwargs) -> np.ndarray:
         # Only inject embedding_dim when send_dimensions is True
         if self.send_dimensions:
             # Check if user provided embedding_dim parameter
-            if "embedding_dim" in kwargs:
-                user_provided_dim = kwargs["embedding_dim"]
+            if 'embedding_dim' in kwargs:
+                user_provided_dim = kwargs['embedding_dim']
                 # If user's value differs from class attribute, output warning
-                if (
-                    user_provided_dim is not None
-                    and user_provided_dim != self.embedding_dim
-                ):
+                if user_provided_dim is not None and user_provided_dim != self.embedding_dim:
                     logger.warning(
                         f"Ignoring user-provided embedding_dim={user_provided_dim}, "
                         f"using declared embedding_dim={self.embedding_dim} from decorator"
                     )

             # Inject embedding_dim from decorator
-            kwargs["embedding_dim"] = self.embedding_dim
+            kwargs['embedding_dim'] = self.embedding_dim

-        # Call the actual embedding function
-        result = await self.func(*args, **kwargs)
-
-        # Validate embedding dimensions using total element count
-        total_elements = result.size  # Total number of elements in the numpy array
-        expected_dim = self.embedding_dim
-
-        # Check if total elements can be evenly divided by embedding_dim
-        if total_elements % expected_dim != 0:
-            raise ValueError(
-                f"Embedding dimension mismatch detected: "
-                f"total elements ({total_elements}) cannot be evenly divided by "
-                f"expected dimension ({expected_dim}). "
-            )
-
-        # Optional: Verify vector count matches input text count
-        actual_vectors = total_elements // expected_dim
-        if args and isinstance(args[0], (list, tuple)):
-            expected_vectors = len(args[0])
-            if actual_vectors != expected_vectors:
-                raise ValueError(
-                    f"Vector count mismatch: "
-                    f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
-                )
-
-        return result
+        return await self.func(*args, **kwargs)


 def compute_args_hash(*args: Any) -> str:
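After this change `EmbeddingFunc.__call__` only injects the declared dimension and delegates; the post-hoc size validation is gone. A reduced, runnable restatement of the remaining behavior (a toy embedder stands in for a real one, and the warning branch is omitted):

```python
import asyncio
from dataclasses import dataclass

import numpy as np

@dataclass
class MiniEmbeddingFunc:
    embedding_dim: int
    func: callable
    send_dimensions: bool = False

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        if self.send_dimensions:
            # The declared dimension overrides anything the caller passed.
            kwargs["embedding_dim"] = self.embedding_dim
        return await self.func(*args, **kwargs)

async def toy_embed(texts, embedding_dim=None):
    return np.zeros((len(texts), embedding_dim or 8), dtype=np.float32)

wrapped = MiniEmbeddingFunc(embedding_dim=4, func=toy_embed, send_dimensions=True)
print(asyncio.run(wrapped(["a", "b"])).shape)  # -> (2, 4)
```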
@@ -1005,76 +906,7 @@ def priority_limit_async_func_call(


 def wrap_embedding_func_with_attrs(**kwargs):
-    """Decorator to add embedding dimension and token limit attributes to embedding functions.
-
-    This decorator wraps an async embedding function and returns an EmbeddingFunc instance
-    that automatically handles dimension parameter injection and attribute management.
-
-    WARNING: DO NOT apply this decorator to wrapper functions that call other
-    decorated embedding functions. This will cause double decoration and parameter
-    injection conflicts.
-
-    Correct usage patterns:
-
-    1. Direct implementation (decorated):
-    ```python
-    @wrap_embedding_func_with_attrs(embedding_dim=1536)
-    async def my_embed(texts, embedding_dim=None):
-        # Direct implementation
-        return embeddings
-    ```
-
-    2. Wrapper calling decorated function (DO NOT decorate wrapper):
-    ```python
-    # my_embed is already decorated above
-
-    async def my_wrapper(texts, **kwargs):  # ❌ DO NOT decorate this!
-        # Must call .func to access unwrapped implementation
-        return await my_embed.func(texts, **kwargs)
-    ```
-
-    3. Wrapper calling decorated function (properly decorated):
-    ```python
-    @wrap_embedding_func_with_attrs(embedding_dim=1536)
-    async def my_wrapper(texts, **kwargs):  # ✅ Can decorate if calling .func
-        # Calling .func avoids double decoration
-        return await my_embed.func(texts, **kwargs)
-    ```
-
-    The decorated function becomes an EmbeddingFunc instance with:
-    - embedding_dim: The embedding dimension
-    - max_token_size: Maximum token limit (optional)
-    - func: The original unwrapped function (access via .func)
-    - __call__: Wrapper that injects embedding_dim parameter
-
-    Double decoration causes:
-    - Double injection of embedding_dim parameter
-    - Incorrect parameter passing to the underlying implementation
-    - Runtime errors due to parameter conflicts
-
-    Args:
-        embedding_dim: The dimension of embedding vectors
-        max_token_size: Maximum number of tokens (optional)
-        send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
-
-    Returns:
-        A decorator that wraps the function as an EmbeddingFunc instance
-
-    Example of correct wrapper implementation:
-    ```python
-    @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
-    @retry(...)
-    async def openai_embed(texts, ...):
-        # Base implementation
-        pass
-
-    @wrap_embedding_func_with_attrs(embedding_dim=1536)  # Note: No @retry here!
-    async def azure_openai_embed(texts, ...):
-        # CRITICAL: Call .func to access unwrapped function
-        return await openai_embed.func(texts, ...)  # ✅ Correct
-        # return await openai_embed(texts, ...)  # ❌ Wrong - double decoration!
-    ```
-    """
+    """Wrap a function with attributes"""

     def final_decro(func) -> EmbeddingFunc:
         new_func = EmbeddingFunc(**kwargs, func=func)
@@ -1090,123 +922,9 @@ def load_json(file_name):
         return json.load(f)


-def _sanitize_string_for_json(text: str) -> str:
-    """Remove characters that cannot be encoded in UTF-8 for JSON serialization.
-
-    Uses regex for optimal performance with zero-copy optimization for clean strings.
-    Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
-
-    Args:
-        text: String to sanitize
-
-    Returns:
-        Original string if clean (zero-copy), sanitized string if dirty
-    """
-    if not text:
-        return text
-
-    # Fast path: Check if sanitization is needed using C-level regex search
-    if not _SURROGATE_PATTERN.search(text):
-        return text  # Zero-copy for clean strings - most common case
-
-    # Slow path: Remove problematic characters using C-level regex substitution
-    return _SURROGATE_PATTERN.sub("", text)
-
-
-class SanitizingJSONEncoder(json.JSONEncoder):
-    """
-    Custom JSON encoder that sanitizes data during serialization.
-
-    This encoder cleans strings during the encoding process without creating
-    a full copy of the data structure, making it memory-efficient for large datasets.
-    """
-
-    def encode(self, o):
-        """Override encode method to handle simple string cases"""
-        if isinstance(o, str):
-            return json.encoder.encode_basestring(_sanitize_string_for_json(o))
-        return super().encode(o)
-
-    def iterencode(self, o, _one_shot=False):
-        """
-        Override iterencode to sanitize strings during serialization.
-        This is the core method that handles complex nested structures.
-        """
-        # Preprocess: sanitize all strings in the object
-        sanitized = self._sanitize_for_encoding(o)
-
-        # Call parent's iterencode with sanitized data
-        for chunk in super().iterencode(sanitized, _one_shot):
-            yield chunk
-
-    def _sanitize_for_encoding(self, obj):
-        """
-        Recursively sanitize strings in an object.
-        Creates new objects only when necessary to avoid deep copies.
-
-        Args:
-            obj: Object to sanitize
-
-        Returns:
-            Sanitized object with cleaned strings
-        """
-        if isinstance(obj, str):
-            return _sanitize_string_for_json(obj)
-
-        elif isinstance(obj, dict):
-            # Create new dict with sanitized keys and values
-            new_dict = {}
-            for k, v in obj.items():
-                clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
-                clean_v = self._sanitize_for_encoding(v)
-                new_dict[clean_k] = clean_v
-            return new_dict
-
-        elif isinstance(obj, (list, tuple)):
-            # Sanitize list/tuple elements
-            cleaned = [self._sanitize_for_encoding(item) for item in obj]
-            return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
-
-        else:
-            # Numbers, booleans, None, etc. remain unchanged
-            return obj
-
-
 def write_json(json_obj, file_name):
-    """
-    Write JSON data to file with optimized sanitization strategy.
-
-    This function uses a two-stage approach:
-    1. Fast path: Try direct serialization (works for clean data ~99% of time)
-    2. Slow path: Use custom encoder that sanitizes during serialization
-
-    The custom encoder approach avoids creating a deep copy of the data,
-    making it memory-efficient. When sanitization occurs, the caller should
-    reload the cleaned data from the file to update shared memory.
-
-    Args:
-        json_obj: Object to serialize (may be a shallow copy from shared memory)
-        file_name: Output file path
-
-    Returns:
-        bool: True if sanitization was applied (caller should reload data),
-        False if direct write succeeded (no reload needed)
-    """
-    try:
-        # Strategy 1: Fast path - try direct serialization
-        with open(file_name, "w", encoding="utf-8") as f:
-            json.dump(json_obj, f, indent=2, ensure_ascii=False)
-        return False  # No sanitization needed, no reload required
-
-    except (UnicodeEncodeError, UnicodeDecodeError) as e:
-        logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
-
-        # Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy)
-        with open(file_name, "w", encoding="utf-8") as f:
-            json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder)
-
-        logger.info(f"JSON sanitization applied during write: {file_name}")
-        return True  # Sanitization applied, reload recommended
+    with open(file_name, "w", encoding="utf-8") as f:
+        json.dump(json_obj, f, indent=2, ensure_ascii=False)


 class TokenizerInterface(Protocol):