feat: implements mock feature in LiteLLM engine
This commit is contained in:
parent
75b98e0dc6
commit
5eaeebd14e
1 changed files with 21 additions and 9 deletions
|
|
@ -3,6 +3,7 @@ import logging
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import litellm
|
import litellm
|
||||||
|
import os
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
litellm.set_verbose = False
|
litellm.set_verbose = False
|
||||||
|
|
@ -14,6 +15,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
api_version: str
|
api_version: str
|
||||||
model: str
|
model: str
|
||||||
dimensions: int
|
dimensions: int
|
||||||
|
mock:bool
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -28,6 +30,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
self.api_version = api_version
|
self.api_version = api_version
|
||||||
self.model = model
|
self.model = model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
|
self.mock = os.getenv("MOCK_EMBEDDING", False).lower() in ("true", "1", "yes")
|
||||||
|
|
||||||
MAX_RETRIES = 5
|
MAX_RETRIES = 5
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
|
|
@ -38,17 +41,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await litellm.aembedding(
|
if self.mock:
|
||||||
self.model,
|
response = {
|
||||||
input = text,
|
"data": [{"embedding": [0.0] * self.dimensions} for _ in text]
|
||||||
api_key = self.api_key,
|
}
|
||||||
api_base = self.endpoint,
|
|
||||||
api_version = self.api_version
|
|
||||||
)
|
|
||||||
|
|
||||||
self.retry_count = 0
|
self.retry_count = 0
|
||||||
|
|
||||||
return [data["embedding"] for data in response.data]
|
return [data["embedding"] for data in response["data"]]
|
||||||
|
else:
|
||||||
|
response = await litellm.aembedding(
|
||||||
|
self.model,
|
||||||
|
input = text,
|
||||||
|
api_key = self.api_key,
|
||||||
|
api_base = self.endpoint,
|
||||||
|
api_version = self.api_version
|
||||||
|
)
|
||||||
|
|
||||||
|
self.retry_count = 0
|
||||||
|
|
||||||
|
return [data["embedding"] for data in response.data]
|
||||||
|
|
||||||
except litellm.exceptions.ContextWindowExceededError as error:
|
except litellm.exceptions.ContextWindowExceededError as error:
|
||||||
if isinstance(text, list):
|
if isinstance(text, list):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue