refactor: Add tenacity retry mechanism
This commit is contained in:
parent
84a23756f5
commit
98daadbb04
5 changed files with 38 additions and 8 deletions
|
|
@ -1,8 +1,17 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_delay,
|
||||||
|
wait_exponential_jitter,
|
||||||
|
retry_if_not_exception_type,
|
||||||
|
before_sleep_log,
|
||||||
|
)
|
||||||
import litellm
|
import litellm
|
||||||
import os
|
import os
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
@ -76,8 +85,13 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
||||||
enable_mocking = str(enable_mocking).lower()
|
enable_mocking = str(enable_mocking).lower()
|
||||||
self.mock = enable_mocking in ("true", "1", "yes")
|
self.mock = enable_mocking in ("true", "1", "yes")
|
||||||
|
|
||||||
@embedding_sleep_and_retry_async()
|
@retry(
|
||||||
@embedding_rate_limit_async
|
stop=stop_after_delay(180),
|
||||||
|
wait=wait_exponential_jitter(1, 180),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||||
"""
|
"""
|
||||||
Embed a list of text strings into vector representations.
|
Embed a list of text strings into vector representations.
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,16 @@ from cognee.shared.logging_utils import get_logger
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import os
|
import os
|
||||||
|
import litellm
|
||||||
|
import logging
|
||||||
import aiohttp.http_exceptions
|
import aiohttp.http_exceptions
|
||||||
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_delay,
|
||||||
|
wait_exponential_jitter,
|
||||||
|
retry_if_not_exception_type,
|
||||||
|
before_sleep_log,
|
||||||
|
)
|
||||||
|
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
|
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
|
||||||
|
|
@ -69,7 +77,6 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
enable_mocking = str(enable_mocking).lower()
|
enable_mocking = str(enable_mocking).lower()
|
||||||
self.mock = enable_mocking in ("true", "1", "yes")
|
self.mock = enable_mocking in ("true", "1", "yes")
|
||||||
|
|
||||||
@embedding_rate_limit_async
|
|
||||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||||
"""
|
"""
|
||||||
Generate embedding vectors for a list of text prompts.
|
Generate embedding vectors for a list of text prompts.
|
||||||
|
|
@ -92,7 +99,13 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
||||||
embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
|
embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@embedding_sleep_and_retry_async()
|
@retry(
|
||||||
|
stop=stop_after_delay(180),
|
||||||
|
wait=wait_exponential_jitter(1, 180),
|
||||||
|
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
async def _get_embedding(self, prompt: str) -> List[float]:
|
async def _get_embedding(self, prompt: str) -> List[float]:
|
||||||
"""
|
"""
|
||||||
Internal method to call the Ollama embeddings endpoint for a single prompt.
|
Internal method to call the Ollama embeddings endpoint for a single prompt.
|
||||||
|
|
|
||||||
2
poetry.lock
generated
2
poetry.lock
generated
|
|
@ -12738,4 +12738,4 @@ posthog = ["posthog"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.10,<=3.13"
|
python-versions = ">=3.10,<=3.13"
|
||||||
content-hash = "38353807b06e5c06caaa107979529937b978204f0f405c6b38cee283f4a49d3c"
|
content-hash = "d8cd8a8db46416e0c844ff90df5bd64551ebf9a0c338fbb2023a61008ff5941d"
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,8 @@ dependencies = [
|
||||||
"networkx>=3.4.2,<4",
|
"networkx>=3.4.2,<4",
|
||||||
"uvicorn>=0.34.0,<1.0.0",
|
"uvicorn>=0.34.0,<1.0.0",
|
||||||
"gunicorn>=20.1.0,<24",
|
"gunicorn>=20.1.0,<24",
|
||||||
"websockets>=15.0.1,<16.0.0"
|
"websockets>=15.0.1,<16.0.0",
|
||||||
|
"tenacity>=9.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|
|
||||||
4
uv.lock
generated
4
uv.lock
generated
|
|
@ -856,7 +856,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cognee"
|
name = "cognee"
|
||||||
version = "0.3.4"
|
version = "0.3.5"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiofiles" },
|
{ name = "aiofiles" },
|
||||||
|
|
@ -892,6 +892,7 @@ dependencies = [
|
||||||
{ name = "rdflib" },
|
{ name = "rdflib" },
|
||||||
{ name = "sqlalchemy" },
|
{ name = "sqlalchemy" },
|
||||||
{ name = "structlog" },
|
{ name = "structlog" },
|
||||||
|
{ name = "tenacity" },
|
||||||
{ name = "tiktoken" },
|
{ name = "tiktoken" },
|
||||||
{ name = "typing-extensions" },
|
{ name = "typing-extensions" },
|
||||||
{ name = "uvicorn" },
|
{ name = "uvicorn" },
|
||||||
|
|
@ -1086,6 +1087,7 @@ requires-dist = [
|
||||||
{ name = "sentry-sdk", extras = ["fastapi"], marker = "extra == 'monitoring'", specifier = ">=2.9.0,<3" },
|
{ name = "sentry-sdk", extras = ["fastapi"], marker = "extra == 'monitoring'", specifier = ">=2.9.0,<3" },
|
||||||
{ name = "sqlalchemy", specifier = ">=2.0.39,<3.0.0" },
|
{ name = "sqlalchemy", specifier = ">=2.0.39,<3.0.0" },
|
||||||
{ name = "structlog", specifier = ">=25.2.0,<26" },
|
{ name = "structlog", specifier = ">=25.2.0,<26" },
|
||||||
|
{ name = "tenacity", specifier = ">=9.0.0" },
|
||||||
{ name = "tiktoken", specifier = ">=0.8.0,<1.0.0" },
|
{ name = "tiktoken", specifier = ">=0.8.0,<1.0.0" },
|
||||||
{ name = "transformers", marker = "extra == 'codegraph'", specifier = ">=4.46.3,<5" },
|
{ name = "transformers", marker = "extra == 'codegraph'", specifier = ">=4.46.3,<5" },
|
||||||
{ name = "transformers", marker = "extra == 'huggingface'", specifier = ">=4.46.3,<5" },
|
{ name = "transformers", marker = "extra == 'huggingface'", specifier = ">=4.46.3,<5" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue