fix: HuggingFace tokenizer (#752)
<!-- .github/pull_request_template.md --> ## Description Resolve issue noticed by [RyabykinIlya](https://github.com/RyabykinIlya) where too many HuggingFace requests were being sent because the embedding engine was not working as a singleton per config ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --------- Signed-off-by: Ryabykin Ilya <ryabykinia@sibur.ru> Co-authored-by: greshish <ryabykinia@yandex.ru> Co-authored-by: Ryabykin Ilya <ryabykinia@sibur.ru>
This commit is contained in:
parent
9ba12b25ef
commit
ba2de9bb22
2 changed files with 58 additions and 17 deletions
|
|
@@ -14,6 +14,21 @@ class EmbeddingConfig(BaseSettings):
|
|||
huggingface_tokenizer: Optional[str] = None
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
def to_dict(self) -> dict:
    """
    Serialize all embedding configuration settings to a dictionary.

    Returns a plain ``dict`` with one key per setting, suitable for
    logging or for passing the configuration across boundaries.
    """
    setting_names = (
        "embedding_provider",
        "embedding_model",
        "embedding_dimensions",
        "embedding_endpoint",
        "embedding_api_key",
        "embedding_api_version",
        "embedding_max_tokens",
        "huggingface_tokenizer",
    )
    # Same key set and values as listing each attribute explicitly.
    return {name: getattr(self, name) for name in setting_names}
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_embedding_config():
|
||||
|
|
|
|||
|
|
@@ -1,40 +1,66 @@
|
|||
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
|
||||
from cognee.infrastructure.llm.config import get_llm_config
|
||||
from .EmbeddingEngine import EmbeddingEngine
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
def get_embedding_engine() -> EmbeddingEngine:
    """
    Return the embedding engine for the currently active configuration.

    Delegates to the cached ``create_embedding_engine`` factory, so repeated
    calls with an unchanged configuration hand back the same engine instance.
    """
    embedding_config = get_embedding_config()
    llm_config = get_llm_config()
    # Embedding engine has to be a singleton based on configuration to ensure too many requests won't be sent to HuggingFace
    return create_embedding_engine(
        embedding_config.embedding_provider,
        embedding_config.embedding_model,
        embedding_config.embedding_dimensions,
        embedding_config.embedding_max_tokens,
        embedding_config.embedding_endpoint,
        embedding_config.embedding_api_key,
        embedding_config.embedding_api_version,
        embedding_config.huggingface_tokenizer,
        llm_config.llm_api_key,
    )
||||
@lru_cache
def create_embedding_engine(
    embedding_provider,
    embedding_model,
    embedding_dimensions,
    embedding_max_tokens,
    embedding_endpoint,
    embedding_api_key,
    embedding_api_version,
    huggingface_tokenizer,
    llm_api_key,
):
    """
    Build — and memoise — an embedding engine for the given settings.

    ``lru_cache`` turns this factory into a per-configuration singleton:
    identical argument tuples always return the same engine object, which
    prevents redundant requests (e.g. repeated HuggingFace tokenizer loads).
    Unrecognised providers fall through to the LiteLLM engine.
    """
    # Engine modules are imported lazily inside each branch so that only the
    # selected provider's dependencies are actually loaded.
    if embedding_provider == "fastembed":
        from .FastembedEmbeddingEngine import FastembedEmbeddingEngine

        return FastembedEmbeddingEngine(
            model=embedding_model,
            dimensions=embedding_dimensions,
            max_tokens=embedding_max_tokens,
        )

    if embedding_provider == "ollama":
        from .OllamaEmbeddingEngine import OllamaEmbeddingEngine

        return OllamaEmbeddingEngine(
            model=embedding_model,
            dimensions=embedding_dimensions,
            max_tokens=embedding_max_tokens,
            endpoint=embedding_endpoint,
            huggingface_tokenizer=huggingface_tokenizer,
        )

    # Default: route every other provider through LiteLLM, falling back to
    # the LLM API key when no dedicated embedding key is configured.
    from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine

    return LiteLLMEmbeddingEngine(
        provider=embedding_provider,
        api_key=embedding_api_key or llm_api_key,
        endpoint=embedding_endpoint,
        api_version=embedding_api_version,
        model=embedding_model,
        dimensions=embedding_dimensions,
        max_tokens=embedding_max_tokens,
    )
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue