diff --git a/src/utils/embeddings.py b/src/utils/embeddings.py index f3c902e7..b0ec035f 100644 --- a/src/utils/embeddings.py +++ b/src/utils/embeddings.py @@ -10,6 +10,8 @@ def get_embedding_dimensions(model_name: str) -> int: # Check all model dictionaries all_models = {**OPENAI_EMBEDDING_DIMENSIONS, **OLLAMA_EMBEDDING_DIMENSIONS, **WATSONX_EMBEDDING_DIMENSIONS} + model_name = model_name.lower().strip().split(":")[0] + if model_name in all_models: dimensions = all_models[model_name] logger.info(f"Found dimensions for model '{model_name}': {dimensions}")