Merge pull request #2029 from danielaskdd/optimize-rag-object-creation
refac: Eliminate Conditional Imports and Simplify Initialization
Commit 3c0ce9e38d
1 changed file with 115 additions and 144 deletions
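The core of the refactor: binding-specific imports that previously ran near the top of create_app behind if-conditions now happen lazily inside small factory helpers, so only the selected binding's dependencies ever have to be installed. A minimal, runnable sketch of that lazy-import idea, using a stdlib module as a stand-in for a binding package (names below are illustrative, not LightRAG APIs):

import importlib


def lazy_get(module_name: str, attr: str):
    """Import `module_name` only when this is called, then return `attr` from it."""
    try:
        module = importlib.import_module(module_name)  # deferred until actually needed
    except ImportError as e:
        # Mirrors the diff's pattern of wrapping ImportError with a clearer message
        raise Exception(f"Failed to import {module_name}: {e}")
    return getattr(module, attr)


# "json" stands in for a binding module such as lightrag.llm.ollama; nothing is
# imported until the selected binding asks for its entry point.
dumps = lazy_get("json", "dumps")
print(dumps({"binding": "openai"}))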
@@ -238,24 +238,100 @@ def create_app(args):
     # Create working directory if it doesn't exist
     Path(args.working_dir).mkdir(parents=True, exist_ok=True)
 
-    if args.llm_binding == "lollms" or args.embedding_binding == "lollms":
-        from lightrag.llm.lollms import lollms_model_complete, lollms_embed
-    if args.llm_binding == "ollama" or args.embedding_binding == "ollama":
-        from lightrag.llm.ollama import ollama_model_complete, ollama_embed
-        from lightrag.llm.binding_options import OllamaLLMOptions
-    if args.llm_binding == "openai" or args.embedding_binding == "openai":
-        from lightrag.llm.openai import openai_complete_if_cache, openai_embed
-        from lightrag.llm.binding_options import OpenAILLMOptions
-    if args.llm_binding == "azure_openai" or args.embedding_binding == "azure_openai":
-        from lightrag.llm.azure_openai import (
-            azure_openai_complete_if_cache,
-            azure_openai_embed,
-        )
-        from lightrag.llm.binding_options import OpenAILLMOptions
-    if args.llm_binding == "aws_bedrock" or args.embedding_binding == "aws_bedrock":
-        from lightrag.llm.bedrock import bedrock_complete_if_cache, bedrock_embed
-    if args.embedding_binding == "jina":
-        from lightrag.llm.jina import jina_embed
+    def create_llm_model_func(binding: str):
+        """
+        Create LLM model function based on binding type.
+        Uses lazy import to avoid unnecessary dependencies.
+        """
+        try:
+            if binding == "lollms":
+                from lightrag.llm.lollms import lollms_model_complete
+
+                return lollms_model_complete
+            elif binding == "ollama":
+                from lightrag.llm.ollama import ollama_model_complete
+
+                return ollama_model_complete
+            elif binding == "aws_bedrock":
+                return bedrock_model_complete  # Already defined locally
+            elif binding == "azure_openai":
+                return azure_openai_model_complete  # Already defined locally
+            else:  # openai and compatible
+                return openai_alike_model_complete  # Already defined locally
+        except ImportError as e:
+            raise Exception(f"Failed to import {binding} LLM binding: {e}")
+
+    def create_llm_model_kwargs(binding: str, args, llm_timeout: int) -> dict:
+        """
+        Create LLM model kwargs based on binding type.
+        Uses lazy import for binding-specific options.
+        """
+        if binding in ["lollms", "ollama"]:
+            try:
+                from lightrag.llm.binding_options import OllamaLLMOptions
+
+                return {
+                    "host": args.llm_binding_host,
+                    "timeout": llm_timeout,
+                    "options": OllamaLLMOptions.options_dict(args),
+                    "api_key": args.llm_binding_api_key,
+                }
+            except ImportError as e:
+                raise Exception(f"Failed to import {binding} options: {e}")
+        return {}
+
+    def create_embedding_function_with_lazy_import(
+        binding, model, host, api_key, dimensions, args
+    ):
+        """
+        Create embedding function with lazy imports for all bindings.
+        Replaces the current create_embedding_function with full lazy import support.
+        """
+
+        async def embedding_function(texts):
+            try:
+                if binding == "lollms":
+                    from lightrag.llm.lollms import lollms_embed
+
+                    return await lollms_embed(
+                        texts, embed_model=model, host=host, api_key=api_key
+                    )
+                elif binding == "ollama":
+                    from lightrag.llm.binding_options import OllamaEmbeddingOptions
+                    from lightrag.llm.ollama import ollama_embed
+
+                    ollama_options = OllamaEmbeddingOptions.options_dict(args)
+                    return await ollama_embed(
+                        texts,
+                        embed_model=model,
+                        host=host,
+                        api_key=api_key,
+                        options=ollama_options,
+                    )
+                elif binding == "azure_openai":
+                    from lightrag.llm.azure_openai import azure_openai_embed
+
+                    return await azure_openai_embed(texts, model=model, api_key=api_key)
+                elif binding == "aws_bedrock":
+                    from lightrag.llm.bedrock import bedrock_embed
+
+                    return await bedrock_embed(texts, model=model)
+                elif binding == "jina":
+                    from lightrag.llm.jina import jina_embed
+
+                    return await jina_embed(
+                        texts, dimensions=dimensions, base_url=host, api_key=api_key
+                    )
+                else:  # openai and compatible
+                    from lightrag.llm.openai import openai_embed
+
+                    return await openai_embed(
+                        texts, model=model, base_url=host, api_key=api_key
+                    )
+            except ImportError as e:
+                raise Exception(f"Failed to import {binding} embedding: {e}")
+
+        return embedding_function
+
     llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
     embedding_timeout = get_env_value(
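To see the shape of the new kwargs helper in isolation, here is a hedged, standalone re-creation (not an import from LightRAG) with a SimpleNamespace standing in for the parsed CLI args; the real helper also merges OllamaLLMOptions.options_dict(args) into the dict, which is omitted here to keep the sketch dependency-free:

from types import SimpleNamespace


def create_llm_model_kwargs(binding: str, args, llm_timeout: int) -> dict:
    # Only the lollms/ollama bindings take host/timeout/api_key kwargs here;
    # every other binding returns an empty dict, as in the diff.
    if binding in ["lollms", "ollama"]:
        return {
            "host": args.llm_binding_host,
            "timeout": llm_timeout,
            "api_key": args.llm_binding_api_key,
        }
    return {}


# Hypothetical argument values, for illustration only.
args = SimpleNamespace(
    llm_binding_host="http://localhost:11434",
    llm_binding_api_key=None,
)
print(create_llm_model_kwargs("ollama", args, llm_timeout=180))  # populated dict
print(create_llm_model_kwargs("openai", args, llm_timeout=180))  # {}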
@@ -269,6 +345,10 @@ def create_app(args):
         keyword_extraction=False,
         **kwargs,
     ) -> str:
+        # Lazy import
+        from lightrag.llm.openai import openai_complete_if_cache
+        from lightrag.llm.binding_options import OpenAILLMOptions
+
         keyword_extraction = kwargs.pop("keyword_extraction", None)
         if keyword_extraction:
             kwargs["response_format"] = GPTKeywordExtractionFormat
@@ -297,6 +377,10 @@ def create_app(args):
         keyword_extraction=False,
         **kwargs,
     ) -> str:
+        # Lazy import
+        from lightrag.llm.azure_openai import azure_openai_complete_if_cache
+        from lightrag.llm.binding_options import OpenAILLMOptions
+
         keyword_extraction = kwargs.pop("keyword_extraction", None)
         if keyword_extraction:
             kwargs["response_format"] = GPTKeywordExtractionFormat
@@ -326,6 +410,9 @@ def create_app(args):
         keyword_extraction=False,
         **kwargs,
     ) -> str:
+        # Lazy import
+        from lightrag.llm.bedrock import bedrock_complete_if_cache
+
         keyword_extraction = kwargs.pop("keyword_extraction", None)
         if keyword_extraction:
             kwargs["response_format"] = GPTKeywordExtractionFormat
@@ -343,80 +430,10 @@ def create_app(args):
             **kwargs,
         )
 
-    def create_embedding_function(binding, model, host, api_key, dimensions, args):
-        """
-        Create embedding function with args object for dynamic option generation.
-
-        This approach completely avoids closure issues by capturing configuration
-        values as function parameters rather than through variable references.
-        The args object is used only for dynamic option generation when needed.
-
-        Args:
-            binding: The embedding provider binding (lollms, ollama, etc.)
-            model: The embedding model name
-            host: The host URL for the embedding service
-            api_key: API key for authentication
-            dimensions: Embedding dimensions
-            args: Arguments object for dynamic option generation (only used when needed)
-
-        Returns:
-            Async function that performs embedding based on the specified provider
-        """
-
-        async def embedding_function(texts):
-            """Embedding function with captured configuration parameters"""
-            if binding == "lollms":
-                return await lollms_embed(
-                    texts,
-                    embed_model=model,
-                    host=host,
-                    api_key=api_key,
-                )
-            elif binding == "ollama":
-                # Only import and generate ollama_options when actually needed
-                from lightrag.llm.binding_options import OllamaEmbeddingOptions
-
-                ollama_options = OllamaEmbeddingOptions.options_dict(args)
-                return await ollama_embed(
-                    texts,
-                    embed_model=model,
-                    host=host,
-                    api_key=api_key,
-                    options=ollama_options,
-                )
-            elif binding == "azure_openai":
-                return await azure_openai_embed(
-                    texts,
-                    model=model,
-                    api_key=api_key,
-                )
-            elif binding == "aws_bedrock":
-                return await bedrock_embed(
-                    texts,
-                    model=model,
-                )
-            elif binding == "jina":
-                return await jina_embed(
-                    texts,
-                    dimensions=dimensions,
-                    base_url=host,
-                    api_key=api_key,
-                )
-            else:
-                # Default to OpenAI-compatible embedding
-                return await openai_embed(
-                    texts,
-                    model=model,
-                    base_url=host,
-                    api_key=api_key,
-                )
-
-        return embedding_function
-
-    # Create embedding function with current configuration
+    # Create embedding function with lazy imports
     embedding_func = EmbeddingFunc(
         embedding_dim=args.embedding_dim,
-        func=create_embedding_function(
+        func=create_embedding_function_with_lazy_import(
             binding=args.embedding_binding,
             model=args.embedding_model,
             host=args.embedding_binding_host,
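The rationale in the removed docstring still applies to the new factory: configuration is captured as factory parameters rather than through a closure over outer variables, so the returned coroutine keeps the values it was built with. A small runnable sketch of that parameter-capture pattern, with toy embedding logic standing in for the real bindings:

import asyncio


def make_embedding_function(binding: str, model: str, host: str):
    # binding/model/host are fixed at factory-call time; later changes to any
    # outer variables cannot affect the returned coroutine.
    async def embedding_function(texts):
        return [f"{binding}:{model}@{host}:{text}" for text in texts]

    return embedding_function


embed = make_embedding_function("ollama", "nomic-embed-text", "http://localhost:11434")
print(asyncio.run(embed(["hello", "world"])))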
@@ -488,37 +505,20 @@ def create_app(args):
         name=args.simulated_model_name, tag=args.simulated_model_tag
     )
 
-    # Initialize RAG
-    if args.llm_binding in ["lollms", "ollama", "openai", "aws_bedrock"]:
+    # Initialize RAG with unified configuration
+    try:
         rag = LightRAG(
             working_dir=args.working_dir,
             workspace=args.workspace,
-            llm_model_func=(
-                lollms_model_complete
-                if args.llm_binding == "lollms"
-                else (
-                    ollama_model_complete
-                    if args.llm_binding == "ollama"
-                    else bedrock_model_complete
-                    if args.llm_binding == "aws_bedrock"
-                    else openai_alike_model_complete
-                )
-            ),
+            llm_model_func=create_llm_model_func(args.llm_binding),
             llm_model_name=args.llm_model,
             llm_model_max_async=args.max_async,
             summary_max_tokens=args.summary_max_tokens,
             summary_context_size=args.summary_context_size,
             chunk_token_size=int(args.chunk_size),
             chunk_overlap_token_size=int(args.chunk_overlap_size),
-            llm_model_kwargs=(
-                {
-                    "host": args.llm_binding_host,
-                    "timeout": llm_timeout,
-                    "options": OllamaLLMOptions.options_dict(args),
-                    "api_key": args.llm_binding_api_key,
-                }
-                if args.llm_binding == "lollms" or args.llm_binding == "ollama"
-                else {}
+            llm_model_kwargs=create_llm_model_kwargs(
+                args.llm_binding, args, llm_timeout
             ),
             embedding_func=embedding_func,
             default_llm_timeout=llm_timeout,
@@ -541,38 +541,9 @@ def create_app(args):
             },
             ollama_server_infos=ollama_server_infos,
         )
-    else:  # azure_openai
-        rag = LightRAG(
-            working_dir=args.working_dir,
-            workspace=args.workspace,
-            llm_model_func=azure_openai_model_complete,
-            chunk_token_size=int(args.chunk_size),
-            chunk_overlap_token_size=int(args.chunk_overlap_size),
-            llm_model_name=args.llm_model,
-            llm_model_max_async=args.max_async,
-            summary_max_tokens=args.summary_max_tokens,
-            summary_context_size=args.summary_context_size,
-            embedding_func=embedding_func,
-            default_llm_timeout=llm_timeout,
-            default_embedding_timeout=embedding_timeout,
-            kv_storage=args.kv_storage,
-            graph_storage=args.graph_storage,
-            vector_storage=args.vector_storage,
-            doc_status_storage=args.doc_status_storage,
-            vector_db_storage_cls_kwargs={
-                "cosine_better_than_threshold": args.cosine_threshold
-            },
-            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
-            enable_llm_cache=args.enable_llm_cache,
-            rerank_model_func=rerank_model_func,
-            max_parallel_insert=args.max_parallel_insert,
-            max_graph_nodes=args.max_graph_nodes,
-            addon_params={
-                "language": args.summary_language,
-                "entity_types": args.entity_types,
-            },
-            ollama_server_infos=ollama_server_infos,
-        )
+    except Exception as e:
+        logger.error(f"Failed to initialize LightRAG: {e}")
+        raise
 
     # Add routes
     app.include_router(
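With the LLM function and kwargs coming from the helpers, the two per-binding LightRAG(...) branches collapse into a single constructor call wrapped in try/except that logs the failure once and re-raises. A hedged sketch of that shape, with a stand-in class instead of LightRAG:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class FakeRAG:
    """Stand-in for LightRAG that just records its keyword arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


def init_rag(llm_model_func, llm_model_kwargs: dict):
    try:
        # One construction path for every binding; the variation lives in the
        # arguments produced by the factory helpers, not in separate branches.
        return FakeRAG(llm_model_func=llm_model_func, llm_model_kwargs=llm_model_kwargs)
    except Exception as e:
        logger.error(f"Failed to initialize LightRAG: {e}")
        raise


rag = init_rag(llm_model_func=lambda prompt: prompt, llm_model_kwargs={})
print(sorted(rag.kwargs))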