Merge pull request #2368 from danielaskdd/milvus-vector-batching

Refactor: Add Embedding Dimension Validation in EmbeddingFunc
This commit is contained in:
Daniel.y 2025-11-17 12:38:22 +08:00 committed by GitHub
commit 8bb54833a7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -353,6 +353,17 @@ class TaskState:
@dataclass
class EmbeddingFunc:
"""Embedding function wrapper with dimension validation
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
This class should not be wrapped multiple times.
Args:
embedding_dim: Expected dimension of the embeddings
func: The actual embedding function to wrap
max_token_size: Optional token limit for the embedding model
send_dimensions: Whether to inject embedding_dim as a keyword argument
"""
embedding_dim: int
func: callable
max_token_size: int | None = None # Token limit for the embedding model
@ -379,7 +390,32 @@ class EmbeddingFunc:
# Inject embedding_dim from decorator
kwargs["embedding_dim"] = self.embedding_dim
return await self.func(*args, **kwargs)
# Call the actual embedding function
result = await self.func(*args, **kwargs)
# Validate embedding dimensions using total element count
total_elements = result.size # Total number of elements in the numpy array
expected_dim = self.embedding_dim
# Check if total elements can be evenly divided by embedding_dim
if total_elements % expected_dim != 0:
raise ValueError(
f"Embedding dimension mismatch detected: "
f"total elements ({total_elements}) cannot be evenly divided by "
f"expected dimension ({expected_dim}). "
)
# Optional: Verify vector count matches input text count
actual_vectors = total_elements // expected_dim
if args and isinstance(args[0], (list, tuple)):
expected_vectors = len(args[0])
if actual_vectors != expected_vectors:
raise ValueError(
f"Vector count mismatch: "
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
)
return result
def compute_args_hash(*args: Any) -> str: