Implement retry for LLMClient (#44)

* implement retry

* chore: Refactor tenacity retry logic and improve LLMClient error handling

* poetry

* remove unnecessary try
This commit is contained in:
Daniel Chalef 2024-08-26 12:53:16 -07:00 committed by GitHub
parent 895afc7be1
commit fc4bf3bde2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 40 additions and 2 deletions

View file

@ -20,7 +20,9 @@ import logging
import typing
from abc import ABC, abstractmethod
import httpx
from diskcache import Cache
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
from ..prompts.models import Message
from .config import LLMConfig
@ -31,6 +33,12 @@ DEFAULT_CACHE_DIR = './llm_cache'
logger = logging.getLogger(__name__)
def is_server_error(exception):
    """Return True when *exception* is an HTTP 5xx server error.

    Used as the tenacity retry predicate: only transient server-side
    failures (status 500-599) are worth retrying; everything else
    (client errors, non-HTTP exceptions) should fail immediately.
    """
    if not isinstance(exception, httpx.HTTPStatusError):
        return False
    status = exception.response.status_code
    return 500 <= status < 600
class LLMClient(ABC):
    def __init__(self, config: LLMConfig | None, cache: bool = False):
        if config is None:
@ -47,6 +55,20 @@ class LLMClient(ABC):
    def get_embedder(self) -> typing.Any:
        pass
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception(is_server_error),
)
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
    """Call _generate_response, retrying up to 3 times on HTTP 5xx errors.

    tenacity re-invokes this method (with exponential backoff, 4-10s)
    only when the raised exception satisfies is_server_error; any other
    HTTP status error is wrapped in a generic Exception and propagated
    without retry.
    """
    try:
        return await self._generate_response(messages)
    except httpx.HTTPStatusError as e:
        # Server errors propagate unchanged so the retry predicate sees them.
        if is_server_error(e):
            raise
        raise Exception(f'LLM request error: {e}') from e
    @abstractmethod
    async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
        pass
@ -66,7 +88,7 @@ class LLMClient(ABC):
            logger.debug(f'Cache hit for {cache_key}')
            return cached_response
response = await self._generate_response(messages) response = await self._generate_response_with_retry(messages)
        if self.cache_enabled:
            self.cache_dir.set(cache_key, response)

17
poetry.lock generated
View file

@ -3253,6 +3253,21 @@ mpmath = ">=1.1.0,<1.4"
[package.extras]
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
[[package]]
name = "tenacity"
version = "9.0.0"
description = "Retry code until it succeeds"
optional = false
python-versions = ">=3.8"
files = [
{file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"},
{file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"},
]
[package.extras]
doc = ["reno", "sphinx"]
test = ["pytest", "tornado (>=4.5)", "typeguard"]
[[package]]
name = "terminado"
version = "0.18.1"
@ -3743,4 +3758,4 @@ test = ["websockets"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "5b90bb6d58d36a2553f5410c418b179aa1c86b55078567c33aaa6fddf6a8c6c6" content-hash = "001663dfc8078ad473675c994b15191db1f53a844e23f40ffa4a704379a61132"

View file

@ -23,6 +23,7 @@ diskcache = "^5.6.3"
arrow = "^1.3.0"
openai = "^1.38.0"
anthropic = "^0.34.1"
tenacity = "^9.0.0"
[tool.poetry.dev-dependencies]
pytest = "^8.3.2"