diff --git a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py index d68941d25..2a71d674d 100644 --- a/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py @@ -1,8 +1,17 @@ import asyncio +import logging + from cognee.shared.logging_utils import get_logger from typing import List, Optional import numpy as np import math +from tenacity import ( + retry, + stop_after_delay, + wait_exponential_jitter, + retry_if_not_exception_type, + before_sleep_log, +) import litellm import os from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine @@ -76,8 +85,13 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine): enable_mocking = str(enable_mocking).lower() self.mock = enable_mocking in ("true", "1", "yes") - @embedding_sleep_and_retry_async() - @embedding_rate_limit_async + @retry( + stop=stop_after_delay(180), + wait=wait_exponential_jitter(1, 180), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) async def embed_text(self, text: List[str]) -> List[List[float]]: """ Embed a list of text strings into vector representations. diff --git a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py index e79ba3f6a..b8ee9c7df 100644 --- a/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py @@ -3,8 +3,16 @@ from cognee.shared.logging_utils import get_logger import aiohttp from typing import List, Optional import os - +import litellm +import logging import aiohttp.http_exceptions +from tenacity import ( + retry, + stop_after_delay, + wait_exponential_jitter, + retry_if_not_exception_type, + before_sleep_log, +) from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine from cognee.infrastructure.llm.tokenizer.HuggingFace import ( @@ -69,7 +77,6 @@ class OllamaEmbeddingEngine(EmbeddingEngine): enable_mocking = str(enable_mocking).lower() self.mock = enable_mocking in ("true", "1", "yes") - @embedding_rate_limit_async async def embed_text(self, text: List[str]) -> List[List[float]]: """ Generate embedding vectors for a list of text prompts. @@ -92,7 +99,13 @@ class OllamaEmbeddingEngine(EmbeddingEngine): embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text]) return embeddings - @embedding_sleep_and_retry_async() + @retry( + stop=stop_after_delay(180), + wait=wait_exponential_jitter(1, 180), + retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, + ) async def _get_embedding(self, prompt: str) -> List[float]: """ Internal method to call the Ollama embeddings endpoint for a single prompt. diff --git a/poetry.lock b/poetry.lock index 551295733..ffc5ec575 100644 --- a/poetry.lock +++ b/poetry.lock @@ -12738,4 +12738,4 @@ posthog = ["posthog"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<=3.13" -content-hash = "38353807b06e5c06caaa107979529937b978204f0f405c6b38cee283f4a49d3c" +content-hash = "d8cd8a8db46416e0c844ff90df5bd64551ebf9a0c338fbb2023a61008ff5941d" diff --git a/pyproject.toml b/pyproject.toml index 3df57e1f5..7ac2915d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,8 @@ dependencies = [ "networkx>=3.4.2,<4", "uvicorn>=0.34.0,<1.0.0", "gunicorn>=20.1.0,<24", - "websockets>=15.0.1,<16.0.0" + "websockets>=15.0.1,<16.0.0", + "tenacity>=9.0.0", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index 570da9289..5c06b96be 100644 --- a/uv.lock +++ b/uv.lock @@ -856,7 +856,7 @@ wheels = [ [[package]] name = "cognee" -version = "0.3.4" +version = "0.3.5" source = { editable = "." } dependencies = [ { name = "aiofiles" }, @@ -892,6 +892,7 @@ dependencies = [ { name = "rdflib" }, { name = "sqlalchemy" }, { name = "structlog" }, + { name = "tenacity" }, { name = "tiktoken" }, { name = "typing-extensions" }, { name = "uvicorn" }, @@ -1086,6 +1087,7 @@ requires-dist = [ { name = "sentry-sdk", extras = ["fastapi"], marker = "extra == 'monitoring'", specifier = ">=2.9.0,<3" }, { name = "sqlalchemy", specifier = ">=2.0.39,<3.0.0" }, { name = "structlog", specifier = ">=25.2.0,<26" }, + { name = "tenacity", specifier = ">=9.0.0" }, { name = "tiktoken", specifier = ">=0.8.0,<1.0.0" }, { name = "transformers", marker = "extra == 'codegraph'", specifier = ">=4.46.3,<5" }, { name = "transformers", marker = "extra == 'huggingface'", specifier = ">=4.46.3,<5" },