fix: HuggingFace tokenizer (#752)

## Description
Resolves an issue noticed by [RyabykinIlya](https://github.com/RyabykinIlya) where too many HuggingFace requests were being sent because the embedding engine was not behaving as a singleton per configuration.
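The fix memoizes the engine factory with `functools.lru_cache`, so exactly one engine is constructed per unique configuration and later lookups reuse it. A minimal sketch of the pattern (the names here are illustrative, not from the codebase):

```python
from functools import lru_cache


@lru_cache
def create_engine(model: str, dimensions: int):
    # Expensive setup (e.g. fetching a HuggingFace tokenizer) runs only once
    # per unique (model, dimensions) argument tuple; repeat calls with the
    # same arguments return the cached instance.
    print(f"creating engine for {model}")
    return object()


first = create_engine("bge-small-en", 384)
second = create_engine("bge-small-en", 384)
assert first is second  # same cached object, no repeated remote requests
```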

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Signed-off-by: Ryabykin Ilya <ryabykinia@sibur.ru>
Co-authored-by: greshish <ryabykinia@yandex.ru>
Co-authored-by: Ryabykin Ilya <ryabykinia@sibur.ru>

2 changed files with 58 additions and 17 deletions.

In the embedding config, `EmbeddingConfig` gains a `to_dict()` serializer:

```diff
@@ -14,6 +14,21 @@ class EmbeddingConfig(BaseSettings):
     huggingface_tokenizer: Optional[str] = None
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
+    def to_dict(self) -> dict:
+        """
+        Serialize all embedding configuration settings to a dictionary.
+        """
+        return {
+            "embedding_provider": self.embedding_provider,
+            "embedding_model": self.embedding_model,
+            "embedding_dimensions": self.embedding_dimensions,
+            "embedding_endpoint": self.embedding_endpoint,
+            "embedding_api_key": self.embedding_api_key,
+            "embedding_api_version": self.embedding_api_version,
+            "embedding_max_tokens": self.embedding_max_tokens,
+            "huggingface_tokenizer": self.huggingface_tokenizer,
+        }
+
 @lru_cache
 def get_embedding_config():
```
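One reason the factory introduced in the next file takes the config's scalar fields rather than the `EmbeddingConfig` object itself (my reading, not stated in the PR): `lru_cache` hashes its arguments to build the cache key, and plain strings and ints hash fine where a mutable settings object generally does not. A minimal illustration:

```python
from functools import lru_cache


@lru_cache
def make(key):
    # lru_cache builds its cache key by hashing the call arguments.
    return object()


make(("openai", "text-embedding-3-large", 3072))  # scalars/tuples hash fine

try:
    make({"provider": "openai"})  # mutable objects like dicts do not
except TypeError as exc:
    print(exc)  # "unhashable type: 'dict'"
```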

In the engine module, construction moves out of `get_embedding_engine()` into a new `create_embedding_engine()` factory memoized with `@lru_cache`:

```diff
@@ -1,40 +1,66 @@
 from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
 from cognee.infrastructure.llm.config import get_llm_config
 from .EmbeddingEngine import EmbeddingEngine
+from functools import lru_cache
 
 
 def get_embedding_engine() -> EmbeddingEngine:
     config = get_embedding_config()
     llm_config = get_llm_config()
 
-    if config.embedding_provider == "fastembed":
+    # Embedding engine has to be a singleton based on configuration to ensure too many requests won't be sent to HuggingFace
+    return create_embedding_engine(
+        config.embedding_provider,
+        config.embedding_model,
+        config.embedding_dimensions,
+        config.embedding_max_tokens,
+        config.embedding_endpoint,
+        config.embedding_api_key,
+        config.embedding_api_version,
+        config.huggingface_tokenizer,
+        llm_config.llm_api_key,
+    )
+
+@lru_cache
+def create_embedding_engine(
+    embedding_provider,
+    embedding_model,
+    embedding_dimensions,
+    embedding_max_tokens,
+    embedding_endpoint,
+    embedding_api_key,
+    embedding_api_version,
+    huggingface_tokenizer,
+    llm_api_key,
+):
+    if embedding_provider == "fastembed":
         from .FastembedEmbeddingEngine import FastembedEmbeddingEngine
 
         return FastembedEmbeddingEngine(
-            model=config.embedding_model,
-            dimensions=config.embedding_dimensions,
-            max_tokens=config.embedding_max_tokens,
+            model=embedding_model,
+            dimensions=embedding_dimensions,
+            max_tokens=embedding_max_tokens,
         )
 
-    if config.embedding_provider == "ollama":
+    if embedding_provider == "ollama":
         from .OllamaEmbeddingEngine import OllamaEmbeddingEngine
 
         return OllamaEmbeddingEngine(
-            model=config.embedding_model,
-            dimensions=config.embedding_dimensions,
-            max_tokens=config.embedding_max_tokens,
-            endpoint=config.embedding_endpoint,
-            huggingface_tokenizer=config.huggingface_tokenizer,
+            model=embedding_model,
+            dimensions=embedding_dimensions,
+            max_tokens=embedding_max_tokens,
+            endpoint=embedding_endpoint,
+            huggingface_tokenizer=huggingface_tokenizer,
         )
 
     from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
 
     return LiteLLMEmbeddingEngine(
-        provider=config.embedding_provider,
-        api_key=config.embedding_api_key or llm_config.llm_api_key,
-        endpoint=config.embedding_endpoint,
-        api_version=config.embedding_api_version,
-        model=config.embedding_model,
-        dimensions=config.embedding_dimensions,
-        max_tokens=config.embedding_max_tokens,
+        provider=embedding_provider,
+        api_key=embedding_api_key or llm_api_key,
+        endpoint=embedding_endpoint,
+        api_version=embedding_api_version,
+        model=embedding_model,
+        dimensions=embedding_dimensions,
+        max_tokens=embedding_max_tokens,
    )
```
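With this in place, repeated lookups share one engine instance. A hedged usage sketch (the import path is assumed from the module layout above; `cache_clear` is the standard reset method that `functools.lru_cache` attaches to the wrapped function):

```python
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine  # path assumed

engine_a = get_embedding_engine()
engine_b = get_embedding_engine()
assert engine_a is engine_b  # same configuration -> same cached engine

# If the embedding configuration changes at runtime, the memoized factory
# can be reset so the next call rebuilds the engine:
# create_embedding_engine.cache_clear()
```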