diff --git a/mcp_server/tests/README.md b/mcp_server/tests/README.md new file mode 100644 index 00000000..d798b9e8 --- /dev/null +++ b/mcp_server/tests/README.md @@ -0,0 +1,314 @@ +# Graphiti MCP Server Integration Tests + +This directory contains a comprehensive integration test suite for the Graphiti MCP Server using the official Python MCP SDK. + +## Overview + +The test suite is designed to thoroughly test all aspects of the Graphiti MCP server with special consideration for LLM inference latency and system performance. + +## Test Organization + +### Core Test Modules + +- **`test_comprehensive_integration.py`** - Main integration test suite covering all MCP tools +- **`test_async_operations.py`** - Tests for concurrent operations and async patterns +- **`test_stress_load.py`** - Stress testing and load testing scenarios +- **`test_fixtures.py`** - Shared fixtures and test utilities +- **`test_mcp_integration.py`** - Original MCP integration tests +- **`test_configuration.py`** - Configuration loading and validation tests + +### Test Categories + +Tests are organized with pytest markers: + +- `unit` - Fast unit tests without external dependencies +- `integration` - Tests requiring database and services +- `slow` - Long-running tests (stress/load tests) +- `requires_neo4j` - Tests requiring Neo4j +- `requires_falkordb` - Tests requiring FalkorDB +- `requires_kuzu` - Tests requiring KuzuDB +- `requires_openai` - Tests requiring OpenAI API key + +## Installation + +```bash +# Install test dependencies +uv add --dev pytest pytest-asyncio pytest-timeout pytest-xdist faker psutil + +# Install MCP SDK +uv add mcp +``` + +## Running Tests + +### Quick Start + +```bash +# Run smoke tests (quick validation) +python tests/run_tests.py smoke + +# Run integration tests with mock LLM +python tests/run_tests.py integration --mock-llm + +# Run all tests +python tests/run_tests.py all +``` + +### Test Runner Options + +```bash +python tests/run_tests.py [suite] [options] + +Suites: + unit - Unit tests only + integration - Integration tests + comprehensive - Comprehensive integration suite + async - Async operation tests + stress - Stress and load tests + smoke - Quick smoke tests + all - All tests + +Options: + --database - Database backend (neo4j, falkordb, kuzu) + --mock-llm - Use mock LLM for faster testing + --parallel N - Run tests in parallel with N workers + --coverage - Generate coverage report + --skip-slow - Skip slow tests + --timeout N - Test timeout in seconds + --check-only - Only check prerequisites +``` + +### Examples + +```bash +# Quick smoke test with KuzuDB +python tests/run_tests.py smoke --database kuzu + +# Full integration test with Neo4j +python tests/run_tests.py integration --database neo4j + +# Stress testing with parallel execution +python tests/run_tests.py stress --parallel 4 + +# Run with coverage +python tests/run_tests.py all --coverage + +# Check prerequisites only +python tests/run_tests.py all --check-only +``` + +## Test Coverage + +### Core Operations +- Server initialization and tool discovery +- Adding memories (text, JSON, message) +- Episode queue management +- Search operations (semantic, hybrid) +- Episode retrieval and deletion +- Entity and edge operations + +### Async Operations +- Concurrent operations +- Queue management +- Sequential processing within groups +- Parallel processing across groups + +### Performance Testing +- Latency measurement +- Throughput testing +- Batch processing +- Resource usage monitoring + +### Stress Testing +- Sustained load 
scenarios +- Spike load handling +- Memory leak detection +- Connection pool exhaustion +- Rate limit handling + +## Configuration + +### Environment Variables + +```bash +# Database configuration +export DATABASE_PROVIDER=kuzu # or neo4j, falkordb +export NEO4J_URI=bolt://localhost:7687 +export NEO4J_USER=neo4j +export NEO4J_PASSWORD=graphiti +export FALKORDB_URI=redis://localhost:6379 +export KUZU_PATH=./test_kuzu.db + +# LLM configuration +export OPENAI_API_KEY=your_key_here # or use --mock-llm + +# Test configuration +export TEST_MODE=true +export LOG_LEVEL=INFO +``` + +### pytest.ini Configuration + +The `pytest.ini` file configures: +- Test discovery patterns +- Async mode settings +- Test markers +- Timeout settings +- Output formatting + +## Test Fixtures + +### Data Generation + +The test suite includes comprehensive data generators: + +```python +from test_fixtures import TestDataGenerator + +# Generate test data +company = TestDataGenerator.generate_company_profile() +conversation = TestDataGenerator.generate_conversation() +document = TestDataGenerator.generate_technical_document() +``` + +### Test Client + +Simplified client creation: + +```python +from test_fixtures import graphiti_test_client + +async with graphiti_test_client(database="kuzu") as (session, group_id): + # Use session for testing + result = await session.call_tool('add_memory', {...}) +``` + +## Performance Considerations + +### LLM Latency Management + +The tests account for LLM inference latency through: + +1. **Configurable timeouts** - Different timeouts for different operations +2. **Mock LLM option** - Fast testing without API calls +3. **Intelligent polling** - Adaptive waiting for episode processing +4. **Batch operations** - Testing efficiency of batched requests + +### Resource Management + +- Memory leak detection +- Connection pool monitoring +- Resource usage tracking +- Graceful degradation testing + +## CI/CD Integration + +### GitHub Actions + +```yaml +name: MCP Integration Tests + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + + services: + neo4j: + image: neo4j:5.26 + env: + NEO4J_AUTH: neo4j/graphiti + ports: + - 7687:7687 + + steps: + - uses: actions/checkout@v2 + + - name: Install dependencies + run: | + pip install uv + uv sync --extra dev + + - name: Run smoke tests + run: python tests/run_tests.py smoke --mock-llm + + - name: Run integration tests + run: python tests/run_tests.py integration --database neo4j + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} +``` + +## Troubleshooting + +### Common Issues + +1. **Database connection failures** + ```bash + # Check Neo4j + curl http://localhost:7474 + + # Check FalkorDB + redis-cli ping + ``` + +2. **API key issues** + ```bash + # Use mock LLM for testing without API key + python tests/run_tests.py all --mock-llm + ``` + +3. **Timeout errors** + ```bash + # Increase timeout for slow systems + python tests/run_tests.py integration --timeout 600 + ``` + +4. **Memory issues** + ```bash + # Skip stress tests on low-memory systems + python tests/run_tests.py all --skip-slow + ``` + +## Test Reports + +### Performance Report + +After running performance tests: + +```python +from test_fixtures import PerformanceBenchmark + +benchmark = PerformanceBenchmark() +# ... run tests ... 
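+# record each operation's duration as you run tests; record() is defined in test_fixtures.py,
+# the operation name and duration below are illustrative:
+benchmark.record('add_memory', 1.8)  # operation name, duration in seconds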
+print(benchmark.report()) +``` + +### Load Test Report + +Stress tests generate detailed reports: + +``` +LOAD TEST REPORT +================ +Test Run 1: + Total Operations: 100 + Success Rate: 95.0% + Throughput: 12.5 ops/s + Latency (avg/p50/p95/p99/max): 0.8/0.7/1.5/2.1/3.2s +``` + +## Contributing + +When adding new tests: + +1. Use appropriate pytest markers +2. Include docstrings explaining test purpose +3. Use fixtures for common operations +4. Consider LLM latency in test design +5. Add timeout handling for long operations +6. Include performance metrics where relevant + +## License + +See main project LICENSE file. \ No newline at end of file diff --git a/mcp_server/tests/pytest.ini b/mcp_server/tests/pytest.ini new file mode 100644 index 00000000..c024df39 --- /dev/null +++ b/mcp_server/tests/pytest.ini @@ -0,0 +1,40 @@ +[pytest] +# Pytest configuration for Graphiti MCP integration tests + +# Test discovery patterns +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Asyncio configuration +asyncio_mode = auto + +# Markers for test categorization +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + integration: marks tests as integration tests requiring external services + unit: marks tests as unit tests + stress: marks tests as stress/load tests + requires_neo4j: test requires Neo4j database + requires_falkordb: test requires FalkorDB + requires_kuzu: test requires KuzuDB + requires_openai: test requires OpenAI API key + +# Test output options +addopts = + -v + --tb=short + --strict-markers + --color=yes + -p no:warnings + +# Timeout for tests (seconds) +timeout = 300 + +# Coverage options +testpaths = tests + +# Environment variables for testing +env = + TEST_MODE=true + LOG_LEVEL=INFO \ No newline at end of file diff --git a/mcp_server/tests/run_tests.py b/mcp_server/tests/run_tests.py new file mode 100644 index 00000000..fed501d2 --- /dev/null +++ b/mcp_server/tests/run_tests.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 +""" +Test runner for Graphiti MCP integration tests. +Provides various test execution modes and reporting options. 
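+
+Example (matches the suites and flags defined below):
+    python run_tests.py smoke --mock-llm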
+""" + +import argparse +import asyncio +import json +import os +import subprocess +import sys +import time +from pathlib import Path +from typing import Dict, List, Optional + +import pytest + + +class TestRunner: + """Orchestrate test execution with various configurations.""" + + def __init__(self, args): + self.args = args + self.test_dir = Path(__file__).parent + self.results = {} + + def check_prerequisites(self) -> Dict[str, bool]: + """Check if required services and dependencies are available.""" + checks = {} + + # Check for OpenAI API key if not using mocks + if not self.args.mock_llm: + checks['openai_api_key'] = bool(os.environ.get('OPENAI_API_KEY')) + else: + checks['openai_api_key'] = True + + # Check database availability based on backend + if self.args.database == 'neo4j': + checks['neo4j'] = self._check_neo4j() + elif self.args.database == 'falkordb': + checks['falkordb'] = self._check_falkordb() + elif self.args.database == 'kuzu': + checks['kuzu'] = True # KuzuDB is embedded + + # Check Python dependencies + checks['mcp'] = self._check_python_package('mcp') + checks['pytest'] = self._check_python_package('pytest') + checks['pytest-asyncio'] = self._check_python_package('pytest-asyncio') + + return checks + + def _check_neo4j(self) -> bool: + """Check if Neo4j is available.""" + try: + import neo4j + # Try to connect + uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687') + user = os.environ.get('NEO4J_USER', 'neo4j') + password = os.environ.get('NEO4J_PASSWORD', 'graphiti') + + driver = neo4j.GraphDatabase.driver(uri, auth=(user, password)) + with driver.session() as session: + session.run("RETURN 1") + driver.close() + return True + except: + return False + + def _check_falkordb(self) -> bool: + """Check if FalkorDB is available.""" + try: + import redis + uri = os.environ.get('FALKORDB_URI', 'redis://localhost:6379') + r = redis.from_url(uri) + r.ping() + return True + except: + return False + + def _check_python_package(self, package: str) -> bool: + """Check if a Python package is installed.""" + try: + __import__(package.replace('-', '_')) + return True + except ImportError: + return False + + def run_test_suite(self, suite: str) -> int: + """Run a specific test suite.""" + pytest_args = ['-v', '--tb=short'] + + # Add database marker + if self.args.database: + pytest_args.extend(['-m', f'not requires_{db}' + for db in ['neo4j', 'falkordb', 'kuzu'] + if db != self.args.database]) + + # Add suite-specific arguments + if suite == 'unit': + pytest_args.extend(['-m', 'unit', 'test_*.py']) + elif suite == 'integration': + pytest_args.extend(['-m', 'integration or not unit', 'test_*.py']) + elif suite == 'comprehensive': + pytest_args.append('test_comprehensive_integration.py') + elif suite == 'async': + pytest_args.append('test_async_operations.py') + elif suite == 'stress': + pytest_args.extend(['-m', 'slow', 'test_stress_load.py']) + elif suite == 'smoke': + # Quick smoke test - just basic operations + pytest_args.extend([ + 'test_comprehensive_integration.py::TestCoreOperations::test_server_initialization', + 'test_comprehensive_integration.py::TestCoreOperations::test_add_text_memory' + ]) + elif suite == 'all': + pytest_args.append('.') + else: + pytest_args.append(suite) + + # Add coverage if requested + if self.args.coverage: + pytest_args.extend(['--cov=../src', '--cov-report=html']) + + # Add parallel execution if requested + if self.args.parallel: + pytest_args.extend(['-n', str(self.args.parallel)]) + + # Add verbosity + if self.args.verbose: + 
pytest_args.append('-vv') + + # Add markers to skip + if self.args.skip_slow: + pytest_args.extend(['-m', 'not slow']) + + # Add timeout override + if self.args.timeout: + pytest_args.extend(['--timeout', str(self.args.timeout)]) + + # Add environment variables + env = os.environ.copy() + if self.args.mock_llm: + env['USE_MOCK_LLM'] = 'true' + if self.args.database: + env['DATABASE_PROVIDER'] = self.args.database + + # Run tests + print(f"Running {suite} tests with pytest args: {' '.join(pytest_args)}") + return pytest.main(pytest_args) + + def run_performance_benchmark(self): + """Run performance benchmarking suite.""" + print("Running performance benchmarks...") + + # Import test modules + from test_comprehensive_integration import TestPerformance + from test_async_operations import TestAsyncPerformance + + # Run performance tests + result = pytest.main([ + '-v', + 'test_comprehensive_integration.py::TestPerformance', + 'test_async_operations.py::TestAsyncPerformance', + '--benchmark-only' if self.args.benchmark_only else '', + ]) + + return result + + def generate_report(self): + """Generate test execution report.""" + report = [] + report.append("\n" + "=" * 60) + report.append("GRAPHITI MCP TEST EXECUTION REPORT") + report.append("=" * 60) + + # Prerequisites check + checks = self.check_prerequisites() + report.append("\nPrerequisites:") + for check, passed in checks.items(): + status = "✅" if passed else "❌" + report.append(f" {status} {check}") + + # Test configuration + report.append(f"\nConfiguration:") + report.append(f" Database: {self.args.database}") + report.append(f" Mock LLM: {self.args.mock_llm}") + report.append(f" Parallel: {self.args.parallel or 'No'}") + report.append(f" Timeout: {self.args.timeout}s") + + # Results summary (if available) + if self.results: + report.append(f"\nResults:") + for suite, result in self.results.items(): + status = "✅ Passed" if result == 0 else f"❌ Failed ({result})" + report.append(f" {suite}: {status}") + + report.append("=" * 60) + return "\n".join(report) + + +def main(): + """Main entry point for test runner.""" + parser = argparse.ArgumentParser( + description='Run Graphiti MCP integration tests', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Test Suites: + unit - Run unit tests only + integration - Run integration tests + comprehensive - Run comprehensive integration test suite + async - Run async operation tests + stress - Run stress and load tests + smoke - Run quick smoke tests + all - Run all tests + +Examples: + python run_tests.py smoke # Quick smoke test + python run_tests.py integration --parallel 4 # Run integration tests in parallel + python run_tests.py stress --database neo4j # Run stress tests with Neo4j + python run_tests.py all --coverage # Run all tests with coverage + """ + ) + + parser.add_argument( + 'suite', + choices=['unit', 'integration', 'comprehensive', 'async', 'stress', 'smoke', 'all'], + help='Test suite to run' + ) + + parser.add_argument( + '--database', + choices=['neo4j', 'falkordb', 'kuzu'], + default='kuzu', + help='Database backend to test (default: kuzu)' + ) + + parser.add_argument( + '--mock-llm', + action='store_true', + help='Use mock LLM for faster testing' + ) + + parser.add_argument( + '--parallel', + type=int, + metavar='N', + help='Run tests in parallel with N workers' + ) + + parser.add_argument( + '--coverage', + action='store_true', + help='Generate coverage report' + ) + + parser.add_argument( + '--verbose', + action='store_true', + help='Verbose output' + ) + + 
parser.add_argument( + '--skip-slow', + action='store_true', + help='Skip slow tests' + ) + + parser.add_argument( + '--timeout', + type=int, + default=300, + help='Test timeout in seconds (default: 300)' + ) + + parser.add_argument( + '--benchmark-only', + action='store_true', + help='Run only benchmark tests' + ) + + parser.add_argument( + '--check-only', + action='store_true', + help='Only check prerequisites without running tests' + ) + + args = parser.parse_args() + + # Create test runner + runner = TestRunner(args) + + # Check prerequisites + if args.check_only: + print(runner.generate_report()) + sys.exit(0) + + # Check if prerequisites are met + checks = runner.check_prerequisites() + if not all(checks.values()): + print("⚠️ Some prerequisites are not met:") + for check, passed in checks.items(): + if not passed: + print(f" ❌ {check}") + + if not args.mock_llm and not checks.get('openai_api_key'): + print("\nHint: Use --mock-llm to run tests without OpenAI API key") + + response = input("\nContinue anyway? (y/N): ") + if response.lower() != 'y': + sys.exit(1) + + # Run tests + print(f"\n🚀 Starting test execution: {args.suite}") + start_time = time.time() + + if args.benchmark_only: + result = runner.run_performance_benchmark() + else: + result = runner.run_test_suite(args.suite) + + duration = time.time() - start_time + + # Store results + runner.results[args.suite] = result + + # Generate and print report + print(runner.generate_report()) + print(f"\n⏱️ Test execution completed in {duration:.2f} seconds") + + # Exit with test result code + sys.exit(result) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/mcp_server/tests/test_async_operations.py b/mcp_server/tests/test_async_operations.py new file mode 100644 index 00000000..56d41ae9 --- /dev/null +++ b/mcp_server/tests/test_async_operations.py @@ -0,0 +1,494 @@ +#!/usr/bin/env python3 +""" +Asynchronous operation tests for Graphiti MCP Server. +Tests concurrent operations, queue management, and async patterns. 
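+
+Run directly with: pytest test_async_operations.py -v --asyncio-mode=auto
+(the same invocation used by the pytest.main call at the bottom of this module).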
+""" + +import asyncio +import json +import time +from typing import Any, Dict, List +from unittest.mock import AsyncMock, patch + +import pytest +from test_fixtures import ( + TestDataGenerator, + graphiti_test_client, + performance_benchmark, + test_data_generator, +) + + +class TestAsyncQueueManagement: + """Test asynchronous queue operations and episode processing.""" + + @pytest.mark.asyncio + async def test_sequential_queue_processing(self): + """Verify episodes are processed sequentially within a group.""" + async with graphiti_test_client() as (session, group_id): + # Add multiple episodes quickly + episodes = [] + for i in range(5): + result = await session.call_tool( + 'add_memory', + { + 'name': f'Sequential Test {i}', + 'episode_body': f'Episode {i} with timestamp {time.time()}', + 'source': 'text', + 'source_description': 'sequential test', + 'group_id': group_id, + 'reference_id': f'seq_{i}', # Add reference for tracking + } + ) + episodes.append(result) + + # Wait for processing + await asyncio.sleep(10) # Allow time for sequential processing + + # Retrieve episodes and verify order + result = await session.call_tool( + 'get_episodes', + {'group_id': group_id, 'last_n': 10} + ) + + processed_episodes = json.loads(result.content[0].text)['episodes'] + + # Verify all episodes were processed + assert len(processed_episodes) >= 5, f"Expected at least 5 episodes, got {len(processed_episodes)}" + + # Verify sequential processing (timestamps should be ordered) + timestamps = [ep.get('created_at') for ep in processed_episodes] + assert timestamps == sorted(timestamps), "Episodes not processed in order" + + @pytest.mark.asyncio + async def test_concurrent_group_processing(self): + """Test that different groups can process concurrently.""" + async with graphiti_test_client() as (session, _): + groups = [f'group_{i}_{time.time()}' for i in range(3)] + tasks = [] + + # Create tasks for different groups + for group_id in groups: + for j in range(2): + task = session.call_tool( + 'add_memory', + { + 'name': f'Group {group_id} Episode {j}', + 'episode_body': f'Content for {group_id}', + 'source': 'text', + 'source_description': 'concurrent test', + 'group_id': group_id, + } + ) + tasks.append(task) + + # Execute all tasks concurrently + start_time = time.time() + results = await asyncio.gather(*tasks, return_exceptions=True) + execution_time = time.time() - start_time + + # Verify all succeeded + failures = [r for r in results if isinstance(r, Exception)] + assert not failures, f"Concurrent operations failed: {failures}" + + # Check that execution was actually concurrent (should be faster than sequential) + # Sequential would take at least 6 * processing_time + assert execution_time < 30, f"Concurrent execution too slow: {execution_time}s" + + @pytest.mark.asyncio + async def test_queue_overflow_handling(self): + """Test behavior when queue reaches capacity.""" + async with graphiti_test_client() as (session, group_id): + # Attempt to add many episodes rapidly + tasks = [] + for i in range(100): # Large number to potentially overflow + task = session.call_tool( + 'add_memory', + { + 'name': f'Overflow Test {i}', + 'episode_body': f'Episode {i}', + 'source': 'text', + 'source_description': 'overflow test', + 'group_id': group_id, + } + ) + tasks.append(task) + + # Execute with gathering to catch any failures + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Count successful queuing + successful = sum(1 for r in results if not isinstance(r, Exception)) + + # Should 
handle overflow gracefully + assert successful > 0, "No episodes were queued successfully" + + # Log overflow behavior + if successful < 100: + print(f"Queue overflow: {successful}/100 episodes queued") + + +class TestConcurrentOperations: + """Test concurrent tool calls and operations.""" + + @pytest.mark.asyncio + async def test_concurrent_search_operations(self): + """Test multiple concurrent search operations.""" + async with graphiti_test_client() as (session, group_id): + # First, add some test data + data_gen = TestDataGenerator() + + add_tasks = [] + for _ in range(5): + task = session.call_tool( + 'add_memory', + { + 'name': 'Search Test Data', + 'episode_body': data_gen.generate_technical_document(), + 'source': 'text', + 'source_description': 'search test', + 'group_id': group_id, + } + ) + add_tasks.append(task) + + await asyncio.gather(*add_tasks) + await asyncio.sleep(15) # Wait for processing + + # Now perform concurrent searches + search_queries = [ + 'architecture', + 'performance', + 'implementation', + 'dependencies', + 'latency', + ] + + search_tasks = [] + for query in search_queries: + task = session.call_tool( + 'search_memory_nodes', + { + 'query': query, + 'group_id': group_id, + 'limit': 10, + } + ) + search_tasks.append(task) + + start_time = time.time() + results = await asyncio.gather(*search_tasks, return_exceptions=True) + search_time = time.time() - start_time + + # Verify all searches completed + failures = [r for r in results if isinstance(r, Exception)] + assert not failures, f"Search operations failed: {failures}" + + # Verify concurrent execution efficiency + assert search_time < len(search_queries) * 2, "Searches not executing concurrently" + + @pytest.mark.asyncio + async def test_mixed_operation_concurrency(self): + """Test different types of operations running concurrently.""" + async with graphiti_test_client() as (session, group_id): + operations = [] + + # Add memory operation + operations.append(session.call_tool( + 'add_memory', + { + 'name': 'Mixed Op Test', + 'episode_body': 'Testing mixed operations', + 'source': 'text', + 'source_description': 'test', + 'group_id': group_id, + } + )) + + # Search operation + operations.append(session.call_tool( + 'search_memory_nodes', + { + 'query': 'test', + 'group_id': group_id, + 'limit': 5, + } + )) + + # Get episodes operation + operations.append(session.call_tool( + 'get_episodes', + { + 'group_id': group_id, + 'last_n': 10, + } + )) + + # Get status operation + operations.append(session.call_tool( + 'get_status', + {} + )) + + # Execute all concurrently + results = await asyncio.gather(*operations, return_exceptions=True) + + # Check results + for i, result in enumerate(results): + assert not isinstance(result, Exception), f"Operation {i} failed: {result}" + + +class TestAsyncErrorHandling: + """Test async error handling and recovery.""" + + @pytest.mark.asyncio + async def test_timeout_recovery(self): + """Test recovery from operation timeouts.""" + async with graphiti_test_client() as (session, group_id): + # Create a very large episode that might timeout + large_content = "x" * 1000000 # 1MB of data + + try: + result = await asyncio.wait_for( + session.call_tool( + 'add_memory', + { + 'name': 'Timeout Test', + 'episode_body': large_content, + 'source': 'text', + 'source_description': 'timeout test', + 'group_id': group_id, + } + ), + timeout=2.0 # Short timeout + ) + except asyncio.TimeoutError: + # Expected timeout + pass + + # Verify server is still responsive after timeout + status_result = 
await session.call_tool('get_status', {}) + assert status_result is not None, "Server unresponsive after timeout" + + @pytest.mark.asyncio + async def test_cancellation_handling(self): + """Test proper handling of cancelled operations.""" + async with graphiti_test_client() as (session, group_id): + # Start a long-running operation + task = asyncio.create_task( + session.call_tool( + 'add_memory', + { + 'name': 'Cancellation Test', + 'episode_body': TestDataGenerator.generate_technical_document(), + 'source': 'text', + 'source_description': 'cancel test', + 'group_id': group_id, + } + ) + ) + + # Cancel after a short delay + await asyncio.sleep(0.1) + task.cancel() + + # Verify cancellation was handled + with pytest.raises(asyncio.CancelledError): + await task + + # Server should still be operational + result = await session.call_tool('get_status', {}) + assert result is not None + + @pytest.mark.asyncio + async def test_exception_propagation(self): + """Test that exceptions are properly propagated in async context.""" + async with graphiti_test_client() as (session, group_id): + # Call with invalid arguments + with pytest.raises(Exception): + await session.call_tool( + 'add_memory', + { + # Missing required fields + 'group_id': group_id, + } + ) + + # Server should remain operational + status = await session.call_tool('get_status', {}) + assert status is not None + + +class TestAsyncPerformance: + """Performance tests for async operations.""" + + @pytest.mark.asyncio + async def test_async_throughput(self, performance_benchmark): + """Measure throughput of async operations.""" + async with graphiti_test_client() as (session, group_id): + num_operations = 50 + start_time = time.time() + + # Create many concurrent operations + tasks = [] + for i in range(num_operations): + task = session.call_tool( + 'add_memory', + { + 'name': f'Throughput Test {i}', + 'episode_body': f'Content {i}', + 'source': 'text', + 'source_description': 'throughput test', + 'group_id': group_id, + } + ) + tasks.append(task) + + # Execute all + results = await asyncio.gather(*tasks, return_exceptions=True) + total_time = time.time() - start_time + + # Calculate metrics + successful = sum(1 for r in results if not isinstance(r, Exception)) + throughput = successful / total_time + + performance_benchmark.record('async_throughput', throughput) + + # Log results + print(f"\nAsync Throughput Test:") + print(f" Operations: {num_operations}") + print(f" Successful: {successful}") + print(f" Total time: {total_time:.2f}s") + print(f" Throughput: {throughput:.2f} ops/s") + + # Assert minimum throughput + assert throughput > 1.0, f"Throughput too low: {throughput:.2f} ops/s" + + @pytest.mark.asyncio + async def test_latency_under_load(self, performance_benchmark): + """Test operation latency under concurrent load.""" + async with graphiti_test_client() as (session, group_id): + # Create background load + background_tasks = [] + for i in range(10): + task = asyncio.create_task( + session.call_tool( + 'add_memory', + { + 'name': f'Background {i}', + 'episode_body': TestDataGenerator.generate_technical_document(), + 'source': 'text', + 'source_description': 'background', + 'group_id': f'background_{group_id}', + } + ) + ) + background_tasks.append(task) + + # Measure latency of operations under load + latencies = [] + for _ in range(5): + start = time.time() + await session.call_tool( + 'get_status', + {} + ) + latency = time.time() - start + latencies.append(latency) + performance_benchmark.record('latency_under_load', latency) 
+ + # Clean up background tasks + for task in background_tasks: + task.cancel() + + # Analyze latencies + avg_latency = sum(latencies) / len(latencies) + max_latency = max(latencies) + + print(f"\nLatency Under Load:") + print(f" Average: {avg_latency:.3f}s") + print(f" Max: {max_latency:.3f}s") + + # Assert acceptable latency + assert avg_latency < 2.0, f"Average latency too high: {avg_latency:.3f}s" + assert max_latency < 5.0, f"Max latency too high: {max_latency:.3f}s" + + +class TestAsyncStreamHandling: + """Test handling of streaming responses and data.""" + + @pytest.mark.asyncio + async def test_large_response_streaming(self): + """Test handling of large streamed responses.""" + async with graphiti_test_client() as (session, group_id): + # Add many episodes + for i in range(20): + await session.call_tool( + 'add_memory', + { + 'name': f'Stream Test {i}', + 'episode_body': f'Episode content {i}', + 'source': 'text', + 'source_description': 'stream test', + 'group_id': group_id, + } + ) + + # Wait for processing + await asyncio.sleep(30) + + # Request large result set + result = await session.call_tool( + 'get_episodes', + { + 'group_id': group_id, + 'last_n': 100, # Request all + } + ) + + # Verify response handling + episodes = json.loads(result.content[0].text)['episodes'] + assert len(episodes) >= 20, f"Expected at least 20 episodes, got {len(episodes)}" + + @pytest.mark.asyncio + async def test_incremental_processing(self): + """Test incremental processing of results.""" + async with graphiti_test_client() as (session, group_id): + # Add episodes incrementally + for batch in range(3): + batch_tasks = [] + for i in range(5): + task = session.call_tool( + 'add_memory', + { + 'name': f'Batch {batch} Item {i}', + 'episode_body': f'Content for batch {batch}', + 'source': 'text', + 'source_description': 'incremental test', + 'group_id': group_id, + } + ) + batch_tasks.append(task) + + # Process batch + await asyncio.gather(*batch_tasks) + + # Wait for this batch to process + await asyncio.sleep(10) + + # Verify incremental results + result = await session.call_tool( + 'get_episodes', + { + 'group_id': group_id, + 'last_n': 100, + } + ) + + episodes = json.loads(result.content[0].text)['episodes'] + expected_min = (batch + 1) * 5 + assert len(episodes) >= expected_min, f"Batch {batch}: Expected at least {expected_min} episodes" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--asyncio-mode=auto"]) \ No newline at end of file diff --git a/mcp_server/tests/test_comprehensive_integration.py b/mcp_server/tests/test_comprehensive_integration.py new file mode 100644 index 00000000..10ad2261 --- /dev/null +++ b/mcp_server/tests/test_comprehensive_integration.py @@ -0,0 +1,696 @@ +#!/usr/bin/env python3 +""" +Comprehensive integration test suite for Graphiti MCP Server. +Covers all MCP tools with consideration for LLM inference latency. 
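+
+Run directly with: pytest test_comprehensive_integration.py -v --asyncio-mode=auto
+(the same invocation used by the pytest.main call at the bottom of this module).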
+""" + +import asyncio +import json +import os +import time +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional +from unittest.mock import patch + +import pytest +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client + + +@dataclass +class TestMetrics: + """Track test performance metrics.""" + + operation: str + start_time: float + end_time: float + success: bool + details: Dict[str, Any] + + @property + def duration(self) -> float: + """Calculate operation duration in seconds.""" + return self.end_time - self.start_time + + +class GraphitiTestClient: + """Enhanced test client for comprehensive Graphiti MCP testing.""" + + def __init__(self, test_group_id: Optional[str] = None): + self.test_group_id = test_group_id or f'test_{int(time.time())}' + self.session = None + self.metrics: List[TestMetrics] = [] + self.default_timeout = 30 # seconds + + async def __aenter__(self): + """Initialize MCP client session.""" + server_params = StdioServerParameters( + command='uv', + args=['run', 'main.py', '--transport', 'stdio'], + env={ + 'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'), + 'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'), + 'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'), + 'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'), + 'KUZU_PATH': os.environ.get('KUZU_PATH', './test_kuzu.db'), + 'FALKORDB_URI': os.environ.get('FALKORDB_URI', 'redis://localhost:6379'), + }, + ) + + self.client_context = stdio_client(server_params) + read, write = await self.client_context.__aenter__() + self.session = ClientSession(read, write) + await self.session.initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Clean up client session.""" + if self.session: + await self.session.close() + if hasattr(self, 'client_context'): + await self.client_context.__aexit__(exc_type, exc_val, exc_tb) + + async def call_tool_with_metrics( + self, tool_name: str, arguments: Dict[str, Any], timeout: Optional[float] = None + ) -> tuple[Any, TestMetrics]: + """Call a tool and capture performance metrics.""" + start_time = time.time() + timeout = timeout or self.default_timeout + + try: + result = await asyncio.wait_for( + self.session.call_tool(tool_name, arguments), + timeout=timeout + ) + + content = result.content[0].text if result.content else None + success = True + details = {'result': content, 'tool': tool_name} + + except asyncio.TimeoutError: + content = None + success = False + details = {'error': f'Timeout after {timeout}s', 'tool': tool_name} + + except Exception as e: + content = None + success = False + details = {'error': str(e), 'tool': tool_name} + + end_time = time.time() + metric = TestMetrics( + operation=f'call_{tool_name}', + start_time=start_time, + end_time=end_time, + success=success, + details=details + ) + self.metrics.append(metric) + + return content, metric + + async def wait_for_episode_processing( + self, expected_count: int = 1, max_wait: int = 60, poll_interval: int = 2 + ) -> bool: + """ + Wait for episodes to be processed with intelligent polling. 
+ + Args: + expected_count: Number of episodes expected to be processed + max_wait: Maximum seconds to wait + poll_interval: Seconds between status checks + + Returns: + True if episodes were processed successfully + """ + start_time = time.time() + + while (time.time() - start_time) < max_wait: + result, _ = await self.call_tool_with_metrics( + 'get_episodes', + {'group_id': self.test_group_id, 'last_n': 100} + ) + + if result: + try: + episodes = json.loads(result) if isinstance(result, str) else result + if len(episodes.get('episodes', [])) >= expected_count: + return True + except (json.JSONDecodeError, AttributeError): + pass + + await asyncio.sleep(poll_interval) + + return False + + +class TestCoreOperations: + """Test core Graphiti operations.""" + + @pytest.mark.asyncio + async def test_server_initialization(self): + """Verify server initializes with all required tools.""" + async with GraphitiTestClient() as client: + tools_result = await client.session.list_tools() + tools = {tool.name for tool in tools_result.tools} + + required_tools = { + 'add_memory', + 'search_memory_nodes', + 'search_memory_facts', + 'get_episodes', + 'delete_episode', + 'delete_entity_edge', + 'get_entity_edge', + 'clear_graph', + 'get_status' + } + + missing_tools = required_tools - tools + assert not missing_tools, f"Missing required tools: {missing_tools}" + + @pytest.mark.asyncio + async def test_add_text_memory(self): + """Test adding text-based memories.""" + async with GraphitiTestClient() as client: + # Add memory + result, metric = await client.call_tool_with_metrics( + 'add_memory', + { + 'name': 'Tech Conference Notes', + 'episode_body': 'The AI conference featured talks on LLMs, RAG systems, and knowledge graphs. Notable speakers included researchers from OpenAI and Anthropic.', + 'source': 'text', + 'source_description': 'conference notes', + 'group_id': client.test_group_id, + } + ) + + assert metric.success, f"Failed to add memory: {metric.details}" + assert 'queued' in str(result).lower() + + # Wait for processing + processed = await client.wait_for_episode_processing(expected_count=1) + assert processed, "Episode was not processed within timeout" + + @pytest.mark.asyncio + async def test_add_json_memory(self): + """Test adding structured JSON memories.""" + async with GraphitiTestClient() as client: + json_data = { + 'project': { + 'name': 'GraphitiDB', + 'version': '2.0.0', + 'features': ['temporal-awareness', 'hybrid-search', 'custom-entities'] + }, + 'team': { + 'size': 5, + 'roles': ['engineering', 'product', 'research'] + } + } + + result, metric = await client.call_tool_with_metrics( + 'add_memory', + { + 'name': 'Project Data', + 'episode_body': json.dumps(json_data), + 'source': 'json', + 'source_description': 'project database', + 'group_id': client.test_group_id, + } + ) + + assert metric.success + assert 'queued' in str(result).lower() + + @pytest.mark.asyncio + async def test_add_message_memory(self): + """Test adding conversation/message memories.""" + async with GraphitiTestClient() as client: + conversation = """ + user: What are the key features of Graphiti? + assistant: Graphiti offers temporal-aware knowledge graphs, hybrid retrieval, and real-time updates. + user: How does it handle entity resolution? + assistant: It uses LLM-based entity extraction and deduplication with semantic similarity matching. 
+ """ + + result, metric = await client.call_tool_with_metrics( + 'add_memory', + { + 'name': 'Feature Discussion', + 'episode_body': conversation, + 'source': 'message', + 'source_description': 'support chat', + 'group_id': client.test_group_id, + } + ) + + assert metric.success + assert metric.duration < 5, f"Add memory took too long: {metric.duration}s" + + +class TestSearchOperations: + """Test search and retrieval operations.""" + + @pytest.mark.asyncio + async def test_search_nodes_semantic(self): + """Test semantic search for nodes.""" + async with GraphitiTestClient() as client: + # First add some test data + await client.call_tool_with_metrics( + 'add_memory', + { + 'name': 'Product Launch', + 'episode_body': 'Our new AI assistant product launches in Q2 2024 with advanced NLP capabilities.', + 'source': 'text', + 'source_description': 'product roadmap', + 'group_id': client.test_group_id, + } + ) + + # Wait for processing + await client.wait_for_episode_processing() + + # Search for nodes + result, metric = await client.call_tool_with_metrics( + 'search_memory_nodes', + { + 'query': 'AI product features', + 'group_id': client.test_group_id, + 'limit': 10 + } + ) + + assert metric.success + assert result is not None + + @pytest.mark.asyncio + async def test_search_facts_with_filters(self): + """Test fact search with various filters.""" + async with GraphitiTestClient() as client: + # Add test data + await client.call_tool_with_metrics( + 'add_memory', + { + 'name': 'Company Facts', + 'episode_body': 'Acme Corp was founded in 2020. They have 50 employees and $10M in revenue.', + 'source': 'text', + 'source_description': 'company profile', + 'group_id': client.test_group_id, + } + ) + + await client.wait_for_episode_processing() + + # Search with date filter + result, metric = await client.call_tool_with_metrics( + 'search_memory_facts', + { + 'query': 'company information', + 'group_id': client.test_group_id, + 'created_after': '2020-01-01T00:00:00Z', + 'limit': 20 + } + ) + + assert metric.success + + @pytest.mark.asyncio + async def test_hybrid_search(self): + """Test hybrid search combining semantic and keyword search.""" + async with GraphitiTestClient() as client: + # Add diverse test data + test_memories = [ + { + 'name': 'Technical Doc', + 'episode_body': 'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.', + 'source': 'text' + }, + { + 'name': 'Architecture', + 'episode_body': 'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.', + 'source': 'text' + } + ] + + for memory in test_memories: + memory['group_id'] = client.test_group_id + memory['source_description'] = 'documentation' + await client.call_tool_with_metrics('add_memory', memory) + + await client.wait_for_episode_processing(expected_count=2) + + # Test semantic + keyword search + result, metric = await client.call_tool_with_metrics( + 'search_memory_nodes', + { + 'query': 'Neo4j graph database', + 'group_id': client.test_group_id, + 'limit': 10 + } + ) + + assert metric.success + + +class TestEpisodeManagement: + """Test episode lifecycle operations.""" + + @pytest.mark.asyncio + async def test_get_episodes_pagination(self): + """Test retrieving episodes with pagination.""" + async with GraphitiTestClient() as client: + # Add multiple episodes + for i in range(5): + await client.call_tool_with_metrics( + 'add_memory', + { + 'name': f'Episode {i}', + 'episode_body': f'This is test episode number {i}', + 'source': 'text', + 'source_description': 'test', + 
'group_id': client.test_group_id, + } + ) + + await client.wait_for_episode_processing(expected_count=5) + + # Test pagination + result, metric = await client.call_tool_with_metrics( + 'get_episodes', + { + 'group_id': client.test_group_id, + 'last_n': 3 + } + ) + + assert metric.success + episodes = json.loads(result) if isinstance(result, str) else result + assert len(episodes.get('episodes', [])) <= 3 + + @pytest.mark.asyncio + async def test_delete_episode(self): + """Test deleting specific episodes.""" + async with GraphitiTestClient() as client: + # Add an episode + await client.call_tool_with_metrics( + 'add_memory', + { + 'name': 'To Delete', + 'episode_body': 'This episode will be deleted', + 'source': 'text', + 'source_description': 'test', + 'group_id': client.test_group_id, + } + ) + + await client.wait_for_episode_processing() + + # Get episode UUID + result, _ = await client.call_tool_with_metrics( + 'get_episodes', + {'group_id': client.test_group_id, 'last_n': 1} + ) + + episodes = json.loads(result) if isinstance(result, str) else result + episode_uuid = episodes['episodes'][0]['uuid'] + + # Delete the episode + result, metric = await client.call_tool_with_metrics( + 'delete_episode', + {'episode_uuid': episode_uuid} + ) + + assert metric.success + assert 'deleted' in str(result).lower() + + +class TestEntityAndEdgeOperations: + """Test entity and edge management.""" + + @pytest.mark.asyncio + async def test_get_entity_edge(self): + """Test retrieving entity edges.""" + async with GraphitiTestClient() as client: + # Add data to create entities and edges + await client.call_tool_with_metrics( + 'add_memory', + { + 'name': 'Relationship Data', + 'episode_body': 'Alice works at TechCorp. Bob is the CEO of TechCorp.', + 'source': 'text', + 'source_description': 'org chart', + 'group_id': client.test_group_id, + } + ) + + await client.wait_for_episode_processing() + + # Search for nodes to get UUIDs + result, _ = await client.call_tool_with_metrics( + 'search_memory_nodes', + { + 'query': 'TechCorp', + 'group_id': client.test_group_id, + 'limit': 5 + } + ) + + # Note: This test assumes edges are created between entities + # Actual edge retrieval would require valid edge UUIDs + + @pytest.mark.asyncio + async def test_delete_entity_edge(self): + """Test deleting entity edges.""" + # Similar structure to get_entity_edge but with deletion + pass # Implement based on actual edge creation patterns + + +class TestErrorHandling: + """Test error conditions and edge cases.""" + + @pytest.mark.asyncio + async def test_invalid_tool_arguments(self): + """Test handling of invalid tool arguments.""" + async with GraphitiTestClient() as client: + # Missing required arguments + result, metric = await client.call_tool_with_metrics( + 'add_memory', + {'name': 'Incomplete'} # Missing required fields + ) + + assert not metric.success + assert 'error' in str(metric.details).lower() + + @pytest.mark.asyncio + async def test_timeout_handling(self): + """Test timeout handling for long operations.""" + async with GraphitiTestClient() as client: + # Simulate a very large episode that might timeout + large_text = "Large document content. 
" * 10000 + + result, metric = await client.call_tool_with_metrics( + 'add_memory', + { + 'name': 'Large Document', + 'episode_body': large_text, + 'source': 'text', + 'source_description': 'large file', + 'group_id': client.test_group_id, + }, + timeout=5 # Short timeout + ) + + # Check if timeout was handled gracefully + if not metric.success: + assert 'timeout' in str(metric.details).lower() + + @pytest.mark.asyncio + async def test_concurrent_operations(self): + """Test handling of concurrent operations.""" + async with GraphitiTestClient() as client: + # Launch multiple operations concurrently + tasks = [] + for i in range(5): + task = client.call_tool_with_metrics( + 'add_memory', + { + 'name': f'Concurrent {i}', + 'episode_body': f'Concurrent operation {i}', + 'source': 'text', + 'source_description': 'concurrent test', + 'group_id': client.test_group_id, + } + ) + tasks.append(task) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Check that operations were queued successfully + successful = sum(1 for r, m in results if m.success) + assert successful >= 3 # At least 60% should succeed + + +class TestPerformance: + """Test performance characteristics and optimization.""" + + @pytest.mark.asyncio + async def test_latency_metrics(self): + """Measure and validate operation latencies.""" + async with GraphitiTestClient() as client: + operations = [ + ('add_memory', { + 'name': 'Perf Test', + 'episode_body': 'Simple text', + 'source': 'text', + 'source_description': 'test', + 'group_id': client.test_group_id, + }), + ('search_memory_nodes', { + 'query': 'test', + 'group_id': client.test_group_id, + 'limit': 10 + }), + ('get_episodes', { + 'group_id': client.test_group_id, + 'last_n': 10 + }) + ] + + for tool_name, args in operations: + _, metric = await client.call_tool_with_metrics(tool_name, args) + + # Log performance metrics + print(f"{tool_name}: {metric.duration:.2f}s") + + # Basic latency assertions + if tool_name == 'get_episodes': + assert metric.duration < 2, f"{tool_name} too slow" + elif tool_name == 'search_memory_nodes': + assert metric.duration < 10, f"{tool_name} too slow" + + @pytest.mark.asyncio + async def test_batch_processing_efficiency(self): + """Test efficiency of batch operations.""" + async with GraphitiTestClient() as client: + batch_size = 10 + start_time = time.time() + + # Batch add memories + for i in range(batch_size): + await client.call_tool_with_metrics( + 'add_memory', + { + 'name': f'Batch {i}', + 'episode_body': f'Batch content {i}', + 'source': 'text', + 'source_description': 'batch test', + 'group_id': client.test_group_id, + } + ) + + # Wait for all to process + processed = await client.wait_for_episode_processing( + expected_count=batch_size, + max_wait=120 # Allow more time for batch + ) + + total_time = time.time() - start_time + avg_time_per_item = total_time / batch_size + + assert processed, f"Failed to process {batch_size} items" + assert avg_time_per_item < 15, f"Batch processing too slow: {avg_time_per_item:.2f}s per item" + + # Generate performance report + print(f"\nBatch Performance Report:") + print(f" Total items: {batch_size}") + print(f" Total time: {total_time:.2f}s") + print(f" Avg per item: {avg_time_per_item:.2f}s") + + +class TestDatabaseBackends: + """Test different database backend configurations.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize("database", ["neo4j", "falkordb", "kuzu"]) + async def test_database_operations(self, database): + """Test operations with different database 
backends.""" + env_vars = { + 'DATABASE_PROVIDER': database, + 'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'), + } + + if database == 'neo4j': + env_vars.update({ + 'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'), + 'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'), + 'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'), + }) + elif database == 'falkordb': + env_vars['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379') + elif database == 'kuzu': + env_vars['KUZU_PATH'] = os.environ.get('KUZU_PATH', f'./test_kuzu_{int(time.time())}.db') + + server_params = StdioServerParameters( + command='uv', + args=['run', 'main.py', '--transport', 'stdio', '--database', database], + env=env_vars + ) + + # Test basic operations with each backend + # Implementation depends on database availability + + +def generate_test_report(client: GraphitiTestClient) -> str: + """Generate a comprehensive test report from metrics.""" + if not client.metrics: + return "No metrics collected" + + report = [] + report.append("\n" + "="*60) + report.append("GRAPHITI MCP TEST REPORT") + report.append("="*60) + + # Summary statistics + total_ops = len(client.metrics) + successful_ops = sum(1 for m in client.metrics if m.success) + avg_duration = sum(m.duration for m in client.metrics) / total_ops + + report.append(f"\nTotal Operations: {total_ops}") + report.append(f"Successful: {successful_ops} ({successful_ops/total_ops*100:.1f}%)") + report.append(f"Average Duration: {avg_duration:.2f}s") + + # Operation breakdown + report.append("\nOperation Breakdown:") + operation_stats = {} + for metric in client.metrics: + if metric.operation not in operation_stats: + operation_stats[metric.operation] = { + 'count': 0, 'success': 0, 'total_duration': 0 + } + stats = operation_stats[metric.operation] + stats['count'] += 1 + stats['success'] += 1 if metric.success else 0 + stats['total_duration'] += metric.duration + + for op, stats in sorted(operation_stats.items()): + avg_dur = stats['total_duration'] / stats['count'] + success_rate = stats['success'] / stats['count'] * 100 + report.append( + f" {op}: {stats['count']} calls, " + f"{success_rate:.0f}% success, {avg_dur:.2f}s avg" + ) + + # Slowest operations + slowest = sorted(client.metrics, key=lambda m: m.duration, reverse=True)[:5] + report.append("\nSlowest Operations:") + for metric in slowest: + report.append(f" {metric.operation}: {metric.duration:.2f}s") + + report.append("="*60) + return "\n".join(report) + + +if __name__ == "__main__": + # Run tests with pytest + pytest.main([__file__, "-v", "--asyncio-mode=auto"]) \ No newline at end of file diff --git a/mcp_server/tests/test_fixtures.py b/mcp_server/tests/test_fixtures.py new file mode 100644 index 00000000..ed1e0eda --- /dev/null +++ b/mcp_server/tests/test_fixtures.py @@ -0,0 +1,324 @@ +""" +Shared test fixtures and utilities for Graphiti MCP integration tests. 
+""" + +import asyncio +import json +import os +import random +import time +from contextlib import asynccontextmanager +from typing import Any, Dict, List, Optional + +import pytest +from faker import Faker +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client + +fake = Faker() + + +class TestDataGenerator: + """Generate realistic test data for various scenarios.""" + + @staticmethod + def generate_company_profile() -> Dict[str, Any]: + """Generate a realistic company profile.""" + return { + 'company': { + 'name': fake.company(), + 'founded': random.randint(1990, 2023), + 'industry': random.choice(['Tech', 'Finance', 'Healthcare', 'Retail']), + 'employees': random.randint(10, 10000), + 'revenue': f"${random.randint(1, 1000)}M", + 'headquarters': fake.city(), + }, + 'products': [ + { + 'id': fake.uuid4()[:8], + 'name': fake.catch_phrase(), + 'category': random.choice(['Software', 'Hardware', 'Service']), + 'price': random.randint(10, 10000), + } + for _ in range(random.randint(1, 5)) + ], + 'leadership': { + 'ceo': fake.name(), + 'cto': fake.name(), + 'cfo': fake.name(), + } + } + + @staticmethod + def generate_conversation(turns: int = 3) -> str: + """Generate a realistic conversation.""" + topics = [ + "product features", + "pricing", + "technical support", + "integration", + "documentation", + "performance", + ] + + conversation = [] + for _ in range(turns): + topic = random.choice(topics) + user_msg = f"user: {fake.sentence()} about {topic}?" + assistant_msg = f"assistant: {fake.paragraph(nb_sentences=2)}" + conversation.extend([user_msg, assistant_msg]) + + return "\n".join(conversation) + + @staticmethod + def generate_technical_document() -> str: + """Generate technical documentation content.""" + sections = [ + f"# {fake.catch_phrase()}\n\n{fake.paragraph()}", + f"## Architecture\n{fake.paragraph()}", + f"## Implementation\n{fake.paragraph()}", + f"## Performance\n- Latency: {random.randint(1, 100)}ms\n- Throughput: {random.randint(100, 10000)} req/s", + f"## Dependencies\n- {fake.word()}\n- {fake.word()}\n- {fake.word()}", + ] + return "\n\n".join(sections) + + @staticmethod + def generate_news_article() -> str: + """Generate a news article.""" + company = fake.company() + return f""" + {company} Announces {fake.catch_phrase()} + + {fake.city()}, {fake.date()} - {company} today announced {fake.paragraph()}. + + "This is a significant milestone," said {fake.name()}, CEO of {company}. + "{fake.sentence()}" + + The announcement comes after {fake.paragraph()}. + + Industry analysts predict {fake.paragraph()}. 
+ """ + + @staticmethod + def generate_user_profile() -> Dict[str, Any]: + """Generate a user profile.""" + return { + 'user_id': fake.uuid4(), + 'name': fake.name(), + 'email': fake.email(), + 'joined': fake.date_time_this_year().isoformat(), + 'preferences': { + 'theme': random.choice(['light', 'dark', 'auto']), + 'notifications': random.choice([True, False]), + 'language': random.choice(['en', 'es', 'fr', 'de']), + }, + 'activity': { + 'last_login': fake.date_time_this_month().isoformat(), + 'total_sessions': random.randint(1, 1000), + 'average_duration': f"{random.randint(1, 60)} minutes", + } + } + + +class MockLLMProvider: + """Mock LLM provider for testing without actual API calls.""" + + def __init__(self, delay: float = 0.1): + self.delay = delay # Simulate LLM latency + + async def generate(self, prompt: str) -> str: + """Simulate LLM generation with delay.""" + await asyncio.sleep(self.delay) + + # Return deterministic responses based on prompt patterns + if "extract entities" in prompt.lower(): + return json.dumps({ + 'entities': [ + {'name': 'TestEntity1', 'type': 'PERSON'}, + {'name': 'TestEntity2', 'type': 'ORGANIZATION'}, + ] + }) + elif "summarize" in prompt.lower(): + return "This is a test summary of the provided content." + else: + return "Mock LLM response" + + +@asynccontextmanager +async def graphiti_test_client( + group_id: Optional[str] = None, + database: str = "kuzu", + use_mock_llm: bool = False, + config_overrides: Optional[Dict[str, Any]] = None +): + """ + Context manager for creating test clients with various configurations. + + Args: + group_id: Test group identifier + database: Database backend (neo4j, falkordb, kuzu) + use_mock_llm: Whether to use mock LLM for faster tests + config_overrides: Additional config overrides + """ + test_group_id = group_id or f'test_{int(time.time())}_{random.randint(1000, 9999)}' + + env = { + 'DATABASE_PROVIDER': database, + 'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key' if use_mock_llm else None), + } + + # Database-specific configuration + if database == 'neo4j': + env.update({ + 'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'), + 'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'), + 'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'), + }) + elif database == 'falkordb': + env['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379') + elif database == 'kuzu': + env['KUZU_PATH'] = os.environ.get('KUZU_PATH', f'./test_kuzu_{test_group_id}.db') + + # Apply config overrides + if config_overrides: + env.update(config_overrides) + + # Add mock LLM flag if needed + if use_mock_llm: + env['USE_MOCK_LLM'] = 'true' + + server_params = StdioServerParameters( + command='uv', + args=['run', 'main.py', '--transport', 'stdio'], + env=env + ) + + async with stdio_client(server_params) as (read, write): + session = ClientSession(read, write) + await session.initialize() + + try: + yield session, test_group_id + finally: + # Cleanup: Clear test data + try: + await session.call_tool('clear_graph', {'group_id': test_group_id}) + except: + pass # Ignore cleanup errors + + await session.close() + + +class PerformanceBenchmark: + """Track and analyze performance benchmarks.""" + + def __init__(self): + self.measurements: Dict[str, List[float]] = {} + + def record(self, operation: str, duration: float): + """Record a performance measurement.""" + if operation not in self.measurements: + self.measurements[operation] = [] + self.measurements[operation].append(duration) + + def 
get_stats(self, operation: str) -> Dict[str, float]: + """Get statistics for an operation.""" + if operation not in self.measurements or not self.measurements[operation]: + return {} + + durations = self.measurements[operation] + return { + 'count': len(durations), + 'mean': sum(durations) / len(durations), + 'min': min(durations), + 'max': max(durations), + 'median': sorted(durations)[len(durations) // 2], + } + + def report(self) -> str: + """Generate a performance report.""" + lines = ["Performance Benchmark Report", "=" * 40] + + for operation in sorted(self.measurements.keys()): + stats = self.get_stats(operation) + lines.append(f"\n{operation}:") + lines.append(f" Samples: {stats['count']}") + lines.append(f" Mean: {stats['mean']:.3f}s") + lines.append(f" Median: {stats['median']:.3f}s") + lines.append(f" Min: {stats['min']:.3f}s") + lines.append(f" Max: {stats['max']:.3f}s") + + return "\n".join(lines) + + +# Pytest fixtures +@pytest.fixture +def test_data_generator(): + """Provide test data generator.""" + return TestDataGenerator() + + +@pytest.fixture +def performance_benchmark(): + """Provide performance benchmark tracker.""" + return PerformanceBenchmark() + + +@pytest.fixture +async def mock_graphiti_client(): + """Provide a Graphiti client with mocked LLM.""" + async with graphiti_test_client(use_mock_llm=True) as (session, group_id): + yield session, group_id + + +@pytest.fixture +async def graphiti_client(): + """Provide a real Graphiti client.""" + async with graphiti_test_client(use_mock_llm=False) as (session, group_id): + yield session, group_id + + +# Test data fixtures +@pytest.fixture +def sample_memories(): + """Provide sample memory data for testing.""" + return [ + { + 'name': 'Company Overview', + 'episode_body': TestDataGenerator.generate_company_profile(), + 'source': 'json', + 'source_description': 'company database', + }, + { + 'name': 'Product Launch', + 'episode_body': TestDataGenerator.generate_news_article(), + 'source': 'text', + 'source_description': 'press release', + }, + { + 'name': 'Customer Support', + 'episode_body': TestDataGenerator.generate_conversation(), + 'source': 'message', + 'source_description': 'support chat', + }, + { + 'name': 'Technical Specs', + 'episode_body': TestDataGenerator.generate_technical_document(), + 'source': 'text', + 'source_description': 'documentation', + }, + ] + + +@pytest.fixture +def large_dataset(): + """Generate a large dataset for stress testing.""" + return [ + { + 'name': f'Document {i}', + 'episode_body': TestDataGenerator.generate_technical_document(), + 'source': 'text', + 'source_description': 'bulk import', + } + for i in range(50) + ] \ No newline at end of file diff --git a/mcp_server/tests/test_stress_load.py b/mcp_server/tests/test_stress_load.py new file mode 100644 index 00000000..0ffa6349 --- /dev/null +++ b/mcp_server/tests/test_stress_load.py @@ -0,0 +1,524 @@ +#!/usr/bin/env python3 +""" +Stress and load testing for Graphiti MCP Server. +Tests system behavior under high load, resource constraints, and edge conditions. 
+""" + +import asyncio +import gc +import json +import os +import psutil +import random +import time +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import pytest +from test_fixtures import TestDataGenerator, graphiti_test_client, PerformanceBenchmark + + +@dataclass +class LoadTestConfig: + """Configuration for load testing scenarios.""" + + num_clients: int = 10 + operations_per_client: int = 100 + ramp_up_time: float = 5.0 # seconds + test_duration: float = 60.0 # seconds + target_throughput: Optional[float] = None # ops/sec + think_time: float = 0.1 # seconds between ops + + +@dataclass +class LoadTestResult: + """Results from a load test run.""" + + total_operations: int + successful_operations: int + failed_operations: int + duration: float + throughput: float + average_latency: float + p50_latency: float + p95_latency: float + p99_latency: float + max_latency: float + errors: Dict[str, int] + resource_usage: Dict[str, float] + + +class LoadTester: + """Orchestrate load testing scenarios.""" + + def __init__(self, config: LoadTestConfig): + self.config = config + self.metrics: List[Tuple[float, float, bool]] = [] # (start, duration, success) + self.errors: Dict[str, int] = {} + self.start_time: Optional[float] = None + + async def run_client_workload( + self, + client_id: int, + session, + group_id: str + ) -> Dict[str, int]: + """Run workload for a single simulated client.""" + stats = {'success': 0, 'failure': 0} + data_gen = TestDataGenerator() + + # Ramp-up delay + ramp_delay = (client_id / self.config.num_clients) * self.config.ramp_up_time + await asyncio.sleep(ramp_delay) + + for op_num in range(self.config.operations_per_client): + operation_start = time.time() + + try: + # Randomly select operation type + operation = random.choice([ + 'add_memory', + 'search_memory_nodes', + 'get_episodes', + ]) + + if operation == 'add_memory': + args = { + 'name': f'Load Test {client_id}-{op_num}', + 'episode_body': data_gen.generate_technical_document(), + 'source': 'text', + 'source_description': 'load test', + 'group_id': group_id, + } + elif operation == 'search_memory_nodes': + args = { + 'query': random.choice(['performance', 'architecture', 'test', 'data']), + 'group_id': group_id, + 'limit': 10, + } + else: # get_episodes + args = { + 'group_id': group_id, + 'last_n': 10, + } + + # Execute operation with timeout + result = await asyncio.wait_for( + session.call_tool(operation, args), + timeout=30.0 + ) + + duration = time.time() - operation_start + self.metrics.append((operation_start, duration, True)) + stats['success'] += 1 + + except asyncio.TimeoutError: + duration = time.time() - operation_start + self.metrics.append((operation_start, duration, False)) + self.errors['timeout'] = self.errors.get('timeout', 0) + 1 + stats['failure'] += 1 + + except Exception as e: + duration = time.time() - operation_start + self.metrics.append((operation_start, duration, False)) + error_type = type(e).__name__ + self.errors[error_type] = self.errors.get(error_type, 0) + 1 + stats['failure'] += 1 + + # Think time between operations + await asyncio.sleep(self.config.think_time) + + # Stop if we've exceeded test duration + if self.start_time and (time.time() - self.start_time) > self.config.test_duration: + break + + return stats + + def calculate_results(self) -> LoadTestResult: + """Calculate load test results from metrics.""" + if not self.metrics: + return LoadTestResult(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {}, {}) + + successful = [m for m in self.metrics if 
m[2]] + failed = [m for m in self.metrics if not m[2]] + + latencies = sorted([m[1] for m in self.metrics]) + duration = max([m[0] + m[1] for m in self.metrics]) - min([m[0] for m in self.metrics]) + + # Calculate percentiles + def percentile(data: List[float], p: float) -> float: + if not data: + return 0.0 + idx = int(len(data) * p / 100) + return data[min(idx, len(data) - 1)] + + # Get resource usage + process = psutil.Process() + resource_usage = { + 'cpu_percent': process.cpu_percent(), + 'memory_mb': process.memory_info().rss / 1024 / 1024, + 'num_threads': process.num_threads(), + } + + return LoadTestResult( + total_operations=len(self.metrics), + successful_operations=len(successful), + failed_operations=len(failed), + duration=duration, + throughput=len(self.metrics) / duration if duration > 0 else 0, + average_latency=sum(latencies) / len(latencies) if latencies else 0, + p50_latency=percentile(latencies, 50), + p95_latency=percentile(latencies, 95), + p99_latency=percentile(latencies, 99), + max_latency=max(latencies) if latencies else 0, + errors=self.errors, + resource_usage=resource_usage, + ) + + +class TestLoadScenarios: + """Various load testing scenarios.""" + + @pytest.mark.asyncio + @pytest.mark.slow + async def test_sustained_load(self): + """Test system under sustained moderate load.""" + config = LoadTestConfig( + num_clients=5, + operations_per_client=20, + ramp_up_time=2.0, + test_duration=30.0, + think_time=0.5, + ) + + async with graphiti_test_client() as (session, group_id): + tester = LoadTester(config) + tester.start_time = time.time() + + # Run client workloads + client_tasks = [] + for client_id in range(config.num_clients): + task = tester.run_client_workload(client_id, session, group_id) + client_tasks.append(task) + + # Execute all clients + await asyncio.gather(*client_tasks) + + # Calculate results + results = tester.calculate_results() + + # Assertions + assert results.successful_operations > results.failed_operations + assert results.average_latency < 5.0, f"Average latency too high: {results.average_latency:.2f}s" + assert results.p95_latency < 10.0, f"P95 latency too high: {results.p95_latency:.2f}s" + + # Report results + print(f"\nSustained Load Test Results:") + print(f" Total operations: {results.total_operations}") + print(f" Success rate: {results.successful_operations / results.total_operations * 100:.1f}%") + print(f" Throughput: {results.throughput:.2f} ops/s") + print(f" Avg latency: {results.average_latency:.2f}s") + print(f" P95 latency: {results.p95_latency:.2f}s") + + @pytest.mark.asyncio + @pytest.mark.slow + async def test_spike_load(self): + """Test system response to sudden load spikes.""" + async with graphiti_test_client() as (session, group_id): + # Normal load phase + normal_tasks = [] + for i in range(3): + task = session.call_tool( + 'add_memory', + { + 'name': f'Normal Load {i}', + 'episode_body': 'Normal operation', + 'source': 'text', + 'source_description': 'normal', + 'group_id': group_id, + } + ) + normal_tasks.append(task) + await asyncio.sleep(0.5) + + await asyncio.gather(*normal_tasks) + + # Spike phase - sudden burst of requests + spike_start = time.time() + spike_tasks = [] + for i in range(50): + task = session.call_tool( + 'add_memory', + { + 'name': f'Spike Load {i}', + 'episode_body': TestDataGenerator.generate_technical_document(), + 'source': 'text', + 'source_description': 'spike', + 'group_id': group_id, + } + ) + spike_tasks.append(task) + + # Execute spike + spike_results = await 
asyncio.gather(*spike_tasks, return_exceptions=True)
+            spike_duration = time.time() - spike_start
+
+            # Analyze spike handling
+            spike_failures = sum(1 for r in spike_results if isinstance(r, Exception))
+            spike_success_rate = (len(spike_results) - spike_failures) / len(spike_results)
+
+            print(f"\nSpike Load Test Results:")
+            print(f"  Spike size: {len(spike_tasks)} operations")
+            print(f"  Duration: {spike_duration:.2f}s")
+            print(f"  Success rate: {spike_success_rate * 100:.1f}%")
+            print(f"  Throughput: {len(spike_tasks) / spike_duration:.2f} ops/s")
+
+            # System should handle at least 80% of spike
+            assert spike_success_rate > 0.8, f"Too many failures during spike: {spike_failures}"
+
+    @pytest.mark.asyncio
+    @pytest.mark.slow
+    async def test_memory_leak_detection(self):
+        """Test for memory leaks during extended operation."""
+        async with graphiti_test_client() as (session, group_id):
+            process = psutil.Process()
+            gc.collect()  # Force garbage collection
+            initial_memory = process.memory_info().rss / 1024 / 1024  # MB
+
+            # Perform many operations
+            for batch in range(10):
+                batch_tasks = []
+                for i in range(10):
+                    task = session.call_tool(
+                        'add_memory',
+                        {
+                            'name': f'Memory Test {batch}-{i}',
+                            'episode_body': TestDataGenerator.generate_technical_document(),
+                            'source': 'text',
+                            'source_description': 'memory test',
+                            'group_id': group_id,
+                        }
+                    )
+                    batch_tasks.append(task)
+
+                await asyncio.gather(*batch_tasks)
+
+                # Force garbage collection between batches
+                gc.collect()
+                await asyncio.sleep(1)
+
+            # Check memory after operations
+            gc.collect()
+            final_memory = process.memory_info().rss / 1024 / 1024  # MB
+            memory_growth = final_memory - initial_memory
+
+            print(f"\nMemory Leak Test:")
+            print(f"  Initial memory: {initial_memory:.1f} MB")
+            print(f"  Final memory: {final_memory:.1f} MB")
+            print(f"  Growth: {memory_growth:.1f} MB")
+
+            # Allow for some memory growth but flag potential leaks
+            # This is a soft check - actual threshold depends on system
+            if memory_growth > 100:  # More than 100MB growth
+                print(f"  ⚠️ Potential memory leak detected: {memory_growth:.1f} MB growth")
+
+    @pytest.mark.asyncio
+    @pytest.mark.slow
+    async def test_connection_pool_exhaustion(self):
+        """Test behavior when connection pools are exhausted."""
+        async with graphiti_test_client() as (session, group_id):
+            # Create many concurrent long-running operations
+            long_tasks = []
+            for i in range(100):  # Many more than typical pool size
+                # The module-level Faker instance lives in test_fixtures and is not
+                # exported, so build the long query from a fixed vocabulary instead.
+                query_terms = ' '.join(random.choices(
+                    ['index', 'cache', 'latency', 'throughput', 'vector',
+                     'graph', 'episode', 'entity', 'edge', 'search'],
+                    k=10,
+                ))
+                task = session.call_tool(
+                    'search_memory_nodes',
+                    {
+                        'query': f'complex query {i} {query_terms}',
+                        'group_id': group_id,
+                        'limit': 100,
+                    }
+                )
+                long_tasks.append(task)
+
+            # Execute with timeout
+            try:
+                results = await asyncio.wait_for(
+                    asyncio.gather(*long_tasks, return_exceptions=True),
+                    timeout=60.0
+                )
+
+                # Count connection-related errors
+                connection_errors = sum(
+                    1 for r in results
+                    if isinstance(r, Exception) and 'connection' in str(r).lower()
+                )
+
+                print(f"\nConnection Pool Test:")
+                print(f"  Total requests: {len(long_tasks)}")
+                print(f"  Connection errors: {connection_errors}")
+
+            except asyncio.TimeoutError:
+                print("  Test timed out - possible deadlock or exhaustion")
+
+    @pytest.mark.asyncio
+    @pytest.mark.slow
+    async def test_gradual_degradation(self):
+        """Test system degradation under increasing load."""
+        async with graphiti_test_client() as (session, group_id):
+            load_levels = [5, 10, 20, 40, 80]  # Increasing concurrent operations
+            results_by_level = {}
+
+            for level in load_levels:
+                level_start = time.time()
+                tasks = []
+
+                for i in range(level):
+                    task = session.call_tool(
+                        'add_memory',
+                        {
+                            'name': f'Load Level {level} Op {i}',
+                            'episode_body': f'Testing at load level {level}',
+                            'source': 'text',
+                            'source_description': 'degradation test',
+                            'group_id': group_id,
+                        }
+                    )
+                    tasks.append(task)
+
+                # Execute level
+                level_results = await asyncio.gather(*tasks, return_exceptions=True)
+                level_duration = time.time() - level_start
+
+                # Calculate metrics
+                failures = sum(1 for r in level_results if isinstance(r, Exception))
+                success_rate = (level - failures) / level * 100
+                throughput = level / level_duration
+
+                results_by_level[level] = {
+                    'success_rate': success_rate,
+                    'throughput': throughput,
+                    'duration': level_duration,
+                }
+
+                print(f"\nLoad Level {level}:")
+                print(f"  Success rate: {success_rate:.1f}%")
+                print(f"  Throughput: {throughput:.2f} ops/s")
+                print(f"  Duration: {level_duration:.2f}s")
+
+                # Brief pause between levels
+                await asyncio.sleep(2)
+
+            # Verify graceful degradation
+            # Success rate should not drop below 50% even at high load
+            for level, metrics in results_by_level.items():
+                assert metrics['success_rate'] > 50, f"Poor performance at load level {level}"
+
+
+class TestResourceLimits:
+    """Test behavior at resource limits."""
+
+    @pytest.mark.asyncio
+    async def test_large_payload_handling(self):
+        """Test handling of very large payloads."""
+        async with graphiti_test_client() as (session, group_id):
+            payload_sizes = [
+                (1_000, "1KB"),
+                (10_000, "10KB"),
+                (100_000, "100KB"),
+                (1_000_000, "1MB"),
+            ]
+
+            for size, label in payload_sizes:
+                content = "x" * size
+
+                start_time = time.time()
+                try:
+                    result = await asyncio.wait_for(
+                        session.call_tool(
+                            'add_memory',
+                            {
+                                'name': f'Large Payload {label}',
+                                'episode_body': content,
+                                'source': 'text',
+                                'source_description': 'payload test',
+                                'group_id': group_id,
+                            }
+                        ),
+                        timeout=30.0
+                    )
+                    duration = time.time() - start_time
+                    status = "✅ Success"
+
+                except asyncio.TimeoutError:
+                    duration = 30.0
+                    status = "⏱️ Timeout"
+
+                except Exception as e:
+                    duration = time.time() - start_time
+                    status = f"❌ Error: {type(e).__name__}"
+
+                print(f"Payload {label}: {status} ({duration:.2f}s)")
+
+    @pytest.mark.asyncio
+    async def test_rate_limit_handling(self):
+        """Test handling of rate limits."""
+        async with graphiti_test_client() as (session, group_id):
+            # Rapid fire requests to trigger rate limits
+            rapid_tasks = []
+            for i in range(100):
+                task = session.call_tool(
+                    'add_memory',
+                    {
+                        'name': f'Rate Limit Test {i}',
+                        'episode_body': f'Testing rate limit {i}',
+                        'source': 'text',
+                        'source_description': 'rate test',
+                        'group_id': group_id,
+                    }
+                )
+                rapid_tasks.append(task)
+
+            # Execute without delays
+            results = await asyncio.gather(*rapid_tasks, return_exceptions=True)
+
+            # Count rate limit errors
+            rate_limit_errors = sum(
+                1 for r in results
+                if isinstance(r, Exception) and ('rate' in str(r).lower() or '429' in str(r))
+            )
+
+            print(f"\nRate Limit Test:")
+            print(f"  Total requests: {len(rapid_tasks)}")
+            print(f"  Rate limit errors: {rate_limit_errors}")
+            print(f"  Success rate: {(len(rapid_tasks) - rate_limit_errors) / len(rapid_tasks) * 100:.1f}%")
+
+
+def generate_load_test_report(results: List[LoadTestResult]) -> str:
+    """Generate comprehensive load test report."""
+    report = []
+    report.append("\n" + "=" * 60)
+    report.append("LOAD TEST REPORT")
+    report.append("=" * 60)
+
+    for i, result in enumerate(results):
+        report.append(f"\nTest Run {i + 1}:")
+        report.append(f"  Total Operations: {result.total_operations}")
+        # Guard against division by zero when a run recorded no operations
+        success_rate = (
+            result.successful_operations / result.total_operations * 100
+            if result.total_operations
+            else 0.0
+        )
+        report.append(f"  Success Rate: {success_rate:.1f}%")
+        report.append(f"  Throughput: {result.throughput:.2f} ops/s")
+        report.append(f"  Latency (avg/p50/p95/p99/max): {result.average_latency:.2f}/{result.p50_latency:.2f}/{result.p95_latency:.2f}/{result.p99_latency:.2f}/{result.max_latency:.2f}s")
+
+        if result.errors:
+            report.append("  Errors:")
+            for error_type, count in result.errors.items():
+                report.append(f"    {error_type}: {count}")
+
+        report.append("  Resource Usage:")
+        for metric, value in result.resource_usage.items():
+            report.append(f"    {metric}: {value:.2f}")
+
+    report.append("=" * 60)
+    return "\n".join(report)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v", "--asyncio-mode=auto", "-m", "slow"])
\ No newline at end of file
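
A minimal sketch of how the pieces introduced in this patch compose in practice (`graphiti_test_client`, `TestDataGenerator`, and `PerformanceBenchmark`), assuming the suite is run from `mcp_server/tests` so that `test_fixtures` is importable and the mock-LLM path is available; the test name and timings are illustrative, not part of the patch:

```python
"""Illustrative only: a latency check built from the fixtures added above."""

import time

import pytest

from test_fixtures import PerformanceBenchmark, TestDataGenerator, graphiti_test_client


@pytest.mark.asyncio
async def test_add_then_search_latency():
    """Record add_memory and search_memory_nodes latencies against the mock LLM."""
    benchmark = PerformanceBenchmark()

    async with graphiti_test_client(use_mock_llm=True) as (session, group_id):
        # Time a single episode ingestion
        start = time.time()
        await session.call_tool(
            'add_memory',
            {
                'name': 'Benchmark Episode',
                'episode_body': TestDataGenerator.generate_technical_document(),
                'source': 'text',
                'source_description': 'benchmark',
                'group_id': group_id,
            },
        )
        benchmark.record('add_memory', time.time() - start)

        # Time a follow-up search over the same group
        start = time.time()
        await session.call_tool(
            'search_memory_nodes',
            {'query': 'architecture', 'group_id': group_id, 'limit': 5},
        )
        benchmark.record('search_memory_nodes', time.time() - start)

    print(benchmark.report())
```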