feat: implements mock feature in LiteLLM engine

This commit is contained in:
hajdul88 2024-12-18 14:19:32 +01:00
parent 75b98e0dc6
commit 5eaeebd14e

View file

@ -3,6 +3,7 @@ import logging
import math import math
from typing import List, Optional from typing import List, Optional
import litellm import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
litellm.set_verbose = False litellm.set_verbose = False
@ -14,6 +15,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_version: str api_version: str
model: str model: str
dimensions: int dimensions: int
mock:bool
def __init__( def __init__(
self, self,
@ -28,6 +30,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.api_version = api_version self.api_version = api_version
self.model = model self.model = model
self.dimensions = dimensions self.dimensions = dimensions
self.mock = os.getenv("MOCK_EMBEDDING", False).lower() in ("true", "1", "yes")
MAX_RETRIES = 5 MAX_RETRIES = 5
retry_count = 0 retry_count = 0
@ -38,17 +41,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
try: try:
response = await litellm.aembedding( if self.mock:
self.model, response = {
input = text, "data": [{"embedding": [0.0] * self.dimensions} for _ in text]
api_key = self.api_key, }
api_base = self.endpoint,
api_version = self.api_version
)
self.retry_count = 0 self.retry_count = 0
return [data["embedding"] for data in response.data] return [data["embedding"] for data in response["data"]]
else:
response = await litellm.aembedding(
self.model,
input = text,
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version
)
self.retry_count = 0
return [data["embedding"] for data in response.data]
except litellm.exceptions.ContextWindowExceededError as error: except litellm.exceptions.ContextWindowExceededError as error:
if isinstance(text, list): if isinstance(text, list):