Fix llm client retry (#102)

* Fix llm client retry
* feat: Improve llm client retry error message

parent ad2962c6ba
commit 6851b1063a

6 changed files with 42 additions and 12 deletions

graphiti_core/llm_client/__init__.py
@@ -1,5 +1,6 @@
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 from .openai_client import OpenAIClient
 
-__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']
+__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError']

graphiti_core/llm_client/anthropic_client.py
@@ -18,12 +18,14 @@ import json
 import logging
 import typing
 
+import anthropic
 from anthropic import AsyncAnthropic
 from openai import AsyncOpenAI
 
 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
@@ -35,7 +37,11 @@ class AnthropicClient(LLMClient):
         if config is None:
             config = LLMConfig()
         super().__init__(config, cache)
-        self.client = AsyncAnthropic(api_key=config.api_key)
+        self.client = AsyncAnthropic(
+            api_key=config.api_key,
+            # we'll use tenacity to retry
+            max_retries=1,
+        )
 
     def get_embedder(self) -> typing.Any:
         openai_client = AsyncOpenAI()
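
Note on max_retries=1: the Anthropic SDK runs its own internal retry loop (its documented default is 2 retries), which would stack multiplicatively with the tenacity policy added in client.py below. A back-of-envelope sketch, assuming the SDK default of 2:

# Each tenacity attempt issues a fresh SDK call, and each SDK call may
# itself be retried. With SDK defaults left in place, a persistently
# failing request could be sent up to 4 * (1 + 2) = 12 times; capping
# the SDK at max_retries=1 bounds it at 4 * (1 + 1) = 8, with tenacity
# owning the backoff schedule.
tenacity_attempts = 4         # stop_after_attempt(4) in client.py below
requests_per_attempt = 1 + 1  # one original request + max_retries=1
print(tenacity_attempts * requests_per_attempt)  # 8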
@@ -58,6 +64,8 @@ class AnthropicClient(LLMClient):
             )
 
             return json.loads('{' + result.content[0].text)  # type: ignore
+        except anthropic.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise

graphiti_core/llm_client/client.py
@@ -22,10 +22,11 @@ from abc import ABC, abstractmethod
 
 import httpx
 from diskcache import Cache
-from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
 
 from ..prompts.models import Message
 from .config import LLMConfig
+from .errors import RateLimitError
 
 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
@@ -33,7 +34,10 @@ DEFAULT_CACHE_DIR = './llm_cache'
 logger = logging.getLogger(__name__)
 
 
-def is_server_error(exception):
+def is_server_or_retry_error(exception):
+    if isinstance(exception, RateLimitError):
+        return True
+
     return (
         isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
     )
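
A quick illustrative check of the renamed predicate, using only httpx's public constructors:

import httpx

request = httpx.Request('POST', 'https://example.com/v1/chat')
server_err = httpx.HTTPStatusError(
    'server error', request=request, response=httpx.Response(503, request=request)
)
client_err = httpx.HTTPStatusError(
    'bad request', request=request, response=httpx.Response(400, request=request)
)

assert is_server_or_retry_error(server_err)        # 5xx: retried
assert is_server_or_retry_error(RateLimitError())  # normalized rate limit: retried
assert not is_server_or_retry_error(client_err)    # other 4xx: fails fast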
@@ -56,18 +60,21 @@ class LLMClient(ABC):
         pass
 
     @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception(is_server_error),
+        stop=stop_after_attempt(4),
+        wait=wait_random_exponential(multiplier=10, min=5, max=120),
+        retry=retry_if_exception(is_server_or_retry_error),
+        after=lambda retry_state: logger.warning(
+            f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
+        )
+        if retry_state.attempt_number > 1
+        else None,
         reraise=True,
     )
     async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
         try:
             return await self._generate_response(messages)
-        except httpx.HTTPStatusError as e:
-            if not is_server_error(e):
-                raise Exception(f'LLM request error: {e}') from e
-            else:
-                raise
+        except (httpx.HTTPStatusError, RateLimitError) as e:
+            raise e
 
     @abstractmethod
     async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
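
For a sense of the new schedule: per tenacity's documented semantics, wait_random_exponential(multiplier=10, min=5, max=120) computes an upper bound of multiplier * 2^attempt clamped to [min, max] and sleeps a uniformly random time up to that bound, so concurrent callers don't retry in lockstep against an already rate-limited API. A hand-rolled sketch of the three retry windows allowed by stop_after_attempt(4):

import random

for attempt in range(1, 4):
    upper = max(5, min(10 * 2**attempt, 120))  # tenacity's clamped bound
    print(f'retry {attempt}: sleep up to {upper}s, e.g. {random.uniform(0, upper):.1f}s')
# retry 1: up to 20s; retry 2: up to 40s; retry 3: up to 80s

The previous wait_exponential(multiplier=1, min=4, max=10) was deterministic and topped out at 10 seconds, arguably too short a cool-down for rate limits.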

graphiti_core/llm_client/errors.py (new file)
@@ -0,0 +1,6 @@
+class RateLimitError(Exception):
+    """Exception raised when the rate limit is exceeded."""
+
+    def __init__(self, message='Rate limit exceeded. Please try again later.'):
+        self.message = message
+        super().__init__(self.message)
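
With RateLimitError re-exported from __init__.py above, callers can handle rate limiting uniformly across providers. A hypothetical caller sketch (assuming the client's public generate_response wrapper around _generate_response_with_retry):

from graphiti_core.llm_client import OpenAIClient, RateLimitError
from graphiti_core.prompts.models import Message

async def summarize(client: OpenAIClient, messages: list[Message]) -> dict:
    try:
        return await client.generate_response(messages)
    except RateLimitError:
        # One exception type to catch, whether the backend raised
        # openai.RateLimitError, anthropic.RateLimitError, or groq.RateLimitError.
        raise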

graphiti_core/llm_client/groq_client.py
@@ -18,6 +18,7 @@ import json
 import logging
 import typing
 
+import groq
 from groq import AsyncGroq
 from groq.types.chat import ChatCompletionMessageParam
 from openai import AsyncOpenAI
@@ -25,6 +26,7 @@ from openai import AsyncOpenAI
 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
 
@@ -59,6 +61,8 @@ class GroqClient(LLMClient):
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
+        except groq.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise

graphiti_core/llm_client/openai_client.py
@@ -18,12 +18,14 @@ import json
 import logging
 import typing
 
+import openai
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 
 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
 
@@ -59,6 +61,8 @@ class OpenAIClient(LLMClient):
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
+        except openai.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise
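
All three clients use the same raise RateLimitError from e pattern, which preserves the provider's original exception as __cause__ for debugging. A minimal illustration with a stand-in for the SDK error:

try:
    try:
        raise ValueError('simulated provider 429')  # stand-in for openai.RateLimitError
    except ValueError as e:
        raise RateLimitError from e
except RateLimitError as err:
    print(err)            # Rate limit exceeded. Please try again later.
    print(err.__cause__)  # simulated provider 429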