diff --git a/lightrag/utils.py b/lightrag/utils.py
index d653c1e3..8c9b7776 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -353,6 +353,17 @@ class TaskState:
 
 @dataclass
 class EmbeddingFunc:
+    """Embedding function wrapper with dimension validation
+    This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
+    This class should not be wrapped multiple times.
+
+    Args:
+        embedding_dim: Expected dimension of the embeddings
+        func: The actual embedding function to wrap
+        max_token_size: Optional token limit for the embedding model
+        send_dimensions: Whether to inject embedding_dim as a keyword argument
+    """
+
     embedding_dim: int
     func: callable
     max_token_size: int | None = None  # Token limit for the embedding model
@@ -379,7 +390,32 @@ class EmbeddingFunc:
             # Inject embedding_dim from decorator
             kwargs["embedding_dim"] = self.embedding_dim
 
-        return await self.func(*args, **kwargs)
+        # Call the actual embedding function
+        result = await self.func(*args, **kwargs)
+
+        # Validate embedding dimensions using total element count
+        total_elements = result.size  # Total number of elements in the numpy array
+        expected_dim = self.embedding_dim
+
+        # Check if total elements can be evenly divided by embedding_dim
+        if total_elements % expected_dim != 0:
+            raise ValueError(
+                f"Embedding dimension mismatch detected: "
+                f"total elements ({total_elements}) cannot be evenly divided by "
+                f"expected dimension ({expected_dim})."
+            )
+
+        # Optional: Verify vector count matches input text count
+        actual_vectors = total_elements // expected_dim
+        if args and isinstance(args[0], (list, tuple)):
+            expected_vectors = len(args[0])
+            if actual_vectors != expected_vectors:
+                raise ValueError(
+                    f"Vector count mismatch: "
+                    f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
+                )
+
+        return result
 
 
 def compute_args_hash(*args: Any) -> str:
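
For context, a minimal usage sketch (not part of the diff) of how the new validation is expected to surface a dimension mismatch. The fake_embed function, its 100-wide output, the declared embedding_dim of 128, and the explicit send_dimensions=False are illustrative assumptions for this sketch, not code from the repository; it assumes EmbeddingFunc is constructed directly with keyword arguments as a dataclass.

# Minimal sketch (assumptions noted above): exercising the new dimension check.
import asyncio

import numpy as np

from lightrag.utils import EmbeddingFunc


async def fake_embed(texts: list[str]) -> np.ndarray:
    # Hypothetical embedder that deliberately returns the wrong width
    # (100 dimensions instead of the declared 128)
    return np.zeros((len(texts), 100), dtype=np.float32)


wrapped = EmbeddingFunc(embedding_dim=128, func=fake_embed, send_dimensions=False)


async def main() -> None:
    try:
        await wrapped(["hello", "world"])
    except ValueError as exc:
        # 2 texts * 100 dims = 200 elements, which is not divisible by 128,
        # so the first validation branch raises before the result is returned.
        print(exc)


asyncio.run(main())

Because the check uses result.size (total element count) rather than result.shape, it behaves the same whether the wrapped function returns a 2-D array of vectors or an already-flattened array.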