Merge pull request #384 from topoteretes/feature/cog-919-implement-mock-embeddings-option

Feature/COG-919: implement mock embeddings option
This commit is contained in:
hajdul88 2024-12-18 15:00:40 +01:00 committed by GitHub
commit b3b8d8aca2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 10 deletions

View file

@ -3,6 +3,7 @@ import logging
import math import math
from typing import List, Optional from typing import List, Optional
import litellm import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
litellm.set_verbose = False litellm.set_verbose = False
@ -14,6 +15,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
api_version: str api_version: str
model: str model: str
dimensions: int dimensions: int
mock:bool
def __init__( def __init__(
self, self,
@ -29,6 +31,11 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
self.model = model self.model = model
self.dimensions = dimensions self.dimensions = dimensions
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
if isinstance(enable_mocking, bool):
enable_mocking= str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes")
MAX_RETRIES = 5 MAX_RETRIES = 5
retry_count = 0 retry_count = 0
@ -38,6 +45,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
try: try:
if self.mock:
response = {
"data": [{"embedding": [0.0] * self.dimensions} for _ in text]
}
self.retry_count = 0
return [data["embedding"] for data in response["data"]]
else:
response = await litellm.aembedding( response = await litellm.aembedding(
self.model, self.model,
input = text, input = text,

View file

@ -73,7 +73,7 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
yield repo yield repo
with ProcessPoolExecutor(max_workers = 12) as executor: with ProcessPoolExecutor() as executor:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
tasks = [ tasks = [