refactor: Change variable and function names based on PR comments

Change variable and function names based on PR comments
This commit is contained in:
Igor Ilic 2025-01-28 10:10:29 +01:00
parent 77a72851fc
commit 0a9f1349f2
6 changed files with 8 additions and 8 deletions

View file

@@ -40,7 +40,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.model = model self.model = model
self.dimensions = dimensions self.dimensions = dimensions
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.tokenizer = self.set_tokenizer() self.tokenizer = self.get_tokenizer()
enable_mocking = os.getenv("MOCK_EMBEDDING", "false") enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
if isinstance(enable_mocking, bool): if isinstance(enable_mocking, bool):
@@ -114,7 +114,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
def get_vector_size(self) -> int: def get_vector_size(self) -> int:
return self.dimensions return self.dimensions
def set_tokenizer(self): def get_tokenizer(self):
logger.debug(f"Loading tokenizer for model {self.model}...") logger.debug(f"Loading tokenizer for model {self.model}...")
# If model also contains provider information, extract only model information # If model also contains provider information, extract only model information
model = self.model.split("/")[-1] model = self.model.split("/")[-1]

View file

@@ -26,7 +26,7 @@ class GeminiTokenizer(TokenizerInterface):
def extract_tokens(self, text: str) -> List[Any]: def extract_tokens(self, text: str) -> List[Any]:
raise NotImplementedError raise NotImplementedError
def num_tokens_from_text(self, text: str) -> int: def count_tokens(self, text: str) -> int:
""" """
Returns the number of tokens in the given text. Returns the number of tokens in the given text.
Args: Args:

View file

@@ -20,7 +20,7 @@ class HuggingFaceTokenizer(TokenizerInterface):
tokens = self.tokenizer.tokenize(text) tokens = self.tokenizer.tokenize(text)
return tokens return tokens
def num_tokens_from_text(self, text: str) -> int: def count_tokens(self, text: str) -> int:
""" """
Returns the number of tokens in the given text. Returns the number of tokens in the given text.
Args: Args:

View file

@@ -30,7 +30,7 @@ class TikTokenTokenizer(TokenizerInterface):
tokens.append(token) tokens.append(token)
return tokens return tokens
def num_tokens_from_text(self, text: str) -> int: def count_tokens(self, text: str) -> int:
""" """
Returns the number of tokens in the given text. Returns the number of tokens in the given text.
Args: Args:
@@ -54,7 +54,7 @@ class TikTokenTokenizer(TokenizerInterface):
str: Trimmed version of text or original text if under the limit. str: Trimmed version of text or original text if under the limit.
""" """
# First check the number of tokens # First check the number of tokens
num_tokens = self.num_tokens_from_string(text) num_tokens = self.count_tokens(text)
# If the number of tokens is within the limit, return the text as is # If the number of tokens is within the limit, return the text as is
if num_tokens <= self.max_tokens: if num_tokens <= self.max_tokens:

View file

@@ -10,7 +10,7 @@ class TokenizerInterface(Protocol):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def num_tokens_from_text(self, text: str) -> int: def count_tokens(self, text: str) -> int:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod

View file

@@ -39,7 +39,7 @@ def chunk_by_paragraph(
data, maximum_length=paragraph_length data, maximum_length=paragraph_length
): ):
# Check if this sentence would exceed length limit # Check if this sentence would exceed length limit
token_count = embedding_engine.tokenizer.num_tokens_from_text(sentence) token_count = embedding_engine.tokenizer.count_tokens(sentence)
if current_word_count > 0 and ( if current_word_count > 0 and (
current_word_count + word_count > paragraph_length current_word_count + word_count > paragraph_length