cherry-pick 33a1482f
This commit is contained in:
parent
e11e30be0e
commit
cacea8ab56
5 changed files with 149 additions and 611 deletions
|
|
@ -233,6 +233,13 @@ OLLAMA_LLM_NUM_CTX=32768
|
|||
### EMBEDDING_BINDING_HOST: host only for Ollama, endpoint for other Embedding service
|
||||
#######################################################################################
|
||||
# EMBEDDING_TIMEOUT=30
|
||||
|
||||
### Control whether to send embedding_dim parameter to embedding API
|
||||
### Set to 'true' to enable dynamic dimension adjustment (only works for OpenAI and Jina)
|
||||
### Set to 'false' (default) to disable sending dimension parameter
|
||||
### Note: This is automatically ignored for backends that don't support dimension parameter
|
||||
# EMBEDDING_SEND_DIM=false
|
||||
|
||||
EMBEDDING_BINDING=ollama
|
||||
EMBEDDING_MODEL=bge-m3:latest
|
||||
EMBEDDING_DIM=1024
|
||||
|
|
|
|||
|
|
@ -56,8 +56,6 @@ from lightrag.api.routers.ollama_api import OllamaAPI
|
|||
from lightrag.utils import logger, set_verbose_debug
|
||||
from lightrag.kg.shared_storage import (
|
||||
get_namespace_data,
|
||||
get_default_workspace,
|
||||
# set_default_workspace,
|
||||
initialize_pipeline_status,
|
||||
cleanup_keyed_lock,
|
||||
finalize_share_data,
|
||||
|
|
@ -91,7 +89,6 @@ class LLMConfigCache:
|
|||
# Initialize configurations based on binding conditions
|
||||
self.openai_llm_options = None
|
||||
self.gemini_llm_options = None
|
||||
self.gemini_embedding_options = None
|
||||
self.ollama_llm_options = None
|
||||
self.ollama_embedding_options = None
|
||||
|
||||
|
|
@ -138,23 +135,6 @@ class LLMConfigCache:
|
|||
)
|
||||
self.ollama_embedding_options = {}
|
||||
|
||||
# Only initialize and log Gemini Embedding options when using Gemini Embedding binding
|
||||
if args.embedding_binding == "gemini":
|
||||
try:
|
||||
from lightrag.llm.binding_options import GeminiEmbeddingOptions
|
||||
|
||||
self.gemini_embedding_options = GeminiEmbeddingOptions.options_dict(
|
||||
args
|
||||
)
|
||||
logger.info(
|
||||
f"Gemini Embedding Options: {self.gemini_embedding_options}"
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"GeminiEmbeddingOptions not available, using default configuration"
|
||||
)
|
||||
self.gemini_embedding_options = {}
|
||||
|
||||
|
||||
def check_frontend_build():
|
||||
"""Check if frontend is built and optionally check if source is up-to-date
|
||||
|
|
@ -316,7 +296,6 @@ def create_app(args):
|
|||
"azure_openai",
|
||||
"aws_bedrock",
|
||||
"jina",
|
||||
"gemini",
|
||||
]:
|
||||
raise Exception("embedding binding not supported")
|
||||
|
||||
|
|
@ -352,9 +331,8 @@ def create_app(args):
|
|||
|
||||
try:
|
||||
# Initialize database connections
|
||||
# set_default_workspace(rag.workspace) # comment this line to test auto default workspace setting in initialize_storages
|
||||
await rag.initialize_storages()
|
||||
await initialize_pipeline_status() # with default workspace
|
||||
await initialize_pipeline_status()
|
||||
|
||||
# Data migration regardless of storage implementation
|
||||
await rag.check_and_migrate_data()
|
||||
|
|
@ -455,29 +433,6 @@ def create_app(args):
|
|||
# Create combined auth dependency for all endpoints
|
||||
combined_auth = get_combined_auth_dependency(api_key)
|
||||
|
||||
def get_workspace_from_request(request: Request) -> str:
|
||||
"""
|
||||
Extract workspace from HTTP request header or use default.
|
||||
|
||||
This enables multi-workspace API support by checking the custom
|
||||
'LIGHTRAG-WORKSPACE' header. If not present, falls back to the
|
||||
server's default workspace configuration.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
Workspace identifier (may be empty string for global namespace)
|
||||
"""
|
||||
# Check custom header first
|
||||
workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip()
|
||||
|
||||
# Fall back to server default if header not provided
|
||||
if not workspace:
|
||||
workspace = args.workspace
|
||||
|
||||
return workspace
|
||||
|
||||
# Create working directory if it doesn't exist
|
||||
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -556,9 +511,7 @@ def create_app(args):
|
|||
|
||||
return optimized_azure_openai_model_complete
|
||||
|
||||
def create_optimized_gemini_llm_func(
|
||||
config_cache: LLMConfigCache, args, llm_timeout: int
|
||||
):
|
||||
def create_optimized_gemini_llm_func(config_cache: LLMConfigCache, args):
|
||||
"""Create optimized Gemini LLM function with cached configuration"""
|
||||
|
||||
async def optimized_gemini_model_complete(
|
||||
|
|
@ -573,8 +526,6 @@ def create_app(args):
|
|||
if history_messages is None:
|
||||
history_messages = []
|
||||
|
||||
# Use pre-processed configuration to avoid repeated parsing
|
||||
kwargs["timeout"] = llm_timeout
|
||||
if (
|
||||
config_cache.gemini_llm_options is not None
|
||||
and "generation_config" not in kwargs
|
||||
|
|
@ -616,7 +567,7 @@ def create_app(args):
|
|||
config_cache, args, llm_timeout
|
||||
)
|
||||
elif binding == "gemini":
|
||||
return create_optimized_gemini_llm_func(config_cache, args, llm_timeout)
|
||||
return create_optimized_gemini_llm_func(config_cache, args)
|
||||
else: # openai and compatible
|
||||
# Use optimized function with pre-processed configuration
|
||||
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
|
||||
|
|
@ -644,108 +595,33 @@ def create_app(args):
|
|||
|
||||
def create_optimized_embedding_function(
|
||||
config_cache: LLMConfigCache, binding, model, host, api_key, args
|
||||
) -> EmbeddingFunc:
|
||||
):
|
||||
"""
|
||||
Create optimized embedding function and return an EmbeddingFunc instance
|
||||
with proper max_token_size inheritance from provider defaults.
|
||||
|
||||
This function:
|
||||
1. Imports the provider embedding function
|
||||
2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
|
||||
3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
|
||||
4. Returns a properly configured EmbeddingFunc instance
|
||||
Create optimized embedding function with pre-processed configuration for applicable bindings.
|
||||
Uses lazy imports for all bindings and avoids repeated configuration parsing.
|
||||
"""
|
||||
|
||||
# Step 1: Import provider function and extract default attributes
|
||||
provider_func = None
|
||||
provider_max_token_size = None
|
||||
provider_embedding_dim = None
|
||||
|
||||
try:
|
||||
if binding == "openai":
|
||||
from lightrag.llm.openai import openai_embed
|
||||
|
||||
provider_func = openai_embed
|
||||
elif binding == "ollama":
|
||||
from lightrag.llm.ollama import ollama_embed
|
||||
|
||||
provider_func = ollama_embed
|
||||
elif binding == "gemini":
|
||||
from lightrag.llm.gemini import gemini_embed
|
||||
|
||||
provider_func = gemini_embed
|
||||
elif binding == "jina":
|
||||
from lightrag.llm.jina import jina_embed
|
||||
|
||||
provider_func = jina_embed
|
||||
elif binding == "azure_openai":
|
||||
from lightrag.llm.azure_openai import azure_openai_embed
|
||||
|
||||
provider_func = azure_openai_embed
|
||||
elif binding == "aws_bedrock":
|
||||
from lightrag.llm.bedrock import bedrock_embed
|
||||
|
||||
provider_func = bedrock_embed
|
||||
elif binding == "lollms":
|
||||
from lightrag.llm.lollms import lollms_embed
|
||||
|
||||
provider_func = lollms_embed
|
||||
|
||||
# Extract attributes if provider is an EmbeddingFunc
|
||||
if provider_func and isinstance(provider_func, EmbeddingFunc):
|
||||
provider_max_token_size = provider_func.max_token_size
|
||||
provider_embedding_dim = provider_func.embedding_dim
|
||||
logger.debug(
|
||||
f"Extracted from {binding} provider: "
|
||||
f"max_token_size={provider_max_token_size}, "
|
||||
f"embedding_dim={provider_embedding_dim}"
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import provider function for {binding}: {e}")
|
||||
|
||||
# Step 2: Apply priority (user config > provider default)
|
||||
# For max_token_size: explicit env var > provider default > None
|
||||
final_max_token_size = args.embedding_token_limit or provider_max_token_size
|
||||
# For embedding_dim: user config (always has value) takes priority
|
||||
# Only use provider default if user config is explicitly None (which shouldn't happen)
|
||||
final_embedding_dim = (
|
||||
args.embedding_dim if args.embedding_dim else provider_embedding_dim
|
||||
)
|
||||
|
||||
# Step 3: Create optimized embedding function (calls underlying function directly)
|
||||
async def optimized_embedding_function(texts, embedding_dim=None):
|
||||
async def optimized_embedding_function(texts):
|
||||
try:
|
||||
if binding == "lollms":
|
||||
from lightrag.llm.lollms import lollms_embed
|
||||
|
||||
# Get real function, skip EmbeddingFunc wrapper if present
|
||||
actual_func = (
|
||||
lollms_embed.func
|
||||
if isinstance(lollms_embed, EmbeddingFunc)
|
||||
else lollms_embed
|
||||
)
|
||||
return await actual_func(
|
||||
return await lollms_embed(
|
||||
texts, embed_model=model, host=host, api_key=api_key
|
||||
)
|
||||
elif binding == "ollama":
|
||||
from lightrag.llm.ollama import ollama_embed
|
||||
|
||||
# Get real function, skip EmbeddingFunc wrapper if present
|
||||
actual_func = (
|
||||
ollama_embed.func
|
||||
if isinstance(ollama_embed, EmbeddingFunc)
|
||||
else ollama_embed
|
||||
)
|
||||
|
||||
# Use pre-processed configuration if available
|
||||
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
|
||||
if config_cache.ollama_embedding_options is not None:
|
||||
ollama_options = config_cache.ollama_embedding_options
|
||||
else:
|
||||
# Fallback for cases where config cache wasn't initialized properly
|
||||
from lightrag.llm.binding_options import OllamaEmbeddingOptions
|
||||
|
||||
ollama_options = OllamaEmbeddingOptions.options_dict(args)
|
||||
|
||||
return await actual_func(
|
||||
return await ollama_embed(
|
||||
texts,
|
||||
embed_model=model,
|
||||
host=host,
|
||||
|
|
@ -755,93 +631,27 @@ def create_app(args):
|
|||
elif binding == "azure_openai":
|
||||
from lightrag.llm.azure_openai import azure_openai_embed
|
||||
|
||||
actual_func = (
|
||||
azure_openai_embed.func
|
||||
if isinstance(azure_openai_embed, EmbeddingFunc)
|
||||
else azure_openai_embed
|
||||
)
|
||||
return await actual_func(texts, model=model, api_key=api_key)
|
||||
return await azure_openai_embed(texts, model=model, api_key=api_key)
|
||||
elif binding == "aws_bedrock":
|
||||
from lightrag.llm.bedrock import bedrock_embed
|
||||
|
||||
actual_func = (
|
||||
bedrock_embed.func
|
||||
if isinstance(bedrock_embed, EmbeddingFunc)
|
||||
else bedrock_embed
|
||||
)
|
||||
return await actual_func(texts, model=model)
|
||||
return await bedrock_embed(texts, model=model)
|
||||
elif binding == "jina":
|
||||
from lightrag.llm.jina import jina_embed
|
||||
|
||||
actual_func = (
|
||||
jina_embed.func
|
||||
if isinstance(jina_embed, EmbeddingFunc)
|
||||
else jina_embed
|
||||
)
|
||||
return await actual_func(
|
||||
texts,
|
||||
embedding_dim=embedding_dim,
|
||||
base_url=host,
|
||||
api_key=api_key,
|
||||
)
|
||||
elif binding == "gemini":
|
||||
from lightrag.llm.gemini import gemini_embed
|
||||
|
||||
actual_func = (
|
||||
gemini_embed.func
|
||||
if isinstance(gemini_embed, EmbeddingFunc)
|
||||
else gemini_embed
|
||||
)
|
||||
|
||||
# Use pre-processed configuration if available
|
||||
if config_cache.gemini_embedding_options is not None:
|
||||
gemini_options = config_cache.gemini_embedding_options
|
||||
else:
|
||||
from lightrag.llm.binding_options import GeminiEmbeddingOptions
|
||||
|
||||
gemini_options = GeminiEmbeddingOptions.options_dict(args)
|
||||
|
||||
return await actual_func(
|
||||
texts,
|
||||
model=model,
|
||||
base_url=host,
|
||||
api_key=api_key,
|
||||
embedding_dim=embedding_dim,
|
||||
task_type=gemini_options.get("task_type", "RETRIEVAL_DOCUMENT"),
|
||||
return await jina_embed(
|
||||
texts, base_url=host, api_key=api_key
|
||||
)
|
||||
else: # openai and compatible
|
||||
from lightrag.llm.openai import openai_embed
|
||||
|
||||
actual_func = (
|
||||
openai_embed.func
|
||||
if isinstance(openai_embed, EmbeddingFunc)
|
||||
else openai_embed
|
||||
)
|
||||
return await actual_func(
|
||||
texts,
|
||||
model=model,
|
||||
base_url=host,
|
||||
api_key=api_key,
|
||||
embedding_dim=embedding_dim,
|
||||
return await openai_embed(
|
||||
texts, model=model, base_url=host, api_key=api_key
|
||||
)
|
||||
except ImportError as e:
|
||||
raise Exception(f"Failed to import {binding} embedding: {e}")
|
||||
|
||||
# Step 4: Wrap in EmbeddingFunc and return
|
||||
embedding_func_instance = EmbeddingFunc(
|
||||
embedding_dim=final_embedding_dim,
|
||||
func=optimized_embedding_function,
|
||||
max_token_size=final_max_token_size,
|
||||
send_dimensions=False, # Will be set later based on binding requirements
|
||||
)
|
||||
|
||||
# Log final embedding configuration
|
||||
logger.info(
|
||||
f"Embedding config: binding={binding} model={model} "
|
||||
f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
|
||||
)
|
||||
|
||||
return embedding_func_instance
|
||||
return optimized_embedding_function
|
||||
|
||||
llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
|
||||
embedding_timeout = get_env_value(
|
||||
|
|
@ -875,62 +685,45 @@ def create_app(args):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
# Create embedding function with optimized configuration and max_token_size inheritance
|
||||
# Create embedding function with optimized configuration
|
||||
import inspect
|
||||
|
||||
# Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
|
||||
embedding_func = create_optimized_embedding_function(
|
||||
|
||||
# Create the optimized embedding function
|
||||
optimized_embedding_func = create_optimized_embedding_function(
|
||||
config_cache=config_cache,
|
||||
binding=args.embedding_binding,
|
||||
model=args.embedding_model,
|
||||
host=args.embedding_binding_host,
|
||||
api_key=args.embedding_binding_api_key,
|
||||
args=args,
|
||||
args=args, # Pass args object for fallback option generation
|
||||
)
|
||||
|
||||
# Get embedding_send_dim from centralized configuration
|
||||
embedding_send_dim = args.embedding_send_dim
|
||||
|
||||
# Check if the underlying function signature has embedding_dim parameter
|
||||
sig = inspect.signature(embedding_func.func)
|
||||
has_embedding_dim_param = "embedding_dim" in sig.parameters
|
||||
|
||||
# Determine send_dimensions value based on binding type
|
||||
# Jina and Gemini REQUIRE dimension parameter (forced to True)
|
||||
# OpenAI and others: controlled by EMBEDDING_SEND_DIM environment variable
|
||||
if args.embedding_binding in ["jina", "gemini"]:
|
||||
# Jina and Gemini APIs require dimension parameter - always send it
|
||||
send_dimensions = has_embedding_dim_param
|
||||
dimension_control = f"forced by {args.embedding_binding.title()} API"
|
||||
else:
|
||||
# For OpenAI and other bindings, respect EMBEDDING_SEND_DIM setting
|
||||
send_dimensions = embedding_send_dim and has_embedding_dim_param
|
||||
if send_dimensions or not embedding_send_dim:
|
||||
dimension_control = "by env var"
|
||||
else:
|
||||
dimension_control = "by not hasparam"
|
||||
|
||||
# Set send_dimensions on the EmbeddingFunc instance
|
||||
embedding_func.send_dimensions = send_dimensions
|
||||
|
||||
|
||||
# Check environment variable for sending dimensions
|
||||
embedding_send_dim = os.getenv("EMBEDDING_SEND_DIM", "false").lower() == "true"
|
||||
|
||||
# Check if the function signature has embedding_dim parameter
|
||||
# Note: Since optimized_embedding_func is an async function, inspect its signature
|
||||
sig = inspect.signature(optimized_embedding_func)
|
||||
has_embedding_dim_param = 'embedding_dim' in sig.parameters
|
||||
|
||||
# Determine send_dimensions value
|
||||
# Only send dimensions if both conditions are met:
|
||||
# 1. EMBEDDING_SEND_DIM environment variable is true
|
||||
# 2. The function has embedding_dim parameter
|
||||
send_dimensions = embedding_send_dim and has_embedding_dim_param
|
||||
|
||||
logger.info(
|
||||
f"Send embedding dimension: {send_dimensions} {dimension_control} "
|
||||
f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
|
||||
f"Embedding configuration: send_dimensions={send_dimensions} "
|
||||
f"(env_var={embedding_send_dim}, has_param={has_embedding_dim_param}, "
|
||||
f"binding={args.embedding_binding})"
|
||||
)
|
||||
|
||||
# Log max_token_size source
|
||||
if embedding_func.max_token_size:
|
||||
source = (
|
||||
"env variable"
|
||||
if args.embedding_token_limit
|
||||
else f"{args.embedding_binding} provider default"
|
||||
)
|
||||
logger.info(
|
||||
f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
|
||||
)
|
||||
else:
|
||||
logger.info("Embedding max_token_size: not set (90% token warning disabled)")
|
||||
|
||||
# Create EmbeddingFunc with send_dimensions attribute
|
||||
embedding_func = EmbeddingFunc(
|
||||
embedding_dim=args.embedding_dim,
|
||||
func=optimized_embedding_func,
|
||||
send_dimensions=send_dimensions,
|
||||
)
|
||||
|
||||
# Configure rerank function based on args.rerank_bindingparameter
|
||||
rerank_model_func = None
|
||||
|
|
@ -970,27 +763,15 @@ def create_app(args):
|
|||
query: str, documents: list, top_n: int = None, extra_body: dict = None
|
||||
):
|
||||
"""Server rerank function with configuration from environment variables"""
|
||||
# Prepare kwargs for rerank function
|
||||
kwargs = {
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": top_n,
|
||||
"api_key": args.rerank_binding_api_key,
|
||||
"model": args.rerank_model,
|
||||
"base_url": args.rerank_binding_host,
|
||||
}
|
||||
|
||||
# Add Cohere-specific parameters if using cohere binding
|
||||
if args.rerank_binding == "cohere":
|
||||
# Enable chunking if configured (useful for models with token limits like ColBERT)
|
||||
kwargs["enable_chunking"] = (
|
||||
os.getenv("RERANK_ENABLE_CHUNKING", "false").lower() == "true"
|
||||
)
|
||||
kwargs["max_tokens_per_doc"] = int(
|
||||
os.getenv("RERANK_MAX_TOKENS_PER_DOC", "4096")
|
||||
)
|
||||
|
||||
return await selected_rerank_func(**kwargs, extra_body=extra_body)
|
||||
return await selected_rerank_func(
|
||||
query=query,
|
||||
documents=documents,
|
||||
top_n=top_n,
|
||||
api_key=args.rerank_binding_api_key,
|
||||
model=args.rerank_model,
|
||||
base_url=args.rerank_binding_host,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
|
||||
rerank_model_func = server_rerank_func
|
||||
logger.info(
|
||||
|
|
@ -1151,10 +932,9 @@ def create_app(args):
|
|||
}
|
||||
|
||||
@app.get("/health", dependencies=[Depends(combined_auth)])
|
||||
async def get_status(request: Request):
|
||||
async def get_status():
|
||||
"""Get current system status"""
|
||||
try:
|
||||
default_workspace = get_default_workspace()
|
||||
pipeline_status = await get_namespace_data("pipeline_status")
|
||||
|
||||
if not auth_configured:
|
||||
|
|
@ -1186,7 +966,7 @@ def create_app(args):
|
|||
"vector_storage": args.vector_storage,
|
||||
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
|
||||
"enable_llm_cache": args.enable_llm_cache,
|
||||
"workspace": default_workspace,
|
||||
"workspace": args.workspace,
|
||||
"max_graph_nodes": args.max_graph_nodes,
|
||||
# Rerank configuration
|
||||
"enable_rerank": rerank_model_func is not None,
|
||||
|
|
@ -1375,12 +1155,6 @@ def check_and_install_dependencies():
|
|||
|
||||
|
||||
def main():
|
||||
# Explicitly initialize configuration for clarity
|
||||
# (The proxy will auto-initialize anyway, but this makes intent clear)
|
||||
from .config import initialize_config
|
||||
|
||||
initialize_config()
|
||||
|
||||
# Check if running under Gunicorn
|
||||
if "GUNICORN_CMD_ARGS" in os.environ:
|
||||
# If started with Gunicorn, return directly as Gunicorn will call get_application
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
import os
|
||||
from typing import Final
|
||||
|
||||
import pipmaster as pm # Pipmaster for dynamic library install
|
||||
|
||||
# install specific modules
|
||||
|
|
@ -21,9 +19,6 @@ from tenacity import (
|
|||
from lightrag.utils import wrap_embedding_func_with_attrs, logger
|
||||
|
||||
|
||||
DEFAULT_JINA_EMBED_DIM: Final[int] = 2048
|
||||
|
||||
|
||||
async def fetch_data(url, headers, data):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
|
|
@ -63,7 +58,7 @@ async def fetch_data(url, headers, data):
|
|||
return data_list
|
||||
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=DEFAULT_JINA_EMBED_DIM)
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=2048)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
|
|
@ -74,7 +69,7 @@ async def fetch_data(url, headers, data):
|
|||
)
|
||||
async def jina_embed(
|
||||
texts: list[str],
|
||||
embedding_dim: int | None = DEFAULT_JINA_EMBED_DIM,
|
||||
embedding_dim: int = 2048,
|
||||
late_chunking: bool = False,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
|
|
@ -100,10 +95,6 @@ async def jina_embed(
|
|||
aiohttp.ClientError: If there is a connection error with the Jina API.
|
||||
aiohttp.ClientResponseError: If the Jina API returns an error response.
|
||||
"""
|
||||
resolved_embedding_dim = (
|
||||
embedding_dim if embedding_dim is not None else DEFAULT_JINA_EMBED_DIM
|
||||
)
|
||||
|
||||
if api_key:
|
||||
os.environ["JINA_API_KEY"] = api_key
|
||||
|
||||
|
|
@ -118,7 +109,7 @@ async def jina_embed(
|
|||
data = {
|
||||
"model": "jina-embeddings-v4",
|
||||
"task": "text-matching",
|
||||
"dimensions": resolved_embedding_dim,
|
||||
"dimensions": embedding_dim,
|
||||
"embedding_type": "base64",
|
||||
"input": texts,
|
||||
}
|
||||
|
|
@ -128,7 +119,7 @@ async def jina_embed(
|
|||
data["late_chunking"] = late_chunking
|
||||
|
||||
logger.debug(
|
||||
f"Jina embedding request: {len(texts)} texts, dimensions: {resolved_embedding_dim}"
|
||||
f"Jina embedding request: {len(texts)} texts, dimensions: {embedding_dim}"
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ if not pm.is_installed("openai"):
|
|||
pm.install("openai")
|
||||
|
||||
from openai import (
|
||||
AsyncOpenAI,
|
||||
APIConnectionError,
|
||||
RateLimitError,
|
||||
APITimeoutError,
|
||||
|
|
@ -27,6 +26,7 @@ from lightrag.utils import (
|
|||
safe_unicode_decode,
|
||||
logger,
|
||||
)
|
||||
|
||||
from lightrag.types import GPTKeywordExtractionFormat
|
||||
from lightrag.api import __api_version__
|
||||
|
||||
|
|
@ -36,6 +36,32 @@ from typing import Any, Union
|
|||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Try to import Langfuse for LLM observability (optional)
|
||||
# Falls back to standard OpenAI client if not available
|
||||
# Langfuse requires proper configuration to work correctly
|
||||
LANGFUSE_ENABLED = False
|
||||
try:
|
||||
# Check if required Langfuse environment variables are set
|
||||
langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
|
||||
langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
|
||||
|
||||
# Only enable Langfuse if both keys are configured
|
||||
if langfuse_public_key and langfuse_secret_key:
|
||||
from langfuse.openai import AsyncOpenAI
|
||||
|
||||
LANGFUSE_ENABLED = True
|
||||
logger.info("Langfuse observability enabled for OpenAI client")
|
||||
else:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
logger.debug(
|
||||
"Langfuse environment variables not configured, using standard OpenAI client"
|
||||
)
|
||||
except ImportError:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
logger.debug("Langfuse not available, using standard OpenAI client")
|
||||
|
||||
# use the .env that is inside the current folder
|
||||
# allows to use different .env file for each lightrag instance
|
||||
# the OS environment variables take precedence over the .env file
|
||||
|
|
@ -370,18 +396,23 @@ async def openai_complete_if_cache(
|
|||
)
|
||||
|
||||
# Ensure resources are released even if no exception occurs
|
||||
if (
|
||||
iteration_started
|
||||
and hasattr(response, "aclose")
|
||||
and callable(getattr(response, "aclose", None))
|
||||
):
|
||||
try:
|
||||
await response.aclose()
|
||||
logger.debug("Successfully closed stream response")
|
||||
except Exception as close_error:
|
||||
logger.warning(
|
||||
f"Failed to close stream response in finally block: {close_error}"
|
||||
)
|
||||
# Note: Some wrapped clients (e.g., Langfuse) may not implement aclose() properly
|
||||
if iteration_started and hasattr(response, "aclose"):
|
||||
aclose_method = getattr(response, "aclose", None)
|
||||
if callable(aclose_method):
|
||||
try:
|
||||
await response.aclose()
|
||||
logger.debug("Successfully closed stream response")
|
||||
except (AttributeError, TypeError) as close_error:
|
||||
# Some wrapper objects may report hasattr(aclose) but fail when called
|
||||
# This is expected behavior for certain client wrappers
|
||||
logger.debug(
|
||||
f"Stream response cleanup not supported by client wrapper: {close_error}"
|
||||
)
|
||||
except Exception as close_error:
|
||||
logger.warning(
|
||||
f"Unexpected error during stream response cleanup: {close_error}"
|
||||
)
|
||||
|
||||
# This prevents resource leaks since the caller doesn't handle closing
|
||||
try:
|
||||
|
|
@ -578,6 +609,7 @@ async def openai_embed(
|
|||
model: str = "text-embedding-3-small",
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
embedding_dim: int | None = None,
|
||||
client_configs: dict[str, Any] | None = None,
|
||||
token_tracker: Any | None = None,
|
||||
) -> np.ndarray:
|
||||
|
|
@ -588,6 +620,12 @@ async def openai_embed(
|
|||
model: The OpenAI embedding model to use.
|
||||
base_url: Optional base URL for the OpenAI API.
|
||||
api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
|
||||
embedding_dim: Optional embedding dimension for dynamic dimension reduction.
|
||||
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper.
|
||||
Do NOT manually pass this parameter when calling the function directly.
|
||||
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
|
||||
Manually passing a different value will trigger a warning and be ignored.
|
||||
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
|
||||
client_configs: Additional configuration options for the AsyncOpenAI client.
|
||||
These will override any default configurations but will be overridden by
|
||||
explicit parameters (api_key, base_url).
|
||||
|
|
@ -607,17 +645,27 @@ async def openai_embed(
|
|||
)
|
||||
|
||||
async with openai_async_client:
|
||||
response = await openai_async_client.embeddings.create(
|
||||
model=model, input=texts, encoding_format="base64"
|
||||
)
|
||||
|
||||
# Prepare API call parameters
|
||||
api_params = {
|
||||
"model": model,
|
||||
"input": texts,
|
||||
"encoding_format": "base64",
|
||||
}
|
||||
|
||||
# Add dimensions parameter only if embedding_dim is provided
|
||||
if embedding_dim is not None:
|
||||
api_params["dimensions"] = embedding_dim
|
||||
|
||||
# Make API call
|
||||
response = await openai_async_client.embeddings.create(**api_params)
|
||||
|
||||
if token_tracker and hasattr(response, "usage"):
|
||||
token_counts = {
|
||||
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
|
||||
"total_tokens": getattr(response.usage, "total_tokens", 0),
|
||||
}
|
||||
token_tracker.add_usage(token_counts)
|
||||
|
||||
|
||||
return np.array(
|
||||
[
|
||||
np.array(dp.embedding, dtype=np.float32)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
from __future__ import annotations
|
||||
import weakref
|
||||
|
||||
import sys
|
||||
|
||||
import asyncio
|
||||
import html
|
||||
import csv
|
||||
|
|
@ -42,35 +40,6 @@ from lightrag.constants import (
|
|||
SOURCE_IDS_LIMIT_METHOD_FIFO,
|
||||
)
|
||||
|
||||
# Precompile regex pattern for JSON sanitization (module-level, compiled once)
|
||||
_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]")
|
||||
|
||||
|
||||
class SafeStreamHandler(logging.StreamHandler):
|
||||
"""StreamHandler that gracefully handles closed streams during shutdown.
|
||||
|
||||
This handler prevents "ValueError: I/O operation on closed file" errors
|
||||
that can occur when pytest or other test frameworks close stdout/stderr
|
||||
before Python's logging cleanup runs.
|
||||
"""
|
||||
|
||||
def flush(self):
|
||||
"""Flush the stream, ignoring errors if the stream is closed."""
|
||||
try:
|
||||
super().flush()
|
||||
except (ValueError, OSError):
|
||||
# Stream is closed or otherwise unavailable, silently ignore
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
"""Close the handler, ignoring errors if the stream is already closed."""
|
||||
try:
|
||||
super().close()
|
||||
except (ValueError, OSError):
|
||||
# Stream is closed or otherwise unavailable, silently ignore
|
||||
pass
|
||||
|
||||
|
||||
# Initialize logger with basic configuration
|
||||
logger = logging.getLogger("lightrag")
|
||||
logger.propagate = False # prevent log message send to root logger
|
||||
|
|
@ -78,7 +47,7 @@ logger.setLevel(logging.INFO)
|
|||
|
||||
# Add console handler if no handlers exist
|
||||
if not logger.handlers:
|
||||
console_handler = SafeStreamHandler()
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter("%(levelname)s: %(message)s")
|
||||
console_handler.setFormatter(formatter)
|
||||
|
|
@ -87,33 +56,6 @@ if not logger.handlers:
|
|||
# Set httpx logging level to WARNING
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def _patch_ascii_colors_console_handler() -> None:
|
||||
"""Prevent ascii_colors from printing flush errors during interpreter exit."""
|
||||
|
||||
try:
|
||||
from ascii_colors import ConsoleHandler
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
if getattr(ConsoleHandler, "_lightrag_patched", False):
|
||||
return
|
||||
|
||||
original_handle_error = ConsoleHandler.handle_error
|
||||
|
||||
def _safe_handle_error(self, message: str) -> None: # type: ignore[override]
|
||||
exc_type, _, _ = sys.exc_info()
|
||||
if exc_type in (ValueError, OSError) and "close" in message.lower():
|
||||
return
|
||||
original_handle_error(self, message)
|
||||
|
||||
ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment]
|
||||
ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined]
|
||||
|
||||
|
||||
_patch_ascii_colors_console_handler()
|
||||
|
||||
|
||||
# Global import for pypinyin with startup-time logging
|
||||
try:
|
||||
import pypinyin
|
||||
|
|
@ -341,8 +283,8 @@ def setup_logger(
|
|||
logger_instance.handlers = [] # Clear existing handlers
|
||||
logger_instance.propagate = False
|
||||
|
||||
# Add console handler with safe stream handling
|
||||
console_handler = SafeStreamHandler()
|
||||
# Add console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(simple_formatter)
|
||||
console_handler.setLevel(level)
|
||||
logger_instance.addHandler(console_handler)
|
||||
|
|
@ -408,69 +350,28 @@ class TaskState:
|
|||
|
||||
@dataclass
|
||||
class EmbeddingFunc:
|
||||
"""Embedding function wrapper with dimension validation
|
||||
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
|
||||
This class should not be wrapped multiple times.
|
||||
|
||||
Args:
|
||||
embedding_dim: Expected dimension of the embeddings
|
||||
func: The actual embedding function to wrap
|
||||
max_token_size: Optional token limit for the embedding model
|
||||
send_dimensions: Whether to inject embedding_dim as a keyword argument
|
||||
"""
|
||||
|
||||
embedding_dim: int
|
||||
func: callable
|
||||
max_token_size: int | None = None # Token limit for the embedding model
|
||||
send_dimensions: bool = (
|
||||
False # Control whether to send embedding_dim to the function
|
||||
)
|
||||
max_token_size: int | None = None # deprecated keep it for compatible only
|
||||
send_dimensions: bool = False # Control whether to send embedding_dim to the function
|
||||
|
||||
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
||||
# Only inject embedding_dim when send_dimensions is True
|
||||
if self.send_dimensions:
|
||||
# Check if user provided embedding_dim parameter
|
||||
if "embedding_dim" in kwargs:
|
||||
user_provided_dim = kwargs["embedding_dim"]
|
||||
if 'embedding_dim' in kwargs:
|
||||
user_provided_dim = kwargs['embedding_dim']
|
||||
# If user's value differs from class attribute, output warning
|
||||
if (
|
||||
user_provided_dim is not None
|
||||
and user_provided_dim != self.embedding_dim
|
||||
):
|
||||
if user_provided_dim is not None and user_provided_dim != self.embedding_dim:
|
||||
logger.warning(
|
||||
f"Ignoring user-provided embedding_dim={user_provided_dim}, "
|
||||
f"using declared embedding_dim={self.embedding_dim} from decorator"
|
||||
)
|
||||
|
||||
|
||||
# Inject embedding_dim from decorator
|
||||
kwargs["embedding_dim"] = self.embedding_dim
|
||||
|
||||
# Call the actual embedding function
|
||||
result = await self.func(*args, **kwargs)
|
||||
|
||||
# Validate embedding dimensions using total element count
|
||||
total_elements = result.size # Total number of elements in the numpy array
|
||||
expected_dim = self.embedding_dim
|
||||
|
||||
# Check if total elements can be evenly divided by embedding_dim
|
||||
if total_elements % expected_dim != 0:
|
||||
raise ValueError(
|
||||
f"Embedding dimension mismatch detected: "
|
||||
f"total elements ({total_elements}) cannot be evenly divided by "
|
||||
f"expected dimension ({expected_dim}). "
|
||||
)
|
||||
|
||||
# Optional: Verify vector count matches input text count
|
||||
actual_vectors = total_elements // expected_dim
|
||||
if args and isinstance(args[0], (list, tuple)):
|
||||
expected_vectors = len(args[0])
|
||||
if actual_vectors != expected_vectors:
|
||||
raise ValueError(
|
||||
f"Vector count mismatch: "
|
||||
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
|
||||
)
|
||||
|
||||
return result
|
||||
kwargs['embedding_dim'] = self.embedding_dim
|
||||
|
||||
return await self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def compute_args_hash(*args: Any) -> str:
|
||||
|
|
@ -1005,76 +906,7 @@ def priority_limit_async_func_call(
|
|||
|
||||
|
||||
def wrap_embedding_func_with_attrs(**kwargs):
|
||||
"""Decorator to add embedding dimension and token limit attributes to embedding functions.
|
||||
|
||||
This decorator wraps an async embedding function and returns an EmbeddingFunc instance
|
||||
that automatically handles dimension parameter injection and attribute management.
|
||||
|
||||
WARNING: DO NOT apply this decorator to wrapper functions that call other
|
||||
decorated embedding functions. This will cause double decoration and parameter
|
||||
injection conflicts.
|
||||
|
||||
Correct usage patterns:
|
||||
|
||||
1. Direct implementation (decorated):
|
||||
```python
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
||||
async def my_embed(texts, embedding_dim=None):
|
||||
# Direct implementation
|
||||
return embeddings
|
||||
```
|
||||
|
||||
2. Wrapper calling decorated function (DO NOT decorate wrapper):
|
||||
```python
|
||||
# my_embed is already decorated above
|
||||
|
||||
async def my_wrapper(texts, **kwargs): # ❌ DO NOT decorate this!
|
||||
# Must call .func to access unwrapped implementation
|
||||
return await my_embed.func(texts, **kwargs)
|
||||
```
|
||||
|
||||
3. Wrapper calling decorated function (properly decorated):
|
||||
```python
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536)
|
||||
async def my_wrapper(texts, **kwargs): # ✅ Can decorate if calling .func
|
||||
# Calling .func avoids double decoration
|
||||
return await my_embed.func(texts, **kwargs)
|
||||
```
|
||||
|
||||
The decorated function becomes an EmbeddingFunc instance with:
|
||||
- embedding_dim: The embedding dimension
|
||||
- max_token_size: Maximum token limit (optional)
|
||||
- func: The original unwrapped function (access via .func)
|
||||
- __call__: Wrapper that injects embedding_dim parameter
|
||||
|
||||
Double decoration causes:
|
||||
- Double injection of embedding_dim parameter
|
||||
- Incorrect parameter passing to the underlying implementation
|
||||
- Runtime errors due to parameter conflicts
|
||||
|
||||
Args:
|
||||
embedding_dim: The dimension of embedding vectors
|
||||
max_token_size: Maximum number of tokens (optional)
|
||||
send_dimensions: Whether to inject embedding_dim as a keyword argument (optional)
|
||||
|
||||
Returns:
|
||||
A decorator that wraps the function as an EmbeddingFunc instance
|
||||
|
||||
Example of correct wrapper implementation:
|
||||
```python
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
||||
@retry(...)
|
||||
async def openai_embed(texts, ...):
|
||||
# Base implementation
|
||||
pass
|
||||
|
||||
@wrap_embedding_func_with_attrs(embedding_dim=1536) # Note: No @retry here!
|
||||
async def azure_openai_embed(texts, ...):
|
||||
# CRITICAL: Call .func to access unwrapped function
|
||||
return await openai_embed.func(texts, ...) # ✅ Correct
|
||||
# return await openai_embed(texts, ...) # ❌ Wrong - double decoration!
|
||||
```
|
||||
"""
|
||||
"""Wrap a function with attributes"""
|
||||
|
||||
def final_decro(func) -> EmbeddingFunc:
|
||||
new_func = EmbeddingFunc(**kwargs, func=func)
|
||||
|
|
@ -1090,123 +922,9 @@ def load_json(file_name):
|
|||
return json.load(f)
|
||||
|
||||
|
||||
def _sanitize_string_for_json(text: str) -> str:
|
||||
"""Remove characters that cannot be encoded in UTF-8 for JSON serialization.
|
||||
|
||||
Uses regex for optimal performance with zero-copy optimization for clean strings.
|
||||
Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings.
|
||||
|
||||
Args:
|
||||
text: String to sanitize
|
||||
|
||||
Returns:
|
||||
Original string if clean (zero-copy), sanitized string if dirty
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# Fast path: Check if sanitization is needed using C-level regex search
|
||||
if not _SURROGATE_PATTERN.search(text):
|
||||
return text # Zero-copy for clean strings - most common case
|
||||
|
||||
# Slow path: Remove problematic characters using C-level regex substitution
|
||||
return _SURROGATE_PATTERN.sub("", text)
|
||||
|
||||
|
||||
class SanitizingJSONEncoder(json.JSONEncoder):
|
||||
"""
|
||||
Custom JSON encoder that sanitizes data during serialization.
|
||||
|
||||
This encoder cleans strings during the encoding process without creating
|
||||
a full copy of the data structure, making it memory-efficient for large datasets.
|
||||
"""
|
||||
|
||||
def encode(self, o):
|
||||
"""Override encode method to handle simple string cases"""
|
||||
if isinstance(o, str):
|
||||
return json.encoder.encode_basestring(_sanitize_string_for_json(o))
|
||||
return super().encode(o)
|
||||
|
||||
def iterencode(self, o, _one_shot=False):
|
||||
"""
|
||||
Override iterencode to sanitize strings during serialization.
|
||||
This is the core method that handles complex nested structures.
|
||||
"""
|
||||
# Preprocess: sanitize all strings in the object
|
||||
sanitized = self._sanitize_for_encoding(o)
|
||||
|
||||
# Call parent's iterencode with sanitized data
|
||||
for chunk in super().iterencode(sanitized, _one_shot):
|
||||
yield chunk
|
||||
|
||||
def _sanitize_for_encoding(self, obj):
|
||||
"""
|
||||
Recursively sanitize strings in an object.
|
||||
Creates new objects only when necessary to avoid deep copies.
|
||||
|
||||
Args:
|
||||
obj: Object to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized object with cleaned strings
|
||||
"""
|
||||
if isinstance(obj, str):
|
||||
return _sanitize_string_for_json(obj)
|
||||
|
||||
elif isinstance(obj, dict):
|
||||
# Create new dict with sanitized keys and values
|
||||
new_dict = {}
|
||||
for k, v in obj.items():
|
||||
clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k
|
||||
clean_v = self._sanitize_for_encoding(v)
|
||||
new_dict[clean_k] = clean_v
|
||||
return new_dict
|
||||
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
# Sanitize list/tuple elements
|
||||
cleaned = [self._sanitize_for_encoding(item) for item in obj]
|
||||
return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned
|
||||
|
||||
else:
|
||||
# Numbers, booleans, None, etc. remain unchanged
|
||||
return obj
|
||||
|
||||
|
||||
def write_json(json_obj, file_name):
|
||||
"""
|
||||
Write JSON data to file with optimized sanitization strategy.
|
||||
|
||||
This function uses a two-stage approach:
|
||||
1. Fast path: Try direct serialization (works for clean data ~99% of time)
|
||||
2. Slow path: Use custom encoder that sanitizes during serialization
|
||||
|
||||
The custom encoder approach avoids creating a deep copy of the data,
|
||||
making it memory-efficient. When sanitization occurs, the caller should
|
||||
reload the cleaned data from the file to update shared memory.
|
||||
|
||||
Args:
|
||||
json_obj: Object to serialize (may be a shallow copy from shared memory)
|
||||
file_name: Output file path
|
||||
|
||||
Returns:
|
||||
bool: True if sanitization was applied (caller should reload data),
|
||||
False if direct write succeeded (no reload needed)
|
||||
"""
|
||||
try:
|
||||
# Strategy 1: Fast path - try direct serialization
|
||||
with open(file_name, "w", encoding="utf-8") as f:
|
||||
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
||||
return False # No sanitization needed, no reload required
|
||||
|
||||
except (UnicodeEncodeError, UnicodeDecodeError) as e:
|
||||
logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}")
|
||||
|
||||
# Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy)
|
||||
with open(file_name, "w", encoding="utf-8") as f:
|
||||
json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder)
|
||||
|
||||
logger.info(f"JSON sanitization applied during write: {file_name}")
|
||||
return True # Sanitization applied, reload recommended
|
||||
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
class TokenizerInterface(Protocol):
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue