cognee/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py

import asyncio
import logging
import math
import os
from typing import List, Optional

import litellm

from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
from cognee.infrastructure.llm.tokenizer.Gemini import GeminiTokenizer
from cognee.infrastructure.llm.tokenizer.HuggingFace import HuggingFaceTokenizer
from cognee.infrastructure.llm.tokenizer.TikToken import TikTokenTokenizer

litellm.set_verbose = False

logger = logging.getLogger("LiteLLMEmbeddingEngine")


class LiteLLMEmbeddingEngine(EmbeddingEngine):
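    """Embedding engine backed by LiteLLM, supporting mocked runs, rate-limit
    retries with exponential backoff, and recursive splitting of batches that
    exceed the model's context window."""
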
    api_key: str
    endpoint: str
    api_version: str
    provider: str
    model: str
    dimensions: int
    mock: bool

    def __init__(
        self,
        provider: str = "openai",
        model: Optional[str] = "text-embedding-3-large",
        dimensions: Optional[int] = 3072,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        api_version: Optional[str] = None,
        max_tokens: float = float("inf"),
    ):
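        """Defaults target OpenAI's text-embedding-3-large, which produces 3072-dimensional vectors."""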
        self.api_key = api_key
        self.endpoint = endpoint
        self.api_version = api_version
        # TODO: Add or remove provider info
        self.provider = provider
        self.model = model
        self.dimensions = dimensions
        self.max_tokens = max_tokens
        self.tokenizer = self.set_tokenizer()
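
        # MOCK_EMBEDDING lets tests run the engine without real API calls:
        # mocked calls return zero vectors of the configured dimensionality.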
        enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
        if isinstance(enable_mocking, bool):
            enable_mocking = str(enable_mocking).lower()
        self.mock = enable_mocking in ("true", "1", "yes")
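
    # Rate-limit retry bookkeeping; retry_count is reset after every successful call.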
    MAX_RETRIES = 5
    retry_count = 0

    async def embed_text(self, text: List[str]) -> List[List[float]]:
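        """Embed a batch of texts, returning one embedding vector per input string."""
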
        async def exponential_backoff(attempt):
            wait_time = min(10 * (2**attempt), 60)  # Max 60 seconds
            await asyncio.sleep(wait_time)

        try:
            if self.mock:
                response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
                self.retry_count = 0
                return [data["embedding"] for data in response["data"]]
            else:
                response = await litellm.aembedding(
                    self.model,
                    input=text,
                    api_key=self.api_key,
                    api_base=self.endpoint,
                    api_version=self.api_version,
                )
                self.retry_count = 0
                return [data["embedding"] for data in response.data]

        except litellm.exceptions.ContextWindowExceededError as error:
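            # Oversized batches are halved and embedded recursively; the halves
            # are awaited concurrently and their results re-joined in order.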
            if isinstance(text, list) and len(text) > 1:
                midpoint = math.ceil(len(text) / 2)
                parts = [text[:midpoint], text[midpoint:]]

                parts_futures = [self.embed_text(part) for part in parts]
                embeddings = await asyncio.gather(*parts_futures)

                all_embeddings = []
                for embeddings_part in embeddings:
                    all_embeddings.extend(embeddings_part)

                return all_embeddings

            # A single text that still overflows the context window cannot be
            # split further; fail instead of recursing on the same input forever.
            logger.error("Context window exceeded for embedding text: %s", str(error))
            raise error

        except litellm.exceptions.RateLimitError:
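            # Back off exponentially (each wait capped at 60 seconds) before
            # retrying, giving up after MAX_RETRIES consecutive failures.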
            if self.retry_count >= self.MAX_RETRIES:
                raise Exception("Rate limit exceeded and no more retries left.")

            await exponential_backoff(self.retry_count)
            self.retry_count += 1

            return await self.embed_text(text)

        except litellm.exceptions.BadRequestError:
            raise EmbeddingException("Failed to index data points.")

        except Exception as error:
            logger.error("Error embedding text: %s", str(error))
            raise error

    def get_vector_size(self) -> int:
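        """Return the dimensionality of the embedding vectors."""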
        return self.dimensions

    def set_tokenizer(self):
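        """Select a tokenizer matching the configured provider: TikToken for OpenAI,
        Gemini's tokenizer for Gemini, and a HuggingFace tokenizer otherwise."""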
        logger.debug(f"Loading tokenizer for model {self.model}...")
        # If the model string also carries provider information (e.g. "openai/..."),
        # keep only the model part.
        model = self.model.split("/")[-1]

        if "openai" in self.provider.lower():
            tokenizer = TikTokenTokenizer(model=model, max_tokens=self.max_tokens)
        elif "gemini" in self.provider.lower():
            tokenizer = GeminiTokenizer(model=model, max_tokens=self.max_tokens)
        else:
            tokenizer = HuggingFaceTokenizer(model=self.model, max_tokens=self.max_tokens)

        logger.debug(f"Tokenizer loaded for model: {self.model}")
        return tokenizer
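

# Minimal usage sketch (illustrative, not part of the module). Set MOCK_EMBEDDING=true
# to exercise the engine without real API calls, or supply a valid api_key:
#
#   engine = LiteLLMEmbeddingEngine(provider="openai", model="text-embedding-3-large")
#   vectors = asyncio.run(engine.embed_text(["hello", "world"]))
#   assert len(vectors) == 2 and len(vectors[0]) == engine.get_vector_size()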