From 294ed1d960cd4285046b5bfe0168b31b67ce8812 Mon Sep 17 00:00:00 2001
From: Igor Ilic
Date: Thu, 23 Jan 2025 16:52:35 +0100
Subject: [PATCH] feat: Add HuggingFace Tokenizer support

---
 .../embeddings/LiteLLMEmbeddingEngine.py      |  3 ++-
 .../llm/tokenizer/HuggingFace/adapter.py      | 18 +++++++++++++--
 .../tests/unit/processing/utils/utils_test.py | 23 -------------------
 3 files changed, 18 insertions(+), 26 deletions(-)

diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
index 842256659..c037b45e0 100644
--- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -6,6 +6,7 @@ import litellm
 import os
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
+from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
 from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
 from transformers import AutoTokenizer
 import tiktoken  # Assuming this is how you import TikToken
@@ -123,7 +124,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         if "openai" in self.provider.lower() or "gpt" in self.model:
             tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
         else:
-            tokenizer = AutoTokenizer.from_pretrained(self.model)
+            tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
 
         logger.debug(f"Tokenizer loaded for model: {self.model}")
         return tokenizer
diff --git a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py
index 19238b62e..7b92fb76b 100644
--- a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py
+++ b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py
@@ -1,5 +1,7 @@
 from typing import List, Any
 
+from transformers import AutoTokenizer
+
 from ..tokenizer_interface import TokenizerInterface
 
 
@@ -12,11 +14,23 @@ class HuggingFaceTokenizer(TokenizerInterface):
         self.model = model
         self.max_tokens = max_tokens
 
+        self.tokenizer = AutoTokenizer.from_pretrained(model)
+
     def extract_tokens(self, text: str) -> List[Any]:
-        raise NotImplementedError
+        tokens = self.tokenizer.tokenize(text)
+        return tokens
 
     def num_tokens_from_text(self, text: str) -> int:
-        raise NotImplementedError
+        """
+        Returns the number of tokens in the given text.
+
+        Args:
+            text: The text to tokenize.
+
+        Returns:
+            The number of tokens in the given text.
+        """
+        return len(self.tokenizer.tokenize(text))
 
     def trim_text_to_max_tokens(self, text: str) -> str:
         raise NotImplementedError
diff --git a/cognee/tests/unit/processing/utils/utils_test.py b/cognee/tests/unit/processing/utils/utils_test.py
index f8c325100..067ab6ea7 100644
--- a/cognee/tests/unit/processing/utils/utils_test.py
+++ b/cognee/tests/unit/processing/utils/utils_test.py
@@ -11,9 +11,7 @@ from cognee.shared.exceptions import IngestionError
 from cognee.shared.utils import (
     get_anonymous_id,
     send_telemetry,
-    num_tokens_from_string,
     get_file_content_hash,
-    trim_text_to_max_tokens,
     prepare_edges,
     prepare_nodes,
     create_cognee_style_network_with_logo,
@@ -45,15 +43,6 @@ def test_get_anonymous_id(mock_open_file, mock_makedirs, temp_dir):
 #     args, kwargs = mock_post.call_args
 #     assert kwargs["json"]["event_name"] == "test_event"
 
-#
-# @patch("tiktoken.encoding_for_model")
-# def test_num_tokens_from_string(mock_encoding):
-#     mock_encoding.return_value.encode = lambda x: list(x)
-#
-#     assert num_tokens_from_string("hello", "test_encoding") == 5
-#     assert num_tokens_from_string("world", "test_encoding") == 5
-#
-
 
 @patch("builtins.open", new_callable=mock_open, read_data=b"test_data")
 def test_get_file_content_hash_file(mock_open_file):
@@ -73,18 +62,6 @@ def test_get_file_content_hash_stream():
     assert result == expected_hash
 
 
-# def test_trim_text_to_max_tokens():
-#     text = "This is a test string with multiple words."
-#     encoding_name = "test_encoding"
-#
-#     with patch("tiktoken.get_encoding") as mock_get_encoding:
-#         mock_get_encoding.return_value.encode = lambda x: list(x)
-#         mock_get_encoding.return_value.decode = lambda x: "".join(x)
-#
-#         result = trim_text_to_max_tokens(text, 5, encoding_name)
-#         assert result == text[:5]
-
-
 def test_prepare_edges():
     graph = nx.MultiDiGraph()
     graph.add_edge("A", "B", key="AB", weight=1)
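-- 
Usage sketch (illustrative): with this patch, non-OpenAI embedding models are
handled by HuggingFaceTokenizer rather than a raw AutoTokenizer. A minimal
example of how the adapter behaves, assuming the "bert-base-uncased"
checkpoint can be fetched from the HuggingFace Hub; the model name,
max_tokens value, and printed outputs are indicative only:

    from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer

    # The adapter wraps transformers.AutoTokenizer (see adapter.py above);
    # max_tokens=512 is an arbitrary illustrative limit.
    tokenizer = HuggingFaceTokenizer(model="bert-base-uncased", max_tokens=512)

    tokens = tokenizer.extract_tokens("Hello, world!")
    print(tokens)  # e.g. ['hello', ',', 'world', '!'] for a lowercasing WordPiece model

    count = tokenizer.num_tokens_from_text("Hello, world!")
    print(count)   # e.g. 4; tokenize() does not add special tokens such as [CLS]/[SEP]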