feat: implements mock feature in LiteLLM engine

hajdul88 2024-12-18 14:19:32 +01:00
parent 75b98e0dc6
commit 5eaeebd14e


@@ -3,6 +3,7 @@ import logging
 import math
 from typing import List, Optional
 import litellm
+import os
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine

 litellm.set_verbose = False
@@ -14,6 +15,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
     api_version: str
     model: str
     dimensions: int
+    mock: bool

     def __init__(
         self,
@@ -28,6 +30,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         self.api_version = api_version
         self.model = model
         self.dimensions = dimensions
+        self.mock = os.getenv("MOCK_EMBEDDING", "false").lower() in ("true", "1", "yes")

     MAX_RETRIES = 5
     retry_count = 0
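Note on the flag above: the default passed to os.getenv has to be a string, because .lower() is called on the result before the membership check; anything other than "true", "1", or "yes" (case-insensitive), including an unset variable, leaves mocking off. A minimal standalone sketch of the same parsing, for illustration only (the helper name is ours, not part of the commit):

    import os

    def mock_embedding_enabled() -> bool:
        # Unset -> "false" -> disabled; "True", "1", "yes" (any casing) -> enabled.
        return os.getenv("MOCK_EMBEDDING", "false").lower() in ("true", "1", "yes")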
@@ -38,17 +41,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
             await asyncio.sleep(wait_time)

         try:
-            response = await litellm.aembedding(
-                self.model,
-                input = text,
-                api_key = self.api_key,
-                api_base = self.endpoint,
-                api_version = self.api_version
-            )
-            self.retry_count = 0
-            return [data["embedding"] for data in response.data]
+            if self.mock:
+                response = {
+                    "data": [{"embedding": [0.0] * self.dimensions} for _ in text]
+                }

+                self.retry_count = 0
+                return [data["embedding"] for data in response["data"]]
+            else:
+                response = await litellm.aembedding(
+                    self.model,
+                    input = text,
+                    api_key = self.api_key,
+                    api_base = self.endpoint,
+                    api_version = self.api_version
+                )

+                self.retry_count = 0
+                return [data["embedding"] for data in response.data]

         except litellm.exceptions.ContextWindowExceededError as error:
             if isinstance(text, list):
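For completeness, a minimal usage sketch of the mock path. The import path, constructor arguments, and the embed_text method name are assumptions inferred from this diff rather than confirmed by it; in mock mode no request reaches the provider, and each input string maps to a zero vector of length dimensions.

    import asyncio
    import os

    # Assumed module path; the engine class is only modified, not defined, in this diff.
    from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine

    # The flag is read from the environment in __init__, so set it before construction.
    os.environ["MOCK_EMBEDDING"] = "true"

    engine = LiteLLMEmbeddingEngine(
        model="openai/text-embedding-3-large",  # illustrative model name
        dimensions=3072,
    )

    # Assumed method name; in mock mode it returns one zero vector per input string.
    vectors = asyncio.run(engine.embed_text(["hello", "world"]))
    print(len(vectors), len(vectors[0]))  # expected: 2 3072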