Merge pull request #384 from topoteretes/feature/cog-919-implement-mock-embeddings-option
Feature/cog 919 implement mock embeddings option
This commit is contained in:
commit
b3b8d8aca2
2 changed files with 26 additions and 10 deletions
|
|
@ -3,6 +3,7 @@ import logging
|
|||
import math
|
||||
from typing import List, Optional
|
||||
import litellm
|
||||
import os
|
||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
||||
litellm.set_verbose = False
|
||||
|
|
@ -14,6 +15,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
api_version: str
|
||||
model: str
|
||||
dimensions: int
|
||||
mock:bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -29,6 +31,11 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
self.model = model
|
||||
self.dimensions = dimensions
|
||||
|
||||
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
|
||||
if isinstance(enable_mocking, bool):
|
||||
enable_mocking= str(enable_mocking).lower()
|
||||
self.mock = enable_mocking in ("true", "1", "yes")
|
||||
|
||||
MAX_RETRIES = 5
|
||||
retry_count = 0
|
||||
|
||||
|
|
@ -38,17 +45,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
await asyncio.sleep(wait_time)
|
||||
|
||||
try:
|
||||
response = await litellm.aembedding(
|
||||
self.model,
|
||||
input = text,
|
||||
api_key = self.api_key,
|
||||
api_base = self.endpoint,
|
||||
api_version = self.api_version
|
||||
)
|
||||
if self.mock:
|
||||
response = {
|
||||
"data": [{"embedding": [0.0] * self.dimensions} for _ in text]
|
||||
}
|
||||
|
||||
self.retry_count = 0
|
||||
self.retry_count = 0
|
||||
|
||||
return [data["embedding"] for data in response.data]
|
||||
return [data["embedding"] for data in response["data"]]
|
||||
else:
|
||||
response = await litellm.aembedding(
|
||||
self.model,
|
||||
input = text,
|
||||
api_key = self.api_key,
|
||||
api_base = self.endpoint,
|
||||
api_version = self.api_version
|
||||
)
|
||||
|
||||
self.retry_count = 0
|
||||
|
||||
return [data["embedding"] for data in response.data]
|
||||
|
||||
except litellm.exceptions.ContextWindowExceededError as error:
|
||||
if isinstance(text, list):
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
|
|||
|
||||
yield repo
|
||||
|
||||
with ProcessPoolExecutor(max_workers = 12) as executor:
|
||||
with ProcessPoolExecutor() as executor:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
tasks = [
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue