From 27f2aa03b3a0d1cb89dec8bbea1c4e39f219531b Mon Sep 17 00:00:00 2001 From: vasilije Date: Sun, 28 Dec 2025 21:48:01 +0100 Subject: [PATCH] added fixes to litellm --- .../embeddings/LiteLLMEmbeddingEngine.py | 70 ++++++++++++++++--- 1 file changed, 60 insertions(+), 10 deletions(-) diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index 12de57617..558b11538 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -14,6 +14,8 @@ from tenacity import ( ) import litellm import os +from urllib.parse import urlparse +import httpx from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.exceptions import EmbeddingException from cognee.infrastructure.llm.tokenizer.HuggingFace import ( @@ -79,10 +81,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): enable_mocking = str(enable_mocking).lower() self.mock = enable_mocking in ("true", "1", "yes") + # Validate provided custom embedding endpoint early to avoid long hangs later + if self.endpoint: + try: + parsed = urlparse(self.endpoint) + except Exception: + parsed = None + if not parsed or parsed.scheme not in ("http", "https") or not parsed.netloc: + logger.error( + "Invalid EMBEDDING_ENDPOINT configured: '%s'. Expected a URL starting with http:// or https://", + str(self.endpoint), + ) + raise EmbeddingException( + "Invalid EMBEDDING_ENDPOINT. Please set a valid URL (e.g., https://host:port) " + "via environment variable EMBEDDING_ENDPOINT." + ) + @retry( - stop=stop_after_delay(128), + stop=stop_after_delay(30), wait=wait_exponential_jitter(2, 128), - retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + retry=retry_if_not_exception_type((litellm.exceptions.NotFoundError, EmbeddingException)), before_sleep=before_sleep_log(logger, logging.DEBUG), reraise=True, ) @@ -111,12 +129,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): return [data["embedding"] for data in response["data"]] else: async with embedding_rate_limiter_context_manager(): - response = await litellm.aembedding( - model=self.model, - input=text, - api_key=self.api_key, - api_base=self.endpoint, - api_version=self.api_version, + # Ensure each attempt does not hang indefinitely + response = await asyncio.wait_for( + litellm.aembedding( + model=self.model, + input=text, + api_key=self.api_key, + api_base=self.endpoint, + api_version=self.api_version, + ), + timeout=30.0, ) return [data["embedding"] for data in response.data] @@ -154,6 +176,27 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): logger.error("Context window exceeded for embedding text: %s", str(error)) raise error + except asyncio.TimeoutError as e: + # Per-attempt timeout – likely an unreachable endpoint + logger.error( + "Embedding endpoint timed out. EMBEDDING_ENDPOINT='%s'. " + "Verify that the endpoint is reachable and correct.", + str(self.endpoint), + ) + raise EmbeddingException( + "Embedding request timed out. Check EMBEDDING_ENDPOINT connectivity." + ) from e + + except (httpx.ConnectError, httpx.ReadTimeout) as e: + logger.error( + "Failed to connect to embedding endpoint. EMBEDDING_ENDPOINT='%s'. " + "Ensure the URL is correct and the server is running.", + str(self.endpoint), + ) + raise EmbeddingException( + "Cannot connect to embedding endpoint. Check EMBEDDING_ENDPOINT." + ) from e + except ( litellm.exceptions.BadRequestError, litellm.exceptions.NotFoundError, @@ -162,8 +205,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): raise EmbeddingException(f"Failed to index data points using model {self.model}") from e except Exception as error: - logger.error("Error embedding text: %s", str(error)) - raise error + # Fall back to a clear, actionable message for connectivity/misconfiguration issues + logger.error( + "Error embedding text: %s. EMBEDDING_ENDPOINT='%s'.", + str(error), + str(self.endpoint), + ) + raise EmbeddingException( + "Embedding failed due to an unexpected error. Verify EMBEDDING_ENDPOINT and provider settings." + ) from error def get_vector_size(self) -> int: """