"""
|
|
Copyright 2024, Zep Software, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import json
import logging
import os
import typing
from json import JSONDecodeError
from typing import TYPE_CHECKING, Literal

from pydantic import BaseModel, ValidationError

from ..prompts.models import Message
from .client import LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError, RefusalError

if TYPE_CHECKING:
    import anthropic
    from anthropic import AsyncAnthropic
    from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
else:
    try:
        import anthropic
        from anthropic import AsyncAnthropic
        from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
    except ImportError:
        raise ImportError(
            'anthropic is required for AnthropicClient. '
            'Install it with: pip install graphiti-core[anthropic]'
        ) from None


logger = logging.getLogger(__name__)

AnthropicModel = Literal[
    'claude-sonnet-4-5-latest',
    'claude-sonnet-4-5-20250929',
    'claude-haiku-4-5-latest',
    'claude-3-7-sonnet-latest',
    'claude-3-7-sonnet-20250219',
    'claude-3-5-haiku-latest',
    'claude-3-5-haiku-20241022',
    'claude-3-5-sonnet-latest',
    'claude-3-5-sonnet-20241022',
    'claude-3-5-sonnet-20240620',
    'claude-3-opus-latest',
    'claude-3-opus-20240229',
    'claude-3-sonnet-20240229',
    'claude-3-haiku-20240307',
    'claude-2.1',
    'claude-2.0',
]

DEFAULT_MODEL: AnthropicModel = 'claude-haiku-4-5-latest'

# Maximum output tokens for different Anthropic models.
# Based on official Anthropic documentation (as of 2025).
# Note: these are the standard limits without beta headers. Some models support
# higher limits with additional configuration (e.g., Claude 3.7 supports 128K
# with the 'anthropic-beta: output-128k-2025-02-19' header, but this is not
# currently implemented).
ANTHROPIC_MODEL_MAX_TOKENS = {
    # Claude 4.5 models - 64K tokens
    'claude-sonnet-4-5-latest': 65536,
    'claude-sonnet-4-5-20250929': 65536,
    'claude-haiku-4-5-latest': 65536,
    # Claude 3.7 models - standard 64K tokens
    'claude-3-7-sonnet-latest': 65536,
    'claude-3-7-sonnet-20250219': 65536,
    # Claude 3.5 models - 8K tokens
    'claude-3-5-haiku-latest': 8192,
    'claude-3-5-haiku-20241022': 8192,
    'claude-3-5-sonnet-latest': 8192,
    'claude-3-5-sonnet-20241022': 8192,
    'claude-3-5-sonnet-20240620': 8192,
    # Claude 3 models - 4K tokens
    'claude-3-opus-latest': 4096,
    'claude-3-opus-20240229': 4096,
    'claude-3-sonnet-20240229': 4096,
    'claude-3-haiku-20240307': 4096,
    # Claude 2 models - 4K tokens
    'claude-2.1': 4096,
    'claude-2.0': 4096,
}

# Default max tokens for models not in the mapping
DEFAULT_ANTHROPIC_MAX_TOKENS = 8192


class AnthropicClient(LLMClient):
    """
    A client for the Anthropic LLM.

    Args:
        config: A configuration object for the LLM.
        cache: Whether to cache the LLM responses.
        client: An optional client instance to use.
        max_tokens: The maximum number of tokens to generate.

    Methods:
        generate_response: Generate a response from the LLM.

    Notes:
        - If an LLMConfig is not provided, the api_key will be pulled from the
          ANTHROPIC_API_KEY environment variable, and default values will be
          used for the remaining LLMConfig fields.
    """

    model: AnthropicModel

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        client: AsyncAnthropic | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> None:
        if config is None:
            config = LLMConfig()
            config.api_key = os.getenv('ANTHROPIC_API_KEY')
            config.max_tokens = max_tokens

        if config.model is None:
            config.model = DEFAULT_MODEL

        super().__init__(config, cache)
        # Explicitly set the instance model to the config model to prevent type checking errors
        self.model = typing.cast(AnthropicModel, config.model)

        if not client:
            self.client = AsyncAnthropic(
                api_key=config.api_key,
                max_retries=1,
            )
        else:
            self.client = client

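    # Instantiation sketch (illustrative; the key and model values below are
    # placeholders, not part of this module):
    #
    #     client = AnthropicClient()  # reads ANTHROPIC_API_KEY, uses DEFAULT_MODEL
    #     client = AnthropicClient(
    #         config=LLMConfig(api_key='sk-ant-...', model='claude-sonnet-4-5-latest')
    #     )  # explicit key and model override the defaults
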
    def _extract_json_from_text(self, text: str) -> dict[str, typing.Any]:
        """Extract JSON from text content.

        A helper method to extract JSON from text content, used when tool use fails or
        no response_model is provided.

        Args:
            text: The text to extract JSON from

        Returns:
            Extracted JSON as a dictionary

        Raises:
            ValueError: If JSON cannot be extracted or parsed
        """
        try:
            # Naive extraction: take the substring from the first '{' to the
            # last '}' and parse it as JSON.
            json_start = text.find('{')
            json_end = text.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                json_str = text[json_start:json_end]
                return json.loads(json_str)
            else:
                raise ValueError(f'Could not extract JSON from model response: {text}')
        except (JSONDecodeError, ValueError) as e:
            raise ValueError(f'Could not extract JSON from model response: {text}') from e

    def _create_tool(
        self, response_model: type[BaseModel] | None = None
    ) -> tuple[list[ToolUnionParam], ToolChoiceParam]:
        """
        Create a tool definition based on the response_model if provided, or a generic JSON tool if not.

        Args:
            response_model: Optional Pydantic model to use for structured output.

        Returns:
            A tuple of (tool definitions, tool choice) for use with the Anthropic API.
        """
        if response_model is not None:
            # Use the response_model to define the tool
            model_schema = response_model.model_json_schema()
            tool_name = response_model.__name__
            description = model_schema.get('description', f'Extract {tool_name} information')
        else:
            # Create a generic JSON output tool
            tool_name = 'generic_json_output'
            description = 'Output data in JSON format'
            model_schema = {
                'type': 'object',
                'additionalProperties': True,
                'description': 'Any JSON object containing the requested information',
            }

        tool = {
            'name': tool_name,
            'description': description,
            'input_schema': model_schema,
        }
        tool_list = [tool]
        tool_list_cast = typing.cast(list[ToolUnionParam], tool_list)
        tool_choice = {'type': 'tool', 'name': tool_name}
        tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
        return tool_list_cast, tool_choice_cast

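    # Illustrative sketch (an assumption for clarity, not code in this module):
    # for a Pydantic model such as
    #
    #     class Person(BaseModel):
    #         """Details about a person."""
    #         name: str
    #
    # _create_tool(Person) yields roughly
    #
    #     tools = [{'name': 'Person', 'description': 'Details about a person.',
    #               'input_schema': Person.model_json_schema()}]
    #     tool_choice = {'type': 'tool', 'name': 'Person'}
    #
    # which forces the API to answer through the tool, producing structured JSON.
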
    def _get_max_tokens_for_model(self, model: str) -> int:
        """Get the maximum output tokens for a specific Anthropic model.

        Args:
            model: The model name to look up

        Returns:
            int: The maximum output tokens for the model
        """
        return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)

    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
        """
        Resolve the maximum output tokens to use based on precedence rules.

        Precedence order (highest to lowest):
        1. Explicit max_tokens parameter passed to generate_response()
        2. Instance max_tokens set during client initialization
        3. Model-specific maximum tokens from ANTHROPIC_MODEL_MAX_TOKENS
        4. DEFAULT_ANTHROPIC_MAX_TOKENS as the final fallback

        Args:
            requested_max_tokens: The max_tokens parameter passed to generate_response()
            model: The model name used to look up model-specific limits

        Returns:
            int: The resolved maximum tokens to use
        """
        # 1. Use the explicit parameter if provided
        if requested_max_tokens is not None:
            return requested_max_tokens

        # 2. Use the instance max_tokens if set during initialization
        if self.max_tokens is not None:
            return self.max_tokens

        # 3./4. Use the model-specific maximum, or DEFAULT_ANTHROPIC_MAX_TOKENS
        return self._get_max_tokens_for_model(model)

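    # Resolution example (illustrative): an explicit max_tokens argument always
    # wins; otherwise the instance max_tokens set at construction applies; only
    # when both are None does the model lookup run, e.g. 65536 for
    # 'claude-haiku-4-5-latest' and DEFAULT_ANTHROPIC_MAX_TOKENS (8192) for an
    # unmapped model.
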
    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Anthropic LLM using a tool-based approach for all requests.

        Args:
            messages: List of message objects to send to the LLM.
            response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded.
            RefusalError: If the LLM refuses to respond.
            Exception: If an error occurs during the generation process.
        """
        system_message = messages[0]
        user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
        user_messages_cast = typing.cast(list[MessageParam], user_messages)

        # Resolve max_tokens dynamically based on the model's capabilities.
        # This allows different models to use their full output capacity.
        max_creation_tokens: int = self._resolve_max_tokens(max_tokens, self.model)

        try:
            # Create the appropriate tool based on whether response_model is provided
            tools, tool_choice = self._create_tool(response_model)
            result = await self.client.messages.create(
                system=system_message.content,
                max_tokens=max_creation_tokens,
                temperature=self.temperature,
                messages=user_messages_cast,
                model=self.model,
                tools=tools,
                tool_choice=tool_choice,
            )

            # Extract the tool output from the response
            for content_item in result.content:
                if content_item.type == 'tool_use':
                    if isinstance(content_item.input, dict):
                        tool_args: dict[str, typing.Any] = content_item.input
                    else:
                        tool_args = json.loads(str(content_item.input))
                    return tool_args

            # If we didn't get a proper tool_use response, try to extract from text
            for content_item in result.content:
                if content_item.type == 'text':
                    return self._extract_json_from_text(content_item.text)

            # If we get here, we couldn't parse a structured response
            raise ValueError(
                f'Could not extract structured data from model response: {result.content}'
            )

        except anthropic.RateLimitError as e:
            raise RateLimitError(f'Rate limit exceeded. Please try again later. Error: {e}') from e
        except anthropic.APIError as e:
            # Special case for content policy violations. We convert these to RefusalError
            # to bypass the retry mechanism, as retrying policy-violating content will always fail.
            # This avoids wasting API calls and provides more specific error messaging to the user.
            if 'refused to respond' in str(e).lower():
                raise RefusalError(str(e)) from e
            raise

    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
        group_id: str | None = None,
        prompt_name: str | None = None,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the LLM.

        Args:
            messages: List of message objects to send to the LLM.
            response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded.
            RefusalError: If the LLM refuses to respond.
            Exception: If an error occurs during the generation process.
        """
        if max_tokens is None:
            max_tokens = self.max_tokens

        # Wrap the entire operation in a tracing span
        with self.tracer.start_span('llm.generate') as span:
            attributes = {
                'llm.provider': 'anthropic',
                'model.size': model_size.value,
                'max_tokens': max_tokens,
            }
            if prompt_name:
                attributes['prompt.name'] = prompt_name
            span.add_attributes(attributes)

            retry_count = 0
            max_retries = 2
            last_error: Exception | None = None

            while retry_count <= max_retries:
                try:
                    response = await self._generate_response(
                        messages, response_model, max_tokens, model_size
                    )

                    # If we have a response_model, attempt to validate the response
                    if response_model is not None:
                        # Validate the response against the response_model
                        model_instance = response_model(**response)
                        return model_instance.model_dump()

                    # If no validation needed, return the response
                    return response

                except (RateLimitError, RefusalError) as e:
                    # These errors should not trigger retries
                    span.set_status('error', str(e))
                    raise
                except Exception as e:
                    last_error = e

                    if retry_count >= max_retries:
                        if isinstance(e, ValidationError):
                            logger.error(
                                f'Validation error after {retry_count}/{max_retries} attempts: {e}'
                            )
                        else:
                            logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
                        span.set_status('error', str(e))
                        span.record_exception(e)
                        raise e

                    if isinstance(e, ValidationError):
                        response_model_cast = typing.cast(type[BaseModel], response_model)
                        error_context = (
                            f'The previous response was invalid. Please provide a valid '
                            f'{response_model_cast.__name__} object. Error: {e}'
                        )
                    else:
                        error_context = (
                            f'The previous response attempt was invalid. '
                            f'Error type: {e.__class__.__name__}. '
                            f'Error details: {str(e)}. '
                            f'Please try again with a valid response.'
                        )

                    # Common retry logic: append the error as a new user message
                    # so the model can correct itself on the next attempt.
                    retry_count += 1
                    messages.append(Message(role='user', content=error_context))
                    logger.warning(
                        f'Retrying after error (attempt {retry_count}/{max_retries}): {e}'
                    )

            # If we somehow get here, raise the last error
            span.set_status('error', str(last_error))
            raise last_error or Exception('Max retries exceeded with no specific error')
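
# Example usage (illustrative sketch; the import paths and values below are
# assumptions for the example, not guarantees about the package layout):
#
#     import asyncio
#
#     from pydantic import BaseModel
#
#     from graphiti_core.llm_client.config import LLMConfig
#     from graphiti_core.prompts.models import Message
#
#     class ExtractedFacts(BaseModel):
#         facts: list[str]
#
#     async def main() -> None:
#         # With config omitted, the key is read from ANTHROPIC_API_KEY instead.
#         client = AnthropicClient(config=LLMConfig(api_key='sk-ant-...'))
#         result = await client.generate_response(
#             messages=[
#                 Message(role='system', content='Extract facts from the text.'),
#                 Message(role='user', content='Alice met Bob in Paris in 2024.'),
#             ],
#             response_model=ExtractedFacts,
#         )
#         print(result)  # e.g. {'facts': ['Alice met Bob in Paris in 2024.']}
#
#     asyncio.run(main())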