cognee/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
Igor Ilic 9faa47fc5a
feat: add default tokenizer in case hugging face is not available (#1177)
<!-- .github/pull_request_template.md -->

## Description
Add a default tokenizer (TikToken) as a fallback for custom models whose tokenizer is not available on Hugging Face.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-08-01 16:37:53 +02:00

188 lines
7 KiB
Python

import asyncio
from cognee.shared.logging_utils import get_logger
from typing import List, Optional
import numpy as np
import math
import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.tokenizer.Mistral import MistralTokenizer
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer
from cognee.infrastructure.llm.embedding_rate_limiter import (
embedding_rate_limit_async,
embedding_sleep_and_retry_async,
)
# Module-level side effect at import time: set LiteLLM's global `set_verbose` flag to
# False (presumably to silence LiteLLM's own debug output — confirm against LiteLLM docs).
litellm.set_verbose = False
# Logger shared by everything in this module, named after the engine class.
logger = get_logger("LiteLLMEmbeddingEngine")
class LiteLLMEmbeddingEngine(EmbeddingEngine):
    """
    Engine for embedding text using a specific LLM model, supporting mock and actual
    embedding calls.

    Live calls go through ``litellm.aembedding``. When the ``MOCK_EMBEDDING`` environment
    variable is set to a truthy value ("true", "1" or "yes", case-insensitive), zero
    vectors of length ``dimensions`` are returned instead and no request is made.

    Public methods:

    - embed_text: Embed a list of strings into vector representations.
    - get_vector_size: Retrieve the size of the embedding vectors.
    - get_tokenizer: Load the appropriate tokenizer for the specified model.
    """

    api_key: str
    endpoint: str
    api_version: str
    provider: str
    model: str
    dimensions: int
    mock: bool

    MAX_RETRIES = 5

    def __init__(
        self,
        model: Optional[str] = "openai/text-embedding-3-large",
        provider: str = "openai",
        dimensions: Optional[int] = 3072,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        api_version: Optional[str] = None,
        max_tokens: int = 512,
    ):
        """
        Store the connection settings, load the tokenizer and resolve the mock flag.

        Parameters:
        -----------

            - model (Optional[str]): Model identifier passed to LiteLLM; may carry a
              "provider/" prefix.
            - provider (str): Provider name, used to choose the tokenizer.
            - dimensions (Optional[int]): Length of the embedding vectors.
            - api_key (Optional[str]): API key forwarded to LiteLLM.
            - endpoint (Optional[str]): Forwarded to LiteLLM as ``api_base``.
            - api_version (Optional[str]): API version forwarded to LiteLLM.
            - max_tokens (int): Token limit handed to the tokenizer.
        """
        self.api_key = api_key
        self.endpoint = endpoint
        self.api_version = api_version
        self.provider = provider
        self.model = model
        self.dimensions = dimensions
        self.max_tokens = max_tokens
        # get_tokenizer reads model, provider and max_tokens, so it must run after them.
        self.tokenizer = self.get_tokenizer()
        self.retry_count = 0

        # os.getenv returns a str here (a str default is supplied), so normalise
        # whitespace and case once; "True", "YES" or " 1 " now enable mocking as well.
        enable_mocking = os.getenv("MOCK_EMBEDDING", "false").strip().lower()
        self.mock = enable_mocking in ("true", "1", "yes")

    @embedding_sleep_and_retry_async()
    @embedding_rate_limit_async
    async def embed_text(self, text: List[str]) -> List[List[float]]:
        """
        Embed a list of text strings into vector representations.

        If the input exceeds the model's context window, the method recursively splits the
        input and combines the results: a multi-item list is halved and the two result
        lists are concatenated; a single oversize string is cut into two overlapping parts
        whose embeddings are averaged into one vector.

        Parameters:
        -----------

            - text (List[str]): A list of strings to be embedded.

        Returns:
        --------

            - List[List[float]]: A list of vectors representing the embedded texts.

        Raises:
        -------

            - EmbeddingException: When LiteLLM reports a BadRequestError or NotFoundError.
            - litellm.exceptions.ContextWindowExceededError: When the context window is
              exceeded and the input cannot be split any further.
            - Exception: Any other error is logged and re-raised unchanged.
        """
        try:
            if self.mock:
                # One independent zero vector per input string; no request is made.
                return [[0.0] * self.dimensions for _ in text]

            response = await litellm.aembedding(
                model=self.model,
                input=text,
                api_key=self.api_key,
                api_base=self.endpoint,
                api_version=self.api_version,
            )
            return [data["embedding"] for data in response.data]

        except litellm.exceptions.ContextWindowExceededError as error:
            if isinstance(text, list) and len(text) > 1:
                # Several strings: halve the batch and embed both halves concurrently.
                mid = math.ceil(len(text) / 2)
                left, right = text[:mid], text[mid:]
                left_vecs, right_vecs = await asyncio.gather(
                    self.embed_text(left),
                    self.embed_text(right),
                )
                return left_vecs + right_vecs

            # If the caller passed ONE oversize string, split the string itself into two
            # overlapping parts so each can be processed on its own.
            if isinstance(text, list) and len(text) == 1:
                s = text[0]
                third = len(s) // 3
                # With fewer than 3 characters `third` is 0, the right part would equal
                # the whole string and the recursion would never terminate; in that case
                # fall through to the error path below.
                if third > 0:
                    logger.debug(f"Pooling embeddings of text string with size: {len(text[0])}")
                    # We are using thirds to intentionally have overlap between split parts
                    # for better embedding calculation: first two thirds / last two thirds.
                    left_part, right_part = s[: third * 2], s[third:]

                    # Recursively embed the split parts in parallel
                    (left_vec,), (right_vec,) = await asyncio.gather(
                        self.embed_text([left_part]),
                        self.embed_text([right_part]),
                    )

                    # POOL the two embeddings into one (element-wise mean)
                    pooled = (np.array(left_vec) + np.array(right_vec)) / 2
                    return [pooled.tolist()]

            logger.error("Context window exceeded for embedding text: %s", str(error))
            raise

        except (
            litellm.exceptions.BadRequestError,
            litellm.exceptions.NotFoundError,
        ) as e:
            logger.error(f"Embedding error with model {self.model}: {str(e)}")
            # Chain the original error so the root cause stays visible in tracebacks.
            raise EmbeddingException(
                f"Failed to index data points using model {self.model}"
            ) from e

        except Exception as error:
            logger.error("Error embedding text: %s", str(error))
            raise

    def get_vector_size(self) -> int:
        """
        Retrieve the dimensionality of the embedding vectors.

        Returns:
        --------

            - int: The size (dimensionality) of the embedding vectors.
        """
        return self.dimensions

    def get_tokenizer(self):
        """
        Load and return the appropriate tokenizer for the specified model based on the provider.

        Providers whose name contains "openai", "gemini" or "mistral" get their dedicated
        tokenizer. Any other provider is tried with HuggingFace first and falls back to
        the default TikToken tokenizer if that fails.

        Returns:
        --------

            The tokenizer instance compatible with the model.
        """
        logger.debug(f"Loading tokenizer for model {self.model}...")

        # If model also contains provider information, extract only model information
        model = self.model.split("/")[-1]
        provider = self.provider.lower()

        if "openai" in provider:
            tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
        elif "gemini" in provider:
            tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
        elif "mistral" in provider:
            tokenizer = MistralTokenizer(model=model, max_tokens=self.max_tokens)
        else:
            try:
                # The full, unsplit model string is passed to the HuggingFace tokenizer.
                tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)
            except Exception as e:
                # Deliberate best-effort: any HuggingFace failure falls back to TikToken.
                logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
                logger.info("Switching to TikToken default tokenizer.")
                tokenizer = TikTokenTokenizer(model=None, max_tokens=self.max_tokens)

        logger.debug(f"Tokenizer loaded for model: {self.model}")
        return tokenizer