refactor: cleanup unused code and improve tokenizer loading logic
Signed-off-by: Faizan Shaikh <faizansk9292@gmail.com>
parent f637f80d7a
commit 240d50e96f
1 changed file with 36 additions and 26 deletions
@@ -50,7 +50,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
     dimensions: int
     mock: bool

-    MAX_RETRIES = 5
+    mock: bool

     def __init__(
         self,
@@ -71,7 +71,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         self.dimensions = dimensions
         self.max_completion_tokens = max_completion_tokens
         self.tokenizer = self.get_tokenizer()
-        self.retry_count = 0
         self.batch_size = batch_size

         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
@@ -195,40 +194,51 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
             The tokenizer instance compatible with the model.
         """
         logger.debug(f"Loading tokenizer for model {self.model}...")
-        # If model also contains provider information, extract only model information
-        model = self.model.split("/")[-1]
+        # Determine the effective model name.
+        # If the provider is OpenAI-compatible but not OpenAI itself, we might have a prefix like "openai/".
+        # We strip known prefixes to get the actual model ID for HuggingFace.
+        effective_model = self.model
+        if "hosted_vllm/" in effective_model:
+            effective_model = effective_model.replace("hosted_vllm/", "")
+        if "openai/" in effective_model:
+            effective_model = effective_model.replace("openai/", "")

-        if "openai" in self.provider.lower():
-            tokenizer = TikTokenTokenizer(
-                model=model, max_completion_tokens=self.max_completion_tokens
+        provider_lower = self.provider.lower()
+
+        if "openai" in provider_lower:
+            # TikToken expects the full model string usually, or at least how it was passed.
+            # But if it's "openai/gpt-4", split acts on self.model.
+            # Original code: model = self.model.split("/")[-1]
+            return TikTokenTokenizer(
+                model=self.model.split("/")[-1],
+                max_completion_tokens=self.max_completion_tokens
             )
-        elif "gemini" in self.provider.lower():
-            # Since Gemini tokenization needs to send an API request to get the token count we will use TikToken to
-            # count tokens as we calculate tokens word by word
-            tokenizer = TikTokenTokenizer(
-                model=None, max_completion_tokens=self.max_completion_tokens
+
+        elif "gemini" in provider_lower:
+            return TikTokenTokenizer(
+                model=None,
+                max_completion_tokens=self.max_completion_tokens
             )
-        # Note: Gemini Tokenizer expects an LLM model as input and not the embedding model
-        # tokenizer = GeminiTokenizer(
-        #     llm_model=llm_model, max_completion_tokens=self.max_completion_tokens
-        # )
-        elif "mistral" in self.provider.lower():
-            tokenizer = MistralTokenizer(
-                model=model, max_completion_tokens=self.max_completion_tokens
+
+        elif "mistral" in provider_lower:
+            return MistralTokenizer(
+                model=self.model.split("/")[-1],
+                max_completion_tokens=self.max_completion_tokens
             )

         else:
             try:
-                tokenizer = HuggingFaceTokenizer(
-                    model=self.model.replace("hosted_vllm/", "").replace("openai/", ""),
+                # Use the effective model name (stripped of prefixes) for HuggingFace
+                return HuggingFaceTokenizer(
+                    model=effective_model,
                     max_completion_tokens=self.max_completion_tokens,
                 )
             except Exception as e:
                 logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
                 logger.info("Switching to TikToken default tokenizer.")
-                tokenizer = TikTokenTokenizer(
-                    model=None, max_completion_tokens=self.max_completion_tokens
+                return TikTokenTokenizer(
+                    model=None,
+                    max_completion_tokens=self.max_completion_tokens
                 )

-        logger.debug(f"Tokenizer loaded for model: {self.model}")
-        return tokenizer
+
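For reference, the control flow of the refactored get_tokenizer can be sketched in isolation. This is a minimal sketch, not the repository's code: StubTokenizer is a hypothetical stand-in for the TikTokenTokenizer, MistralTokenizer, and HuggingFaceTokenizer classes the diff uses, and the try/except HuggingFace fallback is collapsed into a comment; only the prefix stripping and provider dispatch are the point.

# Minimal sketch of the tokenizer-selection flow introduced by this commit.
# StubTokenizer and this standalone function are illustrative only.
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubTokenizer:
    """Stand-in for the repo's tokenizer classes (hypothetical)."""
    kind: str
    model: Optional[str]
    max_completion_tokens: int


def get_tokenizer(provider: str, model: str, max_completion_tokens: int = 512) -> StubTokenizer:
    # Strip known provider prefixes so HuggingFace sees the bare model ID.
    # (replace() is a no-op when the prefix is absent, so no guard is needed.)
    effective_model = model
    for prefix in ("hosted_vllm/", "openai/"):
        effective_model = effective_model.replace(prefix, "")

    provider_lower = provider.lower()

    # Early returns replace the old "assign a local, fall through, return it" style.
    if "openai" in provider_lower:
        return StubTokenizer("tiktoken", model.split("/")[-1], max_completion_tokens)
    elif "gemini" in provider_lower:
        # Gemini token counting would require an API call, so fall back to TikToken's default.
        return StubTokenizer("tiktoken", None, max_completion_tokens)
    elif "mistral" in provider_lower:
        return StubTokenizer("mistral", model.split("/")[-1], max_completion_tokens)

    # Anything else: try HuggingFace with the stripped model ID; the real code
    # falls back to a default TikToken tokenizer if that lookup raises.
    return StubTokenizer("huggingface", effective_model, max_completion_tokens)


if __name__ == "__main__":
    print(get_tokenizer("openai", "openai/text-embedding-3-large"))
    print(get_tokenizer("custom", "hosted_vllm/BAAI/bge-large-en-v1.5"))

Returning directly from each branch removes the shared tokenizer local and the trailing return tokenizer, which is also why the final "Tokenizer loaded" debug log disappears in the diff.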