feat: Add HuggingFace Tokenizer support
This commit is contained in:
parent
93249c72c5
commit
294ed1d960
3 changed files with 18 additions and 26 deletions
|
|
@ -6,6 +6,7 @@ import litellm
|
||||||
import os
|
import os
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
|
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
|
||||||
|
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
|
||||||
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
|
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
import tiktoken # Assuming this is how you import TikToken
|
import tiktoken # Assuming this is how you import TikToken
|
||||||
|
|
@ -123,7 +124,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
if "openai" in self.provider.lower() or "gpt" in self.model:
|
if "openai" in self.provider.lower() or "gpt" in self.model:
|
||||||
tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
|
tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
|
||||||
else:
|
else:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(self.model)
|
tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
|
||||||
|
|
||||||
logger.debug(f"Tokenizer loaded for model: {self.model}")
|
logger.debug(f"Tokenizer loaded for model: {self.model}")
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
from typing import List, Any
|
from typing import List, Any
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from ..tokenizer_interface import TokenizerInterface
|
from ..tokenizer_interface import TokenizerInterface
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -12,11 +14,23 @@ class HuggingFaceTokenizer(TokenizerInterface):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model)
|
||||||
|
|
||||||
def extract_tokens(self, text: str) -> List[Any]:
|
def extract_tokens(self, text: str) -> List[Any]:
|
||||||
raise NotImplementedError
|
tokens = self.tokenizer.tokenize(text)
|
||||||
|
return tokens
|
||||||
|
|
||||||
def num_tokens_from_text(self, text: str) -> int:
|
def num_tokens_from_text(self, text: str) -> int:
|
||||||
raise NotImplementedError
|
"""
|
||||||
|
Returns the number of tokens in the given text.
|
||||||
|
Args:
|
||||||
|
text: str
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
number of tokens in the given text
|
||||||
|
|
||||||
|
"""
|
||||||
|
return len(self.tokenizer.tokenize(text))
|
||||||
|
|
||||||
def trim_text_to_max_tokens(self, text: str) -> str:
|
def trim_text_to_max_tokens(self, text: str) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
||||||
|
|
@ -11,9 +11,7 @@ from cognee.shared.exceptions import IngestionError
|
||||||
from cognee.shared.utils import (
|
from cognee.shared.utils import (
|
||||||
get_anonymous_id,
|
get_anonymous_id,
|
||||||
send_telemetry,
|
send_telemetry,
|
||||||
num_tokens_from_string,
|
|
||||||
get_file_content_hash,
|
get_file_content_hash,
|
||||||
trim_text_to_max_tokens,
|
|
||||||
prepare_edges,
|
prepare_edges,
|
||||||
prepare_nodes,
|
prepare_nodes,
|
||||||
create_cognee_style_network_with_logo,
|
create_cognee_style_network_with_logo,
|
||||||
|
|
@ -45,15 +43,6 @@ def test_get_anonymous_id(mock_open_file, mock_makedirs, temp_dir):
|
||||||
# args, kwargs = mock_post.call_args
|
# args, kwargs = mock_post.call_args
|
||||||
# assert kwargs["json"]["event_name"] == "test_event"
|
# assert kwargs["json"]["event_name"] == "test_event"
|
||||||
|
|
||||||
#
|
|
||||||
# @patch("tiktoken.encoding_for_model")
|
|
||||||
# def test_num_tokens_from_string(mock_encoding):
|
|
||||||
# mock_encoding.return_value.encode = lambda x: list(x)
|
|
||||||
#
|
|
||||||
# assert num_tokens_from_string("hello", "test_encoding") == 5
|
|
||||||
# assert num_tokens_from_string("world", "test_encoding") == 5
|
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
@patch("builtins.open", new_callable=mock_open, read_data=b"test_data")
|
@patch("builtins.open", new_callable=mock_open, read_data=b"test_data")
|
||||||
def test_get_file_content_hash_file(mock_open_file):
|
def test_get_file_content_hash_file(mock_open_file):
|
||||||
|
|
@ -73,18 +62,6 @@ def test_get_file_content_hash_stream():
|
||||||
assert result == expected_hash
|
assert result == expected_hash
|
||||||
|
|
||||||
|
|
||||||
# def test_trim_text_to_max_tokens():
|
|
||||||
# text = "This is a test string with multiple words."
|
|
||||||
# encoding_name = "test_encoding"
|
|
||||||
#
|
|
||||||
# with patch("tiktoken.get_encoding") as mock_get_encoding:
|
|
||||||
# mock_get_encoding.return_value.encode = lambda x: list(x)
|
|
||||||
# mock_get_encoding.return_value.decode = lambda x: "".join(x)
|
|
||||||
#
|
|
||||||
# result = trim_text_to_max_tokens(text, 5, encoding_name)
|
|
||||||
# assert result == text[:5]
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_edges():
|
def test_prepare_edges():
|
||||||
graph = nx.MultiDiGraph()
|
graph = nx.MultiDiGraph()
|
||||||
graph.add_edge("A", "B", key="AB", weight=1)
|
graph.add_edge("A", "B", key="AB", weight=1)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue