Merge pull request #384 from topoteretes/feature/cog-919-implement-mock-embeddings-option

Feature/cog 919 implement mock embeddings option
This commit is contained in:
hajdul88 2024-12-18 15:00:40 +01:00 committed by GitHub
commit b3b8d8aca2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 10 deletions

View file

@ -3,6 +3,7 @@ import logging
import math
from typing import List, Optional
import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
litellm.set_verbose = False
@ -14,6 +15,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_version: str
model: str
dimensions: int
mock:bool
def __init__(
self,
@ -29,6 +31,11 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.model = model
self.dimensions = dimensions
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
if isinstance(enable_mocking, bool):
enable_mocking= str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes")
MAX_RETRIES = 5
retry_count = 0
@ -38,17 +45,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
await asyncio.sleep(wait_time)
try:
response = await litellm.aembedding(
self.model,
input = text,
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version
)
if self.mock:
response = {
"data": [{"embedding": [0.0] * self.dimensions} for _ in text]
}
self.retry_count = 0
self.retry_count = 0
return [data["embedding"] for data in response.data]
return [data["embedding"] for data in response["data"]]
else:
response = await litellm.aembedding(
self.model,
input = text,
api_key = self.api_key,
api_base = self.endpoint,
api_version = self.api_version
)
self.retry_count = 0
return [data["embedding"] for data in response.data]
except litellm.exceptions.ContextWindowExceededError as error:
if isinstance(text, list):

View file

@ -73,7 +73,7 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
yield repo
with ProcessPoolExecutor(max_workers = 12) as executor:
with ProcessPoolExecutor() as executor:
loop = asyncio.get_event_loop()
tasks = [