From 6851b1063a6fc3cbb791e9dd52b55540c8216513 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Tue, 10 Sep 2024 08:15:27 -0700 Subject: [PATCH] Fix llm client retry (#102) * Fix llm client retry * feat: Improve llm client retry error message --- graphiti_core/llm_client/__init__.py | 3 ++- graphiti_core/llm_client/anthropic_client.py | 10 +++++++- graphiti_core/llm_client/client.py | 27 ++++++++++++-------- graphiti_core/llm_client/errors.py | 6 +++++ graphiti_core/llm_client/groq_client.py | 4 +++ graphiti_core/llm_client/openai_client.py | 4 +++ 6 files changed, 42 insertions(+), 12 deletions(-) create mode 100644 graphiti_core/llm_client/errors.py diff --git a/graphiti_core/llm_client/__init__.py b/graphiti_core/llm_client/__init__.py index 24f9a46f..1472aa6b 100644 --- a/graphiti_core/llm_client/__init__.py +++ b/graphiti_core/llm_client/__init__.py @@ -1,5 +1,6 @@ from .client import LLMClient from .config import LLMConfig +from .errors import RateLimitError from .openai_client import OpenAIClient -__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig'] +__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError'] diff --git a/graphiti_core/llm_client/anthropic_client.py b/graphiti_core/llm_client/anthropic_client.py index 33f069ed..abc149b2 100644 --- a/graphiti_core/llm_client/anthropic_client.py +++ b/graphiti_core/llm_client/anthropic_client.py @@ -18,12 +18,14 @@ import json import logging import typing +import anthropic from anthropic import AsyncAnthropic from openai import AsyncOpenAI from ..prompts.models import Message from .client import LLMClient from .config import LLMConfig +from .errors import RateLimitError logger = logging.getLogger(__name__) @@ -35,7 +37,11 @@ class AnthropicClient(LLMClient): if config is None: config = LLMConfig() super().__init__(config, cache) - self.client = AsyncAnthropic(api_key=config.api_key) + self.client = AsyncAnthropic( + api_key=config.api_key, + # we'll use tenacity to retry + max_retries=1, + ) def get_embedder(self) -> typing.Any: openai_client = AsyncOpenAI() @@ -58,6 +64,8 @@ class AnthropicClient(LLMClient): ) return json.loads('{' + result.content[0].text) # type: ignore + except anthropic.RateLimitError as e: + raise RateLimitError from e except Exception as e: logger.error(f'Error in generating LLM response: {e}') raise diff --git a/graphiti_core/llm_client/client.py b/graphiti_core/llm_client/client.py index 5de06d76..4c53dd90 100644 --- a/graphiti_core/llm_client/client.py +++ b/graphiti_core/llm_client/client.py @@ -22,10 +22,11 @@ from abc import ABC, abstractmethod import httpx from diskcache import Cache -from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential +from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential from ..prompts.models import Message from .config import LLMConfig +from .errors import RateLimitError DEFAULT_TEMPERATURE = 0 DEFAULT_CACHE_DIR = './llm_cache' @@ -33,7 +34,10 @@ DEFAULT_CACHE_DIR = './llm_cache' logger = logging.getLogger(__name__) -def is_server_error(exception): +def is_server_or_retry_error(exception): + if isinstance(exception, RateLimitError): + return True + return ( isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600 ) @@ -56,18 +60,21 @@ class LLMClient(ABC): pass @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception(is_server_error), + stop=stop_after_attempt(4), + wait=wait_random_exponential(multiplier=10, min=5, max=120), + retry=retry_if_exception(is_server_or_retry_error), + after=lambda retry_state: logger.warning( + f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...' + ) + if retry_state.attempt_number > 1 + else None, + reraise=True, ) async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]: try: return await self._generate_response(messages) - except httpx.HTTPStatusError as e: - if not is_server_error(e): - raise Exception(f'LLM request error: {e}') from e - else: - raise + except (httpx.HTTPStatusError, RateLimitError) as e: + raise e @abstractmethod async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: diff --git a/graphiti_core/llm_client/errors.py b/graphiti_core/llm_client/errors.py new file mode 100644 index 00000000..13f9b479 --- /dev/null +++ b/graphiti_core/llm_client/errors.py @@ -0,0 +1,6 @@ +class RateLimitError(Exception): + """Exception raised when the rate limit is exceeded.""" + + def __init__(self, message='Rate limit exceeded. Please try again later.'): + self.message = message + super().__init__(self.message) diff --git a/graphiti_core/llm_client/groq_client.py b/graphiti_core/llm_client/groq_client.py index b9a5f190..bd5c4471 100644 --- a/graphiti_core/llm_client/groq_client.py +++ b/graphiti_core/llm_client/groq_client.py @@ -18,6 +18,7 @@ import json import logging import typing +import groq from groq import AsyncGroq from groq.types.chat import ChatCompletionMessageParam from openai import AsyncOpenAI @@ -25,6 +26,7 @@ from openai import AsyncOpenAI from ..prompts.models import Message from .client import LLMClient from .config import LLMConfig +from .errors import RateLimitError logger = logging.getLogger(__name__) @@ -59,6 +61,8 @@ class GroqClient(LLMClient): ) result = response.choices[0].message.content or '' return json.loads(result) + except groq.RateLimitError as e: + raise RateLimitError from e except Exception as e: logger.error(f'Error in generating LLM response: {e}') raise diff --git a/graphiti_core/llm_client/openai_client.py b/graphiti_core/llm_client/openai_client.py index dd727576..9ac71a56 100644 --- a/graphiti_core/llm_client/openai_client.py +++ b/graphiti_core/llm_client/openai_client.py @@ -18,12 +18,14 @@ import json import logging import typing +import openai from openai import AsyncOpenAI from openai.types.chat import ChatCompletionMessageParam from ..prompts.models import Message from .client import LLMClient from .config import LLMConfig +from .errors import RateLimitError logger = logging.getLogger(__name__) @@ -59,6 +61,8 @@ class OpenAIClient(LLMClient): ) result = response.choices[0].message.content or '' return json.loads(result) + except openai.RateLimitError as e: + raise RateLimitError from e except Exception as e: logger.error(f'Error in generating LLM response: {e}') raise