diff --git a/lightrag/api/config.py b/lightrag/api/config.py
index 95ab9f70..4f59d3c1 100644
--- a/lightrag/api/config.py
+++ b/lightrag/api/config.py
@@ -445,6 +445,11 @@ def parse_args() -> argparse.Namespace:
         "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
     )

+    # Embedding token limit configuration
+    args.embedding_token_limit = get_env_value(
+        "EMBEDDING_TOKEN_LIMIT", None, int, special_none=True
+    )
+
     ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
     ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py
index e5a86b3e..adbc5f28 100644
--- a/lightrag/api/lightrag_server.py
+++ b/lightrag/api/lightrag_server.py
@@ -784,6 +784,11 @@ def create_app(args):
         send_dimensions=send_dimensions,
     )

+    # Set max_token_size if EMBEDDING_TOKEN_LIMIT is provided
+    if args.embedding_token_limit is not None:
+        embedding_func.max_token_size = args.embedding_token_limit
+        logger.info(f"Set embedding max_token_size to {args.embedding_token_limit}")
+
     # Configure rerank function based on args.rerank_binding parameter
     rerank_model_func = None
     if args.rerank_binding != "null":
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index e540e3b7..72a4dc6d 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -276,6 +276,13 @@ class LightRAG:
     embedding_func: EmbeddingFunc | None = field(default=None)
     """Function for computing text embeddings. Must be set before use."""

+    @property
+    def embedding_token_limit(self) -> int | None:
+        """Get the token limit for embedding model from embedding_func."""
+        if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
+            return self.embedding_func.max_token_size
+        return None
+
     embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
     """Batch size for embedding computations."""
diff --git a/lightrag/operate.py b/lightrag/operate.py
index ae2be49e..858553b1 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -345,6 +345,20 @@ async def _summarize_descriptions(
         llm_response_cache=llm_response_cache,
         cache_type="summary",
     )
+
+    # Check summary token length against embedding limit
+    embedding_token_limit = global_config.get("embedding_token_limit")
+    if embedding_token_limit is not None and summary:
+        tokenizer = global_config["tokenizer"]
+        summary_token_count = len(tokenizer.encode(summary))
+        threshold = int(embedding_token_limit * 0.9)
+
+        if summary_token_count > threshold:
+            logger.warning(
+                f"Summary tokens ({summary_token_count}) exceed 90% of embedding limit "
+                f"({embedding_token_limit}) for {description_type}: {description_name}"
+            )
+
     return summary
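
For reference, a minimal standalone sketch of the 90% threshold behaviour the operate.py hunk introduces. The names SimpleTokenizer and check_summary_length below are hypothetical stand-ins for this sketch only (LightRAG passes its own tokenizer through global_config); only the threshold arithmetic mirrors the added code: a warning fires once a summary uses more than 90% of EMBEDDING_TOKEN_LIMIT, which would normally be set in the .env file (e.g. EMBEDDING_TOKEN_LIMIT=8192) and left unset to keep the check disabled.

    # Hypothetical illustration of the 90% threshold check; not part of LightRAG's API.
    import os


    class SimpleTokenizer:
        """Stand-in tokenizer: one token per whitespace-separated word."""

        def encode(self, text: str) -> list[str]:
            return text.split()


    def check_summary_length(summary: str, embedding_token_limit: int | None) -> None:
        """Warn when a summary uses more than 90% of the embedding token budget."""
        if embedding_token_limit is None or not summary:
            return
        summary_token_count = len(SimpleTokenizer().encode(summary))
        threshold = int(embedding_token_limit * 0.9)
        if summary_token_count > threshold:
            print(
                f"WARNING: summary tokens ({summary_token_count}) exceed 90% "
                f"of embedding limit ({embedding_token_limit})"
            )


    if __name__ == "__main__":
        # EMBEDDING_TOKEN_LIMIT is read from the environment, mirroring config.py
        limit = int(os.getenv("EMBEDDING_TOKEN_LIMIT", "8"))
        check_summary_length("a " * 20, limit)  # 20 tokens > threshold 7 -> warning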