From 14d5ce0b36258af525f028230f0a367a0ae1b9e7 Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Sun, 22 Sep 2024 11:33:54 -0700
Subject: [PATCH] Override default max tokens for Anthropic and Groq clients
 (#143)

* Override default max tokens for Anthropic and Groq clients

* Override default max tokens for Anthropic and Groq clients

* Override default max tokens for Anthropic and Groq clients
---
 graphiti_core/llm_client/anthropic_client.py | 6 +++++-
 graphiti_core/llm_client/groq_client.py      | 6 +++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/graphiti_core/llm_client/anthropic_client.py b/graphiti_core/llm_client/anthropic_client.py
index abc149b2..ee186f1e 100644
--- a/graphiti_core/llm_client/anthropic_client.py
+++ b/graphiti_core/llm_client/anthropic_client.py
@@ -30,13 +30,17 @@ from .errors import RateLimitError
 logger = logging.getLogger(__name__)
 
 DEFAULT_MODEL = 'claude-3-5-sonnet-20240620'
+DEFAULT_MAX_TOKENS = 8192
 
 
 class AnthropicClient(LLMClient):
     def __init__(self, config: LLMConfig | None = None, cache: bool = False):
         if config is None:
-            config = LLMConfig()
+            config = LLMConfig(max_tokens=DEFAULT_MAX_TOKENS)
+        elif config.max_tokens is None:
+            config.max_tokens = DEFAULT_MAX_TOKENS
         super().__init__(config, cache)
+
         self.client = AsyncAnthropic(
             api_key=config.api_key,
             # we'll use tenacity to retry
diff --git a/graphiti_core/llm_client/groq_client.py b/graphiti_core/llm_client/groq_client.py
index bd5c4471..673b8db1 100644
--- a/graphiti_core/llm_client/groq_client.py
+++ b/graphiti_core/llm_client/groq_client.py
@@ -31,13 +31,17 @@ from .errors import RateLimitError
 logger = logging.getLogger(__name__)
 
 DEFAULT_MODEL = 'llama-3.1-70b-versatile'
+DEFAULT_MAX_TOKENS = 2048
 
 
 class GroqClient(LLMClient):
     def __init__(self, config: LLMConfig | None = None, cache: bool = False):
         if config is None:
-            config = LLMConfig()
+            config = LLMConfig(max_tokens=DEFAULT_MAX_TOKENS)
+        elif config.max_tokens is None:
+            config.max_tokens = DEFAULT_MAX_TOKENS
         super().__init__(config, cache)
+
         self.client = AsyncGroq(api_key=config.api_key)
 
     def get_embedder(self) -> typing.Any: