feat: Add HuggingFace Tokenizer support

Igor Ilic 2025-01-23 16:52:35 +01:00
parent 93249c72c5
commit 294ed1d960
3 changed files with 18 additions and 26 deletions

View file

@@ -6,6 +6,7 @@ import litellm
 import os
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
+from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
 from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
 from transformers import AutoTokenizer
 import tiktoken # Assuming this is how you import TikToken
@@ -123,7 +124,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         if "openai" in self.provider.lower() or "gpt" in self.model:
             tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
         else:
-            tokenizer = AutoTokenizer.from_pretrained(self.model)
+            tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
         logger.debug(f"Tokenizer loaded for model: {self.model}")
         return tokenizer
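The change above keeps the TikToken path for OpenAI/GPT-style models and routes every other provider through the new HuggingFace wrapper instead of a raw AutoTokenizer. A minimal standalone sketch of that selection logic, assuming only the constructor signatures visible in this diff (the pick_tokenizer helper itself is hypothetical, not part of the commit):

from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer

def pick_tokenizer(provider: str, model: str, max_tokens: int):
    # Mirrors the branch in LiteLLMEmbeddingEngine: OpenAI/GPT models keep tiktoken,
    # every other provider now gets a HuggingFace-backed tokenizer.
    if "openai" in provider.lower() or "gpt" in model:
        return TikTokenTokenizer(model=model, max_tokens=max_tokens)
    return HuggingFaceTokenizer(model=model, max_tokens=max_tokens)

# Example: a non-OpenAI provider resolves to the HuggingFace wrapper
# (the model id is an arbitrary example and is downloaded on first use).
tokenizer = pick_tokenizer("ollama", "sentence-transformers/all-MiniLM-L6-v2", 512)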

View file

@@ -1,5 +1,7 @@
 from typing import List, Any
+from transformers import AutoTokenizer
 from ..tokenizer_interface import TokenizerInterface
@@ -12,11 +14,23 @@ class HuggingFaceTokenizer(TokenizerInterface):
         self.model = model
         self.max_tokens = max_tokens
+        self.tokenizer = AutoTokenizer.from_pretrained(model)

     def extract_tokens(self, text: str) -> List[Any]:
-        raise NotImplementedError
+        tokens = self.tokenizer.tokenize(text)
+        return tokens

     def num_tokens_from_text(self, text: str) -> int:
-        raise NotImplementedError
+        """
+        Returns the number of tokens in the given text.
+
+        Args:
+            text: str
+
+        Returns:
+            number of tokens in the given text
+        """
+        return len(self.tokenizer.tokenize(text))

     def trim_text_to_max_tokens(self, text: str) -> str:
         raise NotImplementedError
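A short usage sketch of the now-implemented methods, assuming only what the diff shows (the model id is an arbitrary example that AutoTokenizer.from_pretrained fetches on first use; trim_text_to_max_tokens still raises NotImplementedError):

from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer

text = "Cognee builds knowledge graphs from documents."
tokenizer = HuggingFaceTokenizer(model="bert-base-uncased", max_tokens=512)

tokens = tokenizer.extract_tokens(text)       # list of subword tokens
count = tokenizer.num_tokens_from_text(text)  # len(tokenizer.tokenize(text))
assert count == len(tokens)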

View file

@@ -11,9 +11,7 @@ from cognee.shared.exceptions import IngestionError
 from cognee.shared.utils import (
     get_anonymous_id,
     send_telemetry,
-    num_tokens_from_string,
     get_file_content_hash,
-    trim_text_to_max_tokens,
     prepare_edges,
     prepare_nodes,
     create_cognee_style_network_with_logo,
@@ -45,15 +43,6 @@ def test_get_anonymous_id(mock_open_file, mock_makedirs, temp_dir):
 #     args, kwargs = mock_post.call_args
 #     assert kwargs["json"]["event_name"] == "test_event"
 #
-# @patch("tiktoken.encoding_for_model")
-# def test_num_tokens_from_string(mock_encoding):
-#     mock_encoding.return_value.encode = lambda x: list(x)
-#
-#     assert num_tokens_from_string("hello", "test_encoding") == 5
-#     assert num_tokens_from_string("world", "test_encoding") == 5
-#
 @patch("builtins.open", new_callable=mock_open, read_data=b"test_data")
 def test_get_file_content_hash_file(mock_open_file):
@@ -73,18 +62,6 @@ def test_get_file_content_hash_stream():
     assert result == expected_hash

-# def test_trim_text_to_max_tokens():
-#     text = "This is a test string with multiple words."
-#     encoding_name = "test_encoding"
-#
-#     with patch("tiktoken.get_encoding") as mock_get_encoding:
-#         mock_get_encoding.return_value.encode = lambda x: list(x)
-#         mock_get_encoding.return_value.decode = lambda x: "".join(x)
-#
-#         result = trim_text_to_max_tokens(text, 5, encoding_name)
-#         assert result == text[:5]
-
 def test_prepare_edges():
     graph = nx.MultiDiGraph()
     graph.add_edge("A", "B", key="AB", weight=1)
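The commented-out tests for the removed num_tokens_from_string and trim_text_to_max_tokens helpers are deleted above rather than rewritten. A possible pytest-style replacement that exercises HuggingFaceTokenizer.num_tokens_from_text with a stubbed AutoTokenizer, not part of this commit and with the patch target assumed, could look like:

from unittest.mock import MagicMock, patch

from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer

def test_num_tokens_from_text():
    # Stub the real HuggingFace download; the fake tokenizer just splits on whitespace.
    fake = MagicMock()
    fake.tokenize = lambda text: text.split()

    with patch("transformers.AutoTokenizer.from_pretrained", return_value=fake):
        tokenizer = HuggingFaceTokenizer(model="any-model", max_tokens=128)
        assert tokenizer.num_tokens_from_text("hello world from cognee") == 4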