diff --git a/cognee/infrastructure/databases/vector/embeddings/config.py b/cognee/infrastructure/databases/vector/embeddings/config.py
index 733548f89..13f4092bc 100644
--- a/cognee/infrastructure/databases/vector/embeddings/config.py
+++ b/cognee/infrastructure/databases/vector/embeddings/config.py
@@ -14,6 +14,21 @@ class EmbeddingConfig(BaseSettings):
     huggingface_tokenizer: Optional[str] = None
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
+    def to_dict(self) -> dict:
+        """
+        Serialize all embedding configuration settings to a dictionary.
+        """
+        return {
+            "embedding_provider": self.embedding_provider,
+            "embedding_model": self.embedding_model,
+            "embedding_dimensions": self.embedding_dimensions,
+            "embedding_endpoint": self.embedding_endpoint,
+            "embedding_api_key": self.embedding_api_key,
+            "embedding_api_version": self.embedding_api_version,
+            "embedding_max_tokens": self.embedding_max_tokens,
+            "huggingface_tokenizer": self.huggingface_tokenizer,
+        }
+
 
 @lru_cache
 def get_embedding_config():
diff --git a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
index d4ecb52d2..4adcaf13b 100644
--- a/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
@@ -1,40 +1,66 @@
 from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
 from cognee.infrastructure.llm.config import get_llm_config
 from .EmbeddingEngine import EmbeddingEngine
+from functools import lru_cache
 
 
 def get_embedding_engine() -> EmbeddingEngine:
     config = get_embedding_config()
     llm_config = get_llm_config()
+    # The embedding engine must be a singleton per configuration so that too many requests aren't sent to HuggingFace.
+    return create_embedding_engine(
+        config.embedding_provider,
+        config.embedding_model,
+        config.embedding_dimensions,
+        config.embedding_max_tokens,
+        config.embedding_endpoint,
+        config.embedding_api_key,
+        config.embedding_api_version,
+        config.huggingface_tokenizer,
+        llm_config.llm_api_key,
+    )
 
-    if config.embedding_provider == "fastembed":
+
+@lru_cache
+def create_embedding_engine(
+    embedding_provider,
+    embedding_model,
+    embedding_dimensions,
+    embedding_max_tokens,
+    embedding_endpoint,
+    embedding_api_key,
+    embedding_api_version,
+    huggingface_tokenizer,
+    llm_api_key,
+):
+    if embedding_provider == "fastembed":
         from .FastembedEmbeddingEngine import FastembedEmbeddingEngine
 
         return FastembedEmbeddingEngine(
-            model=config.embedding_model,
-            dimensions=config.embedding_dimensions,
-            max_tokens=config.embedding_max_tokens,
+            model=embedding_model,
+            dimensions=embedding_dimensions,
+            max_tokens=embedding_max_tokens,
         )
 
-    if config.embedding_provider == "ollama":
+    if embedding_provider == "ollama":
         from .OllamaEmbeddingEngine import OllamaEmbeddingEngine
 
         return OllamaEmbeddingEngine(
-            model=config.embedding_model,
-            dimensions=config.embedding_dimensions,
-            max_tokens=config.embedding_max_tokens,
-            endpoint=config.embedding_endpoint,
-            huggingface_tokenizer=config.huggingface_tokenizer,
+            model=embedding_model,
+            dimensions=embedding_dimensions,
+            max_tokens=embedding_max_tokens,
+            endpoint=embedding_endpoint,
+            huggingface_tokenizer=huggingface_tokenizer,
         )
 
     from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
 
     return LiteLLMEmbeddingEngine(
-        provider=config.embedding_provider,
-        api_key=config.embedding_api_key or llm_config.llm_api_key,
-        endpoint=config.embedding_endpoint,
-        api_version=config.embedding_api_version,
-        model=config.embedding_model,
-        dimensions=config.embedding_dimensions,
-        max_tokens=config.embedding_max_tokens,
+        provider=embedding_provider,
+        api_key=embedding_api_key or llm_api_key,
+        endpoint=embedding_endpoint,
+        api_version=embedding_api_version,
+        model=embedding_model,
+        dimensions=embedding_dimensions,
+        max_tokens=embedding_max_tokens,
     )
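
Note on the new EmbeddingConfig.to_dict(): it gives callers a plain-dict view of the embedding settings, e.g. for diagnostics. A small usage sketch (the redaction step is an assumption for illustration, not part of the diff):

    config = get_embedding_config()
    settings = config.to_dict()
    # Redact the secret before logging; the other fields are safe to print.
    settings["embedding_api_key"] = "***" if settings["embedding_api_key"] else None
    print(settings)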
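
Note on the caching approach: because create_embedding_engine is decorated with @lru_cache and called with hashable scalar arguments, calls that share an identical configuration get back the same engine instance. A minimal sketch of that behavior, using a hypothetical Engine class and factory (names are illustrative, not part of the diff):

    from functools import lru_cache

    class Engine:
        def __init__(self, model: str, dimensions: int):
            self.model = model
            self.dimensions = dimensions

    @lru_cache
    def create_engine(model: str, dimensions: int) -> Engine:
        # Runs once per unique (model, dimensions) pair; later calls with
        # the same arguments return the cached instance.
        return Engine(model, dimensions)

    a = create_engine("text-embedding-3-large", 3072)
    b = create_engine("text-embedding-3-large", 3072)
    assert a is b  # the factory behaves as a per-configuration singleton

Passing individual scalars rather than the EmbeddingConfig object also keeps the cache key hashable, which lru_cache requires; strings and integers qualify without extra work.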