added fixes to litellm

This commit is contained in:
vasilije 2025-12-28 21:48:01 +01:00
parent 310e9e97ae
commit 27f2aa03b3

View file

@ -14,6 +14,8 @@ from tenacity import (
)
import litellm
import os
from urllib.parse import urlparse
import httpx
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions import EmbeddingException
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
@ -79,10 +81,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
enable_mocking = str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes")
# Validate provided custom embedding endpoint early to avoid long hangs later
if self.endpoint:
try:
parsed = urlparse(self.endpoint)
except Exception:
parsed = None
if not parsed or parsed.scheme not in ("http", "https") or not parsed.netloc:
logger.error(
"Invalid EMBEDDING_ENDPOINT configured: '%s'. Expected a URL starting with http:// or https://",
str(self.endpoint),
)
raise EmbeddingException(
"Invalid EMBEDDING_ENDPOINT. Please set a valid URL (e.g., https://host:port) "
"via environment variable EMBEDDING_ENDPOINT."
)
@retry(
stop=stop_after_delay(128),
stop=stop_after_delay(30),
wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
retry=retry_if_not_exception_type((litellm.exceptions.NotFoundError, EmbeddingException)),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
@ -111,12 +129,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
return [data["embedding"] for data in response["data"]]
else:
async with embedding_rate_limiter_context_manager():
response = await litellm.aembedding(
model=self.model,
input=text,
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
# Ensure each attempt does not hang indefinitely
response = await asyncio.wait_for(
litellm.aembedding(
model=self.model,
input=text,
api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
),
timeout=30.0,
)
return [data["embedding"] for data in response.data]
@ -154,6 +176,27 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
logger.error("Context window exceeded for embedding text: %s", str(error))
raise error
except asyncio.TimeoutError as e:
# Per-attempt timeout likely an unreachable endpoint
logger.error(
"Embedding endpoint timed out. EMBEDDING_ENDPOINT='%s'. "
"Verify that the endpoint is reachable and correct.",
str(self.endpoint),
)
raise EmbeddingException(
"Embedding request timed out. Check EMBEDDING_ENDPOINT connectivity."
) from e
except (httpx.ConnectError, httpx.ReadTimeout) as e:
logger.error(
"Failed to connect to embedding endpoint. EMBEDDING_ENDPOINT='%s'. "
"Ensure the URL is correct and the server is running.",
str(self.endpoint),
)
raise EmbeddingException(
"Cannot connect to embedding endpoint. Check EMBEDDING_ENDPOINT."
) from e
except (
litellm.exceptions.BadRequestError,
litellm.exceptions.NotFoundError,
@ -162,8 +205,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
raise EmbeddingException(f"Failed to index data points using model {self.model}") from e
except Exception as error:
logger.error("Error embedding text: %s", str(error))
raise error
# Fall back to a clear, actionable message for connectivity/misconfiguration issues
logger.error(
"Error embedding text: %s. EMBEDDING_ENDPOINT='%s'.",
str(error),
str(self.endpoint),
)
raise EmbeddingException(
"Embedding failed due to an unexpected error. Verify EMBEDDING_ENDPOINT and provider settings."
) from error
def get_vector_size(self) -> int:
"""