added fixes to litellm

This commit is contained in:
vasilije 2025-12-28 21:48:01 +01:00
parent 310e9e97ae
commit 27f2aa03b3

View file

@ -14,6 +14,8 @@ from tenacity import (
) )
import litellm import litellm
import os import os
from urllib.parse import urlparse
import httpx
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
from cognee.infrastructure.databases.exceptions import EmbeddingException from cognee.infrastructure.databases.exceptions import EmbeddingException
from cognee.infrastructure.llm.tokenizer.HuggingFace import ( from cognee.infrastructure.llm.tokenizer.HuggingFace import (
@ -79,10 +81,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
enable_mocking = str(enable_mocking).lower() enable_mocking = str(enable_mocking).lower()
self.mock = enable_mocking in ("true", "1", "yes") self.mock = enable_mocking in ("true", "1", "yes")
# Validate provided custom embedding endpoint early to avoid long hangs later
if self.endpoint:
try:
parsed = urlparse(self.endpoint)
except Exception:
parsed = None
if not parsed or parsed.scheme not in ("http", "https") or not parsed.netloc:
logger.error(
"Invalid EMBEDDING_ENDPOINT configured: '%s'. Expected a URL starting with http:// or https://",
str(self.endpoint),
)
raise EmbeddingException(
"Invalid EMBEDDING_ENDPOINT. Please set a valid URL (e.g., https://host:port) "
"via environment variable EMBEDDING_ENDPOINT."
)
@retry( @retry(
stop=stop_after_delay(128), stop=stop_after_delay(30),
wait=wait_exponential_jitter(2, 128), wait=wait_exponential_jitter(2, 128),
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), retry=retry_if_not_exception_type((litellm.exceptions.NotFoundError, EmbeddingException)),
before_sleep=before_sleep_log(logger, logging.DEBUG), before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True, reraise=True,
) )
@ -111,12 +129,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
return [data["embedding"] for data in response["data"]] return [data["embedding"] for data in response["data"]]
else: else:
async with embedding_rate_limiter_context_manager(): async with embedding_rate_limiter_context_manager():
response = await litellm.aembedding( # Ensure each attempt does not hang indefinitely
model=self.model, response = await asyncio.wait_for(
input=text, litellm.aembedding(
api_key=self.api_key, model=self.model,
api_base=self.endpoint, input=text,
api_version=self.api_version, api_key=self.api_key,
api_base=self.endpoint,
api_version=self.api_version,
),
timeout=30.0,
) )
return [data["embedding"] for data in response.data] return [data["embedding"] for data in response.data]
@ -154,6 +176,27 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
logger.error("Context window exceeded for embedding text: %s", str(error)) logger.error("Context window exceeded for embedding text: %s", str(error))
raise error raise error
except asyncio.TimeoutError as e:
# Per-attempt timeout likely an unreachable endpoint
logger.error(
"Embedding endpoint timed out. EMBEDDING_ENDPOINT='%s'. "
"Verify that the endpoint is reachable and correct.",
str(self.endpoint),
)
raise EmbeddingException(
"Embedding request timed out. Check EMBEDDING_ENDPOINT connectivity."
) from e
except (httpx.ConnectError, httpx.ReadTimeout) as e:
logger.error(
"Failed to connect to embedding endpoint. EMBEDDING_ENDPOINT='%s'. "
"Ensure the URL is correct and the server is running.",
str(self.endpoint),
)
raise EmbeddingException(
"Cannot connect to embedding endpoint. Check EMBEDDING_ENDPOINT."
) from e
except ( except (
litellm.exceptions.BadRequestError, litellm.exceptions.BadRequestError,
litellm.exceptions.NotFoundError, litellm.exceptions.NotFoundError,
@ -162,8 +205,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
raise EmbeddingException(f"Failed to index data points using model {self.model}") from e raise EmbeddingException(f"Failed to index data points using model {self.model}") from e
except Exception as error: except Exception as error:
logger.error("Error embedding text: %s", str(error)) # Fall back to a clear, actionable message for connectivity/misconfiguration issues
raise error logger.error(
"Error embedding text: %s. EMBEDDING_ENDPOINT='%s'.",
str(error),
str(self.endpoint),
)
raise EmbeddingException(
"Embedding failed due to an unexpected error. Verify EMBEDDING_ENDPOINT and provider settings."
) from error
def get_vector_size(self) -> int: def get_vector_size(self) -> int:
""" """