refactor: Add tenacity retry mechanism
This commit is contained in:
parent
84a23756f5
commit
98daadbb04
5 changed files with 38 additions and 8 deletions
|
|
@ -1,8 +1,17 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from typing import List, Optional
|
||||
import numpy as np
|
||||
import math
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_delay,
|
||||
wait_exponential_jitter,
|
||||
retry_if_not_exception_type,
|
||||
before_sleep_log,
|
||||
)
|
||||
import litellm
|
||||
import os
|
||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
|
|
@ -76,8 +85,13 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
|
|||
enable_mocking = str(enable_mocking).lower()
|
||||
self.mock = enable_mocking in ("true", "1", "yes")
|
||||
|
||||
@embedding_sleep_and_retry_async()
|
||||
@embedding_rate_limit_async
|
||||
@retry(
|
||||
stop=stop_after_delay(180),
|
||||
wait=wait_exponential_jitter(1, 180),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Embed a list of text strings into vector representations.
|
||||
|
|
|
|||
|
|
@ -3,8 +3,16 @@ from cognee.shared.logging_utils import get_logger
|
|||
import aiohttp
|
||||
from typing import List, Optional
|
||||
import os
|
||||
|
||||
import litellm
|
||||
import logging
|
||||
import aiohttp.http_exceptions
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_delay,
|
||||
wait_exponential_jitter,
|
||||
retry_if_not_exception_type,
|
||||
before_sleep_log,
|
||||
)
|
||||
|
||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from cognee.infrastructure.llm.tokenizer.HuggingFace import (
|
||||
|
|
@ -69,7 +77,6 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
enable_mocking = str(enable_mocking).lower()
|
||||
self.mock = enable_mocking in ("true", "1", "yes")
|
||||
|
||||
@embedding_rate_limit_async
|
||||
async def embed_text(self, text: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Generate embedding vectors for a list of text prompts.
|
||||
|
|
@ -92,7 +99,13 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
|
|||
embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
|
||||
return embeddings
|
||||
|
||||
@embedding_sleep_and_retry_async()
|
||||
@retry(
|
||||
stop=stop_after_delay(180),
|
||||
wait=wait_exponential_jitter(1, 180),
|
||||
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
async def _get_embedding(self, prompt: str) -> List[float]:
|
||||
"""
|
||||
Internal method to call the Ollama embeddings endpoint for a single prompt.
|
||||
|
|
|
|||
2
poetry.lock
generated
2
poetry.lock
generated
|
|
@ -12738,4 +12738,4 @@ posthog = ["posthog"]
|
|||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<=3.13"
|
||||
content-hash = "38353807b06e5c06caaa107979529937b978204f0f405c6b38cee283f4a49d3c"
|
||||
content-hash = "d8cd8a8db46416e0c844ff90df5bd64551ebf9a0c338fbb2023a61008ff5941d"
|
||||
|
|
|
|||
|
|
@ -54,7 +54,8 @@ dependencies = [
|
|||
"networkx>=3.4.2,<4",
|
||||
"uvicorn>=0.34.0,<1.0.0",
|
||||
"gunicorn>=20.1.0,<24",
|
||||
"websockets>=15.0.1,<16.0.0"
|
||||
"websockets>=15.0.1,<16.0.0",
|
||||
"tenacity>=9.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
4
uv.lock
generated
4
uv.lock
generated
|
|
@ -856,7 +856,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "cognee"
|
||||
version = "0.3.4"
|
||||
version = "0.3.5"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
|
|
@ -892,6 +892,7 @@ dependencies = [
|
|||
{ name = "rdflib" },
|
||||
{ name = "sqlalchemy" },
|
||||
{ name = "structlog" },
|
||||
{ name = "tenacity" },
|
||||
{ name = "tiktoken" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "uvicorn" },
|
||||
|
|
@ -1086,6 +1087,7 @@ requires-dist = [
|
|||
{ name = "sentry-sdk", extras = ["fastapi"], marker = "extra == 'monitoring'", specifier = ">=2.9.0,<3" },
|
||||
{ name = "sqlalchemy", specifier = ">=2.0.39,<3.0.0" },
|
||||
{ name = "structlog", specifier = ">=25.2.0,<26" },
|
||||
{ name = "tenacity", specifier = ">=9.0.0" },
|
||||
{ name = "tiktoken", specifier = ">=0.8.0,<1.0.0" },
|
||||
{ name = "transformers", marker = "extra == 'codegraph'", specifier = ">=4.46.3,<5" },
|
||||
{ name = "transformers", marker = "extra == 'huggingface'", specifier = ">=4.46.3,<5" },
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue