feat: Add HuggingFace Tokenizer support
parent 93249c72c5
commit 294ed1d960

3 changed files with 18 additions and 26 deletions
@@ -6,6 +6,7 @@ import litellm
 import os
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
+from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
 from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
 from transformers import AutoTokenizer
 import tiktoken # Assuming this is how you import TikToken
@@ -123,7 +124,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         if "openai" in self.provider.lower() or "gpt" in self.model:
             tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
         else:
-            tokenizer = AutoTokenizer.from_pretrained(self.model)
+            tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)

         logger.debug(f"Tokenizer loaded for model: {self.model}")
         return tokenizer
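For context, here is a minimal standalone sketch of the dispatch above. The load_tokenizer name and the provider/model/max_tokens parameters are illustrative; in the commit this logic lives on LiteLLMEmbeddingEngine, and after this change both branches return an object implementing the shared TokenizerInterface.

from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer

def load_tokenizer(provider: str, model: str, max_tokens: int):
    # OpenAI-style models are tokenized with tiktoken; everything else now
    # goes through the HuggingFaceTokenizer wrapper instead of a raw
    # transformers.AutoTokenizer.
    if "openai" in provider.lower() or "gpt" in model:
        return TikTokenTokenizer(model=model, max_tokens=max_tokens)
    return HuggingFaceTokenizer(model=model, max_tokens=max_tokens)
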
@@ -1,5 +1,7 @@
 from typing import List, Any

+from transformers import AutoTokenizer
+
 from ..tokenizer_interface import TokenizerInterface

@@ -12,11 +14,23 @@ class HuggingFaceTokenizer(TokenizerInterface):
         self.model = model
         self.max_tokens = max_tokens

+        self.tokenizer = AutoTokenizer.from_pretrained(model)
+
     def extract_tokens(self, text: str) -> List[Any]:
-        raise NotImplementedError
+        tokens = self.tokenizer.tokenize(text)
+        return tokens

     def num_tokens_from_text(self, text: str) -> int:
-        raise NotImplementedError
+        """
+        Returns the number of tokens in the given text.
+        Args:
+            text: str
+
+        Returns:
+            number of tokens in the given text
+
+        """
+        return len(self.tokenizer.tokenize(text))

     def trim_text_to_max_tokens(self, text: str) -> str:
         raise NotImplementedError
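A short usage sketch of the new class. The model id bert-base-uncased is a placeholder; any Hugging Face hub model that ships a tokenizer works, and the tokenizer files are downloaded on first use.

from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer

tokenizer = HuggingFaceTokenizer(model="bert-base-uncased", max_tokens=512)

# extract_tokens returns the string tokens produced by the underlying
# AutoTokenizer; num_tokens_from_text just counts them.
print(tokenizer.extract_tokens("Hello, world!"))        # e.g. ['hello', ',', 'world', '!']
print(tokenizer.num_tokens_from_text("Hello, world!"))  # e.g. 4
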
@@ -11,9 +11,7 @@ from cognee.shared.exceptions import IngestionError
 from cognee.shared.utils import (
     get_anonymous_id,
     send_telemetry,
-    num_tokens_from_string,
     get_file_content_hash,
-    trim_text_to_max_tokens,
     prepare_edges,
     prepare_nodes,
     create_cognee_style_network_with_logo,
@@ -45,15 +43,6 @@ def test_get_anonymous_id(mock_open_file, mock_makedirs, temp_dir):
     # args, kwargs = mock_post.call_args
     # assert kwargs["json"]["event_name"] == "test_event"

-#
-# @patch("tiktoken.encoding_for_model")
-# def test_num_tokens_from_string(mock_encoding):
-#     mock_encoding.return_value.encode = lambda x: list(x)
-#
-#     assert num_tokens_from_string("hello", "test_encoding") == 5
-#     assert num_tokens_from_string("world", "test_encoding") == 5
-#
-

 @patch("builtins.open", new_callable=mock_open, read_data=b"test_data")
 def test_get_file_content_hash_file(mock_open_file):
@@ -73,18 +62,6 @@ def test_get_file_content_hash_stream():
     assert result == expected_hash


-# def test_trim_text_to_max_tokens():
-#     text = "This is a test string with multiple words."
-#     encoding_name = "test_encoding"
-#
-#     with patch("tiktoken.get_encoding") as mock_get_encoding:
-#         mock_get_encoding.return_value.encode = lambda x: list(x)
-#         mock_get_encoding.return_value.decode = lambda x: "".join(x)
-#
-#         result = trim_text_to_max_tokens(text, 5, encoding_name)
-#         assert result == text[:5]
-
-
 def test_prepare_edges():
     graph = nx.MultiDiGraph()
     graph.add_edge("A", "B", key="AB", weight=1)
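The deleted commented-out tests covered the old utils helpers that this commit removes from the imports. A hypothetical replacement targeting the new tokenizer API could mock out the model download, for example:

from unittest.mock import MagicMock, patch

from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer

@patch("transformers.AutoTokenizer.from_pretrained")
def test_num_tokens_from_text(mock_from_pretrained):
    # Replace the real tokenizer with a whitespace splitter so the test
    # runs offline and deterministically.
    fake = MagicMock()
    fake.tokenize = lambda text: text.split()
    mock_from_pretrained.return_value = fake

    tokenizer = HuggingFaceTokenizer(model="any-model", max_tokens=10)
    assert tokenizer.num_tokens_from_text("hello world") == 2
    assert tokenizer.extract_tokens("a b c") == ["a", "b", "c"]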