Merge pull request #2359 from danielaskdd/embedding-limit
Refact: Add Embedding Token Limit Configuration and Improve Error Handling

Commit: 3b76eea20b
16 changed files with 434 additions and 102 deletions

env.example (38 changed lines)
@@ -255,21 +255,23 @@ OLLAMA_LLM_NUM_CTX=32768
### For OpenAI: Set to 'true' to enable dynamic dimension adjustment
### For OpenAI: Set to 'false' (default) to disable sending dimension parameter
### Note: Automatically ignored for backends that don't support dimension parameter (e.g., Ollama)
# EMBEDDING_SEND_DIM=false

EMBEDDING_BINDING=ollama
EMBEDDING_MODEL=bge-m3:latest
EMBEDDING_DIM=1024
EMBEDDING_BINDING_API_KEY=your_api_key
# If LightRAG deployed in Docker uses host.docker.internal instead of localhost
EMBEDDING_BINDING_HOST=http://localhost:11434

### OpenAI compatible (VoyageAI embedding openai compatible)
# EMBEDDING_BINDING=openai
# EMBEDDING_MODEL=text-embedding-3-large
# EMBEDDING_DIM=3072
# EMBEDDING_BINDING_HOST=https://api.openai.com/v1
# Ollama embedding
# EMBEDDING_BINDING=ollama
# EMBEDDING_MODEL=bge-m3:latest
# EMBEDDING_DIM=1024
# EMBEDDING_BINDING_API_KEY=your_api_key
### If LightRAG deployed in Docker uses host.docker.internal instead of localhost
# EMBEDDING_BINDING_HOST=http://localhost:11434

### OpenAI compatible embedding
EMBEDDING_BINDING=openai
EMBEDDING_MODEL=text-embedding-3-large
EMBEDDING_DIM=3072
EMBEDDING_SEND_DIM=false
EMBEDDING_TOKEN_LIMIT=8192
EMBEDDING_BINDING_HOST=https://api.openai.com/v1
EMBEDDING_BINDING_API_KEY=your_api_key

### Optional for Azure
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
@@ -277,6 +279,16 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
# AZURE_EMBEDDING_API_KEY=your_api_key

### Gemini embedding
# EMBEDDING_BINDING=gemini
# EMBEDDING_MODEL=gemini-embedding-001
# EMBEDDING_DIM=1536
# EMBEDDING_TOKEN_LIMIT=2048
# EMBEDDING_BINDING_HOST=https://generativelanguage.googleapis.com
# EMBEDDING_BINDING_API_KEY=your_api_key
### Gemini embedding requires sending dimension to server
# EMBEDDING_SEND_DIM=true

### Jina AI Embedding
# EMBEDDING_BINDING=jina
# EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings
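The new EMBEDDING_TOKEN_LIMIT entries are optional. When set, the server derives a 90% warning threshold from them; a quick worked example of that arithmetic (the limit is taken from the OpenAI block above, the rest is illustrative):

```python
# Illustrative: the 90% warning threshold derived from EMBEDDING_TOKEN_LIMIT.
limit = 8192                  # EMBEDDING_TOKEN_LIMIT from the OpenAI example above
threshold = int(limit * 0.9)  # 7372; summaries longer than this trigger a warning
print(threshold)
```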
@@ -445,6 +445,11 @@ def parse_args() -> argparse.Namespace:
        "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
    )

    # Embedding token limit configuration
    args.embedding_token_limit = get_env_value(
        "EMBEDDING_TOKEN_LIMIT", None, int, special_none=True
    )

    ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
    ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
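The special_none=True flag is what lets EMBEDDING_TOKEN_LIMIT default to None instead of a number. A minimal sketch of that behavior; the helper below is a stand-in written for illustration, not the project's actual get_env_value:

```python
import os

def get_env_value_sketch(key, default, cast, special_none=False):
    """Stand-in helper: cast the env var, but keep None when it is unset
    (or explicitly "None") so "no limit configured" stays distinguishable."""
    raw = os.getenv(key)
    if raw is None:
        return default
    if special_none and raw.strip().lower() == "none":
        return None
    return cast(raw)

# Unset -> None, "8192" -> 8192
print(get_env_value_sketch("EMBEDDING_TOKEN_LIMIT", None, int, special_none=True))
```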
@@ -618,33 +618,108 @@ def create_app(args):

    def create_optimized_embedding_function(
        config_cache: LLMConfigCache, binding, model, host, api_key, args
    ):
    ) -> EmbeddingFunc:
        """
        Create optimized embedding function with pre-processed configuration for applicable bindings.
        Uses lazy imports for all bindings and avoids repeated configuration parsing.
        Create optimized embedding function and return an EmbeddingFunc instance
        with proper max_token_size inheritance from provider defaults.

        This function:
        1. Imports the provider embedding function
        2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
        3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
        4. Returns a properly configured EmbeddingFunc instance
        """

        # Step 1: Import provider function and extract default attributes
        provider_func = None
        provider_max_token_size = None
        provider_embedding_dim = None

        try:
            if binding == "openai":
                from lightrag.llm.openai import openai_embed

                provider_func = openai_embed
            elif binding == "ollama":
                from lightrag.llm.ollama import ollama_embed

                provider_func = ollama_embed
            elif binding == "gemini":
                from lightrag.llm.gemini import gemini_embed

                provider_func = gemini_embed
            elif binding == "jina":
                from lightrag.llm.jina import jina_embed

                provider_func = jina_embed
            elif binding == "azure_openai":
                from lightrag.llm.azure_openai import azure_openai_embed

                provider_func = azure_openai_embed
            elif binding == "aws_bedrock":
                from lightrag.llm.bedrock import bedrock_embed

                provider_func = bedrock_embed
            elif binding == "lollms":
                from lightrag.llm.lollms import lollms_embed

                provider_func = lollms_embed

            # Extract attributes if provider is an EmbeddingFunc
            if provider_func and isinstance(provider_func, EmbeddingFunc):
                provider_max_token_size = provider_func.max_token_size
                provider_embedding_dim = provider_func.embedding_dim
                logger.debug(
                    f"Extracted from {binding} provider: "
                    f"max_token_size={provider_max_token_size}, "
                    f"embedding_dim={provider_embedding_dim}"
                )
        except ImportError as e:
            logger.warning(f"Could not import provider function for {binding}: {e}")

        # Step 2: Apply priority (user config > provider default)
        # For max_token_size: explicit env var > provider default > None
        final_max_token_size = args.embedding_token_limit or provider_max_token_size
        # For embedding_dim: user config (always has value) takes priority
        # Only use provider default if user config is explicitly None (which shouldn't happen)
        final_embedding_dim = (
            args.embedding_dim if args.embedding_dim else provider_embedding_dim
        )

        # Step 3: Create optimized embedding function (calls underlying function directly)
        async def optimized_embedding_function(texts, embedding_dim=None):
            try:
                if binding == "lollms":
                    from lightrag.llm.lollms import lollms_embed

                    return await lollms_embed(
                    # Get real function, skip EmbeddingFunc wrapper if present
                    actual_func = (
                        lollms_embed.func
                        if isinstance(lollms_embed, EmbeddingFunc)
                        else lollms_embed
                    )
                    return await actual_func(
                        texts, embed_model=model, host=host, api_key=api_key
                    )
                elif binding == "ollama":
                    from lightrag.llm.ollama import ollama_embed

                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
                    # Get real function, skip EmbeddingFunc wrapper if present
                    actual_func = (
                        ollama_embed.func
                        if isinstance(ollama_embed, EmbeddingFunc)
                        else ollama_embed
                    )

                    # Use pre-processed configuration if available
                    if config_cache.ollama_embedding_options is not None:
                        ollama_options = config_cache.ollama_embedding_options
                    else:
                        # Fallback for cases where config cache wasn't initialized properly
                        from lightrag.llm.binding_options import OllamaEmbeddingOptions

                        ollama_options = OllamaEmbeddingOptions.options_dict(args)

                    return await ollama_embed(
                    return await actual_func(
                        texts,
                        embed_model=model,
                        host=host,
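The actual_func pattern above repeats for every binding: if the provider module already exposes an EmbeddingFunc wrapper, the server calls its .func attribute so the embedding call is not wrapped twice. A stripped-down sketch of the same idea, using a toy stand-in for EmbeddingFunc:

```python
import asyncio
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class ToyEmbeddingFunc:  # stand-in for lightrag's EmbeddingFunc
    embedding_dim: int
    func: Callable
    max_token_size: Optional[int] = None

async def raw_embed(texts):  # pretend provider implementation
    return [[0.0] * 4 for _ in texts]

provider = ToyEmbeddingFunc(embedding_dim=4, func=raw_embed, max_token_size=8192)

# Call the underlying coroutine, not the wrapper, to avoid double-wrapping.
actual_func = provider.func if isinstance(provider, ToyEmbeddingFunc) else provider
print(asyncio.run(actual_func(["hello"])))  # [[0.0, 0.0, 0.0, 0.0]]
```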
@@ -654,15 +729,30 @@ def create_app(args):
                elif binding == "azure_openai":
                    from lightrag.llm.azure_openai import azure_openai_embed

                    return await azure_openai_embed(texts, model=model, api_key=api_key)
                    actual_func = (
                        azure_openai_embed.func
                        if isinstance(azure_openai_embed, EmbeddingFunc)
                        else azure_openai_embed
                    )
                    return await actual_func(texts, model=model, api_key=api_key)
                elif binding == "aws_bedrock":
                    from lightrag.llm.bedrock import bedrock_embed

                    return await bedrock_embed(texts, model=model)
                    actual_func = (
                        bedrock_embed.func
                        if isinstance(bedrock_embed, EmbeddingFunc)
                        else bedrock_embed
                    )
                    return await actual_func(texts, model=model)
                elif binding == "jina":
                    from lightrag.llm.jina import jina_embed

                    return await jina_embed(
                    actual_func = (
                        jina_embed.func
                        if isinstance(jina_embed, EmbeddingFunc)
                        else jina_embed
                    )
                    return await actual_func(
                        texts,
                        embedding_dim=embedding_dim,
                        base_url=host,
@@ -671,16 +761,21 @@ def create_app(args):
                elif binding == "gemini":
                    from lightrag.llm.gemini import gemini_embed

                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
                    actual_func = (
                        gemini_embed.func
                        if isinstance(gemini_embed, EmbeddingFunc)
                        else gemini_embed
                    )

                    # Use pre-processed configuration if available
                    if config_cache.gemini_embedding_options is not None:
                        gemini_options = config_cache.gemini_embedding_options
                    else:
                        # Fallback for cases where config cache wasn't initialized properly
                        from lightrag.llm.binding_options import GeminiEmbeddingOptions

                        gemini_options = GeminiEmbeddingOptions.options_dict(args)

                    return await gemini_embed(
                    return await actual_func(
                        texts,
                        model=model,
                        base_url=host,
@@ -691,7 +786,12 @@ def create_app(args):
                else:  # openai and compatible
                    from lightrag.llm.openai import openai_embed

                    return await openai_embed(
                    actual_func = (
                        openai_embed.func
                        if isinstance(openai_embed, EmbeddingFunc)
                        else openai_embed
                    )
                    return await actual_func(
                        texts,
                        model=model,
                        base_url=host,
@@ -701,7 +801,21 @@ def create_app(args):
            except ImportError as e:
                raise Exception(f"Failed to import {binding} embedding: {e}")

        return optimized_embedding_function
        # Step 4: Wrap in EmbeddingFunc and return
        embedding_func_instance = EmbeddingFunc(
            embedding_dim=final_embedding_dim,
            func=optimized_embedding_function,
            max_token_size=final_max_token_size,
            send_dimensions=False,  # Will be set later based on binding requirements
        )

        # Log final embedding configuration
        logger.info(
            f"Embedding config: binding={binding} model={model} "
            f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
        )

        return embedding_func_instance

    llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
    embedding_timeout = get_env_value(
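Step 2's precedence is just a truthiness chain: an explicit EMBEDDING_TOKEN_LIMIT wins, otherwise the provider default captured in Step 1 is used, otherwise the limit stays unset. The same rule in isolation, with illustrative numbers:

```python
def resolve_max_token_size(env_limit, provider_default):
    # explicit env var > provider default > None (warning disabled)
    return env_limit or provider_default

print(resolve_max_token_size(None, 8192))  # 8192: provider default wins
print(resolve_max_token_size(2048, 8192))  # 2048: EMBEDDING_TOKEN_LIMIT wins
print(resolve_max_token_size(None, None))  # None: no limit, warning disabled
```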
@@ -735,25 +849,24 @@ def create_app(args):
        **kwargs,
    )

    # Create embedding function with optimized configuration
    # Create embedding function with optimized configuration and max_token_size inheritance
    import inspect

    # Create the optimized embedding function
    optimized_embedding_func = create_optimized_embedding_function(
    # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
    embedding_func = create_optimized_embedding_function(
        config_cache=config_cache,
        binding=args.embedding_binding,
        model=args.embedding_model,
        host=args.embedding_binding_host,
        api_key=args.embedding_binding_api_key,
        args=args,  # Pass args object for fallback option generation
        args=args,
    )

    # Get embedding_send_dim from centralized configuration
    embedding_send_dim = args.embedding_send_dim

    # Check if the function signature has embedding_dim parameter
    # Note: Since optimized_embedding_func is an async function, inspect its signature
    sig = inspect.signature(optimized_embedding_func)
    # Check if the underlying function signature has embedding_dim parameter
    sig = inspect.signature(embedding_func.func)
    has_embedding_dim_param = "embedding_dim" in sig.parameters

    # Determine send_dimensions value based on binding type
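Inspecting embedding_func.func rather than the EmbeddingFunc wrapper is what makes the embedding_dim detection work. The check itself is plain inspect.signature; the two wrappers below are hypothetical stand-ins for the per-binding functions:

```python
import inspect

async def openai_like(texts, model=None, embedding_dim=None):  # hypothetical wrapper
    ...

async def ollama_like(texts, embed_model=None):  # hypothetical wrapper
    ...

for fn in (openai_like, ollama_like):
    sig = inspect.signature(fn)
    print(fn.__name__, "embedding_dim" in sig.parameters)
# openai_like True
# ollama_like False
```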
@@ -771,18 +884,27 @@ def create_app(args):
    else:
        dimension_control = "by not hasparam"

    # Set send_dimensions on the EmbeddingFunc instance
    embedding_func.send_dimensions = send_dimensions

    logger.info(
        f"Send embedding dimension: {send_dimensions} {dimension_control} "
        f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
        f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
        f"binding={args.embedding_binding})"
    )

    # Create EmbeddingFunc with send_dimensions attribute
    embedding_func = EmbeddingFunc(
        embedding_dim=args.embedding_dim,
        func=optimized_embedding_func,
        send_dimensions=send_dimensions,
    )
    # Log max_token_size source
    if embedding_func.max_token_size:
        source = (
            "env variable"
            if args.embedding_token_limit
            else f"{args.embedding_binding} provider default"
        )
        logger.info(
            f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
        )
    else:
        logger.info("Embedding max_token_size: not set (90% token warning disabled)")

    # Configure rerank function based on args.rerank_binding parameter
    rerank_model_func = None
@@ -276,6 +276,9 @@ class LightRAG:
    embedding_func: EmbeddingFunc | None = field(default=None)
    """Function for computing text embeddings. Must be set before use."""

    embedding_token_limit: int | None = field(default=None, init=False)
    """Token limit for embedding model. Set automatically from embedding_func.max_token_size in __post_init__."""

    embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
    """Batch size for embedding computations."""
@@ -519,6 +522,16 @@ class LightRAG:
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

        # Init Embedding
        # Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
        embedding_max_token_size = None
        if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
            embedding_max_token_size = self.embedding_func.max_token_size
            logger.debug(
                f"Captured embedding max_token_size: {embedding_max_token_size}"
            )
        self.embedding_token_limit = embedding_max_token_size

        # Step 2: Apply priority wrapper decorator
        self.embedding_func = priority_limit_async_func_call(
            self.embedding_func_max_async,
            llm_timeout=self.default_embedding_timeout,
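The capture-before-wrap ordering matters because priority_limit_async_func_call returns a plain callable; once it replaces self.embedding_func, dataclass attributes such as max_token_size are no longer reachable. A toy illustration of that attribute loss (both the wrapper and the EmbeddingFunc here are stand-ins):

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class ToyEmbeddingFunc:  # stand-in for EmbeddingFunc
    embedding_dim: int
    func: Callable
    max_token_size: Optional[int] = None

def toy_priority_wrapper(max_async):  # stand-in for priority_limit_async_func_call(...)
    def decorator(f):
        async def wrapped(*args, **kwargs):
            return await f(*args, **kwargs)
        return wrapped  # plain coroutine function: no max_token_size attribute
    return decorator

async def stub_embed(texts):
    return []

embedding_func = ToyEmbeddingFunc(embedding_dim=4, func=stub_embed, max_token_size=8192)

captured = getattr(embedding_func, "max_token_size", None)      # capture first: 8192
embedding_func = toy_priority_wrapper(16)(embedding_func.func)  # wrapper replaces the object
print(captured, hasattr(embedding_func, "max_token_size"))      # 8192 False
```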
@@ -1,6 +1,7 @@
import copy
import os
import json
import logging

import pipmaster as pm  # Pipmaster for dynamic library install
@@ -16,6 +17,7 @@ from tenacity import (
)

import sys
from lightrag.utils import wrap_embedding_func_with_attrs

if sys.version_info < (3, 9):
    from typing import AsyncIterator
@@ -23,21 +25,121 @@ else:
    from collections.abc import AsyncIterator
from typing import Union

# Import botocore exceptions for proper exception handling
try:
    from botocore.exceptions import (
        ClientError,
        ConnectionError as BotocoreConnectionError,
        ReadTimeoutError,
    )
except ImportError:
    # If botocore is not installed, define placeholders
    ClientError = Exception
    BotocoreConnectionError = Exception
    ReadTimeoutError = Exception


class BedrockError(Exception):
    """Generic error for issues related to Amazon Bedrock"""


class BedrockRateLimitError(BedrockError):
    """Error for rate limiting and throttling issues"""


class BedrockConnectionError(BedrockError):
    """Error for network and connection issues"""


class BedrockTimeoutError(BedrockError):
    """Error for timeout issues"""


def _set_env_if_present(key: str, value):
    """Set environment variable only if a non-empty value is provided."""
    if value is not None and value != "":
        os.environ[key] = value


def _handle_bedrock_exception(e: Exception, operation: str = "Bedrock API") -> None:
    """Convert AWS Bedrock exceptions to appropriate custom exceptions.

    Args:
        e: The exception to handle
        operation: Description of the operation for error messages

    Raises:
        BedrockRateLimitError: For rate limiting and throttling issues (retryable)
        BedrockConnectionError: For network and server issues (retryable)
        BedrockTimeoutError: For timeout issues (retryable)
        BedrockError: For validation and other non-retryable errors
    """
    error_message = str(e)

    # Handle botocore ClientError with specific error codes
    if isinstance(e, ClientError):
        error_code = e.response.get("Error", {}).get("Code", "")
        error_msg = e.response.get("Error", {}).get("Message", error_message)

        # Rate limiting and throttling errors (retryable)
        if error_code in [
            "ThrottlingException",
            "ProvisionedThroughputExceededException",
        ]:
            logging.error(f"{operation} rate limit error: {error_msg}")
            raise BedrockRateLimitError(f"Rate limit error: {error_msg}")

        # Server errors (retryable)
        elif error_code in ["ServiceUnavailableException", "InternalServerException"]:
            logging.error(f"{operation} connection error: {error_msg}")
            raise BedrockConnectionError(f"Service error: {error_msg}")

        # Check for 5xx HTTP status codes (retryable)
        elif e.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0) >= 500:
            logging.error(f"{operation} server error: {error_msg}")
            raise BedrockConnectionError(f"Server error: {error_msg}")

        # Validation and other client errors (non-retryable)
        else:
            logging.error(f"{operation} client error: {error_msg}")
            raise BedrockError(f"Client error: {error_msg}")

    # Connection errors (retryable)
    elif isinstance(e, BotocoreConnectionError):
        logging.error(f"{operation} connection error: {error_message}")
        raise BedrockConnectionError(f"Connection error: {error_message}")

    # Timeout errors (retryable)
    elif isinstance(e, (ReadTimeoutError, TimeoutError)):
        logging.error(f"{operation} timeout error: {error_message}")
        raise BedrockTimeoutError(f"Timeout error: {error_message}")

    # Custom Bedrock errors (already properly typed)
    elif isinstance(
        e,
        (
            BedrockRateLimitError,
            BedrockConnectionError,
            BedrockTimeoutError,
            BedrockError,
        ),
    ):
        raise

    # Unknown errors (non-retryable)
    else:
        logging.error(f"{operation} unexpected error: {error_message}")
        raise BedrockError(f"Unexpected error: {error_message}")


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, max=60),
    retry=retry_if_exception_type((BedrockError)),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(BedrockRateLimitError)
        | retry_if_exception_type(BedrockConnectionError)
        | retry_if_exception_type(BedrockTimeoutError)
    ),
)
async def bedrock_complete_if_cache(
    model,
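Because _handle_bedrock_exception re-raises everything as one of the typed errors, the retry condition can be written as a union of the retryable types; tenacity's retry_if_exception_type predicates compose with `|`, which is what the decorator above relies on. A compressed, self-contained sketch of the same behavior with toy exceptions:

```python
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

class ToyRateLimitError(Exception):
    pass

class ToyValidationError(Exception):
    pass

attempts = {"n": 0}

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=0.01, max=0.05),  # tiny waits, demo only
    retry=retry_if_exception_type(ToyRateLimitError)
    | retry_if_exception_type(TimeoutError),
    reraise=True,
)
def flaky_call():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ToyRateLimitError("throttled")  # retryable, so it is retried
    return "ok"

print(flaky_call(), attempts["n"])  # ok 3
# A ToyValidationError would propagate on the first attempt, mirroring BedrockError above.
```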
@@ -158,9 +260,6 @@ async def bedrock_complete_if_cache(
                    break

            except Exception as e:
                # Log the specific error for debugging
                logging.error(f"Bedrock streaming error: {e}")

                # Try to clean up resources if possible
                if (
                    iteration_started
@@ -175,7 +274,8 @@ async def bedrock_complete_if_cache(
                            f"Failed to close Bedrock event stream: {close_error}"
                        )

                raise BedrockError(f"Streaming error: {e}")
                # Convert to appropriate exception type
                _handle_bedrock_exception(e, "Bedrock streaming")

            finally:
                # Clean up the event stream
@@ -231,10 +331,8 @@ async def bedrock_complete_if_cache(
            return content

    except Exception as e:
        if isinstance(e, BedrockError):
            raise
        else:
            raise BedrockError(f"Bedrock API error: {e}")
        # Convert to appropriate exception type
        _handle_bedrock_exception(e, "Bedrock converse")


# Generic Bedrock completion function
@@ -253,12 +351,16 @@ async def bedrock_complete(
    return result


# @wrap_embedding_func_with_attrs(embedding_dim=1024)
# @retry(
#     stop=stop_after_attempt(3),
#     wait=wait_exponential(multiplier=1, min=4, max=10),
#     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),  # TODO: fix exceptions
# )
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(BedrockRateLimitError)
        | retry_if_exception_type(BedrockConnectionError)
        | retry_if_exception_type(BedrockTimeoutError)
    ),
)
async def bedrock_embed(
    texts: list[str],
    model: str = "amazon.titan-embed-text-v2:0",
@@ -281,48 +383,101 @@ async def bedrock_embed(
    async with session.client(
        "bedrock-runtime", region_name=region
    ) as bedrock_async_client:
        if (model_provider := model.split(".")[0]) == "amazon":
            embed_texts = []
            for text in texts:
                if "v2" in model:
        try:
            if (model_provider := model.split(".")[0]) == "amazon":
                embed_texts = []
                for text in texts:
                    try:
                        if "v2" in model:
                            body = json.dumps(
                                {
                                    "inputText": text,
                                    # 'dimensions': embedding_dim,
                                    "embeddingTypes": ["float"],
                                }
                            )
                        elif "v1" in model:
                            body = json.dumps({"inputText": text})
                        else:
                            raise BedrockError(f"Model {model} is not supported!")

                        response = await bedrock_async_client.invoke_model(
                            modelId=model,
                            body=body,
                            accept="application/json",
                            contentType="application/json",
                        )

                        response_body = await response.get("body").json()

                        # Validate response structure
                        if not response_body or "embedding" not in response_body:
                            raise BedrockError(
                                f"Invalid embedding response structure for text: {text[:50]}..."
                            )

                        embedding = response_body["embedding"]
                        if not embedding:
                            raise BedrockError(
                                f"Received empty embedding for text: {text[:50]}..."
                            )

                        embed_texts.append(embedding)

                    except Exception as e:
                        # Convert to appropriate exception type
                        _handle_bedrock_exception(
                            e, "Bedrock embedding (amazon, text chunk)"
                        )

            elif model_provider == "cohere":
                try:
                    body = json.dumps(
                        {
                            "inputText": text,
                            # 'dimensions': embedding_dim,
                            "embeddingTypes": ["float"],
                            "texts": texts,
                            "input_type": "search_document",
                            "truncate": "NONE",
                        }
                    )
                    elif "v1" in model:
                        body = json.dumps({"inputText": text})
                    else:
                        raise ValueError(f"Model {model} is not supported!")

                    response = await bedrock_async_client.invoke_model(
                        modelId=model,
                        body=body,
                        accept="application/json",
                        contentType="application/json",
                    response = await bedrock_async_client.invoke_model(
                        model=model,
                        body=body,
                        accept="application/json",
                        contentType="application/json",
                    )

                    response_body = json.loads(response.get("body").read())

                    # Validate response structure
                    if not response_body or "embeddings" not in response_body:
                        raise BedrockError(
                            "Invalid embedding response structure from Cohere"
                        )

                    embeddings = response_body["embeddings"]
                    if not embeddings or len(embeddings) != len(texts):
                        raise BedrockError(
                            f"Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}"
                        )

                    embed_texts = embeddings

                except Exception as e:
                    # Convert to appropriate exception type
                    _handle_bedrock_exception(e, "Bedrock embedding (cohere)")

            else:
                raise BedrockError(
                    f"Model provider '{model_provider}' is not supported!"
                )

            response_body = await response.get("body").json()
            # Final validation
            if not embed_texts:
                raise BedrockError("No embeddings generated")

            embed_texts.append(response_body["embedding"])
        elif model_provider == "cohere":
            body = json.dumps(
                {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
            )
            return np.array(embed_texts)

            response = await bedrock_async_client.invoke_model(
                model=model,
                body=body,
                accept="application/json",
                contentType="application/json",
            )

            response_body = json.loads(response.get("body").read())

            embed_texts = response_body["embeddings"]
        else:
            raise ValueError(f"Model provider '{model_provider}' is not supported!")

        return np.array(embed_texts)
        except Exception as e:
            # Convert to appropriate exception type
            _handle_bedrock_exception(e, "Bedrock embedding")
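With the typed errors, callers of bedrock_embed can tell transient throttling apart from hard failures. A hedged usage sketch; it assumes AWS credentials are configured in the environment and that the module shown above is lightrag.llm.bedrock:

```python
import asyncio

from lightrag.llm.bedrock import BedrockError, BedrockRateLimitError, bedrock_embed

async def main():
    try:
        vectors = await bedrock_embed(
            ["LightRAG embedding token limit test"],
            model="amazon.titan-embed-text-v2:0",
        )
        print(vectors.shape)  # e.g. (1, 1024) for Titan v2
    except BedrockRateLimitError:
        print("still throttled after retries; back off and try again later")
    except BedrockError as exc:
        print(f"non-retryable Bedrock failure: {exc}")

asyncio.run(main())
```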
@@ -453,7 +453,7 @@ async def gemini_model_complete(
    )


@wrap_embedding_func_with_attrs(embedding_dim=1536)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -26,6 +26,7 @@ from lightrag.exceptions import (
)
import torch
import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs

os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -141,6 +142,7 @@ async def hf_model_complete(
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
    # Detect the appropriate device
    if torch.cuda.is_available():
@@ -58,7 +58,7 @@ async def fetch_data(url, headers, data):
    return data_list


@wrap_embedding_func_with_attrs(embedding_dim=2048)
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -174,7 +174,7 @@ async def llama_index_complete(
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1536)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -26,6 +26,10 @@ from lightrag.exceptions import (
from typing import Union, List
import numpy as np

from lightrag.utils import (
    wrap_embedding_func_with_attrs,
)


@retry(
    stop=stop_after_attempt(3),
@@ -134,6 +138,7 @@ async def lollms_model_complete(
    )


@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
async def lollms_embed(
    texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
) -> np.ndarray:
@@ -33,7 +33,7 @@ from lightrag.utils import (
import numpy as np


@wrap_embedding_func_with_attrs(embedding_dim=2048)
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -25,7 +25,10 @@ from lightrag.api import __api_version__

import numpy as np
from typing import Optional, Union
from lightrag.utils import logger
from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    logger,
)


_OLLAMA_CLOUD_HOST = "https://ollama.com"
@@ -169,6 +172,7 @@ async def ollama_model_complete(
    )


@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
    api_key = kwargs.pop("api_key", None)
    if not api_key:
@@ -47,7 +47,7 @@ try:

    # Only enable Langfuse if both keys are configured
    if langfuse_public_key and langfuse_secret_key:
        from langfuse.openai import AsyncOpenAI
        from langfuse.openai import AsyncOpenAI  # type: ignore[import-untyped]

        LANGFUSE_ENABLED = True
        logger.info("Langfuse observability enabled for OpenAI client")
@@ -604,7 +604,7 @@ async def nvidia_openai_complete(
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1536)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -345,6 +345,20 @@ async def _summarize_descriptions(
        llm_response_cache=llm_response_cache,
        cache_type="summary",
    )

    # Check summary token length against embedding limit
    embedding_token_limit = global_config.get("embedding_token_limit")
    if embedding_token_limit is not None and summary:
        tokenizer = global_config["tokenizer"]
        summary_token_count = len(tokenizer.encode(summary))
        threshold = int(embedding_token_limit * 0.9)

        if summary_token_count > threshold:
            logger.warning(
                f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
                f"({embedding_token_limit}) for {description_type}: {description_name}"
            )

    return summary
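The warning is purely informational; nothing is truncated. The same check can be reproduced outside the pipeline, here with tiktoken standing in for whatever tokenizer global_config carries:

```python
# Illustrative reproduction of the 90% check, assuming a tiktoken tokenizer.
import tiktoken

embedding_token_limit = 8192
tokenizer = tiktoken.get_encoding("cl100k_base")

summary = "An entity description produced by the summarization LLM. " * 1000
summary_token_count = len(tokenizer.encode(summary))
threshold = int(embedding_token_limit * 0.9)

if summary_token_count > threshold:
    print(
        f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
        f"({embedding_token_limit})"
    )
```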
@@ -355,7 +355,7 @@ class TaskState:
class EmbeddingFunc:
    embedding_dim: int
    func: callable
    max_token_size: int | None = None  # deprecated keep it for compatible only
    max_token_size: int | None = None  # Token limit for the embedding model
    send_dimensions: bool = (
        False  # Control whether to send embedding_dim to the function
    )
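For context, wrap_embedding_func_with_attrs presumably just attaches these fields to the raw coroutine by returning an EmbeddingFunc; a hand-rolled equivalent looks roughly like this (a sketch under that assumption, not the project's actual implementation):

```python
import numpy as np
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class EmbeddingFuncSketch:  # mirrors the fields shown above
    embedding_dim: int
    func: Callable
    max_token_size: Optional[int] = None
    send_dimensions: bool = False

def wrap_embedding_func_with_attrs_sketch(**attrs):
    """Rough equivalent: attach embedding metadata to a raw embed coroutine."""
    def decorator(func):
        return EmbeddingFuncSketch(func=func, **attrs)
    return decorator

@wrap_embedding_func_with_attrs_sketch(embedding_dim=4, max_token_size=8192)
async def my_embed(texts: list[str]) -> np.ndarray:
    return np.zeros((len(texts), 4))

print(my_embed.embedding_dim, my_embed.max_token_size)  # 4 8192
```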