added fixes to litellm
This commit is contained in:
parent
310e9e97ae
commit
27f2aa03b3
1 changed file with 60 additions and 10 deletions
|
|
@ -14,6 +14,8 @@ from tenacity import (
|
||||||
)
|
)
|
||||||
import litellm
|
import litellm
|
||||||
import os
|
import os
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import httpx
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from cognee.infrastructure.databases.exceptions import EmbeddingException
|
from cognee.infrastructure.databases.exceptions import EmbeddingException
|
||||||
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
|
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
|
||||||
|
|
@ -79,10 +81,26 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
enable_mocking = str(enable_mocking).lower()
|
enable_mocking = str(enable_mocking).lower()
|
||||||
self.mock = enable_mocking in ("true", "1", "yes")
|
self.mock = enable_mocking in ("true", "1", "yes")
|
||||||
|
|
||||||
|
# Validate provided custom embedding endpoint early to avoid long hangs later
|
||||||
|
if self.endpoint:
|
||||||
|
try:
|
||||||
|
parsed = urlparse(self.endpoint)
|
||||||
|
except Exception:
|
||||||
|
parsed = None
|
||||||
|
if not parsed or parsed.scheme not in ("http", "https") or not parsed.netloc:
|
||||||
|
logger.error(
|
||||||
|
"Invalid EMBEDDING_ENDPOINT configured: '%s'. Expected a URL starting with http:// or https://",
|
||||||
|
str(self.endpoint),
|
||||||
|
)
|
||||||
|
raise EmbeddingException(
|
||||||
|
"Invalid EMBEDDING_ENDPOINT. Please set a valid URL (e.g., https://host:port) "
|
||||||
|
"via environment variable EMBEDDING_ENDPOINT."
|
||||||
|
)
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_delay(128),
|
stop=stop_after_delay(30),
|
||||||
wait=wait_exponential_jitter(2, 128),
|
wait=wait_exponential_jitter(2, 128),
|
||||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
retry=retry_if_not_exception_type((litellm.exceptions.NotFoundError, EmbeddingException)),
|
||||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
|
|
@ -111,12 +129,16 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
return [data["embedding"] for data in response["data"]]
|
return [data["embedding"] for data in response["data"]]
|
||||||
else:
|
else:
|
||||||
async with embedding_rate_limiter_context_manager():
|
async with embedding_rate_limiter_context_manager():
|
||||||
response = await litellm.aembedding(
|
# Ensure each attempt does not hang indefinitely
|
||||||
model=self.model,
|
response = await asyncio.wait_for(
|
||||||
input=text,
|
litellm.aembedding(
|
||||||
api_key=self.api_key,
|
model=self.model,
|
||||||
api_base=self.endpoint,
|
input=text,
|
||||||
api_version=self.api_version,
|
api_key=self.api_key,
|
||||||
|
api_base=self.endpoint,
|
||||||
|
api_version=self.api_version,
|
||||||
|
),
|
||||||
|
timeout=30.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [data["embedding"] for data in response.data]
|
return [data["embedding"] for data in response.data]
|
||||||
|
|
@ -154,6 +176,27 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
logger.error("Context window exceeded for embedding text: %s", str(error))
|
logger.error("Context window exceeded for embedding text: %s", str(error))
|
||||||
raise error
|
raise error
|
||||||
|
|
||||||
|
except asyncio.TimeoutError as e:
|
||||||
|
# Per-attempt timeout – likely an unreachable endpoint
|
||||||
|
logger.error(
|
||||||
|
"Embedding endpoint timed out. EMBEDDING_ENDPOINT='%s'. "
|
||||||
|
"Verify that the endpoint is reachable and correct.",
|
||||||
|
str(self.endpoint),
|
||||||
|
)
|
||||||
|
raise EmbeddingException(
|
||||||
|
"Embedding request timed out. Check EMBEDDING_ENDPOINT connectivity."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
except (httpx.ConnectError, httpx.ReadTimeout) as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to connect to embedding endpoint. EMBEDDING_ENDPOINT='%s'. "
|
||||||
|
"Ensure the URL is correct and the server is running.",
|
||||||
|
str(self.endpoint),
|
||||||
|
)
|
||||||
|
raise EmbeddingException(
|
||||||
|
"Cannot connect to embedding endpoint. Check EMBEDDING_ENDPOINT."
|
||||||
|
) from e
|
||||||
|
|
||||||
except (
|
except (
|
||||||
litellm.exceptions.BadRequestError,
|
litellm.exceptions.BadRequestError,
|
||||||
litellm.exceptions.NotFoundError,
|
litellm.exceptions.NotFoundError,
|
||||||
|
|
@ -162,8 +205,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
raise EmbeddingException(f"Failed to index data points using model {self.model}") from e
|
raise EmbeddingException(f"Failed to index data points using model {self.model}") from e
|
||||||
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
logger.error("Error embedding text: %s", str(error))
|
# Fall back to a clear, actionable message for connectivity/misconfiguration issues
|
||||||
raise error
|
logger.error(
|
||||||
|
"Error embedding text: %s. EMBEDDING_ENDPOINT='%s'.",
|
||||||
|
str(error),
|
||||||
|
str(self.endpoint),
|
||||||
|
)
|
||||||
|
raise EmbeddingException(
|
||||||
|
"Embedding failed due to an unexpected error. Verify EMBEDDING_ENDPOINT and provider settings."
|
||||||
|
) from error
|
||||||
|
|
||||||
def get_vector_size(self) -> int:
|
def get_vector_size(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue