cognee/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
2026-01-08 12:45:03 +01:00

283 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import logging
from cognee.shared.logging_utils import get_logger
from typing import List, Optional
import numpy as np
import math
from tenacity import (
retry,
stop_after_delay,
wait_exponential_jitter,
retry_if_not_exception_type,
before_sleep_log,
)
import litellm
import os
from urllib.parse import urlparse
import httpx
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions import EmbeddingException
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
HuggingFaceTokenizer,
)
from cognee.infrastructure.llm.tokenizer.Mistral import (
MistralTokenizer,
)
from cognee.infrastructure.llm.tokenizer.TikToken import (
TikTokenTokenizer,
)
from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
# Turn off LiteLLM's verbose flag. This is an attribute on the litellm module itself,
# so the setting applies to the whole process, not only to this engine.
litellm.set_verbose = False
# Module-level logger shared by the engine below (also handed to tenacity's before_sleep_log).
logger = get_logger("LiteLLMEmbeddingEngine")
class LiteLLMEmbeddingEngine(EmbeddingEngine):
    """
    Engine for embedding text using a specific LLM model, supporting mock and actual
    embedding calls.

    Public methods:
    - embed_text: Embed a list of strings into vector representations.
    - get_vector_size: Retrieve the size of the embedding vectors.
    - get_batch_size: Retrieve the batch size configured for embedding calls.
    - get_tokenizer: Load the appropriate tokenizer for the specified model.
    """

    api_key: str
    endpoint: str
    api_version: str
    provider: str
    model: str
    dimensions: int
    mock: bool

    MAX_RETRIES = 5

    def __init__(
        self,
        model: Optional[str] = "openai/text-embedding-3-large",
        provider: str = "openai",
        dimensions: Optional[int] = 3072,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        api_version: Optional[str] = None,
        max_completion_tokens: int = 512,
        batch_size: int = 100,
    ):
        """
        Store the embedding configuration, validate the endpoint and load the tokenizer.

        Parameters:
        -----------
            - model: Model identifier passed to LiteLLM; may carry a provider prefix.
            - provider: Provider name, used to pick a tokenizer.
            - dimensions: Size of the embedding vectors reported by get_vector_size.
            - api_key: API key forwarded to LiteLLM.
            - endpoint: Optional custom API base URL; must be an http(s) URL when set.
            - api_version: API version forwarded to LiteLLM.
            - max_completion_tokens: Token limit handed to the tokenizer.
            - batch_size: Batch size reported by get_batch_size.

        Raises:
        -------
            - EmbeddingException: If `endpoint` is set but is not a valid http(s) URL.
        """
        self.api_key = api_key
        self.endpoint = endpoint
        self.api_version = api_version
        self.provider = provider
        self.model = model
        self.dimensions = dimensions
        self.max_completion_tokens = max_completion_tokens
        self.retry_count = 0
        self.batch_size = batch_size

        # Environment values are strings; normalise case and whitespace so that
        # "True", "YES" or " 1 " enable mocking just like "true" does. str() keeps
        # this tolerant of a patched getenv that returns a bool.
        enable_mocking = str(os.getenv("MOCK_EMBEDDING", "false")).strip().lower()
        self.mock = enable_mocking in ("true", "1", "yes")

        # Validate provided custom embedding endpoint early to avoid long hangs later.
        # This runs before the tokenizer is loaded so a misconfigured endpoint fails
        # fast instead of after tokenizer initialisation.
        if self.endpoint:
            try:
                parsed = urlparse(self.endpoint)
            except Exception:
                parsed = None
            if not parsed or parsed.scheme not in ("http", "https") or not parsed.netloc:
                logger.error(
                    "Invalid EMBEDDING_ENDPOINT configured: '%s'. Expected a URL starting with http:// or https://",
                    str(self.endpoint),
                )
                raise EmbeddingException(
                    "Invalid EMBEDDING_ENDPOINT. Please set a valid URL (e.g., https://host:port) "
                    "via environment variable EMBEDDING_ENDPOINT."
                )

        self.tokenizer = self.get_tokenizer()

    @retry(
        stop=stop_after_delay(30),
        wait=wait_exponential_jitter(2, 128),
        retry=retry_if_not_exception_type((litellm.exceptions.NotFoundError, EmbeddingException)),
        before_sleep=before_sleep_log(logger, logging.DEBUG),
        reraise=True,
    )
    async def embed_text(self, text: List[str]) -> List[List[float]]:
        """
        Embed a list of text strings into vector representations.

        If the input exceeds the model's context window, the input is recursively split
        and the partial results are combined. Both mock and live embedding are supported.
        Timeouts, connection errors, bad requests and any other unexpected error are
        logged and re-raised as EmbeddingException, which the retry policy above does
        not retry.

        Parameters:
        -----------
            - text (List[str]): A list of strings to be embedded.

        Returns:
        --------
            - List[List[float]]: A list of vectors representing the embedded texts.

        Raises:
        -------
            - EmbeddingException: On timeout, connection failure, provider error, or when
              an oversized string cannot be split any further.
            - litellm.exceptions.ContextWindowExceededError: If the context window is
              exceeded for an input that is neither a multi-item nor a single-item list.
        """
        try:
            # Nothing to embed: skip the rate limiter and the network call entirely.
            if isinstance(text, list) and not text:
                return []

            if self.mock:
                return [[0.0] * self.dimensions for _ in text]

            async with embedding_rate_limiter_context_manager():
                # Ensure each attempt does not hang indefinitely
                response = await asyncio.wait_for(
                    litellm.aembedding(
                        model=self.model,
                        input=text,
                        api_key=self.api_key,
                        api_base=self.endpoint,
                        api_version=self.api_version,
                    ),
                    timeout=30.0,
                )
                return [data["embedding"] for data in response.data]

        except litellm.exceptions.ContextWindowExceededError as error:
            return await self._embed_after_context_overflow(text, error)

        except asyncio.TimeoutError as e:
            # Per-attempt timeout likely an unreachable endpoint
            logger.error(
                "Embedding endpoint timed out. EMBEDDING_ENDPOINT='%s'. "
                "Verify that the endpoint is reachable and correct.",
                str(self.endpoint),
            )
            raise EmbeddingException(
                "Embedding request timed out. Check EMBEDDING_ENDPOINT connectivity."
            ) from e

        except (httpx.ConnectError, httpx.ReadTimeout) as e:
            logger.error(
                "Failed to connect to embedding endpoint. EMBEDDING_ENDPOINT='%s'. "
                "Ensure the URL is correct and the server is running.",
                str(self.endpoint),
            )
            raise EmbeddingException(
                "Cannot connect to embedding endpoint. Check EMBEDDING_ENDPOINT."
            ) from e

        except (
            litellm.exceptions.BadRequestError,
            litellm.exceptions.NotFoundError,
        ) as e:
            logger.error(f"Embedding error with model {self.model}: {str(e)}")
            raise EmbeddingException(f"Failed to index data points using model {self.model}") from e

        except Exception as error:
            # Fall back to a clear, actionable message for connectivity/misconfiguration issues
            logger.error(
                "Error embedding text: %s. EMBEDDING_ENDPOINT='%s'.",
                str(error),
                str(self.endpoint),
            )
            raise EmbeddingException(
                "Embedding failed due to an unexpected error. Verify EMBEDDING_ENDPOINT and provider settings."
            ) from error

    async def _embed_after_context_overflow(self, text, error):
        """
        Recover from a context-window overflow by splitting the input and re-embedding.

        A list with several items is halved and both halves are embedded concurrently.
        A single oversized string is cut into two overlapping parts whose embeddings
        are averaged into one vector. Anything else re-raises the original error.
        """
        if isinstance(text, list) and len(text) > 1:
            mid = math.ceil(len(text) / 2)
            left_vecs, right_vecs = await asyncio.gather(
                self.embed_text(text[:mid]),
                self.embed_text(text[mid:]),
            )
            return left_vecs + right_vecs

        # If caller passed ONE oversize string split the string itself so we can process it
        if isinstance(text, list) and len(text) == 1:
            s = text[0]
            logger.debug(f"Pooling embeddings of text string with size: {len(s)}")
            third = len(s) // 3

            # With fewer than three characters the "right" part would be the whole
            # string again, so recursing would never terminate. Give up instead.
            if third == 0:
                logger.error("Context window exceeded for embedding text: %s", str(error))
                raise EmbeddingException(
                    "Context window exceeded for a text that cannot be split any further."
                ) from error

            # We are using thirds to intentionally have overlap between split parts
            # for better embedding calculation
            left_part, right_part = s[: third * 2], s[third:]

            # Recursively embed the split parts in parallel
            (left_vec,), (right_vec,) = await asyncio.gather(
                self.embed_text([left_part]),
                self.embed_text([right_part]),
            )

            # POOL the two embeddings into one
            pooled = (np.array(left_vec) + np.array(right_vec)) / 2
            return [pooled.tolist()]

        logger.error("Context window exceeded for embedding text: %s", str(error))
        raise error

    def get_vector_size(self) -> int:
        """
        Retrieve the dimensionality of the embedding vectors.

        Returns:
        --------
            - int: The size (dimensionality) of the embedding vectors.
        """
        return self.dimensions

    def get_batch_size(self) -> int:
        """
        Return the desired batch size for embedding calls.

        Returns:
        --------
            - int: The batch size configured at construction time.
        """
        return self.batch_size

    def get_tokenizer(self):
        """
        Load and return the appropriate tokenizer for the specified model based on the provider.

        OpenAI and Gemini providers use TikToken, Mistral uses the Mistral tokenizer, and
        every other provider tries HuggingFace first and falls back to the default
        TikToken tokenizer if that fails.

        Returns:
        --------
            The tokenizer instance compatible with the model.
        """
        logger.debug(f"Loading tokenizer for model {self.model}...")

        # If model also contains provider information, extract only model information
        model = self.model.split("/")[-1]
        provider = self.provider.lower()

        if "openai" in provider:
            tokenizer = TikTokenTokenizer(
                model=model, max_completion_tokens=self.max_completion_tokens
            )
        elif "gemini" in provider:
            # Since Gemini tokenization needs to send an API request to get the token count we will use TikToken to
            # count tokens as we calculate tokens word by word
            # Note: Gemini Tokenizer expects an LLM model as input and not the embedding model
            tokenizer = TikTokenTokenizer(
                model=None, max_completion_tokens=self.max_completion_tokens
            )
        elif "mistral" in provider:
            tokenizer = MistralTokenizer(
                model=model, max_completion_tokens=self.max_completion_tokens
            )
        else:
            try:
                tokenizer = HuggingFaceTokenizer(
                    model=self.model.replace("hosted_vllm/", ""),
                    max_completion_tokens=self.max_completion_tokens,
                )
            except Exception as e:
                # Best-effort: any HuggingFace failure degrades to the default tokenizer.
                logger.warning(f"Could not get tokenizer from HuggingFace due to: {e}")
                logger.info("Switching to TikToken default tokenizer.")
                tokenizer = TikTokenTokenizer(
                    model=None, max_completion_tokens=self.max_completion_tokens
                )

        logger.debug(f"Tokenizer loaded for model: {self.model}")
        return tokenizer