refactor: cleanup unused code and improve tokenizer loading logic
Signed-off-by: Faizan Shaikh <faizansk9292@gmail.com>
parent f637f80d7a
commit 240d50e96f

1 changed file with 36 additions and 26 deletions

@@ -50,7 +50,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
     dimensions: int
-    mock: bool
 
     MAX_RETRIES = 5
+    mock: bool
 
     def __init__(
         self,

@@ -71,7 +71,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         self.dimensions = dimensions
         self.max_completion_tokens = max_completion_tokens
         self.tokenizer = self.get_tokenizer()
-        self.retry_count = 0
         self.batch_size = batch_size
 
         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
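
The `MOCK_EMBEDDING` environment variable read above is a plain string; as a minimal sketch, here is how such a flag is typically normalized into the class's `mock: bool` attribute. The parsing rules below are an assumption for illustration; the diff only shows the `os.getenv` call.

    import os

    # Assumed normalization of the MOCK_EMBEDDING env var; the diff only
    # shows the os.getenv call, so treat this parsing as illustrative.
    enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
    mock = str(enable_mocking).lower() in ("true", "1", "yes")
    print(mock)  # False unless MOCK_EMBEDDING is set to a truthy value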

@@ -195,40 +194,51 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
             The tokenizer instance compatible with the model.
         """
         logger.debug(f"Loading tokenizer for model {self.model}...")
-        # If the model also contains provider information, extract only the model name
-        model = self.model.split("/")[-1]
+        # Determine the effective model name.
+        # If the provider is OpenAI-compatible but not OpenAI itself, the model may carry a prefix like "openai/".
+        # Strip known prefixes to get the actual model ID for HuggingFace.
+        effective_model = self.model
+        if "hosted_vllm/" in effective_model:
+            effective_model = effective_model.replace("hosted_vllm/", "")
+        if "openai/" in effective_model:
+            effective_model = effective_model.replace("openai/", "")
 
-        if "openai" in self.provider.lower():
-            tokenizer = TikTokenTokenizer(
-                model=model, max_completion_tokens=self.max_completion_tokens
+        provider_lower = self.provider.lower()
+
+        if "openai" in provider_lower:
+            # TikToken expects the bare model ID rather than the full provider string,
+            # so split off any "provider/" prefix (e.g. "openai/gpt-4" -> "gpt-4"),
+            # as the original code did with: model = self.model.split("/")[-1]
+            return TikTokenTokenizer(
+                model=self.model.split("/")[-1],
+                max_completion_tokens=self.max_completion_tokens
             )
-        elif "gemini" in self.provider.lower():
-            # Since Gemini tokenization needs an API request to get the token count,
-            # we use TikToken instead and count tokens word by word
-            tokenizer = TikTokenTokenizer(
-                model=None, max_completion_tokens=self.max_completion_tokens
+        elif "gemini" in provider_lower:
+            return TikTokenTokenizer(
+                model=None,
+                max_completion_tokens=self.max_completion_tokens
            )
         # Note: Gemini Tokenizer expects an LLM model as input and not the embedding model
         # tokenizer = GeminiTokenizer(
         #     llm_model=llm_model, max_completion_tokens=self.max_completion_tokens
         # )
-        elif "mistral" in self.provider.lower():
-            tokenizer = MistralTokenizer(
-                model=model, max_completion_tokens=self.max_completion_tokens
+        elif "mistral" in provider_lower:
+            return MistralTokenizer(
+                model=self.model.split("/")[-1],
+                max_completion_tokens=self.max_completion_tokens
             )
 
         else:
             try:
-                tokenizer = HuggingFaceTokenizer(
-                    model=self.model.replace("hosted_vllm/", "").replace("openai/", ""),
+                # Use the effective model name (stripped of prefixes) for HuggingFace
+                return HuggingFaceTokenizer(
+                    model=effective_model,
                     max_completion_tokens=self.max_completion_tokens,
                 )
             except Exception as e:
                 logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
                 logger.info("Switching to TikToken default tokenizer.")
-                tokenizer = TikTokenTokenizer(
-                    model=None, max_completion_tokens=self.max_completion_tokens
+                return TikTokenTokenizer(
+                    model=None,
+                    max_completion_tokens=self.max_completion_tokens
                 )
-
-        logger.debug(f"Tokenizer loaded for model: {self.model}")
-        return tokenizer
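
To see what the patched routing does end to end, here is a small self-contained sketch of the prefix stripping and provider dispatch. The `resolve_tokenizer` helper and the example provider/model pairs are hypothetical illustrations; only the routing rules mirror the diff.

    def resolve_tokenizer(provider: str, model: str) -> tuple[str, str | None]:
        """Mirror the patched routing: return (tokenizer_family, model_id)."""
        # Strip known provider prefixes for the HuggingFace fallback path.
        effective_model = model
        for prefix in ("hosted_vllm/", "openai/"):
            effective_model = effective_model.replace(prefix, "")

        provider_lower = provider.lower()
        if "openai" in provider_lower:
            return ("tiktoken", model.split("/")[-1])
        if "gemini" in provider_lower:
            # Counted locally with TikToken; no Gemini API round-trip needed.
            return ("tiktoken", None)
        if "mistral" in provider_lower:
            return ("mistral", model.split("/")[-1])
        return ("huggingface", effective_model)

    print(resolve_tokenizer("openai", "openai/text-embedding-3-large"))
    # ('tiktoken', 'text-embedding-3-large')
    print(resolve_tokenizer("custom", "hosted_vllm/BAAI/bge-m3"))
    # ('huggingface', 'BAAI/bge-m3')

Note the design change the diff makes: every branch now returns immediately, which is why the trailing `logger.debug(...)` and `return tokenizer` epilogue of the old version is deleted rather than kept as dead code.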
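
The `else` branch's HuggingFace-to-TikToken fallback is the one path that can raise. A minimal sketch of that pattern with stand-in classes (the class bodies and failure mode are illustrative, not the project's real tokenizer implementations):

    import logging

    logger = logging.getLogger(__name__)

    class HuggingFaceTokenizer:  # stand-in that simulates a failed load
        def __init__(self, model, max_completion_tokens):
            raise RuntimeError(f"unknown model: {model}")

    class TikTokenTokenizer:  # stand-in default tokenizer
        def __init__(self, model, max_completion_tokens):
            self.model = model

    def load_with_fallback(model, max_completion_tokens):
        try:
            return HuggingFaceTokenizer(model, max_completion_tokens)
        except Exception as e:
            logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
            logger.info("Switching to TikToken default tokenizer.")
            return TikTokenTokenizer(model=None, max_completion_tokens=max_completion_tokens)

    tok = load_with_fallback("no-such/model", 8191)
    print(type(tok).__name__)  # TikTokenTokenizer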