From fc4bf3bde20d263c228edee26a44be2887849a52 Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Mon, 26 Aug 2024 12:53:16 -0700
Subject: [PATCH] Implement retry for LLMClient (#44)

* implement retry

* chore: Refactor tenacity retry logic and improve LLMClient error handling

* poetry

* remove unnecessary try

---
 graphiti_core/llm_client/client.py | 24 +++++++++++++++++++++++-
 poetry.lock                        | 17 ++++++++++++++++-
 pyproject.toml                     |  1 +
 3 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/graphiti_core/llm_client/client.py b/graphiti_core/llm_client/client.py
index 02bd6f4f..5de06d76 100644
--- a/graphiti_core/llm_client/client.py
+++ b/graphiti_core/llm_client/client.py
@@ -20,7 +20,9 @@ import logging
 import typing
 from abc import ABC, abstractmethod
 
+import httpx
 from diskcache import Cache
+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
 
 from ..prompts.models import Message
 from .config import LLMConfig
@@ -31,6 +33,12 @@ DEFAULT_CACHE_DIR = './llm_cache'
 logger = logging.getLogger(__name__)
 
 
+def is_server_error(exception):
+    return (
+        isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
+    )
+
+
 class LLMClient(ABC):
     def __init__(self, config: LLMConfig | None, cache: bool = False):
         if config is None:
@@ -47,6 +55,20 @@ class LLMClient(ABC):
     def get_embedder(self) -> typing.Any:
         pass
 
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception(is_server_error),
+    )
+    async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
+        try:
+            return await self._generate_response(messages)
+        except httpx.HTTPStatusError as e:
+            if not is_server_error(e):
+                raise Exception(f'LLM request error: {e}') from e
+            else:
+                raise
+
     @abstractmethod
     async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
         pass
@@ -66,7 +88,7 @@ class LLMClient(ABC):
             logger.debug(f'Cache hit for {cache_key}')
             return cached_response
 
-        response = await self._generate_response(messages)
+        response = await self._generate_response_with_retry(messages)
 
         if self.cache_enabled:
             self.cache_dir.set(cache_key, response)
diff --git a/poetry.lock b/poetry.lock
index 6d964a5f..22653da0 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3253,6 +3253,21 @@ mpmath = ">=1.1.0,<1.4"
 [package.extras]
 dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
 
+[[package]]
+name = "tenacity"
+version = "9.0.0"
+description = "Retry code until it succeeds"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"},
+    {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"},
+]
+
+[package.extras]
+doc = ["reno", "sphinx"]
+test = ["pytest", "tornado (>=4.5)", "typeguard"]
+
 [[package]]
 name = "terminado"
 version = "0.18.1"
@@ -3743,4 +3758,4 @@ test = ["websockets"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "5b90bb6d58d36a2553f5410c418b179aa1c86b55078567c33aaa6fddf6a8c6c6"
+content-hash = "001663dfc8078ad473675c994b15191db1f53a844e23f40ffa4a704379a61132"
diff --git a/pyproject.toml b/pyproject.toml
index 2456f13d..956e48d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ diskcache = "^5.6.3"
 arrow = "^1.3.0"
 openai = "^1.38.0"
 anthropic = "^0.34.1"
+tenacity = "^9.0.0"
 
 [tool.poetry.dev-dependencies]
 pytest = "^8.3.2"
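
For illustration, a minimal, self-contained sketch of how the retry policy
added above behaves. It reuses the patch's is_server_error predicate and
tenacity settings verbatim, but flaky_llm_call, the attempts counter, and the
example URL are hypothetical stand-ins for the real
LLMClient._generate_response path:

    import asyncio

    import httpx
    from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential


    def is_server_error(exception: BaseException) -> bool:
        # Same predicate as the patch: only HTTP 5xx responses are retried.
        return (
            isinstance(exception, httpx.HTTPStatusError)
            and 500 <= exception.response.status_code < 600
        )


    attempts = 0


    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception(is_server_error),
    )
    async def flaky_llm_call() -> str:
        # Hypothetical stand-in for _generate_response: fails with a 502
        # on the first two attempts, then succeeds.
        global attempts
        attempts += 1
        if attempts < 3:
            request = httpx.Request('POST', 'https://api.example.com/v1/chat')
            httpx.Response(502, request=request).raise_for_status()
        return 'ok'


    # Each retry waits 4-10 seconds under this policy, so expect roughly
    # 8 seconds of backoff before the third attempt succeeds.
    print(asyncio.run(flaky_llm_call()))  # -> 'ok' on the third attempt
    print(attempts)                       # -> 3

Note the division of labor in the patch: the predicate makes tenacity retry
only 5xx responses with exponential backoff, while the try/except in
_generate_response_with_retry converts non-5xx HTTPStatusError into a generic
Exception, which fails the predicate and therefore surfaces immediately
instead of being retried.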