Anthropic client (#361)

* update Anthropic client to use tool calling and add tests

* fix linting errors before creating pull request by making literal types for anthropic models
Evan Schultz 2025-04-16 13:35:07 -06:00 committed by GitHub
parent aab53d6e73
commit 113179f674
6 changed files with 770 additions and 82 deletions

CONTRIBUTING.md

@@ -9,104 +9,114 @@ We've restructured our contribution paths to solve this problem:
# Four Ways to Get Involved
### Pick Up Existing Issues
Our developers regularly tag issues with "help wanted" and "good first issue." These are pre-vetted tasks with clear scope and someone ready to help you if you get stuck.
### Create Your Own Tickets
See something that needs fixing? Have an idea for an improvement? You don't need permission to identify problems. The people closest to the pain are often best positioned to describe the solution.
For **feature requests**, tell us the story of what you're trying to accomplish. What are you working on? What's getting in your way? What would make your life easier? Submit these through our [GitHub issue tracker](https://github.com/getzep/graphiti/issues) with a "Feature Request" label.
For **bug reports**, we need enough context to reproduce the problem. Use the [GitHub issue tracker](https://github.com/getzep/graphiti/issues) and include:
- A clear title that summarizes the specific problem
- What you were trying to do when you encountered the bug
- What you expected to happen
- What actually happened
- A code sample or test case that demonstrates the issue
### Share Your Use Cases
Sometimes the most valuable contribution isn't code. If you're using our project in an interesting way, add it to the [examples](https://github.com/getzep/graphiti/tree/main/examples) folder. This helps others discover new possibilities and counts as a meaningful contribution. We regularly feature compelling examples in our blog posts and videos - your work might be showcased to the broader community!
### Help Others in Discord
Join our [Discord server](https://discord.gg/2JbGZQZT) community and pitch in at the helpdesk. Answering questions and helping troubleshoot issues is an incredibly valuable contribution that benefits everyone. The knowledge you share today saves someone hours of frustration tomorrow.
## What happens next?
Once you've found an issue tagged with "good first issue" or "help wanted," or prepared an example to share, here's how to turn that into a contribution:
1. Share your approach in the issue discussion or [Discord](https://discord.gg/2JbGZQZT) before diving deep into code. This helps ensure your solution adheres to the architecture of Graphiti from the start and saves you from potential rework.
2. Fork the repo, make your changes in a branch, and submit a PR. We've included more detailed technical instructions below; be open to feedback during review.
## Setup
1. Fork the repository on GitHub.
2. Clone your fork locally:
```
git clone https://github.com/getzep/graphiti
cd graphiti
```
3. Set up your development environment:
- Ensure you have Python 3.10+ installed.
- Install Poetry: https://python-poetry.org/docs/#installation
- Install project dependencies:
```
make install
```
- To run integration tests, set the appropriate environment variables
```
export TEST_OPENAI_API_KEY=...
export TEST_OPENAI_MODEL=...
export TEST_ANTHROPIC_API_KEY=...
export NEO4J_URI=neo4j://...
export NEO4J_USER=...
export NEO4J_PASSWORD=...
```
## Making Changes
1. Create a new branch for your changes:
```
git checkout -b your-branch-name
```
2. Make your changes in the codebase.
3. Write or update tests as necessary.
4. Run the tests to ensure they pass:
```
make test
```
5. Format your code:
```
make format
```
6. Run linting checks:
```
make lint
```
## Submitting Changes
1. Commit your changes:
```
git commit -m "Your detailed commit message"
```
2. Push to your fork:
```
git push origin your-branch-name
```
3. Submit a pull request through the GitHub website to https://github.com/getzep/graphiti.
## Pull Request Guidelines
- Provide a clear title and description of your changes.
- Include any relevant issue numbers in the PR description.
- Ensure all tests pass and there are no linting errors.
- Update documentation if you're changing functionality.
## Code Style and Quality
We use several tools to maintain code quality:
- Ruff for linting and formatting
- Mypy for static type checking
- Pytest for testing
Before submitting a pull request, please run:
@@ -117,6 +127,7 @@ make check
This command will format your code, run linting checks, and execute tests.
# Questions?
Stuck on a contribution or have a half-formed idea? Come say hello in our [Discord server](https://discord.gg/2JbGZQZT). Whether you're ready to contribute or just want to learn more, we're happy to have you! It's faster than GitHub issues and you'll find both maintainers and fellow contributors ready to help.
Thank you for contributing to Graphiti!

graphiti_core/llm_client/anthropic_client.py

@@ -16,64 +16,317 @@ limitations under the License.
import json
import logging
import os
import typing
from json import JSONDecodeError
from typing import Literal

import anthropic
from anthropic import AsyncAnthropic
from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
from pydantic import BaseModel, ValidationError

from ..prompts.models import Message
from .client import LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig
from .errors import RateLimitError, RefusalError

logger = logging.getLogger(__name__)

AnthropicModel = Literal[
    'claude-3-7-sonnet-latest',
    'claude-3-7-sonnet-20250219',
    'claude-3-5-haiku-latest',
    'claude-3-5-haiku-20241022',
    'claude-3-5-sonnet-latest',
    'claude-3-5-sonnet-20241022',
    'claude-3-5-sonnet-20240620',
    'claude-3-opus-latest',
    'claude-3-opus-20240229',
    'claude-3-sonnet-20240229',
    'claude-3-haiku-20240307',
    'claude-2.1',
    'claude-2.0',
]

DEFAULT_MODEL: AnthropicModel = 'claude-3-7-sonnet-latest'
class AnthropicClient(LLMClient):
    """
    A client for the Anthropic LLM.

    Args:
        config: A configuration object for the LLM.
        cache: Whether to cache the LLM responses.
        client: An optional client instance to use.
        max_tokens: The maximum number of tokens to generate.

    Methods:
        generate_response: Generate a response from the LLM.

    Notes:
        - If a LLMConfig is not provided, api_key will be pulled from the ANTHROPIC_API_KEY environment
          variable, and all default values will be used for the LLMConfig.
    """

    model: AnthropicModel

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        client: AsyncAnthropic | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> None:
        if config is None:
            config = LLMConfig()
            config.api_key = os.getenv('ANTHROPIC_API_KEY')
            config.max_tokens = max_tokens
        if config.model is None:
            config.model = DEFAULT_MODEL

        super().__init__(config, cache)
        # Explicitly set the instance model to the config model to prevent type checking errors
        self.model = typing.cast(AnthropicModel, config.model)

        if not client:
            self.client = AsyncAnthropic(
                api_key=config.api_key,
                max_retries=1,
            )
        else:
            self.client = client
    def _extract_json_from_text(self, text: str) -> dict[str, typing.Any]:
        """Extract JSON from text content.

        A helper method to extract JSON from text content, used when tool use fails or
        no response_model is provided.

        Args:
            text: The text to extract JSON from

        Returns:
            Extracted JSON as a dictionary

        Raises:
            ValueError: If JSON cannot be extracted or parsed
        """
        try:
            json_start = text.find('{')
            json_end = text.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                json_str = text[json_start:json_end]
                return json.loads(json_str)
            else:
                raise ValueError(f'Could not extract JSON from model response: {text}')
        except (JSONDecodeError, ValueError) as e:
            raise ValueError(f'Could not extract JSON from model response: {text}') from e
    def _create_tool(
        self, response_model: type[BaseModel] | None = None
    ) -> tuple[list[ToolUnionParam], ToolChoiceParam]:
        """
        Create a tool definition based on the response_model if provided, or a generic JSON tool if not.

        Args:
            response_model: Optional Pydantic model to use for structured output.

        Returns:
            A list containing a single tool definition for use with the Anthropic API.
        """
        if response_model is not None:
            # temporary debug log
            logger.info(f'Creating tool for response_model: {response_model}')
            # Use the response_model to define the tool
            model_schema = response_model.model_json_schema()
            tool_name = response_model.__name__
            description = model_schema.get('description', f'Extract {tool_name} information')
        else:
            # temporary debug log
            logger.info('Creating generic JSON output tool')
            # Create a generic JSON output tool
            tool_name = 'generic_json_output'
            description = 'Output data in JSON format'
            model_schema = {
                'type': 'object',
                'additionalProperties': True,
                'description': 'Any JSON object containing the requested information',
            }

        tool = {
            'name': tool_name,
            'description': description,
            'input_schema': model_schema,
        }
        tool_list = [tool]
        tool_list_cast = typing.cast(list[ToolUnionParam], tool_list)
        tool_choice = {'type': 'tool', 'name': tool_name}
        tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
        return tool_list_cast, tool_choice_cast
    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Anthropic LLM using tool-based approach for all requests.

        Args:
            messages: List of message objects to send to the LLM.
            response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded.
            RefusalError: If the LLM refuses to respond.
            Exception: If an error occurs during the generation process.
        """
        system_message = messages[0]
        user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
        user_messages_cast = typing.cast(list[MessageParam], user_messages)

        # TODO: Replace hacky min finding solution after fixing hardcoded EXTRACT_EDGES_MAX_TOKENS = 16384 in
        # edge_operations.py. Throws errors with cheaper models that lower max_tokens.
        max_creation_tokens: int = min(
            max_tokens if max_tokens is not None else self.config.max_tokens,
            DEFAULT_MAX_TOKENS,
        )

        try:
            # Create the appropriate tool based on whether response_model is provided
            tools, tool_choice = self._create_tool(response_model)
            # temporary debug log
            logger.info(f'using model: {self.model} with max_tokens: {self.max_tokens}')
            result = await self.client.messages.create(
                system=system_message.content,
                max_tokens=max_creation_tokens,
                temperature=self.temperature,
                messages=user_messages_cast,
                model=self.model,
                tools=tools,
                tool_choice=tool_choice,
            )

            # Extract the tool output from the response
            for content_item in result.content:
                if content_item.type == 'tool_use':
                    if isinstance(content_item.input, dict):
                        tool_args: dict[str, typing.Any] = content_item.input
                    else:
                        tool_args = json.loads(str(content_item.input))
                    return tool_args

            # If we didn't get a proper tool_use response, try to extract from text
            # logger.debug(
            #     f'Did not get a tool_use response, trying to extract json from text. Result: {result.content}'
            # )
            # temporary debug log
            logger.info(
                f'Did not get a tool_use response, trying to extract json from text. Result: {result.content}'
            )
            for content_item in result.content:
                if content_item.type == 'text':
                    return self._extract_json_from_text(content_item.text)
                else:
                    raise ValueError(
                        f'Could not extract structured data from model response: {result.content}'
                    )

            # If we get here, we couldn't parse a structured response
            raise ValueError(
                f'Could not extract structured data from model response: {result.content}'
            )
        except anthropic.RateLimitError as e:
            raise RateLimitError(f'Rate limit exceeded. Please try again later. Error: {e}') from e
        except anthropic.APIError as e:
            # Special case for content policy violations. We convert these to RefusalError
            # to bypass the retry mechanism, as retrying policy-violating content will always fail.
            # This avoids wasting API calls and provides more specific error messaging to the user.
            if 'refused to respond' in str(e).lower():
                raise RefusalError(str(e)) from e
            raise e
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise e
    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the LLM.

        Args:
            messages: List of message objects to send to the LLM.
            response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded.
            RefusalError: If the LLM refuses to respond.
            Exception: If an error occurs during the generation process.
        """
        retry_count = 0
        max_retries = 2
        last_error: Exception | None = None

        while retry_count <= max_retries:
            try:
                response = await self._generate_response(messages, response_model, max_tokens)

                # If we have a response_model, attempt to validate the response
                if response_model is not None:
                    # Validate the response against the response_model
                    model_instance = response_model(**response)
                    return model_instance.model_dump()

                # If no validation needed, return the response
                return response
            except (RateLimitError, RefusalError):
                # These errors should not trigger retries
                raise
            except Exception as e:
                last_error = e

                if retry_count >= max_retries:
                    if isinstance(e, ValidationError):
                        logger.error(
                            f'Validation error after {retry_count}/{max_retries} attempts: {e}'
                        )
                    else:
                        logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
                    raise e

                if isinstance(e, ValidationError):
                    response_model_cast = typing.cast(type[BaseModel], response_model)
                    error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
                else:
                    error_context = (
                        f'The previous response attempt was invalid. '
                        f'Error type: {e.__class__.__name__}. '
                        f'Error details: {str(e)}. '
                        f'Please try again with a valid response.'
                    )

                # Common retry logic
                retry_count += 1
                messages.append(Message(role='user', content=error_context))
                logger.warning(f'Retrying after error (attempt {retry_count}/{max_retries}): {e}')

        # If we somehow get here, raise the last error
        raise last_error or Exception('Max retries exceeded with no specific error')
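For orientation, the end-to-end flow of the new client (build a tool from the response model, force tool use, validate the output, return a dict) can be exercised roughly as follows. This is a minimal sketch based on the tests added in this PR, not code from the commit itself; the `Greeting` model is hypothetical, and it assumes `ANTHROPIC_API_KEY` is set in the environment.

```
import asyncio

from pydantic import BaseModel, Field

from graphiti_core.llm_client.anthropic_client import AnthropicClient
from graphiti_core.prompts.models import Message


class Greeting(BaseModel):
    """Hypothetical response model used only for this sketch."""

    message: str = Field(..., description='A message from the model')


async def main() -> None:
    # With no LLMConfig, the client reads ANTHROPIC_API_KEY from the environment
    # and falls back to DEFAULT_MODEL ('claude-3-7-sonnet-latest').
    client = AnthropicClient()

    messages = [
        Message(role='system', content='Only output JSON that matches the requested schema.'),
        Message(role='user', content="Return a greeting in the 'message' field."),
    ]

    # generate_response creates a tool from the Greeting schema, forces the model to call it,
    # validates the tool output against Greeting, and returns a plain dict.
    result = await client.generate_response(messages, response_model=Greeting)
    print(result['message'])


asyncio.run(main())
```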

graphiti_core/llm_client/errors.py

@@ -29,3 +29,11 @@ class RefusalError(Exception):
    def __init__(self, message: str):
        self.message = message
        super().__init__(self.message)


class EmptyResponseError(Exception):
    """Exception raised when the LLM returns an empty response."""

    def __init__(self, message: str):
        self.message = message
        super().__init__(self.message)
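The new `EmptyResponseError` sits alongside the existing `RateLimitError` and `RefusalError`, both of which the Anthropic client now raises without retrying. As a rough illustration of how a caller might branch on these types, here is a sketch; the `handle_prompt` wrapper is hypothetical and not part of this commit, and note that nothing in this diff raises `EmptyResponseError` yet.

```
import logging

from graphiti_core.llm_client.anthropic_client import AnthropicClient
from graphiti_core.llm_client.errors import EmptyResponseError, RateLimitError, RefusalError
from graphiti_core.prompts.models import Message

logger = logging.getLogger(__name__)


async def handle_prompt(client: AnthropicClient, messages: list[Message]) -> dict | None:
    """Hypothetical wrapper showing how a caller might handle the error types."""
    try:
        return await client.generate_response(messages)
    except RateLimitError:
        # Raised when the Anthropic API reports throttling; the caller may retry later with backoff.
        logger.warning('Rate limited by Anthropic; backing off before retrying.')
        return None
    except RefusalError as e:
        # Raised for content policy refusals; retrying the same prompt will not help.
        logger.error(f'Model refused to respond: {e.message}')
        return None
    except EmptyResponseError as e:
        # Added in this PR; reserved for the case where the LLM returns an empty response.
        logger.error(f'Empty response from model: {e.message}')
        return None
```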

tests/integrations/test_anthropic_client_int.py

@@ -0,0 +1,85 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Running tests: pytest -xvs tests/integrations/test_anthropic_client_int.py

import os

import pytest
from pydantic import BaseModel, Field

from graphiti_core.llm_client.anthropic_client import AnthropicClient
from graphiti_core.prompts.models import Message

# Skip all tests if no API key is available
pytestmark = pytest.mark.skipif(
    'TEST_ANTHROPIC_API_KEY' not in os.environ,
    reason='Anthropic API key not available',
)


# Rename to avoid pytest collection as a test class
class SimpleResponseModel(BaseModel):
    """Test response model."""

    message: str = Field(..., description='A message from the model')


@pytest.mark.asyncio
@pytest.mark.integration
async def test_generate_simple_response():
    """Test generating a simple response from the Anthropic API."""
    if 'TEST_ANTHROPIC_API_KEY' not in os.environ:
        pytest.skip('Anthropic API key not available')

    client = AnthropicClient()

    messages = [
        Message(
            role='user',
            content="Respond with a JSON object containing a 'message' field with value 'Hello, world!'",
        )
    ]

    try:
        response = await client.generate_response(messages, response_model=SimpleResponseModel)

        assert isinstance(response, dict)
        assert 'message' in response
        assert response['message'] == 'Hello, world!'
    except Exception as e:
        pytest.skip(f'Test skipped due to Anthropic API error: {str(e)}')


@pytest.mark.asyncio
@pytest.mark.integration
async def test_extract_json_from_text():
    """Test the extract_json_from_text method with real data."""
    # We don't need an actual API connection for this test,
    # so we can create the client without worrying about the API key
    with pytest.MonkeyPatch.context() as monkeypatch:
        # Temporarily set an environment variable to avoid API key error
        monkeypatch.setenv('ANTHROPIC_API_KEY', 'fake_key_for_testing')
        client = AnthropicClient(cache=False)

        # A string with embedded JSON
        text = 'Some text before {"message": "Hello, world!"} and after'

        result = client._extract_json_from_text(text)

        assert isinstance(result, dict)
        assert 'message' in result
        assert result['message'] == 'Hello, world!'

tests/llm_client/test_anthropic_client.py

@@ -0,0 +1,255 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Running tests: pytest -xvs tests/llm_client/test_anthropic_client.py

import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel

from graphiti_core.llm_client.anthropic_client import AnthropicClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.errors import RateLimitError, RefusalError
from graphiti_core.prompts.models import Message
# Rename class to avoid pytest collection as a test class
class ResponseModel(BaseModel):
    """Test model for response testing."""

    test_field: str
    optional_field: int = 0


@pytest.fixture
def mock_async_anthropic():
    """Fixture to mock the AsyncAnthropic client."""
    with patch('anthropic.AsyncAnthropic') as mock_client:
        # Setup mock instance and its create method
        mock_instance = mock_client.return_value
        mock_instance.messages.create = AsyncMock()
        yield mock_instance


@pytest.fixture
def anthropic_client(mock_async_anthropic):
    """Fixture to create an AnthropicClient with a mocked AsyncAnthropic."""
    # Use a context manager to patch the AsyncAnthropic constructor to avoid
    # the client actually trying to create a real connection
    with patch('anthropic.AsyncAnthropic', return_value=mock_async_anthropic):
        config = LLMConfig(
            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
        )
        client = AnthropicClient(config=config, cache=False)

        # Replace the client's client with our mock to ensure we're using the mock
        client.client = mock_async_anthropic

        return client
class TestAnthropicClientInitialization:
    """Tests for AnthropicClient initialization."""

    def test_init_with_config(self):
        """Test initialization with a config object."""
        config = LLMConfig(
            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
        )
        client = AnthropicClient(config=config, cache=False)

        assert client.config == config
        assert client.model == 'test-model'
        assert client.temperature == 0.5
        assert client.max_tokens == 1000

    def test_init_with_default_model(self):
        """Test initialization with default model when none is provided."""
        config = LLMConfig(api_key='test_api_key')
        client = AnthropicClient(config=config, cache=False)

        assert client.model == 'claude-3-7-sonnet-latest'

    @patch.dict(os.environ, {'ANTHROPIC_API_KEY': 'env_api_key'})
    def test_init_without_config(self):
        """Test initialization without a config, using environment variable."""
        client = AnthropicClient(cache=False)

        assert client.config.api_key == 'env_api_key'
        assert client.model == 'claude-3-7-sonnet-latest'

    def test_init_with_custom_client(self):
        """Test initialization with a custom AsyncAnthropic client."""
        mock_client = MagicMock()
        client = AnthropicClient(client=mock_client)

        assert client.client == mock_client
class TestAnthropicClientGenerateResponse:
    """Tests for AnthropicClient generate_response method."""

    @pytest.mark.asyncio
    async def test_generate_response_with_tool_use(self, anthropic_client, mock_async_anthropic):
        """Test successful response generation with tool use."""
        # Setup mock response
        content_item = MagicMock()
        content_item.type = 'tool_use'
        content_item.input = {'test_field': 'test_value'}

        mock_response = MagicMock()
        mock_response.content = [content_item]
        mock_async_anthropic.messages.create.return_value = mock_response

        # Call method
        messages = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        result = await anthropic_client.generate_response(
            messages=messages, response_model=ResponseModel
        )

        # Assertions
        assert isinstance(result, dict)
        assert result['test_field'] == 'test_value'
        mock_async_anthropic.messages.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_response_with_text_response(
        self, anthropic_client, mock_async_anthropic
    ):
        """Test response generation when getting text response instead of tool use."""
        # Setup mock response with text content
        content_item = MagicMock()
        content_item.type = 'text'
        content_item.text = '{"test_field": "extracted_value"}'

        mock_response = MagicMock()
        mock_response.content = [content_item]
        mock_async_anthropic.messages.create.return_value = mock_response

        # Call method
        messages = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        result = await anthropic_client.generate_response(
            messages=messages, response_model=ResponseModel
        )

        # Assertions
        assert isinstance(result, dict)
        assert result['test_field'] == 'extracted_value'

    @pytest.mark.asyncio
    async def test_rate_limit_error(self, anthropic_client, mock_async_anthropic):
        """Test handling of rate limit errors."""

        # Create a custom RateLimitError from Anthropic
        class MockRateLimitError(Exception):
            pass

        # Patch the Anthropic error with our mock to avoid constructor issues
        with patch('anthropic.RateLimitError', MockRateLimitError):
            # Setup mock to raise our mocked RateLimitError
            mock_async_anthropic.messages.create.side_effect = MockRateLimitError(
                'Rate limit exceeded'
            )

            # Call method and check exception
            messages = [Message(role='user', content='Test message')]
            with pytest.raises(RateLimitError):
                await anthropic_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_refusal_error(self, anthropic_client, mock_async_anthropic):
        """Test handling of content policy violations (refusal errors)."""

        # Create a custom APIError that matches what we need
        class MockAPIError(Exception):
            def __init__(self, message):
                self.message = message
                super().__init__(message)

        # Patch the Anthropic error with our mock
        with patch('anthropic.APIError', MockAPIError):
            # Setup mock to raise APIError with refusal message
            mock_async_anthropic.messages.create.side_effect = MockAPIError('refused to respond')

            # Call method and check exception
            messages = [Message(role='user', content='Test message')]
            with pytest.raises(RefusalError):
                await anthropic_client.generate_response(messages)
    @pytest.mark.asyncio
    async def test_extract_json_from_text(self, anthropic_client):
        """Test the _extract_json_from_text method."""
        # Valid JSON embedded in text
        text = 'Some text before {"test_field": "value"} and after'
        result = anthropic_client._extract_json_from_text(text)
        assert result == {'test_field': 'value'}

        # Invalid JSON
        with pytest.raises(ValueError):
            anthropic_client._extract_json_from_text('Not JSON at all')

    @pytest.mark.asyncio
    async def test_create_tool(self, anthropic_client):
        """Test the _create_tool method with and without response model."""
        # With response model
        tools, tool_choice = anthropic_client._create_tool(ResponseModel)
        assert len(tools) == 1
        assert tools[0]['name'] == 'ResponseModel'
        assert tool_choice['name'] == 'ResponseModel'

        # Without response model (generic JSON)
        tools, tool_choice = anthropic_client._create_tool()
        assert len(tools) == 1
        assert tools[0]['name'] == 'generic_json_output'

    @pytest.mark.asyncio
    async def test_validation_error_retry(self, anthropic_client, mock_async_anthropic):
        """Test retry behavior on validation error."""
        # First call returns invalid data, second call returns valid data
        content_item1 = MagicMock()
        content_item1.type = 'tool_use'
        content_item1.input = {'wrong_field': 'wrong_value'}

        content_item2 = MagicMock()
        content_item2.type = 'tool_use'
        content_item2.input = {'test_field': 'correct_value'}

        # Setup mock to return different responses on consecutive calls
        mock_response1 = MagicMock()
        mock_response1.content = [content_item1]
        mock_response2 = MagicMock()
        mock_response2.content = [content_item2]
        mock_async_anthropic.messages.create.side_effect = [mock_response1, mock_response2]

        # Call method
        messages = [Message(role='user', content='Test message')]
        result = await anthropic_client.generate_response(messages, response_model=ResponseModel)

        # Should have called create twice due to retry
        assert mock_async_anthropic.messages.create.call_count == 2
        assert result['test_field'] == 'correct_value'


if __name__ == '__main__':
    pytest.main(['-v', 'test_anthropic_client.py'])

tests/llm_client/test_errors.py

@@ -0,0 +1,76 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Running tests: pytest -xvs tests/llm_client/test_errors.py

import pytest

from graphiti_core.llm_client.errors import EmptyResponseError, RateLimitError, RefusalError


class TestRateLimitError:
    """Tests for the RateLimitError class."""

    def test_default_message(self):
        """Test that the default message is set correctly."""
        error = RateLimitError()
        assert error.message == 'Rate limit exceeded. Please try again later.'
        assert str(error) == 'Rate limit exceeded. Please try again later.'

    def test_custom_message(self):
        """Test that a custom message can be set."""
        custom_message = 'Custom rate limit message'
        error = RateLimitError(custom_message)
        assert error.message == custom_message
        assert str(error) == custom_message


class TestRefusalError:
    """Tests for the RefusalError class."""

    def test_message_required(self):
        """Test that a message is required for RefusalError."""
        with pytest.raises(TypeError):
            # Intentionally not providing the required message parameter
            RefusalError()  # type: ignore

    def test_message_assignment(self):
        """Test that the message is assigned correctly."""
        message = 'The LLM refused to respond to this prompt.'
        error = RefusalError(message=message)  # Add explicit keyword argument

        assert error.message == message
        assert str(error) == message


class TestEmptyResponseError:
    """Tests for the EmptyResponseError class."""

    def test_message_required(self):
        """Test that a message is required for EmptyResponseError."""
        with pytest.raises(TypeError):
            # Intentionally not providing the required message parameter
            EmptyResponseError()  # type: ignore

    def test_message_assignment(self):
        """Test that the message is assigned correctly."""
        message = 'The LLM returned an empty response.'
        error = EmptyResponseError(message=message)  # Add explicit keyword argument

        assert error.message == message
        assert str(error) == message


if __name__ == '__main__':
    pytest.main(['-v', 'test_errors.py'])