From c13f9116d9cb319a8d08a4e84d21cb8dd105e9af Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 17 Nov 2025 12:26:54 +0800 Subject: [PATCH 1/2] Add embedding dimension validation to EmbeddingFunc wrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Validate total elements divisibility • Check vector count matches input count • Raise clear error messages on mismatch • Ensure embedding output correctness • Add docstring for EmbeddingFunc class --- lightrag/utils.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/lightrag/utils.py b/lightrag/utils.py index d653c1e3..aa04f338 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -353,6 +353,16 @@ class TaskState: @dataclass class EmbeddingFunc: + """Embedding function wrapper with dimension validation + This class wraps an embedding function to ensure that the output embeddings have the correct dimension. + This class should not be wrapped multiple times. + + Args: + embedding_dim: Expected dimension of the embeddings + func: The actual embedding function to wrap + max_token_size: Optional token limit for the embedding model + send_dimensions: Whether to inject embedding_dim as a keyword argument + """ embedding_dim: int func: callable max_token_size: int | None = None # Token limit for the embedding model @@ -379,7 +389,32 @@ class EmbeddingFunc: # Inject embedding_dim from decorator kwargs["embedding_dim"] = self.embedding_dim - return await self.func(*args, **kwargs) + # Call the actual embedding function + result = await self.func(*args, **kwargs) + + # Validate embedding dimensions using total element count + total_elements = result.size # Total number of elements in the numpy array + expected_dim = self.embedding_dim + + # Check if total elements can be evenly divided by embedding_dim + if total_elements % expected_dim != 0: + raise ValueError( + f"Embedding dimension mismatch detected: " + f"total elements ({total_elements}) cannot be evenly divided by " + f"expected dimension ({expected_dim}). " + ) + + # Optional: Verify vector count matches input text count + actual_vectors = total_elements // expected_dim + if args and isinstance(args[0], (list, tuple)): + expected_vectors = len(args[0]) + if actual_vectors != expected_vectors: + raise ValueError( + f"Vector count mismatch: " + f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)." + ) + + return result def compute_args_hash(*args: Any) -> str: From 90f52acf0cf915e75318d53798c84b3a76ec0477 Mon Sep 17 00:00:00 2001 From: yangdx Date: Mon, 17 Nov 2025 12:28:53 +0800 Subject: [PATCH 2/2] Fix linting --- lightrag/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightrag/utils.py b/lightrag/utils.py index aa04f338..8c9b7776 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -363,6 +363,7 @@ class EmbeddingFunc: max_token_size: Optional token limit for the embedding model send_dimensions: Whether to inject embedding_dim as a keyword argument """ + embedding_dim: int func: callable max_token_size: int | None = None # Token limit for the embedding model