fix: HuggingFace tokenizer (#752)
## Description

Resolves an issue noticed by [RyabykinIlya](https://github.com/RyabykinIlya) where too many HuggingFace requests were being sent because the embedding engine was not working as a singleton per configuration.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

---------

Signed-off-by: Ryabykin Ilya <ryabykinia@sibur.ru>
Co-authored-by: greshish <ryabykinia@yandex.ru>
Co-authored-by: Ryabykin Ilya <ryabykinia@sibur.ru>
parent 9ba12b25ef
commit ba2de9bb22

2 changed files with 58 additions and 17 deletions
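The core of the fix is memoizing the engine factory with `functools.lru_cache`. As a minimal illustrative sketch (the `make_engine` name and arguments are hypothetical, not code from this commit), caching a factory this way yields one instance per unique argument tuple, so repeated lookups with the same configuration reuse the same engine instead of re-running expensive setup such as tokenizer downloads:

```python
from functools import lru_cache


@lru_cache
def make_engine(model: str, dimensions: int) -> object:
    # Expensive setup (e.g. fetching a HuggingFace tokenizer) runs only on a cache miss.
    print(f"building engine for {model} ({dimensions} dims)")
    return object()


first = make_engine("bge-small", 384)
second = make_engine("bge-small", 384)  # cache hit: prints nothing, triggers no extra requests
assert first is second
```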
cognee/infrastructure/databases/vector/embeddings/config.py

```diff
@@ -14,6 +14,21 @@ class EmbeddingConfig(BaseSettings):
     huggingface_tokenizer: Optional[str] = None

     model_config = SettingsConfigDict(env_file=".env", extra="allow")

+    def to_dict(self) -> dict:
+        """
+        Serialize all embedding configuration settings to a dictionary.
+        """
+        return {
+            "embedding_provider": self.embedding_provider,
+            "embedding_model": self.embedding_model,
+            "embedding_dimensions": self.embedding_dimensions,
+            "embedding_endpoint": self.embedding_endpoint,
+            "embedding_api_key": self.embedding_api_key,
+            "embedding_api_version": self.embedding_api_version,
+            "embedding_max_tokens": self.embedding_max_tokens,
+            "huggingface_tokenizer": self.huggingface_tokenizer,
+        }
+

 @lru_cache
 def get_embedding_config():
```
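For illustration, a possible use of the new `to_dict` (the printed values are hypothetical; nothing here is from the commit):

```python
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config

config = get_embedding_config()
print(config.to_dict())
# e.g. {'embedding_provider': 'openai', 'embedding_model': 'text-embedding-3-large', ...,
#       'huggingface_tokenizer': None}
```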
cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py

```diff
@@ -1,40 +1,66 @@
 from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
 from cognee.infrastructure.llm.config import get_llm_config
 from .EmbeddingEngine import EmbeddingEngine
+from functools import lru_cache


 def get_embedding_engine() -> EmbeddingEngine:
     config = get_embedding_config()
     llm_config = get_llm_config()
+    # Embedding engine has to be a singleton based on configuration to ensure too many requests won't be sent to HuggingFace
+    return create_embedding_engine(
+        config.embedding_provider,
+        config.embedding_model,
+        config.embedding_dimensions,
+        config.embedding_max_tokens,
+        config.embedding_endpoint,
+        config.embedding_api_key,
+        config.embedding_api_version,
+        config.huggingface_tokenizer,
+        llm_config.llm_api_key,
+    )

-    if config.embedding_provider == "fastembed":
+
+@lru_cache
+def create_embedding_engine(
+    embedding_provider,
+    embedding_model,
+    embedding_dimensions,
+    embedding_max_tokens,
+    embedding_endpoint,
+    embedding_api_key,
+    embedding_api_version,
+    huggingface_tokenizer,
+    llm_api_key,
+):
+    if embedding_provider == "fastembed":
         from .FastembedEmbeddingEngine import FastembedEmbeddingEngine

         return FastembedEmbeddingEngine(
-            model=config.embedding_model,
-            dimensions=config.embedding_dimensions,
-            max_tokens=config.embedding_max_tokens,
+            model=embedding_model,
+            dimensions=embedding_dimensions,
+            max_tokens=embedding_max_tokens,
         )

-    if config.embedding_provider == "ollama":
+    if embedding_provider == "ollama":
         from .OllamaEmbeddingEngine import OllamaEmbeddingEngine

         return OllamaEmbeddingEngine(
-            model=config.embedding_model,
-            dimensions=config.embedding_dimensions,
-            max_tokens=config.embedding_max_tokens,
-            endpoint=config.embedding_endpoint,
-            huggingface_tokenizer=config.huggingface_tokenizer,
+            model=embedding_model,
+            dimensions=embedding_dimensions,
+            max_tokens=embedding_max_tokens,
+            endpoint=embedding_endpoint,
+            huggingface_tokenizer=huggingface_tokenizer,
         )

     from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine

     return LiteLLMEmbeddingEngine(
-        provider=config.embedding_provider,
-        api_key=config.embedding_api_key or llm_config.llm_api_key,
-        endpoint=config.embedding_endpoint,
-        api_version=config.embedding_api_version,
-        model=config.embedding_model,
-        dimensions=config.embedding_dimensions,
-        max_tokens=config.embedding_max_tokens,
+        provider=embedding_provider,
+        api_key=embedding_api_key or llm_api_key,
+        endpoint=embedding_endpoint,
+        api_version=embedding_api_version,
+        model=embedding_model,
+        dimensions=embedding_dimensions,
+        max_tokens=embedding_max_tokens,
     )
```
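A design detail worth noting (my reading of the diff, not stated in the commit): `create_embedding_engine` receives the individual config fields rather than the `EmbeddingConfig` object or its `to_dict()` result because `lru_cache` keys its cache on the arguments, which must be hashable; strings, ints, and `None` are, while a dict is not. A small sketch of the error that flattening the arguments avoids:

```python
from functools import lru_cache


@lru_cache
def build_from_mapping(settings: dict) -> object:
    # lru_cache must hash its arguments to look them up; a dict cannot be hashed.
    return object()


try:
    build_from_mapping({"embedding_provider": "fastembed"})
except TypeError as exc:
    print(exc)  # unhashable type: 'dict'
```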