Override default max tokens for Anthropic and Groq clients (#143)
* Override default max tokens for Anthropic and Groq clients
This commit is contained in:
parent
d8c49c1c0a
commit
14d5ce0b36
2 changed files with 10 additions and 2 deletions
|
|
@ -30,13 +30,17 @@ from .errors import RateLimitError
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_MODEL = 'claude-3-5-sonnet-20240620'
|
DEFAULT_MODEL = 'claude-3-5-sonnet-20240620'
|
||||||
|
DEFAULT_MAX_TOKENS = 8192
|
||||||
|
|
||||||
|
|
||||||
class AnthropicClient(LLMClient):
|
class AnthropicClient(LLMClient):
|
||||||
def __init__(self, config: LLMConfig | None = None, cache: bool = False):
|
def __init__(self, config: LLMConfig | None = None, cache: bool = False):
|
||||||
if config is None:
|
if config is None:
|
||||||
config = LLMConfig()
|
config = LLMConfig(max_tokens=DEFAULT_MAX_TOKENS)
|
||||||
|
elif config.max_tokens is None:
|
||||||
|
config.max_tokens = DEFAULT_MAX_TOKENS
|
||||||
super().__init__(config, cache)
|
super().__init__(config, cache)
|
||||||
|
|
||||||
self.client = AsyncAnthropic(
|
self.client = AsyncAnthropic(
|
||||||
api_key=config.api_key,
|
api_key=config.api_key,
|
||||||
# we'll use tenacity to retry
|
# we'll use tenacity to retry
|
||||||
|
|
|
||||||
|
|
@ -31,13 +31,17 @@ from .errors import RateLimitError
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_MODEL = 'llama-3.1-70b-versatile'
|
DEFAULT_MODEL = 'llama-3.1-70b-versatile'
|
||||||
|
DEFAULT_MAX_TOKENS = 2048
|
||||||
|
|
||||||
|
|
||||||
class GroqClient(LLMClient):
|
class GroqClient(LLMClient):
|
||||||
def __init__(self, config: LLMConfig | None = None, cache: bool = False):
    """Set up the Groq client.

    Ensures a usable ``LLMConfig`` is always passed to the base class:
    a missing config is replaced with one carrying the module default
    token limit, and a config without ``max_tokens`` has the default
    filled in before initialization.
    """
    # Guarantee max_tokens is set, whether or not a config was supplied.
    if config is None:
        config = LLMConfig(max_tokens=DEFAULT_MAX_TOKENS)
    elif config.max_tokens is None:
        config.max_tokens = DEFAULT_MAX_TOKENS

    super().__init__(config, cache)

    self.client = AsyncGroq(api_key=config.api_key)
|
||||||
|
|
||||||
def get_embedder(self) -> typing.Any:
|
def get_embedder(self) -> typing.Any:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue