diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index 10992b22c..f81bc8515 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -40,7 +40,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): self.model = model self.dimensions = dimensions self.max_tokens = max_tokens - self.tokenizer = self.set_tokenizer() + self.tokenizer = self.get_tokenizer() enable_mocking = os.getenv("MOCK_EMBEDDING", "false") if isinstance(enable_mocking, bool): @@ -114,7 +114,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): def get_vector_size(self) -> int: return self.dimensions - def set_tokenizer(self): + def get_tokenizer(self): logger.debug(f"Loading tokenizer for model {self.model}...") # If model also contains provider information, extract only model information model = self.model.split("/")[-1] diff --git a/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py index f3131ea08..e4cc4f145 100644 --- a/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/Gemini/adapter.py @@ -26,7 +26,7 @@ class GeminiTokenizer(TokenizerInterface): def extract_tokens(self, text: str) -> List[Any]: raise NotImplementedError - def num_tokens_from_text(self, text: str) -> int: + def count_tokens(self, text: str) -> int: """ Returns the number of tokens in the given text. Args: diff --git a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py index a8eac29d9..878458414 100644 --- a/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py @@ -20,7 +20,7 @@ class HuggingFaceTokenizer(TokenizerInterface): tokens = self.tokenizer.tokenize(text) return tokens - def num_tokens_from_text(self, text: str) -> int: + def count_tokens(self, text: str) -> int: """ Returns the number of tokens in the given text. Args: diff --git a/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py index 862a79296..3d649ef38 100644 --- a/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +++ b/cognee/infrastructure/llm/tokenizer/TikToken/adapter.py @@ -30,7 +30,7 @@ class TikTokenTokenizer(TokenizerInterface): tokens.append(token) return tokens - def num_tokens_from_text(self, text: str) -> int: + def count_tokens(self, text: str) -> int: """ Returns the number of tokens in the given text. Args: @@ -54,7 +54,7 @@ class TikTokenTokenizer(TokenizerInterface): str: Trimmed version of text or original text if under the limit. """ # First check the number of tokens - num_tokens = self.num_tokens_from_string(text) + num_tokens = self.count_tokens(text) # If the number of tokens is within the limit, return the text as is if num_tokens <= self.max_tokens: diff --git a/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py b/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py index abd111f12..c533f0cf9 100644 --- a/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py +++ b/cognee/infrastructure/llm/tokenizer/tokenizer_interface.py @@ -10,7 +10,7 @@ class TokenizerInterface(Protocol): raise NotImplementedError @abstractmethod - def num_tokens_from_text(self, text: str) -> int: + def count_tokens(self, text: str) -> int: raise NotImplementedError @abstractmethod diff --git a/cognee/tasks/chunks/chunk_by_paragraph.py b/cognee/tasks/chunks/chunk_by_paragraph.py index 077db1cd4..7d7221b87 100644 --- a/cognee/tasks/chunks/chunk_by_paragraph.py +++ b/cognee/tasks/chunks/chunk_by_paragraph.py @@ -39,7 +39,7 @@ def chunk_by_paragraph( data, maximum_length=paragraph_length ): # Check if this sentence would exceed length limit - token_count = embedding_engine.tokenizer.num_tokens_from_text(sentence) + token_count = embedding_engine.tokenizer.count_tokens(sentence) if current_word_count > 0 and ( current_word_count + word_count > paragraph_length