Fix llm client retry (#102)

* Fix llm client retry

* feat: Improve llm client retry error message
Commit authored by Daniel Chalef on 2024-09-10 08:15:27 -07:00 and committed via GitHub.
parent ad2962c6ba
commit 6851b1063a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 42 additions and 12 deletions

View file

@@ -1,5 +1,6 @@
from .client import LLMClient
from .config import LLMConfig
from .errors import RateLimitError
from .openai_client import OpenAIClient
__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']
__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError']

View file

@@ -18,12 +18,14 @@ import json
import logging
import typing
import anthropic
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig
from .errors import RateLimitError
logger = logging.getLogger(__name__)
@@ -35,7 +37,11 @@ class AnthropicClient(LLMClient):
if config is None:
config = LLMConfig()
super().__init__(config, cache)
self.client = AsyncAnthropic(api_key=config.api_key)
self.client = AsyncAnthropic(
api_key=config.api_key,
# we'll use tenacity to retry
max_retries=1,
)
def get_embedder(self) -> typing.Any:
openai_client = AsyncOpenAI()
@@ -58,6 +64,8 @@ class AnthropicClient(LLMClient):
)
return json.loads('{' + result.content[0].text) # type: ignore
except anthropic.RateLimitError as e:
raise RateLimitError from e
except Exception as e:
logger.error(f'Error in generating LLM response: {e}')
raise

View file

@@ -22,10 +22,11 @@ from abc import ABC, abstractmethod
import httpx
from diskcache import Cache
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
from ..prompts.models import Message
from .config import LLMConfig
from .errors import RateLimitError
DEFAULT_TEMPERATURE = 0
DEFAULT_CACHE_DIR = './llm_cache'
@@ -33,7 +34,10 @@ DEFAULT_CACHE_DIR = './llm_cache'
logger = logging.getLogger(__name__)
def is_server_or_retry_error(exception):
    """Return True when *exception* should trigger a retry.

    Retryable errors are the package's own RateLimitError (raised by the
    provider clients when the upstream API reports rate limiting) and
    httpx HTTPStatusError responses in the 5xx server-error range.
    """
    if isinstance(exception, RateLimitError):
        return True
    # Only server-side (5xx) HTTP failures are treated as transient;
    # 4xx client errors are not worth retrying.
    return (
        isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
    )
@ -56,18 +60,21 @@ class LLMClient(ABC):
pass
@retry(
    # Up to 4 total attempts, with randomized exponential backoff between
    # 5s and 120s; retries fire only for rate-limit / 5xx server errors.
    stop=stop_after_attempt(4),
    wait=wait_random_exponential(multiplier=10, min=5, max=120),
    retry=retry_if_exception(is_server_or_retry_error),
    after=lambda retry_state: logger.warning(
        f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
    )
    if retry_state.attempt_number > 1
    else None,
    reraise=True,
)
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
    """Call _generate_response, letting tenacity retry transient failures.

    HTTP status errors and RateLimitError propagate unchanged so the
    @retry predicate (is_server_or_retry_error) can decide whether to
    retry; with reraise=True the final failure reaches the caller.
    """
    try:
        return await self._generate_response(messages)
    except (httpx.HTTPStatusError, RateLimitError):
        # Bare raise preserves the original traceback; 'raise e' would
        # re-anchor it at this frame.
        raise
@abstractmethod
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:

View file

@@ -0,0 +1,6 @@
class RateLimitError(Exception):
    """Signals that an upstream LLM provider rejected a request because
    the rate limit was exceeded."""

    def __init__(self, message='Rate limit exceeded. Please try again later.'):
        # Expose the text on the instance as well as via Exception.args.
        self.message = message
        super().__init__(message)

View file

@@ -18,6 +18,7 @@ import json
import logging
import typing
import groq
from groq import AsyncGroq
from groq.types.chat import ChatCompletionMessageParam
from openai import AsyncOpenAI
@@ -25,6 +26,7 @@ from openai import AsyncOpenAI
from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig
from .errors import RateLimitError
logger = logging.getLogger(__name__)
@@ -59,6 +61,8 @@ class GroqClient(LLMClient):
)
result = response.choices[0].message.content or ''
return json.loads(result)
except groq.RateLimitError as e:
raise RateLimitError from e
except Exception as e:
logger.error(f'Error in generating LLM response: {e}')
raise

View file

@@ -18,12 +18,14 @@ import json
import logging
import typing
import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig
from .errors import RateLimitError
logger = logging.getLogger(__name__)
@@ -59,6 +61,8 @@ class OpenAIClient(LLMClient):
)
result = response.choices[0].message.content or ''
return json.loads(result)
except openai.RateLimitError as e:
raise RateLimitError from e
except Exception as e:
logger.error(f'Error in generating LLM response: {e}')
raise