refactor: Change variable and function names based on PR comments
parent 77a72851fc
commit 0a9f1349f2
6 changed files with 8 additions and 8 deletions
```diff
@@ -40,7 +40,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         self.model = model
         self.dimensions = dimensions
         self.max_tokens = max_tokens
-        self.tokenizer = self.set_tokenizer()
+        self.tokenizer = self.get_tokenizer()
 
         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
         if isinstance(enable_mocking, bool):
```
```diff
@@ -114,7 +114,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
     def get_vector_size(self) -> int:
         return self.dimensions
 
-    def set_tokenizer(self):
+    def get_tokenizer(self):
         logger.debug(f"Loading tokenizer for model {self.model}...")
         # If model also contains provider information, extract only model information
         model = self.model.split("/")[-1]
```
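The new name fits the behavior: the method returns a tokenizer rather than mutating state. A minimal usage sketch of the renamed getter, assuming constructor keywords that mirror the attributes assigned in the first hunk; the provider-prefixed model string and values are placeholders, not taken from this commit:

```python
# Hypothetical usage; argument values are illustrative only.
engine = LiteLLMEmbeddingEngine(
    model="openai/text-embedding-3-small",  # provider/model, hence the split("/") above
    dimensions=1536,
    max_tokens=8191,
)
tokenizer = engine.get_tokenizer()  # formerly set_tokenizer(); returns, not sets
```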
```diff
@@ -26,7 +26,7 @@ class GeminiTokenizer(TokenizerInterface):
     def extract_tokens(self, text: str) -> List[Any]:
         raise NotImplementedError
 
-    def num_tokens_from_text(self, text: str) -> int:
+    def count_tokens(self, text: str) -> int:
         """
         Returns the number of tokens in the given text.
         Args:
```
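The commit only shows the rename, not the method body. A hedged sketch of how a Gemini-backed `count_tokens` could be implemented, assuming the `google-generativeai` client and a `self.model` attribute:

```python
import google.generativeai as genai

# Sketch only: delegate token counting to the Gemini API.
def count_tokens(self, text: str) -> int:
    model = genai.GenerativeModel(self.model)
    return model.count_tokens(text).total_tokens
```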
```diff
@@ -20,7 +20,7 @@ class HuggingFaceTokenizer(TokenizerInterface):
         tokens = self.tokenizer.tokenize(text)
         return tokens
 
-    def num_tokens_from_text(self, text: str) -> int:
+    def count_tokens(self, text: str) -> int:
         """
         Returns the number of tokens in the given text.
         Args:
```
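Given the visible `self.tokenizer.tokenize(text)` call above, the renamed method plausibly counts those same tokens. A self-contained sketch using `transformers`; the constructor signature is an assumption:

```python
from transformers import AutoTokenizer

class HuggingFaceTokenizer:
    def __init__(self, model: str, max_tokens: int = 512):
        # Load the tokenizer that extract_tokens() delegates to
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.max_tokens = max_tokens

    def count_tokens(self, text: str) -> int:
        # Count the same tokens that tokenize() produces
        return len(self.tokenizer.tokenize(text))
```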
```diff
@@ -30,7 +30,7 @@ class TikTokenTokenizer(TokenizerInterface):
             tokens.append(token)
         return tokens
 
-    def num_tokens_from_text(self, text: str) -> int:
+    def count_tokens(self, text: str) -> int:
         """
         Returns the number of tokens in the given text.
         Args:
```
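For the tiktoken-backed class, counting presumably reduces to the length of the encoded sequence. A sketch, assuming the encoding is resolved from a `self.model` attribute:

```python
import tiktoken

# Sketch only; the encoding lookup is an assumption.
def count_tokens(self, text: str) -> int:
    encoding = tiktoken.encoding_for_model(self.model)
    return len(encoding.encode(text))
```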
```diff
@@ -54,7 +54,7 @@ class TikTokenTokenizer(TokenizerInterface):
             str: Trimmed version of text or original text if under the limit.
         """
         # First check the number of tokens
-        num_tokens = self.num_tokens_from_string(text)
+        num_tokens = self.count_tokens(text)
 
         # If the number of tokens is within the limit, return the text as is
         if num_tokens <= self.max_tokens:
```
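Note this hunk also collapses a second old name, `num_tokens_from_string`, into the unified `count_tokens`. A sketch of the surrounding trim logic; the method name here is hypothetical, and the truncation step (not shown in the diff) is assumed to be a tiktoken encode/slice/decode round-trip:

```python
def trim_text_to_max_tokens(self, text: str) -> str:
    # First check the number of tokens
    num_tokens = self.count_tokens(text)
    # If the number of tokens is within the limit, return the text as is
    if num_tokens <= self.max_tokens:
        return text
    # Assumed strategy: keep only the first max_tokens tokens
    encoding = tiktoken.encoding_for_model(self.model)
    return encoding.decode(encoding.encode(text)[: self.max_tokens])
```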
```diff
@@ -10,7 +10,7 @@ class TokenizerInterface(Protocol):
         raise NotImplementedError
 
     @abstractmethod
-    def num_tokens_from_text(self, text: str) -> int:
+    def count_tokens(self, text: str) -> int:
         raise NotImplementedError
 
     @abstractmethod
```
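Putting the visible pieces together, the protocol after this commit plausibly reads as below; members outside the hunk are elided, and `extract_tokens` is inferred from the implementing classes:

```python
from abc import abstractmethod
from typing import Any, List, Protocol

class TokenizerInterface(Protocol):
    @abstractmethod
    def extract_tokens(self, text: str) -> List[Any]:
        raise NotImplementedError

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        raise NotImplementedError
```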
```diff
@@ -39,7 +39,7 @@ def chunk_by_paragraph(
         data, maximum_length=paragraph_length
     ):
         # Check if this sentence would exceed length limit
-        token_count = embedding_engine.tokenizer.num_tokens_from_text(sentence)
+        token_count = embedding_engine.tokenizer.count_tokens(sentence)
 
         if current_word_count > 0 and (
             current_word_count + word_count > paragraph_length
```
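The chunker is the one call site outside the tokenizer classes. As a standalone illustration of the pattern it uses, a greedy token-budget chunker; the helper name and yielding strategy are mine, not the repository's:

```python
def chunk_by_token_budget(sentences, tokenizer, budget: int):
    """Greedily pack sentences into chunks of at most `budget` tokens."""
    chunk, used = [], 0
    for sentence in sentences:
        token_count = tokenizer.count_tokens(sentence)
        if chunk and used + token_count > budget:
            yield " ".join(chunk)
            chunk, used = [], 0
        chunk.append(sentence)
        used += token_count
    if chunk:
        yield " ".join(chunk)
```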