feat: implements mock feature in LiteLLM engine

This commit is contained in:
hajdul88 2024-12-18 14:19:32 +01:00
parent 75b98e0dc6
commit 5eaeebd14e

View file

@ -3,6 +3,7 @@ import logging
import math import math
from typing import List, Optional from typing import List, Optional
import litellm import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
litellm.set_verbose = False litellm.set_verbose = False
@ -14,6 +15,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_version: str api_version: str
model: str model: str
dimensions: int dimensions: int
mock:bool
def __init__( def __init__(
self, self,
@ -28,6 +30,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.api_version = api_version self.api_version = api_version
self.model = model self.model = model
self.dimensions = dimensions self.dimensions = dimensions
self.mock = os.getenv("MOCK_EMBEDDING", False).lower() in ("true", "1", "yes")
MAX_RETRIES = 5 MAX_RETRIES = 5
retry_count = 0 retry_count = 0
@ -38,6 +41,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
try: try:
if self.mock:
response = {
"data": [{"embedding": [0.0] * self.dimensions} for _ in text]
}
self.retry_count = 0
return [data["embedding"] for data in response["data"]]
else:
response = await litellm.aembedding( response = await litellm.aembedding(
self.model, self.model,
input = text, input = text,