diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
index 7aeb39b59..d83a2bcee 100644
--- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -50,7 +50,5 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
     dimensions: int
     mock: bool
 
-    MAX_RETRIES = 5
-
     def __init__(
         self,
@@ -71,7 +71,6 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         self.dimensions = dimensions
         self.max_completion_tokens = max_completion_tokens
         self.tokenizer = self.get_tokenizer()
-        self.retry_count = 0
         self.batch_size = batch_size
 
         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
@@ -195,40 +194,51 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
             The tokenizer instance compatible with the model.
         """
         logger.debug(f"Loading tokenizer for model {self.model}...")
-        # If model also contains provider information, extract only model information
-        model = self.model.split("/")[-1]
+
+        # Determine the effective model name. OpenAI-compatible providers
+        # (e.g. hosted vLLM) may prefix the model with "hosted_vllm/" or "openai/";
+        # strip known prefixes to get the actual model ID for HuggingFace.
+        effective_model = self.model
+        if "hosted_vllm/" in effective_model:
+            effective_model = effective_model.replace("hosted_vllm/", "")
+        if "openai/" in effective_model:
+            effective_model = effective_model.replace("openai/", "")
 
-        if "openai" in self.provider.lower():
-            tokenizer = TikTokenTokenizer(
-                model=model, max_completion_tokens=self.max_completion_tokens
+        provider_lower = self.provider.lower()
+
+        if "openai" in provider_lower:
+            # TikToken only needs the bare model name, so strip any provider
+            # prefix (e.g. "openai/gpt-4" -> "gpt-4") before passing it on.
+            return TikTokenTokenizer(
+                model=self.model.split("/")[-1],
+                max_completion_tokens=self.max_completion_tokens
             )
-        elif "gemini" in self.provider.lower():
-            # Since Gemini tokenization needs to send an API request to get the token count we will use TikToken to
-            # count tokens as we calculate tokens word by word
-            tokenizer = TikTokenTokenizer(
-                model=None, max_completion_tokens=self.max_completion_tokens
+
+        elif "gemini" in provider_lower:
+            # Gemini token counting requires an API request, so count tokens
+            # locally with TikToken instead.
+            return TikTokenTokenizer(
+                model=None,
+                max_completion_tokens=self.max_completion_tokens
             )
-        # Note: Gemini Tokenizer expects an LLM model as input and not the embedding model
-        # tokenizer = GeminiTokenizer(
-        #     llm_model=llm_model, max_completion_tokens=self.max_completion_tokens
-        # )
-        elif "mistral" in self.provider.lower():
-            tokenizer = MistralTokenizer(
-                model=model, max_completion_tokens=self.max_completion_tokens
+
+        elif "mistral" in provider_lower:
+            return MistralTokenizer(
+                model=self.model.split("/")[-1],
+                max_completion_tokens=self.max_completion_tokens
             )
+
         else:
             try:
-                tokenizer = HuggingFaceTokenizer(
-                    model=self.model.replace("hosted_vllm/", "").replace("openai/", ""),
+                # Use the effective model name (stripped of prefixes) for HuggingFace.
+                return HuggingFaceTokenizer(
+                    model=effective_model,
                     max_completion_tokens=self.max_completion_tokens,
                 )
             except Exception as e:
                 logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
                 logger.info("Switching to TikToken default tokenizer.")
-                tokenizer = TikTokenTokenizer(
-                    model=None, max_completion_tokens=self.max_completion_tokens
+                return TikTokenTokenizer(
+                    model=None,
+                    max_completion_tokens=self.max_completion_tokens
                 )
-
-        logger.debug(f"Tokenizer loaded for model: {self.model}")
-        return tokenizer
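
For reference, a minimal standalone sketch of the prefix handling that get_tokenizer now applies before falling back to HuggingFace. The helper name strip_provider_prefix is hypothetical and only mirrors the effective_model logic in the diff; it is not a cognee API.

```python
# Hypothetical sketch of the provider-prefix stripping used above.
KNOWN_PREFIXES = ("hosted_vllm/", "openai/")


def strip_provider_prefix(model: str) -> str:
    """Remove known provider prefixes so the bare model ID can be looked up."""
    for prefix in KNOWN_PREFIXES:
        model = model.replace(prefix, "")
    return model


if __name__ == "__main__":
    # Only the provider prefix is removed; any org segment in the model ID stays.
    print(strip_provider_prefix("hosted_vllm/BAAI/bge-m3"))        # BAAI/bge-m3
    print(strip_provider_prefix("openai/text-embedding-3-large"))  # text-embedding-3-large
    print(strip_provider_prefix("text-embedding-3-small"))         # text-embedding-3-small
```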