Merge pull request #2359 from danielaskdd/embedding-limit

Refact: Add Embedding Token Limit Configuration and Improve Error Handling
Daniel.y 2025-11-15 01:27:26 +08:00 committed by GitHub
commit 3b76eea20b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 434 additions and 102 deletions

View file

@ -255,21 +255,23 @@ OLLAMA_LLM_NUM_CTX=32768
### For OpenAI: Set to 'true' to enable dynamic dimension adjustment
### For OpenAI: Set to 'false' (default) to disable sending dimension parameter
### Note: Automatically ignored for backends that don't support dimension parameter (e.g., Ollama)
# EMBEDDING_SEND_DIM=false
EMBEDDING_BINDING=ollama
EMBEDDING_MODEL=bge-m3:latest
EMBEDDING_DIM=1024
EMBEDDING_BINDING_API_KEY=your_api_key
# If LightRAG is deployed in Docker, use host.docker.internal instead of localhost
EMBEDDING_BINDING_HOST=http://localhost:11434
### OpenAI compatible (VoyageAI embedding openai compatible)
# EMBEDDING_BINDING=openai
# EMBEDDING_MODEL=text-embedding-3-large
# EMBEDDING_DIM=3072
# EMBEDDING_BINDING_HOST=https://api.openai.com/v1
# Ollama embedding
# EMBEDDING_BINDING=ollama
# EMBEDDING_MODEL=bge-m3:latest
# EMBEDDING_DIM=1024
# EMBEDDING_BINDING_API_KEY=your_api_key
### If LightRAG is deployed in Docker, use host.docker.internal instead of localhost
# EMBEDDING_BINDING_HOST=http://localhost:11434
### OpenAI compatible embedding
EMBEDDING_BINDING=openai
EMBEDDING_MODEL=text-embedding-3-large
EMBEDDING_DIM=3072
EMBEDDING_SEND_DIM=false
EMBEDDING_TOKEN_LIMIT=8192
EMBEDDING_BINDING_HOST=https://api.openai.com/v1
EMBEDDING_BINDING_API_KEY=your_api_key
### Optional for Azure
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
@ -277,6 +279,16 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
# AZURE_EMBEDDING_API_KEY=your_api_key
### Gemini embedding
# EMBEDDING_BINDING=gemini
# EMBEDDING_MODEL=gemini-embedding-001
# EMBEDDING_DIM=1536
# EMBEDDING_TOKEN_LIMIT=2048
# EMBEDDING_BINDING_HOST=https://generativelanguage.googleapis.com
# EMBEDDING_BINDING_API_KEY=your_api_key
### Gemini embedding requires sending dimension to server
# EMBEDDING_SEND_DIM=true
### Jina AI Embedding
# EMBEDDING_BINDING=jina
# EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings
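Illustrative note (not part of this PR): a minimal Python sketch of how a caller could honor the EMBEDDING_SEND_DIM and EMBEDDING_TOKEN_LIMIT settings above against an OpenAI-compatible endpoint; the tiktoken encoding and helper name are assumptions.

import os
import tiktoken  # assumed tokenizer for OpenAI-style embedding models
from openai import AsyncOpenAI

async def embed_with_env_config(texts: list[str]) -> list[list[float]]:
    client = AsyncOpenAI(
        base_url=os.getenv("EMBEDDING_BINDING_HOST", "https://api.openai.com/v1"),
        api_key=os.getenv("EMBEDDING_BINDING_API_KEY"),
    )
    model = os.getenv("EMBEDDING_MODEL", "text-embedding-3-large")
    token_limit = int(os.getenv("EMBEDDING_TOKEN_LIMIT", "0")) or None
    send_dim = os.getenv("EMBEDDING_SEND_DIM", "false").lower() == "true"
    # Warn when an input approaches the configured token limit (mirrors the 90% rule in this PR)
    if token_limit:
        enc = tiktoken.get_encoding("cl100k_base")
        for t in texts:
            if len(enc.encode(t)) > int(token_limit * 0.9):
                print("warning: input exceeds 90% of EMBEDDING_TOKEN_LIMIT")
    kwargs = {"model": model, "input": texts}
    if send_dim:  # only send the dimensions parameter when the backend supports it
        kwargs["dimensions"] = int(os.getenv("EMBEDDING_DIM", "3072"))
    resp = await client.embeddings.create(**kwargs)
    return [d.embedding for d in resp.data]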

View file

@ -445,6 +445,11 @@ def parse_args() -> argparse.Namespace:
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
)
# Embedding token limit configuration
args.embedding_token_limit = get_env_value(
"EMBEDDING_TOKEN_LIMIT", None, int, special_none=True
)
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
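For context, a simplified sketch (an assumption, not the actual helper) of what get_env_value with special_none=True does here: an unset variable falls back to the default and a literal "None" maps to None, so EMBEDDING_TOKEN_LIMIT stays unconfigured unless explicitly set.

import os

def get_env_value_sketch(env_key, default, value_type=str, special_none=False):
    # Simplified stand-in for LightRAG's get_env_value (illustrative only)
    raw = os.environ.get(env_key)
    if raw is None or raw.strip() == "":
        return default
    if special_none and raw.strip().lower() == "none":
        return None
    try:
        return value_type(raw)
    except ValueError:
        return default

# EMBEDDING_TOKEN_LIMIT unset -> None (limit not configured, 90% warning disabled)
# EMBEDDING_TOKEN_LIMIT=8192  -> 8192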

View file

@ -618,33 +618,108 @@ def create_app(args):
def create_optimized_embedding_function(
config_cache: LLMConfigCache, binding, model, host, api_key, args
):
) -> EmbeddingFunc:
"""
Create optimized embedding function with pre-processed configuration for applicable bindings.
Uses lazy imports for all bindings and avoids repeated configuration parsing.
Create optimized embedding function and return an EmbeddingFunc instance
with proper max_token_size inheritance from provider defaults.
This function:
1. Imports the provider embedding function
2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
4. Returns a properly configured EmbeddingFunc instance
"""
# Step 1: Import provider function and extract default attributes
provider_func = None
provider_max_token_size = None
provider_embedding_dim = None
try:
if binding == "openai":
from lightrag.llm.openai import openai_embed
provider_func = openai_embed
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed
provider_func = ollama_embed
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed
provider_func = gemini_embed
elif binding == "jina":
from lightrag.llm.jina import jina_embed
provider_func = jina_embed
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed
provider_func = azure_openai_embed
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed
provider_func = bedrock_embed
elif binding == "lollms":
from lightrag.llm.lollms import lollms_embed
provider_func = lollms_embed
# Extract attributes if provider is an EmbeddingFunc
if provider_func and isinstance(provider_func, EmbeddingFunc):
provider_max_token_size = provider_func.max_token_size
provider_embedding_dim = provider_func.embedding_dim
logger.debug(
f"Extracted from {binding} provider: "
f"max_token_size={provider_max_token_size}, "
f"embedding_dim={provider_embedding_dim}"
)
except ImportError as e:
logger.warning(f"Could not import provider function for {binding}: {e}")
# Step 2: Apply priority (user config > provider default)
# For max_token_size: explicit env var > provider default > None
final_max_token_size = args.embedding_token_limit or provider_max_token_size
# For embedding_dim: user config (always has value) takes priority
# Only use provider default if user config is explicitly None (which shouldn't happen)
final_embedding_dim = (
args.embedding_dim if args.embedding_dim else provider_embedding_dim
)
# Step 3: Create optimized embedding function (calls underlying function directly)
async def optimized_embedding_function(texts, embedding_dim=None):
try:
if binding == "lollms":
from lightrag.llm.lollms import lollms_embed
return await lollms_embed(
# Get real function, skip EmbeddingFunc wrapper if present
actual_func = (
lollms_embed.func
if isinstance(lollms_embed, EmbeddingFunc)
else lollms_embed
)
return await actual_func(
texts, embed_model=model, host=host, api_key=api_key
)
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
# Get real function, skip EmbeddingFunc wrapper if present
actual_func = (
ollama_embed.func
if isinstance(ollama_embed, EmbeddingFunc)
else ollama_embed
)
# Use pre-processed configuration if available
if config_cache.ollama_embedding_options is not None:
ollama_options = config_cache.ollama_embedding_options
else:
# Fallback for cases where config cache wasn't initialized properly
from lightrag.llm.binding_options import OllamaEmbeddingOptions
ollama_options = OllamaEmbeddingOptions.options_dict(args)
return await ollama_embed(
return await actual_func(
texts,
embed_model=model,
host=host,
@ -654,15 +729,30 @@ def create_app(args):
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed
return await azure_openai_embed(texts, model=model, api_key=api_key)
actual_func = (
azure_openai_embed.func
if isinstance(azure_openai_embed, EmbeddingFunc)
else azure_openai_embed
)
return await actual_func(texts, model=model, api_key=api_key)
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed
return await bedrock_embed(texts, model=model)
actual_func = (
bedrock_embed.func
if isinstance(bedrock_embed, EmbeddingFunc)
else bedrock_embed
)
return await actual_func(texts, model=model)
elif binding == "jina":
from lightrag.llm.jina import jina_embed
return await jina_embed(
actual_func = (
jina_embed.func
if isinstance(jina_embed, EmbeddingFunc)
else jina_embed
)
return await actual_func(
texts,
embedding_dim=embedding_dim,
base_url=host,
@ -671,16 +761,21 @@ def create_app(args):
elif binding == "gemini":
from lightrag.llm.gemini import gemini_embed
# Use pre-processed configuration if available, otherwise fallback to dynamic parsing
actual_func = (
gemini_embed.func
if isinstance(gemini_embed, EmbeddingFunc)
else gemini_embed
)
# Use pre-processed configuration if available
if config_cache.gemini_embedding_options is not None:
gemini_options = config_cache.gemini_embedding_options
else:
# Fallback for cases where config cache wasn't initialized properly
from lightrag.llm.binding_options import GeminiEmbeddingOptions
gemini_options = GeminiEmbeddingOptions.options_dict(args)
return await gemini_embed(
return await actual_func(
texts,
model=model,
base_url=host,
@ -691,7 +786,12 @@ def create_app(args):
else: # openai and compatible
from lightrag.llm.openai import openai_embed
return await openai_embed(
actual_func = (
openai_embed.func
if isinstance(openai_embed, EmbeddingFunc)
else openai_embed
)
return await actual_func(
texts,
model=model,
base_url=host,
@ -701,7 +801,21 @@ def create_app(args):
except ImportError as e:
raise Exception(f"Failed to import {binding} embedding: {e}")
return optimized_embedding_function
# Step 4: Wrap in EmbeddingFunc and return
embedding_func_instance = EmbeddingFunc(
embedding_dim=final_embedding_dim,
func=optimized_embedding_function,
max_token_size=final_max_token_size,
send_dimensions=False, # Will be set later based on binding requirements
)
# Log final embedding configuration
logger.info(
f"Embedding config: binding={binding} model={model} "
f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
)
return embedding_func_instance
llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
embedding_timeout = get_env_value(
@ -735,25 +849,24 @@ def create_app(args):
**kwargs,
)
# Create embedding function with optimized configuration
# Create embedding function with optimized configuration and max_token_size inheritance
import inspect
# Create the optimized embedding function
optimized_embedding_func = create_optimized_embedding_function(
# Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
embedding_func = create_optimized_embedding_function(
config_cache=config_cache,
binding=args.embedding_binding,
model=args.embedding_model,
host=args.embedding_binding_host,
api_key=args.embedding_binding_api_key,
args=args, # Pass args object for fallback option generation
args=args,
)
# Get embedding_send_dim from centralized configuration
embedding_send_dim = args.embedding_send_dim
# Check if the function signature has embedding_dim parameter
# Note: Since optimized_embedding_func is an async function, inspect its signature
sig = inspect.signature(optimized_embedding_func)
# Check if the underlying function signature has embedding_dim parameter
sig = inspect.signature(embedding_func.func)
has_embedding_dim_param = "embedding_dim" in sig.parameters
# Determine send_dimensions value based on binding type
@ -771,18 +884,27 @@ def create_app(args):
else:
dimension_control = "by not hasparam"
# Set send_dimensions on the EmbeddingFunc instance
embedding_func.send_dimensions = send_dimensions
logger.info(
f"Send embedding dimension: {send_dimensions} {dimension_control} "
f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
f"binding={args.embedding_binding})"
)
# Create EmbeddingFunc with send_dimensions attribute
embedding_func = EmbeddingFunc(
embedding_dim=args.embedding_dim,
func=optimized_embedding_func,
send_dimensions=send_dimensions,
)
# Log max_token_size source
if embedding_func.max_token_size:
source = (
"env variable"
if args.embedding_token_limit
else f"{args.embedding_binding} provider default"
)
logger.info(
f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
)
else:
logger.info("Embedding max_token_size: not set (90% token warning disabled)")
# Configure rerank function based on args.rerank_binding parameter
rerank_model_func = None
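Stepping back from the hunks above: the embedding function in this file is now resolved with a simple priority rule, roughly as in this hedged sketch (provider defaults such as 8192 for openai_embed come from the decorators added in this PR; the helper name is illustrative).

from lightrag.utils import EmbeddingFunc

def resolve_max_token_size(env_limit: int | None, provider_func) -> int | None:
    # explicit EMBEDDING_TOKEN_LIMIT > provider default on the decorated function > None
    provider_default = (
        provider_func.max_token_size if isinstance(provider_func, EmbeddingFunc) else None
    )
    return env_limit or provider_default

# With openai_embed declaring max_token_size=8192 and no env override, the server
# builds roughly: EmbeddingFunc(embedding_dim=3072, func=optimized_embedding_function,
#                               max_token_size=8192, send_dimensions=False)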

View file

@ -276,6 +276,9 @@ class LightRAG:
embedding_func: EmbeddingFunc | None = field(default=None)
"""Function for computing text embeddings. Must be set before use."""
embedding_token_limit: int | None = field(default=None, init=False)
"""Token limit for embedding model. Set automatically from embedding_func.max_token_size in __post_init__."""
embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
"""Batch size for embedding computations."""
@ -519,6 +522,16 @@ class LightRAG:
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
# Init Embedding
# Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
embedding_max_token_size = None
if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
embedding_max_token_size = self.embedding_func.max_token_size
logger.debug(
f"Captured embedding max_token_size: {embedding_max_token_size}"
)
self.embedding_token_limit = embedding_max_token_size
# Step 2: Apply priority wrapper decorator
self.embedding_func = priority_limit_async_func_call(
self.embedding_func_max_async,
llm_timeout=self.default_embedding_timeout,
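The Step 1 capture matters because priority_limit_async_func_call returns a plain coroutine wrapper that no longer carries the EmbeddingFunc dataclass attributes; a minimal, simplified illustration (the decorator and names below are assumptions, not the real implementation).

import asyncio
from dataclasses import dataclass

@dataclass
class FakeEmbeddingFunc:
    embedding_dim: int
    func: object
    max_token_size: int | None = None

def fake_priority_wrapper(max_async):
    def decorator(f):
        async def wrapped(*args, **kwargs):  # plain coroutine: dataclass attrs are gone
            return await f(*args, **kwargs)
        return wrapped
    return decorator

async def demo():
    ef = FakeEmbeddingFunc(embedding_dim=1024, func=None, max_token_size=8192)
    captured = ef.max_token_size            # Step 1: capture -> 8192
    wrapped = fake_priority_wrapper(8)(ef)  # Step 2: wrap (attributes are lost here)
    assert captured == 8192 and not hasattr(wrapped, "max_token_size")

asyncio.run(demo())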

View file

@ -1,6 +1,7 @@
import copy
import os
import json
import logging
import pipmaster as pm # Pipmaster for dynamic library install
@ -16,6 +17,7 @@ from tenacity import (
)
import sys
from lightrag.utils import wrap_embedding_func_with_attrs
if sys.version_info < (3, 9):
from typing import AsyncIterator
@ -23,21 +25,121 @@ else:
from collections.abc import AsyncIterator
from typing import Union
# Import botocore exceptions for proper exception handling
try:
from botocore.exceptions import (
ClientError,
ConnectionError as BotocoreConnectionError,
ReadTimeoutError,
)
except ImportError:
# If botocore is not installed, define placeholders
ClientError = Exception
BotocoreConnectionError = Exception
ReadTimeoutError = Exception
class BedrockError(Exception):
"""Generic error for issues related to Amazon Bedrock"""
class BedrockRateLimitError(BedrockError):
"""Error for rate limiting and throttling issues"""
class BedrockConnectionError(BedrockError):
"""Error for network and connection issues"""
class BedrockTimeoutError(BedrockError):
"""Error for timeout issues"""
def _set_env_if_present(key: str, value):
"""Set environment variable only if a non-empty value is provided."""
if value is not None and value != "":
os.environ[key] = value
def _handle_bedrock_exception(e: Exception, operation: str = "Bedrock API") -> None:
"""Convert AWS Bedrock exceptions to appropriate custom exceptions.
Args:
e: The exception to handle
operation: Description of the operation for error messages
Raises:
BedrockRateLimitError: For rate limiting and throttling issues (retryable)
BedrockConnectionError: For network and server issues (retryable)
BedrockTimeoutError: For timeout issues (retryable)
BedrockError: For validation and other non-retryable errors
"""
error_message = str(e)
# Handle botocore ClientError with specific error codes
if isinstance(e, ClientError):
error_code = e.response.get("Error", {}).get("Code", "")
error_msg = e.response.get("Error", {}).get("Message", error_message)
# Rate limiting and throttling errors (retryable)
if error_code in [
"ThrottlingException",
"ProvisionedThroughputExceededException",
]:
logging.error(f"{operation} rate limit error: {error_msg}")
raise BedrockRateLimitError(f"Rate limit error: {error_msg}")
# Server errors (retryable)
elif error_code in ["ServiceUnavailableException", "InternalServerException"]:
logging.error(f"{operation} connection error: {error_msg}")
raise BedrockConnectionError(f"Service error: {error_msg}")
# Check for 5xx HTTP status codes (retryable)
elif e.response.get("ResponseMetadata", {}).get("HTTPStatusCode", 0) >= 500:
logging.error(f"{operation} server error: {error_msg}")
raise BedrockConnectionError(f"Server error: {error_msg}")
# Validation and other client errors (non-retryable)
else:
logging.error(f"{operation} client error: {error_msg}")
raise BedrockError(f"Client error: {error_msg}")
# Connection errors (retryable)
elif isinstance(e, BotocoreConnectionError):
logging.error(f"{operation} connection error: {error_message}")
raise BedrockConnectionError(f"Connection error: {error_message}")
# Timeout errors (retryable)
elif isinstance(e, (ReadTimeoutError, TimeoutError)):
logging.error(f"{operation} timeout error: {error_message}")
raise BedrockTimeoutError(f"Timeout error: {error_message}")
# Custom Bedrock errors (already properly typed)
elif isinstance(
e,
(
BedrockRateLimitError,
BedrockConnectionError,
BedrockTimeoutError,
BedrockError,
),
):
raise
# Unknown errors (non-retryable)
else:
logging.error(f"{operation} unexpected error: {error_message}")
raise BedrockError(f"Unexpected error: {error_message}")
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, max=60),
retry=retry_if_exception_type((BedrockError)),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(BedrockRateLimitError)
| retry_if_exception_type(BedrockConnectionError)
| retry_if_exception_type(BedrockTimeoutError)
),
)
async def bedrock_complete_if_cache(
model,
@ -158,9 +260,6 @@ async def bedrock_complete_if_cache(
break
except Exception as e:
# Log the specific error for debugging
logging.error(f"Bedrock streaming error: {e}")
# Try to clean up resources if possible
if (
iteration_started
@ -175,7 +274,8 @@ async def bedrock_complete_if_cache(
f"Failed to close Bedrock event stream: {close_error}"
)
raise BedrockError(f"Streaming error: {e}")
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock streaming")
finally:
# Clean up the event stream
@ -231,10 +331,8 @@ async def bedrock_complete_if_cache(
return content
except Exception as e:
if isinstance(e, BedrockError):
raise
else:
raise BedrockError(f"Bedrock API error: {e}")
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock converse")
# Generic Bedrock completion function
@ -253,12 +351,16 @@ async def bedrock_complete(
return result
# @wrap_embedding_func_with_attrs(embedding_dim=1024)
# @retry(
# stop=stop_after_attempt(3),
# wait=wait_exponential(multiplier=1, min=4, max=10),
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=(
retry_if_exception_type(BedrockRateLimitError)
| retry_if_exception_type(BedrockConnectionError)
| retry_if_exception_type(BedrockTimeoutError)
),
)
async def bedrock_embed(
texts: list[str],
model: str = "amazon.titan-embed-text-v2:0",
@ -281,48 +383,101 @@ async def bedrock_embed(
async with session.client(
"bedrock-runtime", region_name=region
) as bedrock_async_client:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
if "v2" in model:
try:
if (model_provider := model.split(".")[0]) == "amazon":
embed_texts = []
for text in texts:
try:
if "v2" in model:
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
"embeddingTypes": ["float"],
}
)
elif "v1" in model:
body = json.dumps({"inputText": text})
else:
raise BedrockError(f"Model {model} is not supported!")
response = await bedrock_async_client.invoke_model(
modelId=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = await response.get("body").json()
# Validate response structure
if not response_body or "embedding" not in response_body:
raise BedrockError(
f"Invalid embedding response structure for text: {text[:50]}..."
)
embedding = response_body["embedding"]
if not embedding:
raise BedrockError(
f"Received empty embedding for text: {text[:50]}..."
)
embed_texts.append(embedding)
except Exception as e:
# Convert to appropriate exception type
_handle_bedrock_exception(
e, "Bedrock embedding (amazon, text chunk)"
)
elif model_provider == "cohere":
try:
body = json.dumps(
{
"inputText": text,
# 'dimensions': embedding_dim,
"embeddingTypes": ["float"],
"texts": texts,
"input_type": "search_document",
"truncate": "NONE",
}
)
elif "v1" in model:
body = json.dumps({"inputText": text})
else:
raise ValueError(f"Model {model} is not supported!")
response = await bedrock_async_client.invoke_model(
modelId=model,
body=body,
accept="application/json",
contentType="application/json",
response = await bedrock_async_client.invoke_model(
model=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
# Validate response structure
if not response_body or "embeddings" not in response_body:
raise BedrockError(
"Invalid embedding response structure from Cohere"
)
embeddings = response_body["embeddings"]
if not embeddings or len(embeddings) != len(texts):
raise BedrockError(
f"Invalid embeddings count: expected {len(texts)}, got {len(embeddings) if embeddings else 0}"
)
embed_texts = embeddings
except Exception as e:
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock embedding (cohere)")
else:
raise BedrockError(
f"Model provider '{model_provider}' is not supported!"
)
response_body = await response.get("body").json()
# Final validation
if not embed_texts:
raise BedrockError("No embeddings generated")
embed_texts.append(response_body["embedding"])
elif model_provider == "cohere":
body = json.dumps(
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
)
return np.array(embed_texts)
response = await bedrock_async_client.invoke_model(
model=model,
body=body,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embed_texts = response_body["embeddings"]
else:
raise ValueError(f"Model provider '{model_provider}' is not supported!")
return np.array(embed_texts)
except Exception as e:
# Convert to appropriate exception type
_handle_bedrock_exception(e, "Bedrock embedding")
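As a rough usage sketch (not part of this diff): a ThrottlingException ClientError flows through _handle_bedrock_exception as a BedrockRateLimitError, which the tenacity policy treats as retryable; the error code is a standard Bedrock one, the rest is illustrative.

from botocore.exceptions import ClientError
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from lightrag.llm.bedrock import BedrockRateLimitError, _handle_bedrock_exception

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=5),
    retry=retry_if_exception_type(BedrockRateLimitError),
)
def flaky_call():
    # Simulate AWS returning a throttling error
    err = ClientError(
        {"Error": {"Code": "ThrottlingException", "Message": "Too many requests"}},
        "InvokeModel",
    )
    _handle_bedrock_exception(err, "Bedrock embedding")  # raises BedrockRateLimitError

# flaky_call() would be retried up to 3 times before the error propagates.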

View file

@ -453,7 +453,7 @@ async def gemini_model_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),

View file

@ -26,6 +26,7 @@ from lightrag.exceptions import (
)
import torch
import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@ -141,6 +142,7 @@ async def hf_model_complete(
return result
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
# Detect the appropriate device
if torch.cuda.is_available():

View file

@ -58,7 +58,7 @@ async def fetch_data(url, headers, data):
return data_list
@wrap_embedding_func_with_attrs(embedding_dim=2048)
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),

View file

@ -174,7 +174,7 @@ async def llama_index_complete(
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),

View file

@ -26,6 +26,10 @@ from lightrag.exceptions import (
from typing import Union, List
import numpy as np
from lightrag.utils import (
wrap_embedding_func_with_attrs,
)
@retry(
stop=stop_after_attempt(3),
@ -134,6 +138,7 @@ async def lollms_model_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
async def lollms_embed(
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
) -> np.ndarray:

View file

@ -33,7 +33,7 @@ from lightrag.utils import (
import numpy as np
@wrap_embedding_func_with_attrs(embedding_dim=2048)
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),

View file

@ -25,7 +25,10 @@ from lightrag.api import __api_version__
import numpy as np
from typing import Optional, Union
from lightrag.utils import logger
from lightrag.utils import (
wrap_embedding_func_with_attrs,
logger,
)
_OLLAMA_CLOUD_HOST = "https://ollama.com"
@ -169,6 +172,7 @@ async def ollama_model_complete(
)
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
api_key = kwargs.pop("api_key", None)
if not api_key:

View file

@ -47,7 +47,7 @@ try:
# Only enable Langfuse if both keys are configured
if langfuse_public_key and langfuse_secret_key:
from langfuse.openai import AsyncOpenAI
from langfuse.openai import AsyncOpenAI # type: ignore[import-untyped]
LANGFUSE_ENABLED = True
logger.info("Langfuse observability enabled for OpenAI client")
@ -604,7 +604,7 @@ async def nvidia_openai_complete(
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536)
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),

View file

@ -345,6 +345,20 @@ async def _summarize_descriptions(
llm_response_cache=llm_response_cache,
cache_type="summary",
)
# Check summary token length against embedding limit
embedding_token_limit = global_config.get("embedding_token_limit")
if embedding_token_limit is not None and summary:
tokenizer = global_config["tokenizer"]
summary_token_count = len(tokenizer.encode(summary))
threshold = int(embedding_token_limit * 0.9)
if summary_token_count > threshold:
logger.warning(
f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
f"({embedding_token_limit}) for {description_type}: {description_name}"
)
return summary
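For intuition, the 90% warning added above reduces to this small self-contained check (the encode callable mirrors LightRAG's tokenizer interface; the values are examples only).

def should_warn(summary: str, embedding_token_limit: int | None, encode) -> bool:
    if embedding_token_limit is None or not summary:
        return False
    threshold = int(embedding_token_limit * 0.9)  # e.g. 8192 -> 7372
    return len(encode(summary)) > threshold

# With a whitespace "tokenizer" and a limit of 10 tokens, an 11-word summary warns:
assert should_warn("w " * 11, 10, lambda s: s.split()) is True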

View file

@ -355,7 +355,7 @@ class TaskState:
class EmbeddingFunc:
embedding_dim: int
func: callable
max_token_size: int | None = None # deprecated keep it for compatible only
max_token_size: int | None = None # Token limit for the embedding model
send_dimensions: bool = (
False # Control whether to send embedding_dim to the function
)
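With max_token_size no longer deprecated, a custom embedding function can advertise its model's limit and LightRAG will pick it up as embedding_token_limit; a hedged sketch following the decorator pattern the built-in providers use (the random vectors only keep the example runnable).

import numpy as np
from lightrag.utils import wrap_embedding_func_with_attrs, EmbeddingFunc

@wrap_embedding_func_with_attrs(embedding_dim=768, max_token_size=4096)
async def my_embed(texts: list[str]) -> np.ndarray:
    # Call your real embedding backend here; random vectors are placeholders
    return np.random.rand(len(texts), 768)

assert isinstance(my_embed, EmbeddingFunc)   # the decorator wraps the function into an EmbeddingFunc
assert my_embed.max_token_size == 4096       # inherited by LightRAG as embedding_token_limit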