diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py
index 16737341..ccfbb4f7 100644
--- a/lightrag/llm/bedrock.py
+++ b/lightrag/llm/bedrock.py
@@ -16,6 +16,7 @@ from tenacity import (
 )
 
 import sys
+from lightrag.utils import wrap_embedding_func_with_attrs
 
 if sys.version_info < (3, 9):
     from typing import AsyncIterator
@@ -253,7 +254,7 @@ async def bedrock_complete(
     return result
 
 
-# @wrap_embedding_func_with_attrs(embedding_dim=1024)
+@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
 # @retry(
 #     stop=stop_after_attempt(3),
 #     wait=wait_exponential(multiplier=1, min=4, max=10),
diff --git a/lightrag/llm/siliconcloud.py b/lightrag/llm/deprecated/siliconcloud.py
similarity index 100%
rename from lightrag/llm/siliconcloud.py
rename to lightrag/llm/deprecated/siliconcloud.py
diff --git a/lightrag/llm/gemini.py b/lightrag/llm/gemini.py
index 983d6b9f..37ce7206 100644
--- a/lightrag/llm/gemini.py
+++ b/lightrag/llm/gemini.py
@@ -453,7 +453,7 @@ async def gemini_model_complete(
     )
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=1536)
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=2048)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/hf.py b/lightrag/llm/hf.py
index c33b1c7f..447f95c3 100644
--- a/lightrag/llm/hf.py
+++ b/lightrag/llm/hf.py
@@ -26,6 +26,7 @@ from lightrag.exceptions import (
 )
 import torch
 import numpy as np
+from lightrag.utils import wrap_embedding_func_with_attrs
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -141,6 +142,7 @@ async def hf_model_complete(
     return result
 
 
+@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
 async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     # Detect the appropriate device
     if torch.cuda.is_available():
diff --git a/lightrag/llm/jina.py b/lightrag/llm/jina.py
index 70de5995..f61faadd 100644
--- a/lightrag/llm/jina.py
+++ b/lightrag/llm/jina.py
@@ -58,7 +58,7 @@ async def fetch_data(url, headers, data):
     return data_list
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=2048)
+@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/llama_index_impl.py b/lightrag/llm/llama_index_impl.py
index 38ec7cd1..c44e6c7a 100644
--- a/lightrag/llm/llama_index_impl.py
+++ b/lightrag/llm/llama_index_impl.py
@@ -174,7 +174,7 @@ async def llama_index_complete(
     return result
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=1536)
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/lollms.py b/lightrag/llm/lollms.py
index 9274dbfc..2f2a1dbf 100644
--- a/lightrag/llm/lollms.py
+++ b/lightrag/llm/lollms.py
@@ -26,6 +26,10 @@ from lightrag.exceptions import (
 from typing import Union, List
 import numpy as np
 
+from lightrag.utils import (
+    wrap_embedding_func_with_attrs,
+)
+
 
 @retry(
     stop=stop_after_attempt(3),
@@ -134,6 +138,7 @@ async def lollms_model_complete(
     )
 
 
+@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
 async def lollms_embed(
     texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
 ) -> np.ndarray:
diff --git a/lightrag/llm/nvidia_openai.py b/lightrag/llm/nvidia_openai.py
index 1cbab380..1ebaf3a6 100644
--- a/lightrag/llm/nvidia_openai.py
+++ b/lightrag/llm/nvidia_openai.py
@@ -33,7 +33,7 @@ from lightrag.utils import (
 import numpy as np
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=2048)
+@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py
index 670351bc..e35dc293 100644
--- a/lightrag/llm/ollama.py
+++ b/lightrag/llm/ollama.py
@@ -25,7 +25,10 @@ from lightrag.api import __api_version__
 import numpy as np
 from typing import Optional, Union
 
-from lightrag.utils import logger
+from lightrag.utils import (
+    wrap_embedding_func_with_attrs,
+    logger,
+)
 
 _OLLAMA_CLOUD_HOST = "https://ollama.com"
 
@@ -169,6 +172,7 @@ async def ollama_model_complete(
     )
 
 
+@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
 async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     api_key = kwargs.pop("api_key", None)
     if not api_key:
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index a2bbfa23..8c984e51 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -47,7 +47,7 @@ try:
 
     # Only enable Langfuse if both keys are configured
     if langfuse_public_key and langfuse_secret_key:
-        from langfuse.openai import AsyncOpenAI
+        from langfuse.openai import AsyncOpenAI  # type: ignore[import-untyped]
 
         LANGFUSE_ENABLED = True
         logger.info("Langfuse observability enabled for OpenAI client")
@@ -604,7 +604,7 @@ async def nvidia_openai_complete(
     return result
 
 
-@wrap_embedding_func_with_attrs(embedding_dim=1536)
+@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=60),
diff --git a/lightrag/utils.py b/lightrag/utils.py
index b78b7523..d653c1e3 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -355,7 +355,7 @@ class TaskState:
 class EmbeddingFunc:
     embedding_dim: int
     func: callable
-    max_token_size: int | None = None  # deprecated keep it for compatible only
+    max_token_size: int | None = None  # Token limit for the embedding model
     send_dimensions: bool = (
         False  # Control whether to send embedding_dim to the function
     )
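
For reviewers unfamiliar with the decorator being applied across these providers: below is a minimal sketch of how `wrap_embedding_func_with_attrs` is assumed to behave, based on the `EmbeddingFunc` dataclass visible in the `lightrag/utils.py` hunk above. The decorator body and the async `__call__` passthrough are illustrative assumptions, not the project's actual implementation.

from __future__ import annotations

# Illustrative sketch only; the real implementation lives in lightrag/utils.py.
# The EmbeddingFunc fields mirror the dataclass shown in the diff; the decorator
# body and the async passthrough are assumptions for illustration.
from dataclasses import dataclass
from typing import Callable


@dataclass
class EmbeddingFunc:
    embedding_dim: int
    func: Callable
    max_token_size: int | None = None  # Token limit for the embedding model
    send_dimensions: bool = False  # Whether to pass embedding_dim to func

    async def __call__(self, *args, **kwargs):
        # Delegate to the wrapped async embedding function.
        return await self.func(*args, **kwargs)


def wrap_embedding_func_with_attrs(**kwargs):
    """Attach embedding metadata (dimension, token limit) to an embedding function."""
    def decorator(func: Callable) -> EmbeddingFunc:
        return EmbeddingFunc(func=func, **kwargs)
    return decorator


# Usage, matching the pattern in this diff:
#
#   @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
#   async def my_embed(texts: list[str]): ...
#
# Callers can then read my_embed.embedding_dim and my_embed.max_token_size
# instead of hard-coding per-provider limits.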