Merge branch 'main' into workspace-isolation
This commit is contained in:
commit
f1d8f18c80
1 changed files with 37 additions and 1 deletions
|
|
@ -353,6 +353,17 @@ class TaskState:
|
|||
|
||||
@dataclass
|
||||
class EmbeddingFunc:
|
||||
"""Embedding function wrapper with dimension validation
|
||||
This class wraps an embedding function to ensure that the output embeddings have the correct dimension.
|
||||
This class should not be wrapped multiple times.
|
||||
|
||||
Args:
|
||||
embedding_dim: Expected dimension of the embeddings
|
||||
func: The actual embedding function to wrap
|
||||
max_token_size: Optional token limit for the embedding model
|
||||
send_dimensions: Whether to inject embedding_dim as a keyword argument
|
||||
"""
|
||||
|
||||
embedding_dim: int
|
||||
func: callable
|
||||
max_token_size: int | None = None # Token limit for the embedding model
|
||||
|
|
@ -379,7 +390,32 @@ class EmbeddingFunc:
|
|||
# Inject embedding_dim from decorator
|
||||
kwargs["embedding_dim"] = self.embedding_dim
|
||||
|
||||
return await self.func(*args, **kwargs)
|
||||
# Call the actual embedding function
|
||||
result = await self.func(*args, **kwargs)
|
||||
|
||||
# Validate embedding dimensions using total element count
|
||||
total_elements = result.size # Total number of elements in the numpy array
|
||||
expected_dim = self.embedding_dim
|
||||
|
||||
# Check if total elements can be evenly divided by embedding_dim
|
||||
if total_elements % expected_dim != 0:
|
||||
raise ValueError(
|
||||
f"Embedding dimension mismatch detected: "
|
||||
f"total elements ({total_elements}) cannot be evenly divided by "
|
||||
f"expected dimension ({expected_dim}). "
|
||||
)
|
||||
|
||||
# Optional: Verify vector count matches input text count
|
||||
actual_vectors = total_elements // expected_dim
|
||||
if args and isinstance(args[0], (list, tuple)):
|
||||
expected_vectors = len(args[0])
|
||||
if actual_vectors != expected_vectors:
|
||||
raise ValueError(
|
||||
f"Vector count mismatch: "
|
||||
f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def compute_args_hash(*args: Any) -> str:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue