refactor: Change variable and function names based on PR comments
Change variable and function names based on PR comments
This commit is contained in:
parent
77a72851fc
commit
0a9f1349f2
6 changed files with 8 additions and 8 deletions
|
|
@@ -40,7 +40,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.tokenizer = self.set_tokenizer()
|
self.tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
|
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
|
||||||
if isinstance(enable_mocking, bool):
|
if isinstance(enable_mocking, bool):
|
||||||
|
|
@@ -114,7 +114,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
def get_vector_size(self) -> int:
|
def get_vector_size(self) -> int:
|
||||||
return self.dimensions
|
return self.dimensions
|
||||||
|
|
||||||
def set_tokenizer(self):
|
def get_tokenizer(self):
|
||||||
logger.debug(f"Loading tokenizer for model {self.model}...")
|
logger.debug(f"Loading tokenizer for model {self.model}...")
|
||||||
# If model also contains provider information, extract only model information
|
# If model also contains provider information, extract only model information
|
||||||
model = self.model.split("/")[-1]
|
model = self.model.split("/")[-1]
|
||||||
|
|
|
||||||
|
|
@@ -26,7 +26,7 @@ class GeminiTokenizer(TokenizerInterface):
|
||||||
def extract_tokens(self, text: str) -> List[Any]:
|
def extract_tokens(self, text: str) -> List[Any]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def num_tokens_from_text(self, text: str) -> int:
|
def count_tokens(self, text: str) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the number of tokens in the given text.
|
Returns the number of tokens in the given text.
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@@ -20,7 +20,7 @@ class HuggingFaceTokenizer(TokenizerInterface):
|
||||||
tokens = self.tokenizer.tokenize(text)
|
tokens = self.tokenizer.tokenize(text)
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def num_tokens_from_text(self, text: str) -> int:
|
def count_tokens(self, text: str) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the number of tokens in the given text.
|
Returns the number of tokens in the given text.
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@@ -30,7 +30,7 @@ class TikTokenTokenizer(TokenizerInterface):
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def num_tokens_from_text(self, text: str) -> int:
|
def count_tokens(self, text: str) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the number of tokens in the given text.
|
Returns the number of tokens in the given text.
|
||||||
Args:
|
Args:
|
||||||
|
|
@@ -54,7 +54,7 @@ class TikTokenTokenizer(TokenizerInterface):
|
||||||
str: Trimmed version of text or original text if under the limit.
|
str: Trimmed version of text or original text if under the limit.
|
||||||
"""
|
"""
|
||||||
# First check the number of tokens
|
# First check the number of tokens
|
||||||
num_tokens = self.num_tokens_from_string(text)
|
num_tokens = self.count_tokens(text)
|
||||||
|
|
||||||
# If the number of tokens is within the limit, return the text as is
|
# If the number of tokens is within the limit, return the text as is
|
||||||
if num_tokens <= self.max_tokens:
|
if num_tokens <= self.max_tokens:
|
||||||
|
|
|
||||||
|
|
@@ -10,7 +10,7 @@ class TokenizerInterface(Protocol):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def num_tokens_from_text(self, text: str) -> int:
|
def count_tokens(self, text: str) -> int:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
||||||
|
|
@@ -39,7 +39,7 @@ def chunk_by_paragraph(
|
||||||
data, maximum_length=paragraph_length
|
data, maximum_length=paragraph_length
|
||||||
):
|
):
|
||||||
# Check if this sentence would exceed length limit
|
# Check if this sentence would exceed length limit
|
||||||
token_count = embedding_engine.tokenizer.num_tokens_from_text(sentence)
|
token_count = embedding_engine.tokenizer.count_tokens(sentence)
|
||||||
|
|
||||||
if current_word_count > 0 and (
|
if current_word_count > 0 and (
|
||||||
current_word_count + word_count > paragraph_length
|
current_word_count + word_count > paragraph_length
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue