Implement retry for LLMClient (#44)
* implement retry * chore: Refactor tenacity retry logic and improve LLMClient error handling * poetry * remove unnecessary try
This commit is contained in:
parent
895afc7be1
commit
fc4bf3bde2
3 changed files with 40 additions and 2 deletions
|
|
@@ -20,7 +20,9 @@ import logging
|
||||||
import typing
|
import typing
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
import httpx
|
||||||
from diskcache import Cache
|
from diskcache import Cache
|
||||||
|
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .config import LLMConfig
|
from .config import LLMConfig
|
||||||
|
|
@@ -31,6 +33,12 @@ DEFAULT_CACHE_DIR = './llm_cache'
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def is_server_error(exception):
    """Return True when *exception* is an ``httpx.HTTPStatusError`` carrying a
    5xx (server-side) HTTP status — i.e. a transient failure worth retrying.

    Any other exception type, and any 4xx client error, yields False.
    """
    if not isinstance(exception, httpx.HTTPStatusError):
        return False
    status = exception.response.status_code
    return 500 <= status < 600
|
||||||
|
|
||||||
|
|
||||||
class LLMClient(ABC):
|
class LLMClient(ABC):
|
||||||
def __init__(self, config: LLMConfig | None, cache: bool = False):
|
def __init__(self, config: LLMConfig | None, cache: bool = False):
|
||||||
if config is None:
|
if config is None:
|
||||||
|
|
@@ -47,6 +55,20 @@ class LLMClient(ABC):
|
||||||
def get_embedder(self) -> typing.Any:
    """Return the embedding client associated with this LLM client.

    NOTE(review): this base implementation is a no-op stub returning None;
    concrete subclasses are presumably expected to override it (a decorator
    such as @abstractmethod may exist above this view — confirm in full file).
    """
    pass
|
||||||
|
|
||||||
|
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception(is_server_error),
)
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
    """Invoke ``_generate_response`` with automatic retries on server errors.

    Up to 3 attempts with exponential backoff (4–10 s) are made, but only
    when ``is_server_error`` matches the raised exception (HTTP 5xx).

    Args:
        messages: Conversation messages to send to the underlying LLM.

    Returns:
        The parsed response payload from ``_generate_response``.

    Raises:
        Exception: wraps any non-5xx ``httpx.HTTPStatusError`` (client-side
            errors are not retried).
        httpx.HTTPStatusError: re-raised unchanged for 5xx responses so the
            tenacity predicate can trigger another attempt.
    """
    try:
        return await self._generate_response(messages)
    except httpx.HTTPStatusError as e:
        if is_server_error(e):
            # Propagate unchanged: tenacity's retry predicate keys off this.
            raise
        # Client-side (non-retryable) failure — surface it with context.
        raise Exception(f'LLM request error: {e}') from e
|
||||||
|
|
||||||
@abstractmethod
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
    """Send *messages* to the underlying LLM provider and return the parsed
    response payload.

    Abstract: each concrete client (e.g. OpenAI, Anthropic) must implement
    the actual HTTP call. Retry handling is applied by the
    ``_generate_response_with_retry`` wrapper, not here.
    """
    pass
|
||||||
|
|
@@ -66,7 +88,7 @@ class LLMClient(ABC):
|
||||||
logger.debug(f'Cache hit for {cache_key}')
|
logger.debug(f'Cache hit for {cache_key}')
|
||||||
return cached_response
|
return cached_response
|
||||||
|
|
||||||
response = await self._generate_response(messages)
|
response = await self._generate_response_with_retry(messages)
|
||||||
|
|
||||||
if self.cache_enabled:
|
if self.cache_enabled:
|
||||||
self.cache_dir.set(cache_key, response)
|
self.cache_dir.set(cache_key, response)
|
||||||
|
|
|
||||||
17
poetry.lock
generated
17
poetry.lock
generated
|
|
@@ -3253,6 +3253,21 @@ mpmath = ">=1.1.0,<1.4"
|
||||||
[package.extras]
|
[package.extras]
|
||||||
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
|
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tenacity"
|
||||||
|
version = "9.0.0"
|
||||||
|
description = "Retry code until it succeeds"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"},
|
||||||
|
{file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
doc = ["reno", "sphinx"]
|
||||||
|
test = ["pytest", "tornado (>=4.5)", "typeguard"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "terminado"
|
name = "terminado"
|
||||||
version = "0.18.1"
|
version = "0.18.1"
|
||||||
|
|
@@ -3743,4 +3758,4 @@ test = ["websockets"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "5b90bb6d58d36a2553f5410c418b179aa1c86b55078567c33aaa6fddf6a8c6c6"
|
content-hash = "001663dfc8078ad473675c994b15191db1f53a844e23f40ffa4a704379a61132"
|
||||||
|
|
|
||||||
|
|
@@ -23,6 +23,7 @@ diskcache = "^5.6.3"
|
||||||
arrow = "^1.3.0"
|
arrow = "^1.3.0"
|
||||||
openai = "^1.38.0"
|
openai = "^1.38.0"
|
||||||
anthropic = "^0.34.1"
|
anthropic = "^0.34.1"
|
||||||
|
tenacity = "^9.0.0"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
pytest = "^8.3.2"
|
pytest = "^8.3.2"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue