refactor: clean up unused code and improve tokenizer loading logic

Signed-off-by: Faizan Shaikh <faizansk9292@gmail.com>
Faizan Shaikh 2025-12-19 18:43:32 +05:30
parent f637f80d7a
commit 240d50e96f


@@ -50,7 +50,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
     dimensions: int
-    mock: bool
     MAX_RETRIES = 5
+    mock: bool

     def __init__(
         self,
@@ -71,7 +71,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         self.dimensions = dimensions
         self.max_completion_tokens = max_completion_tokens
         self.tokenizer = self.get_tokenizer()
-        self.retry_count = 0
         self.batch_size = batch_size

         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
@@ -195,40 +194,51 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
             The tokenizer instance compatible with the model.
         """
         logger.debug(f"Loading tokenizer for model {self.model}...")
-        # If model also contains provider information, extract only model information
-        model = self.model.split("/")[-1]
+        # Determine the effective model name.
+        # If the provider is OpenAI-compatible but not OpenAI itself, we might have a prefix like "openai/".
+        # We strip known prefixes to get the actual model ID for HuggingFace.
+        effective_model = self.model
+        if "hosted_vllm/" in effective_model:
+            effective_model = effective_model.replace("hosted_vllm/", "")
+        if "openai/" in effective_model:
+            effective_model = effective_model.replace("openai/", "")

-        if "openai" in self.provider.lower():
-            tokenizer = TikTokenTokenizer(
-                model=model, max_completion_tokens=self.max_completion_tokens
+        provider_lower = self.provider.lower()
+        if "openai" in provider_lower:
+            # TikToken expects the full model string usually, or at least how it was passed.
+            # But if it's "openai/gpt-4", split acts on self.model.
+            # Original code: model = self.model.split("/")[-1]
+            return TikTokenTokenizer(
+                model=self.model.split("/")[-1],
+                max_completion_tokens=self.max_completion_tokens
             )
-        elif "gemini" in self.provider.lower():
-            # Since Gemini tokenization needs to send an API request to get the token count we will use TikToken to
-            # count tokens as we calculate tokens word by word
-            tokenizer = TikTokenTokenizer(
-                model=None, max_completion_tokens=self.max_completion_tokens
+        elif "gemini" in provider_lower:
+            return TikTokenTokenizer(
+                model=None,
+                max_completion_tokens=self.max_completion_tokens
             )
             # Note: Gemini Tokenizer expects an LLM model as input and not the embedding model
             # tokenizer = GeminiTokenizer(
             #     llm_model=llm_model, max_completion_tokens=self.max_completion_tokens
             # )
-        elif "mistral" in self.provider.lower():
-            tokenizer = MistralTokenizer(
-                model=model, max_completion_tokens=self.max_completion_tokens
+        elif "mistral" in provider_lower:
+            return MistralTokenizer(
+                model=self.model.split("/")[-1],
+                max_completion_tokens=self.max_completion_tokens
             )
         else:
             try:
-                tokenizer = HuggingFaceTokenizer(
-                    model=self.model.replace("hosted_vllm/", "").replace("openai/", ""),
+                # Use the effective model name (stripped of prefixes) for HuggingFace
+                return HuggingFaceTokenizer(
+                    model=effective_model,
                     max_completion_tokens=self.max_completion_tokens,
                 )
             except Exception as e:
                 logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
                 logger.info("Switching to TikToken default tokenizer.")
-                tokenizer = TikTokenTokenizer(
-                    model=None, max_completion_tokens=self.max_completion_tokens
+                return TikTokenTokenizer(
+                    model=None,
+                    max_completion_tokens=self.max_completion_tokens
                 )
-        logger.debug(f"Tokenizer loaded for model: {self.model}")
-        return tokenizer
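For context, the selection logic after this change reads as a chain of early returns: OpenAI-style providers get TikToken with the bare model name, Gemini falls back to TikToken's default encoding (a real Gemini token count would require an API round trip), Mistral gets its own tokenizer, and everything else tries HuggingFace with the routing prefixes ("hosted_vllm/", "openai/") stripped, degrading to TikToken on failure. The sketch below is a minimal, runnable rendering of that dispatch; Tok and pick_tokenizer are illustrative names, not part of the codebase, and the provider/model strings in the usage lines are invented for this example.

# Minimal sketch of the tokenizer selection added in this commit.
# "Tok" stands in for the real TikTokenTokenizer / MistralTokenizer /
# HuggingFaceTokenizer classes; only the (kind, model) decision is modeled.
from typing import NamedTuple, Optional


class Tok(NamedTuple):
    kind: str               # which tokenizer family would be constructed
    model: Optional[str]    # model id that would be passed to it


def pick_tokenizer(provider: str, model: str) -> Tok:
    # Strip known routing prefixes to recover the bare model id
    # (the commit guards each replace with an "in" check; a plain
    # replace() is equivalent here).
    effective_model = model.replace("hosted_vllm/", "").replace("openai/", "")
    provider_lower = provider.lower()

    if "openai" in provider_lower:
        # TikToken gets the last path segment, e.g. "openai/gpt-4" -> "gpt-4".
        return Tok("tiktoken", model.split("/")[-1])
    if "gemini" in provider_lower:
        # Counting via Gemini would need an API request, so the
        # default TikToken encoding is used instead (model=None).
        return Tok("tiktoken", None)
    if "mistral" in provider_lower:
        return Tok("mistral", model.split("/")[-1])
    # Fallback: HuggingFace lookup on the prefix-stripped id; the real
    # code degrades to TikTokenTokenizer(model=None) if this raises.
    return Tok("huggingface", effective_model)


print(pick_tokenizer("openai", "openai/text-embedding-3-large"))
# Tok(kind='tiktoken', model='text-embedding-3-large')
print(pick_tokenizer("hosted_vllm", "hosted_vllm/BAAI/bge-m3"))
# Tok(kind='huggingface', model='BAAI/bge-m3')

One behavioral side effect of the early returns: the old trailing logger.debug(f"Tokenizer loaded for model: {self.model}") is gone, so after this commit the debug log only records the start of tokenizer loading, not its completion.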