From 113179f6741b4d441b9b764d3cc40461e542a947 Mon Sep 17 00:00:00 2001 From: Evan Schultz Date: Wed, 16 Apr 2025 13:35:07 -0600 Subject: [PATCH] Anthropic client (#361) * update Anthropic client to use tool calling and add tests * fix linting errors before creating pull request by making literal types for anthropic models --- CONTRIBUTING.md | 111 +++--- graphiti_core/llm_client/anthropic_client.py | 317 ++++++++++++++++-- graphiti_core/llm_client/errors.py | 8 + .../integrations/test_anthropic_client_int.py | 85 +++++ tests/llm_client/test_anthropic_client.py | 255 ++++++++++++++ tests/llm_client/test_errors.py | 76 +++++ 6 files changed, 770 insertions(+), 82 deletions(-) create mode 100644 tests/integrations/test_anthropic_client_int.py create mode 100644 tests/llm_client/test_anthropic_client.py create mode 100644 tests/llm_client/test_errors.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2a64daa1..53ba8c36 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,104 +9,114 @@ We've restructured our contribution paths to solve this problem: # Four Ways to Get Involved ### Pick Up Existing Issues + Our developers regularly tag issues with "help wanted" and "good first issue." These are pre-vetted tasks with clear scope and someone ready to help you if you get stuck. ### Create Your Own Tickets + See something that needs fixing? Have an idea for an improvement? You don't need permission to identify problems. The people closest to the pain are often best positioned to describe the solution. For **feature requests**, tell us the story of what you're trying to accomplish. What are you working on? What's getting in your way? What would make your life easier? Submit these through our [GitHub issue tracker](https://github.com/getzep/graphiti/issues) with a "Feature Request" label. For **bug reports**, we need enough context to reproduce the problem. Use the [GitHub issue tracker](https://github.com/getzep/graphiti/issues) and include: -- A clear title that summarizes the specific problem -- What you were trying to do when you encountered the bug -- What you expected to happen -- What actually happened -- A code sample or test case that demonstrates the issue + +- A clear title that summarizes the specific problem +- What you were trying to do when you encountered the bug +- What you expected to happen +- What actually happened +- A code sample or test case that demonstrates the issue ### Share Your Use Cases + Sometimes the most valuable contribution isn't code. If you're using our project in an interesting way, add it to the [examples](https://github.com/getzep/graphiti/tree/main/examples) folder. This helps others discover new possibilities and counts as a meaningful contribution. We regularly feature compelling examples in our blog posts and videos - your work might be showcased to the broader community! ### Help Others in Discord -Join our [Discord server](https://discord.gg/2JbGZQZT) community and pitch in at the helpdesk. Answering questions and helping troubleshoot issues is an incredibly valuable contribution that benefits everyone. The knowledge you share today saves someone hours of frustration tomorrow. + +Join our [Discord server](https://discord.gg/2JbGZQZT) community and pitch in at the helpdesk. Answering questions and helping troubleshoot issues is an incredibly valuable contribution that benefits everyone. The knowledge you share today saves someone hours of frustration tomorrow. ## What happens next? + Once you've found an issue tagged with "good first issue" or "help wanted," or prepared an example to share, here's how to turn that into a contribution: -1. Share your approach in the issue discussion or [Discord](https://discord.gg/2JbGZQZT) before diving deep into code. This helps ensure your solution adheres to the architecture of Graphiti from the start and saves you from potential rework. +1. Share your approach in the issue discussion or [Discord](https://discord.gg/2JbGZQZT) before diving deep into code. This helps ensure your solution adheres to the architecture of Graphiti from the start and saves you from potential rework. 2. Fork the repo, make your changes in a branch, and submit a PR. We've included more detailed technical instructions below; be open to feedback during review. ## Setup + 1. Fork the repository on GitHub. 2. Clone your fork locally: - ``` - git clone https://github.com/getzep/graphiti - cd graphiti - ``` + ``` + git clone https://github.com/getzep/graphiti + cd graphiti + ``` 3. Set up your development environment: - - Ensure you have Python 3.10+ installed. - - Install Poetry: https://python-poetry.org/docs/#installation - - Install project dependencies: - ``` - make install - ``` - - To run integration tests, set the appropriate environment variables - ``` - export TEST_OPENAI_API_KEY=... - export TEST_OPENAI_MODEL=... - export NEO4J_URI=neo4j://... - export NEO4J_USER=... - export NEO4J_PASSWORD=... - ``` + - Ensure you have Python 3.10+ installed. + - Install Poetry: https://python-poetry.org/docs/#installation + - Install project dependencies: + ``` + make install + ``` + - To run integration tests, set the appropriate environment variables + + ``` + export TEST_OPENAI_API_KEY=... + export TEST_OPENAI_MODEL=... + export TEST_ANTHROPIC_API_KEY=... + + export NEO4J_URI=neo4j://... + export NEO4J_USER=... + export NEO4J_PASSWORD=... + ``` ## Making Changes 1. Create a new branch for your changes: - ``` - git checkout -b your-branch-name - ``` + ``` + git checkout -b your-branch-name + ``` 2. Make your changes in the codebase. 3. Write or update tests as necessary. 4. Run the tests to ensure they pass: - ``` - make test - ``` + ``` + make test + ``` 5. Format your code: - ``` - make format - ``` + ``` + make format + ``` 6. Run linting checks: - ``` - make lint - ``` + ``` + make lint + ``` ## Submitting Changes 1. Commit your changes: - ``` - git commit -m "Your detailed commit message" - ``` + ``` + git commit -m "Your detailed commit message" + ``` 2. Push to your fork: - ``` - git push origin your-branch-name - ``` + ``` + git push origin your-branch-name + ``` 3. Submit a pull request through the GitHub website to https://github.com/getzep/graphiti. ## Pull Request Guidelines -- Provide a clear title and description of your changes. -- Include any relevant issue numbers in the PR description. -- Ensure all tests pass and there are no linting errors. -- Update documentation if you're changing functionality. +- Provide a clear title and description of your changes. +- Include any relevant issue numbers in the PR description. +- Ensure all tests pass and there are no linting errors. +- Update documentation if you're changing functionality. ## Code Style and Quality We use several tools to maintain code quality: -- Ruff for linting and formatting -- Mypy for static type checking -- Pytest for testing +- Ruff for linting and formatting +- Mypy for static type checking +- Pytest for testing Before submitting a pull request, please run: @@ -117,6 +127,7 @@ make check This command will format your code, run linting checks, and execute tests. # Questions? + Stuck on a contribution or have a half-formed idea? Come say hello in our [Discord server](https://discord.gg/2JbGZQZT). Whether you're ready to contribute or just want to learn more, we're happy to have you! It's faster than GitHub issues and you'll find both maintainers and fellow contributors ready to help. Thank you for contributing to Graphiti! diff --git a/graphiti_core/llm_client/anthropic_client.py b/graphiti_core/llm_client/anthropic_client.py index c6f32c5d..efa837c8 100644 --- a/graphiti_core/llm_client/anthropic_client.py +++ b/graphiti_core/llm_client/anthropic_client.py @@ -16,64 +16,317 @@ limitations under the License. import json import logging +import os import typing +from json import JSONDecodeError +from typing import Literal import anthropic from anthropic import AsyncAnthropic -from pydantic import BaseModel +from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam +from pydantic import BaseModel, ValidationError from ..prompts.models import Message from .client import LLMClient -from .config import LLMConfig -from .errors import RateLimitError +from .config import DEFAULT_MAX_TOKENS, LLMConfig +from .errors import RateLimitError, RefusalError logger = logging.getLogger(__name__) -DEFAULT_MODEL = 'claude-3-7-sonnet-latest' -DEFAULT_MAX_TOKENS = 8192 +AnthropicModel = Literal[ + 'claude-3-7-sonnet-latest', + 'claude-3-7-sonnet-20250219', + 'claude-3-5-haiku-latest', + 'claude-3-5-haiku-20241022', + 'claude-3-5-sonnet-latest', + 'claude-3-5-sonnet-20241022', + 'claude-3-5-sonnet-20240620', + 'claude-3-opus-latest', + 'claude-3-opus-20240229', + 'claude-3-sonnet-20240229', + 'claude-3-haiku-20240307', + 'claude-2.1', + 'claude-2.0', +] + +DEFAULT_MODEL: AnthropicModel = 'claude-3-7-sonnet-latest' class AnthropicClient(LLMClient): - def __init__(self, config: LLMConfig | None = None, cache: bool = False): - if config is None: - config = LLMConfig(max_tokens=DEFAULT_MAX_TOKENS) - elif config.max_tokens is None: - config.max_tokens = DEFAULT_MAX_TOKENS - super().__init__(config, cache) + """ + A client for the Anthropic LLM. - self.client = AsyncAnthropic( - api_key=config.api_key, - # we'll use tenacity to retry - max_retries=1, - ) + Args: + config: A configuration object for the LLM. + cache: Whether to cache the LLM responses. + client: An optional client instance to use. + max_tokens: The maximum number of tokens to generate. + + Methods: + generate_response: Generate a response from the LLM. + + Notes: + - If a LLMConfig is not provided, api_key will be pulled from the ANTHROPIC_API_KEY environment + variable, and all default values will be used for the LLMConfig. + + """ + + model: AnthropicModel + + def __init__( + self, + config: LLMConfig | None = None, + cache: bool = False, + client: AsyncAnthropic | None = None, + max_tokens: int = DEFAULT_MAX_TOKENS, + ) -> None: + if config is None: + config = LLMConfig() + config.api_key = os.getenv('ANTHROPIC_API_KEY') + config.max_tokens = max_tokens + + if config.model is None: + config.model = DEFAULT_MODEL + + super().__init__(config, cache) + # Explicitly set the instance model to the config model to prevent type checking errors + self.model = typing.cast(AnthropicModel, config.model) + + if not client: + self.client = AsyncAnthropic( + api_key=config.api_key, + max_retries=1, + ) + else: + self.client = client + + def _extract_json_from_text(self, text: str) -> dict[str, typing.Any]: + """Extract JSON from text content. + + A helper method to extract JSON from text content, used when tool use fails or + no response_model is provided. + + Args: + text: The text to extract JSON from + + Returns: + Extracted JSON as a dictionary + + Raises: + ValueError: If JSON cannot be extracted or parsed + """ + try: + json_start = text.find('{') + json_end = text.rfind('}') + 1 + if json_start >= 0 and json_end > json_start: + json_str = text[json_start:json_end] + return json.loads(json_str) + else: + raise ValueError(f'Could not extract JSON from model response: {text}') + except (JSONDecodeError, ValueError) as e: + raise ValueError(f'Could not extract JSON from model response: {text}') from e + + def _create_tool( + self, response_model: type[BaseModel] | None = None + ) -> tuple[list[ToolUnionParam], ToolChoiceParam]: + """ + Create a tool definition based on the response_model if provided, or a generic JSON tool if not. + + Args: + response_model: Optional Pydantic model to use for structured output. + + Returns: + A list containing a single tool definition for use with the Anthropic API. + """ + if response_model is not None: + # temporary debug log + logger.info(f'Creating tool for response_model: {response_model}') + # Use the response_model to define the tool + model_schema = response_model.model_json_schema() + tool_name = response_model.__name__ + description = model_schema.get('description', f'Extract {tool_name} information') + else: + # temporary debug log + logger.info('Creating generic JSON output tool') + # Create a generic JSON output tool + tool_name = 'generic_json_output' + description = 'Output data in JSON format' + model_schema = { + 'type': 'object', + 'additionalProperties': True, + 'description': 'Any JSON object containing the requested information', + } + + tool = { + 'name': tool_name, + 'description': description, + 'input_schema': model_schema, + } + tool_list = [tool] + tool_list_cast = typing.cast(list[ToolUnionParam], tool_list) + tool_choice = {'type': 'tool', 'name': tool_name} + tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice) + return tool_list_cast, tool_choice_cast async def _generate_response( self, messages: list[Message], response_model: type[BaseModel] | None = None, - max_tokens: int = DEFAULT_MAX_TOKENS, + max_tokens: int | None = None, ) -> dict[str, typing.Any]: - system_message = messages[0] - user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [ - {'role': 'assistant', 'content': '{'} - ] + """ + Generate a response from the Anthropic LLM using tool-based approach for all requests. - # Ensure max_tokens is not greater than config.max_tokens or DEFAULT_MAX_TOKENS - max_tokens = min(max_tokens, self.config.max_tokens, DEFAULT_MAX_TOKENS) + Args: + messages: List of message objects to send to the LLM. + response_model: Optional Pydantic model to use for structured output. + max_tokens: Maximum number of tokens to generate. + + Returns: + Dictionary containing the structured response from the LLM. + + Raises: + RateLimitError: If the rate limit is exceeded. + RefusalError: If the LLM refuses to respond. + Exception: If an error occurs during the generation process. + """ + system_message = messages[0] + user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + user_messages_cast = typing.cast(list[MessageParam], user_messages) + + # TODO: Replace hacky min finding solution after fixing hardcoded EXTRACT_EDGES_MAX_TOKENS = 16384 in + # edge_operations.py. Throws errors with cheaper models that lower max_tokens. + max_creation_tokens: int = min( + max_tokens if max_tokens is not None else self.config.max_tokens, + DEFAULT_MAX_TOKENS, + ) try: + # Create the appropriate tool based on whether response_model is provided + tools, tool_choice = self._create_tool(response_model) + # temporary debug log + logger.info(f'using model: {self.model} with max_tokens: {self.max_tokens}') result = await self.client.messages.create( - system='Only include JSON in the response. Do not include any additional text or explanation of the content.\n' - + system_message.content, - max_tokens=max_tokens, + system=system_message.content, + max_tokens=max_creation_tokens, temperature=self.temperature, - messages=user_messages, # type: ignore - model=self.model or DEFAULT_MODEL, + messages=user_messages_cast, + model=self.model, + tools=tools, + tool_choice=tool_choice, + ) + + # Extract the tool output from the response + for content_item in result.content: + if content_item.type == 'tool_use': + if isinstance(content_item.input, dict): + tool_args: dict[str, typing.Any] = content_item.input + else: + tool_args = json.loads(str(content_item.input)) + return tool_args + + # If we didn't get a proper tool_use response, try to extract from text + # logger.debug( + # f'Did not get a tool_use response, trying to extract json from text. Result: {result.content}' + # ) + # temporary debug log + logger.info( + f'Did not get a tool_use response, trying to extract json from text. Result: {result.content}' + ) + for content_item in result.content: + if content_item.type == 'text': + return self._extract_json_from_text(content_item.text) + else: + raise ValueError( + f'Could not extract structured data from model response: {result.content}' + ) + + # If we get here, we couldn't parse a structured response + raise ValueError( + f'Could not extract structured data from model response: {result.content}' ) - return json.loads('{' + result.content[0].text) # type: ignore except anthropic.RateLimitError as e: - raise RateLimitError from e + raise RateLimitError(f'Rate limit exceeded. Please try again later. Error: {e}') from e + except anthropic.APIError as e: + # Special case for content policy violations. We convert these to RefusalError + # to bypass the retry mechanism, as retrying policy-violating content will always fail. + # This avoids wasting API calls and provides more specific error messaging to the user. + if 'refused to respond' in str(e).lower(): + raise RefusalError(str(e)) from e + raise e except Exception as e: - logger.error(f'Error in generating LLM response: {e}') - raise + raise e + + async def generate_response( + self, + messages: list[Message], + response_model: type[BaseModel] | None = None, + max_tokens: int = DEFAULT_MAX_TOKENS, + ) -> dict[str, typing.Any]: + """ + Generate a response from the LLM. + + Args: + messages: List of message objects to send to the LLM. + response_model: Optional Pydantic model to use for structured output. + max_tokens: Maximum number of tokens to generate. + + Returns: + Dictionary containing the structured response from the LLM. + + Raises: + RateLimitError: If the rate limit is exceeded. + RefusalError: If the LLM refuses to respond. + Exception: If an error occurs during the generation process. + """ + retry_count = 0 + max_retries = 2 + last_error: Exception | None = None + + while retry_count <= max_retries: + try: + response = await self._generate_response(messages, response_model, max_tokens) + + # If we have a response_model, attempt to validate the response + if response_model is not None: + # Validate the response against the response_model + model_instance = response_model(**response) + return model_instance.model_dump() + + # If no validation needed, return the response + return response + + except (RateLimitError, RefusalError): + # These errors should not trigger retries + raise + except Exception as e: + last_error = e + + if retry_count >= max_retries: + if isinstance(e, ValidationError): + logger.error( + f'Validation error after {retry_count}/{max_retries} attempts: {e}' + ) + else: + logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}') + raise e + + if isinstance(e, ValidationError): + response_model_cast = typing.cast(type[BaseModel], response_model) + error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}' + else: + error_context = ( + f'The previous response attempt was invalid. ' + f'Error type: {e.__class__.__name__}. ' + f'Error details: {str(e)}. ' + f'Please try again with a valid response.' + ) + + # Common retry logic + retry_count += 1 + messages.append(Message(role='user', content=error_context)) + logger.warning(f'Retrying after error (attempt {retry_count}/{max_retries}): {e}') + + # If we somehow get here, raise the last error + raise last_error or Exception('Max retries exceeded with no specific error') diff --git a/graphiti_core/llm_client/errors.py b/graphiti_core/llm_client/errors.py index cd8c22a1..362f62ad 100644 --- a/graphiti_core/llm_client/errors.py +++ b/graphiti_core/llm_client/errors.py @@ -29,3 +29,11 @@ class RefusalError(Exception): def __init__(self, message: str): self.message = message super().__init__(self.message) + + +class EmptyResponseError(Exception): + """Exception raised when the LLM returns an empty response.""" + + def __init__(self, message: str): + self.message = message + super().__init__(self.message) diff --git a/tests/integrations/test_anthropic_client_int.py b/tests/integrations/test_anthropic_client_int.py new file mode 100644 index 00000000..6b9dcdc5 --- /dev/null +++ b/tests/integrations/test_anthropic_client_int.py @@ -0,0 +1,85 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Running tests: pytest -xvs tests/integrations/test_anthropic_client_int.py + +import os + +import pytest +from pydantic import BaseModel, Field + +from graphiti_core.llm_client.anthropic_client import AnthropicClient +from graphiti_core.prompts.models import Message + +# Skip all tests if no API key is available +pytestmark = pytest.mark.skipif( + 'TEST_ANTHROPIC_API_KEY' not in os.environ, + reason='Anthropic API key not available', +) + + +# Rename to avoid pytest collection as a test class +class SimpleResponseModel(BaseModel): + """Test response model.""" + + message: str = Field(..., description='A message from the model') + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_generate_simple_response(): + """Test generating a simple response from the Anthropic API.""" + if 'TEST_ANTHROPIC_API_KEY' not in os.environ: + pytest.skip('Anthropic API key not available') + + client = AnthropicClient() + + messages = [ + Message( + role='user', + content="Respond with a JSON object containing a 'message' field with value 'Hello, world!'", + ) + ] + + try: + response = await client.generate_response(messages, response_model=SimpleResponseModel) + + assert isinstance(response, dict) + assert 'message' in response + assert response['message'] == 'Hello, world!' + except Exception as e: + pytest.skip(f'Test skipped due to Anthropic API error: {str(e)}') + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_extract_json_from_text(): + """Test the extract_json_from_text method with real data.""" + # We don't need an actual API connection for this test, + # so we can create the client without worrying about the API key + with pytest.MonkeyPatch.context() as monkeypatch: + # Temporarily set an environment variable to avoid API key error + monkeypatch.setenv('ANTHROPIC_API_KEY', 'fake_key_for_testing') + client = AnthropicClient(cache=False) + + # A string with embedded JSON + text = 'Some text before {"message": "Hello, world!"} and after' + + result = client._extract_json_from_text(text) + + assert isinstance(result, dict) + assert 'message' in result + assert result['message'] == 'Hello, world!' diff --git a/tests/llm_client/test_anthropic_client.py b/tests/llm_client/test_anthropic_client.py new file mode 100644 index 00000000..e8ba82cd --- /dev/null +++ b/tests/llm_client/test_anthropic_client.py @@ -0,0 +1,255 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Running tests: pytest -xvs tests/llm_client/test_anthropic_client.py + +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import BaseModel + +from graphiti_core.llm_client.anthropic_client import AnthropicClient +from graphiti_core.llm_client.config import LLMConfig +from graphiti_core.llm_client.errors import RateLimitError, RefusalError +from graphiti_core.prompts.models import Message + + +# Rename class to avoid pytest collection as a test class +class ResponseModel(BaseModel): + """Test model for response testing.""" + + test_field: str + optional_field: int = 0 + + +@pytest.fixture +def mock_async_anthropic(): + """Fixture to mock the AsyncAnthropic client.""" + with patch('anthropic.AsyncAnthropic') as mock_client: + # Setup mock instance and its create method + mock_instance = mock_client.return_value + mock_instance.messages.create = AsyncMock() + yield mock_instance + + +@pytest.fixture +def anthropic_client(mock_async_anthropic): + """Fixture to create an AnthropicClient with a mocked AsyncAnthropic.""" + # Use a context manager to patch the AsyncAnthropic constructor to avoid + # the client actually trying to create a real connection + with patch('anthropic.AsyncAnthropic', return_value=mock_async_anthropic): + config = LLMConfig( + api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000 + ) + client = AnthropicClient(config=config, cache=False) + # Replace the client's client with our mock to ensure we're using the mock + client.client = mock_async_anthropic + return client + + +class TestAnthropicClientInitialization: + """Tests for AnthropicClient initialization.""" + + def test_init_with_config(self): + """Test initialization with a config object.""" + config = LLMConfig( + api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000 + ) + client = AnthropicClient(config=config, cache=False) + + assert client.config == config + assert client.model == 'test-model' + assert client.temperature == 0.5 + assert client.max_tokens == 1000 + + def test_init_with_default_model(self): + """Test initialization with default model when none is provided.""" + config = LLMConfig(api_key='test_api_key') + client = AnthropicClient(config=config, cache=False) + + assert client.model == 'claude-3-7-sonnet-latest' + + @patch.dict(os.environ, {'ANTHROPIC_API_KEY': 'env_api_key'}) + def test_init_without_config(self): + """Test initialization without a config, using environment variable.""" + client = AnthropicClient(cache=False) + + assert client.config.api_key == 'env_api_key' + assert client.model == 'claude-3-7-sonnet-latest' + + def test_init_with_custom_client(self): + """Test initialization with a custom AsyncAnthropic client.""" + mock_client = MagicMock() + client = AnthropicClient(client=mock_client) + + assert client.client == mock_client + + +class TestAnthropicClientGenerateResponse: + """Tests for AnthropicClient generate_response method.""" + + @pytest.mark.asyncio + async def test_generate_response_with_tool_use(self, anthropic_client, mock_async_anthropic): + """Test successful response generation with tool use.""" + # Setup mock response + content_item = MagicMock() + content_item.type = 'tool_use' + content_item.input = {'test_field': 'test_value'} + + mock_response = MagicMock() + mock_response.content = [content_item] + mock_async_anthropic.messages.create.return_value = mock_response + + # Call method + messages = [ + Message(role='system', content='System message'), + Message(role='user', content='User message'), + ] + result = await anthropic_client.generate_response( + messages=messages, response_model=ResponseModel + ) + + # Assertions + assert isinstance(result, dict) + assert result['test_field'] == 'test_value' + mock_async_anthropic.messages.create.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_response_with_text_response( + self, anthropic_client, mock_async_anthropic + ): + """Test response generation when getting text response instead of tool use.""" + # Setup mock response with text content + content_item = MagicMock() + content_item.type = 'text' + content_item.text = '{"test_field": "extracted_value"}' + + mock_response = MagicMock() + mock_response.content = [content_item] + mock_async_anthropic.messages.create.return_value = mock_response + + # Call method + messages = [ + Message(role='system', content='System message'), + Message(role='user', content='User message'), + ] + result = await anthropic_client.generate_response( + messages=messages, response_model=ResponseModel + ) + + # Assertions + assert isinstance(result, dict) + assert result['test_field'] == 'extracted_value' + + @pytest.mark.asyncio + async def test_rate_limit_error(self, anthropic_client, mock_async_anthropic): + """Test handling of rate limit errors.""" + + # Create a custom RateLimitError from Anthropic + class MockRateLimitError(Exception): + pass + + # Patch the Anthropic error with our mock to avoid constructor issues + with patch('anthropic.RateLimitError', MockRateLimitError): + # Setup mock to raise our mocked RateLimitError + mock_async_anthropic.messages.create.side_effect = MockRateLimitError( + 'Rate limit exceeded' + ) + + # Call method and check exception + messages = [Message(role='user', content='Test message')] + with pytest.raises(RateLimitError): + await anthropic_client.generate_response(messages) + + @pytest.mark.asyncio + async def test_refusal_error(self, anthropic_client, mock_async_anthropic): + """Test handling of content policy violations (refusal errors).""" + + # Create a custom APIError that matches what we need + class MockAPIError(Exception): + def __init__(self, message): + self.message = message + super().__init__(message) + + # Patch the Anthropic error with our mock + with patch('anthropic.APIError', MockAPIError): + # Setup mock to raise APIError with refusal message + mock_async_anthropic.messages.create.side_effect = MockAPIError('refused to respond') + + # Call method and check exception + messages = [Message(role='user', content='Test message')] + with pytest.raises(RefusalError): + await anthropic_client.generate_response(messages) + + @pytest.mark.asyncio + async def test_extract_json_from_text(self, anthropic_client): + """Test the _extract_json_from_text method.""" + # Valid JSON embedded in text + text = 'Some text before {"test_field": "value"} and after' + result = anthropic_client._extract_json_from_text(text) + assert result == {'test_field': 'value'} + + # Invalid JSON + with pytest.raises(ValueError): + anthropic_client._extract_json_from_text('Not JSON at all') + + @pytest.mark.asyncio + async def test_create_tool(self, anthropic_client): + """Test the _create_tool method with and without response model.""" + # With response model + tools, tool_choice = anthropic_client._create_tool(ResponseModel) + assert len(tools) == 1 + assert tools[0]['name'] == 'ResponseModel' + assert tool_choice['name'] == 'ResponseModel' + + # Without response model (generic JSON) + tools, tool_choice = anthropic_client._create_tool() + assert len(tools) == 1 + assert tools[0]['name'] == 'generic_json_output' + + @pytest.mark.asyncio + async def test_validation_error_retry(self, anthropic_client, mock_async_anthropic): + """Test retry behavior on validation error.""" + # First call returns invalid data, second call returns valid data + content_item1 = MagicMock() + content_item1.type = 'tool_use' + content_item1.input = {'wrong_field': 'wrong_value'} + + content_item2 = MagicMock() + content_item2.type = 'tool_use' + content_item2.input = {'test_field': 'correct_value'} + + # Setup mock to return different responses on consecutive calls + mock_response1 = MagicMock() + mock_response1.content = [content_item1] + + mock_response2 = MagicMock() + mock_response2.content = [content_item2] + + mock_async_anthropic.messages.create.side_effect = [mock_response1, mock_response2] + + # Call method + messages = [Message(role='user', content='Test message')] + result = await anthropic_client.generate_response(messages, response_model=ResponseModel) + + # Should have called create twice due to retry + assert mock_async_anthropic.messages.create.call_count == 2 + assert result['test_field'] == 'correct_value' + + +if __name__ == '__main__': + pytest.main(['-v', 'test_anthropic_client.py']) diff --git a/tests/llm_client/test_errors.py b/tests/llm_client/test_errors.py new file mode 100644 index 00000000..0cd12963 --- /dev/null +++ b/tests/llm_client/test_errors.py @@ -0,0 +1,76 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Running tests: pytest -xvs tests/llm_client/test_errors.py + +import pytest + +from graphiti_core.llm_client.errors import EmptyResponseError, RateLimitError, RefusalError + + +class TestRateLimitError: + """Tests for the RateLimitError class.""" + + def test_default_message(self): + """Test that the default message is set correctly.""" + error = RateLimitError() + assert error.message == 'Rate limit exceeded. Please try again later.' + assert str(error) == 'Rate limit exceeded. Please try again later.' + + def test_custom_message(self): + """Test that a custom message can be set.""" + custom_message = 'Custom rate limit message' + error = RateLimitError(custom_message) + assert error.message == custom_message + assert str(error) == custom_message + + +class TestRefusalError: + """Tests for the RefusalError class.""" + + def test_message_required(self): + """Test that a message is required for RefusalError.""" + with pytest.raises(TypeError): + # Intentionally not providing the required message parameter + RefusalError() # type: ignore + + def test_message_assignment(self): + """Test that the message is assigned correctly.""" + message = 'The LLM refused to respond to this prompt.' + error = RefusalError(message=message) # Add explicit keyword argument + assert error.message == message + assert str(error) == message + + +class TestEmptyResponseError: + """Tests for the EmptyResponseError class.""" + + def test_message_required(self): + """Test that a message is required for EmptyResponseError.""" + with pytest.raises(TypeError): + # Intentionally not providing the required message parameter + EmptyResponseError() # type: ignore + + def test_message_assignment(self): + """Test that the message is assigned correctly.""" + message = 'The LLM returned an empty response.' + error = EmptyResponseError(message=message) # Add explicit keyword argument + assert error.message == message + assert str(error) == message + + +if __name__ == '__main__': + pytest.main(['-v', 'test_errors.py'])