Fix llm client retry (#102)
* Fix llm client retry * feat: Improve llm client retry error message
This commit is contained in:
parent
ad2962c6ba
commit
6851b1063a
6 changed files with 42 additions and 12 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
from .client import LLMClient
|
from .client import LLMClient
|
||||||
from .config import LLMConfig
|
from .config import LLMConfig
|
||||||
|
from .errors import RateLimitError
|
||||||
from .openai_client import OpenAIClient
|
from .openai_client import OpenAIClient
|
||||||
|
|
||||||
__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']
|
__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError']
|
||||||
|
|
|
||||||
|
|
@ -18,12 +18,14 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
import anthropic
|
||||||
from anthropic import AsyncAnthropic
|
from anthropic import AsyncAnthropic
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .client import LLMClient
|
from .client import LLMClient
|
||||||
from .config import LLMConfig
|
from .config import LLMConfig
|
||||||
|
from .errors import RateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -35,7 +37,11 @@ class AnthropicClient(LLMClient):
|
||||||
if config is None:
|
if config is None:
|
||||||
config = LLMConfig()
|
config = LLMConfig()
|
||||||
super().__init__(config, cache)
|
super().__init__(config, cache)
|
||||||
self.client = AsyncAnthropic(api_key=config.api_key)
|
self.client = AsyncAnthropic(
|
||||||
|
api_key=config.api_key,
|
||||||
|
# we'll use tenacity to retry
|
||||||
|
max_retries=1,
|
||||||
|
)
|
||||||
|
|
||||||
def get_embedder(self) -> typing.Any:
|
def get_embedder(self) -> typing.Any:
|
||||||
openai_client = AsyncOpenAI()
|
openai_client = AsyncOpenAI()
|
||||||
|
|
@ -58,6 +64,8 @@ class AnthropicClient(LLMClient):
|
||||||
)
|
)
|
||||||
|
|
||||||
return json.loads('{' + result.content[0].text) # type: ignore
|
return json.loads('{' + result.content[0].text) # type: ignore
|
||||||
|
except anthropic.RateLimitError as e:
|
||||||
|
raise RateLimitError from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'Error in generating LLM response: {e}')
|
logger.error(f'Error in generating LLM response: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,11 @@ from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from diskcache import Cache
|
from diskcache import Cache
|
||||||
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
|
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .config import LLMConfig
|
from .config import LLMConfig
|
||||||
|
from .errors import RateLimitError
|
||||||
|
|
||||||
DEFAULT_TEMPERATURE = 0
|
DEFAULT_TEMPERATURE = 0
|
||||||
DEFAULT_CACHE_DIR = './llm_cache'
|
DEFAULT_CACHE_DIR = './llm_cache'
|
||||||
|
|
@ -33,7 +34,10 @@ DEFAULT_CACHE_DIR = './llm_cache'
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def is_server_error(exception):
|
def is_server_or_retry_error(exception):
|
||||||
|
if isinstance(exception, RateLimitError):
|
||||||
|
return True
|
||||||
|
|
||||||
return (
|
return (
|
||||||
isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
|
isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
|
||||||
)
|
)
|
||||||
|
|
@ -56,18 +60,21 @@ class LLMClient(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(4),
|
||||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
wait=wait_random_exponential(multiplier=10, min=5, max=120),
|
||||||
retry=retry_if_exception(is_server_error),
|
retry=retry_if_exception(is_server_or_retry_error),
|
||||||
|
after=lambda retry_state: logger.warning(
|
||||||
|
f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
|
||||||
|
)
|
||||||
|
if retry_state.attempt_number > 1
|
||||||
|
else None,
|
||||||
|
reraise=True,
|
||||||
)
|
)
|
||||||
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
|
async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
|
||||||
try:
|
try:
|
||||||
return await self._generate_response(messages)
|
return await self._generate_response(messages)
|
||||||
except httpx.HTTPStatusError as e:
|
except (httpx.HTTPStatusError, RateLimitError) as e:
|
||||||
if not is_server_error(e):
|
raise e
|
||||||
raise Exception(f'LLM request error: {e}') from e
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
|
||||||
|
|
|
||||||
6
graphiti_core/llm_client/errors.py
Normal file
6
graphiti_core/llm_client/errors.py
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
class RateLimitError(Exception):
|
||||||
|
"""Exception raised when the rate limit is exceeded."""
|
||||||
|
|
||||||
|
def __init__(self, message='Rate limit exceeded. Please try again later.'):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
@ -18,6 +18,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
import groq
|
||||||
from groq import AsyncGroq
|
from groq import AsyncGroq
|
||||||
from groq.types.chat import ChatCompletionMessageParam
|
from groq.types.chat import ChatCompletionMessageParam
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
@ -25,6 +26,7 @@ from openai import AsyncOpenAI
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .client import LLMClient
|
from .client import LLMClient
|
||||||
from .config import LLMConfig
|
from .config import LLMConfig
|
||||||
|
from .errors import RateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -59,6 +61,8 @@ class GroqClient(LLMClient):
|
||||||
)
|
)
|
||||||
result = response.choices[0].message.content or ''
|
result = response.choices[0].message.content or ''
|
||||||
return json.loads(result)
|
return json.loads(result)
|
||||||
|
except groq.RateLimitError as e:
|
||||||
|
raise RateLimitError from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'Error in generating LLM response: {e}')
|
logger.error(f'Error in generating LLM response: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -18,12 +18,14 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
import openai
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
|
||||||
from ..prompts.models import Message
|
from ..prompts.models import Message
|
||||||
from .client import LLMClient
|
from .client import LLMClient
|
||||||
from .config import LLMConfig
|
from .config import LLMConfig
|
||||||
|
from .errors import RateLimitError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -59,6 +61,8 @@ class OpenAIClient(LLMClient):
|
||||||
)
|
)
|
||||||
result = response.choices[0].message.content or ''
|
result = response.choices[0].message.content or ''
|
||||||
return json.loads(result)
|
return json.loads(result)
|
||||||
|
except openai.RateLimitError as e:
|
||||||
|
raise RateLimitError from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'Error in generating LLM response: {e}')
|
logger.error(f'Error in generating LLM response: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue