Merge pull request #384 from topoteretes/feature/cog-919-implement-mock-embeddings-option
Feature/cog 919 implement mock embeddings option
This commit is contained in:
commit
b3b8d8aca2
2 changed files with 26 additions and 10 deletions
|
|
@ -3,6 +3,7 @@ import logging
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import litellm
|
import litellm
|
||||||
|
import os
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
litellm.set_verbose = False
|
litellm.set_verbose = False
|
||||||
|
|
@ -14,6 +15,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
api_version: str
|
api_version: str
|
||||||
model: str
|
model: str
|
||||||
dimensions: int
|
dimensions: int
|
||||||
|
mock:bool
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -29,6 +31,11 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
|
|
||||||
|
enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
|
||||||
|
if isinstance(enable_mocking, bool):
|
||||||
|
enable_mocking= str(enable_mocking).lower()
|
||||||
|
self.mock = enable_mocking in ("true", "1", "yes")
|
||||||
|
|
||||||
MAX_RETRIES = 5
|
MAX_RETRIES = 5
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
|
|
||||||
|
|
@ -38,17 +45,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await litellm.aembedding(
|
if self.mock:
|
||||||
self.model,
|
response = {
|
||||||
input = text,
|
"data": [{"embedding": [0.0] * self.dimensions} for _ in text]
|
||||||
api_key = self.api_key,
|
}
|
||||||
api_base = self.endpoint,
|
|
||||||
api_version = self.api_version
|
|
||||||
)
|
|
||||||
|
|
||||||
self.retry_count = 0
|
self.retry_count = 0
|
||||||
|
|
||||||
return [data["embedding"] for data in response.data]
|
return [data["embedding"] for data in response["data"]]
|
||||||
|
else:
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
self.model,
|
||||||
|
input = text,
|
||||||
|
api_key = self.api_key,
|
||||||
|
api_base = self.endpoint,
|
||||||
|
api_version = self.api_version
|
||||||
|
)
|
||||||
|
|
||||||
|
self.retry_count = 0
|
||||||
|
|
||||||
|
return [data["embedding"] for data in response.data]
|
||||||
|
|
||||||
except litellm.exceptions.ContextWindowExceededError as error:
|
except litellm.exceptions.ContextWindowExceededError as error:
|
||||||
if isinstance(text, list):
|
if isinstance(text, list):
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ async def get_repo_file_dependencies(repo_path: str) -> AsyncGenerator[list, Non
|
||||||
|
|
||||||
yield repo
|
yield repo
|
||||||
|
|
||||||
with ProcessPoolExecutor(max_workers = 12) as executor:
|
with ProcessPoolExecutor() as executor:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue