conductor-checkpoint-msg_01TscHXmijzkqcTJX5sGTYP8

parent fefcd1a2de, commit 968c36c2d1
7 changed files with 2728 additions and 0 deletions
mcp_server/tests/README.md (new file, 314 lines)
@@ -0,0 +1,314 @@
# Graphiti MCP Server Integration Tests

This directory contains a comprehensive integration test suite for the Graphiti MCP Server using the official Python MCP SDK.

## Overview

The test suite is designed to thoroughly test all aspects of the Graphiti MCP server, with special consideration for LLM inference latency and system performance.

## Test Organization

### Core Test Modules

- **`test_comprehensive_integration.py`** - Main integration test suite covering all MCP tools
- **`test_async_operations.py`** - Tests for concurrent operations and async patterns
- **`test_stress_load.py`** - Stress-testing and load-testing scenarios
- **`test_fixtures.py`** - Shared fixtures and test utilities
- **`test_mcp_integration.py`** - Original MCP integration tests
- **`test_configuration.py`** - Configuration loading and validation tests

### Test Categories

Tests are organized with pytest markers:

- `unit` - Fast unit tests without external dependencies
- `integration` - Tests requiring database and services
- `slow` - Long-running tests (stress/load tests)
- `requires_neo4j` - Tests requiring Neo4j
- `requires_falkordb` - Tests requiring FalkorDB
- `requires_kuzu` - Tests requiring KuzuDB
- `requires_openai` - Tests requiring an OpenAI API key

## Installation

```bash
# Install test dependencies
uv add --dev pytest pytest-asyncio pytest-timeout pytest-xdist faker psutil

# Install the MCP SDK
uv add mcp
```

## Running Tests

### Quick Start

```bash
# Run smoke tests (quick validation)
python tests/run_tests.py smoke

# Run integration tests with a mock LLM
python tests/run_tests.py integration --mock-llm

# Run all tests
python tests/run_tests.py all
```

### Test Runner Options

```bash
python tests/run_tests.py [suite] [options]

Suites:
  unit          - Unit tests only
  integration   - Integration tests
  comprehensive - Comprehensive integration suite
  async         - Async operation tests
  stress        - Stress and load tests
  smoke         - Quick smoke tests
  all           - All tests

Options:
  --database    - Database backend (neo4j, falkordb, kuzu)
  --mock-llm    - Use mock LLM for faster testing
  --parallel N  - Run tests in parallel with N workers
  --coverage    - Generate coverage report
  --skip-slow   - Skip slow tests
  --timeout N   - Test timeout in seconds
  --check-only  - Only check prerequisites
```

### Examples

```bash
# Quick smoke test with KuzuDB
python tests/run_tests.py smoke --database kuzu

# Full integration test with Neo4j
python tests/run_tests.py integration --database neo4j

# Stress testing with parallel execution
python tests/run_tests.py stress --parallel 4

# Run with coverage
python tests/run_tests.py all --coverage

# Check prerequisites only
python tests/run_tests.py all --check-only
```

## Test Coverage

### Core Operations
- Server initialization and tool discovery
- Adding memories (text, JSON, message)
- Episode queue management
- Search operations (semantic, hybrid)
- Episode retrieval and deletion
- Entity and edge operations

### Async Operations
- Concurrent operations
- Queue management
- Sequential processing within groups
- Parallel processing across groups

### Performance Testing
- Latency measurement
- Throughput testing
- Batch processing
- Resource usage monitoring

### Stress Testing
- Sustained load scenarios
- Spike load handling
- Memory leak detection
- Connection pool exhaustion
- Rate limit handling

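The ordering guarantees the async tests check (sequential within a group, parallel across groups) can be modeled with one worker task per group. This is a simplified sketch of the pattern, not the server's actual implementation:

```python
import asyncio
from collections import defaultdict

async def process(results, group_id, item):
    """Stand-in for real LLM/database work on one episode."""
    await asyncio.sleep(0.01)
    results[group_id].append(item)

async def run(episodes):
    """FIFO processing per group; distinct groups drain concurrently."""
    queues = defaultdict(asyncio.Queue)
    results = defaultdict(list)
    for group_id, item in episodes:
        queues[group_id].put_nowait(item)

    async def worker(group_id, queue):
        while not queue.empty():
            await process(results, group_id, await queue.get())

    await asyncio.gather(*(worker(g, q) for g, q in queues.items()))
    return results

results = asyncio.run(run([('a', 1), ('b', 1), ('a', 2), ('b', 2)]))
# Within each group the FIFO order is preserved: results['a'] == [1, 2]
```
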
## Configuration

### Environment Variables

```bash
# Database configuration
export DATABASE_PROVIDER=kuzu        # or neo4j, falkordb
export NEO4J_URI=bolt://localhost:7687
export NEO4J_USER=neo4j
export NEO4J_PASSWORD=graphiti
export FALKORDB_URI=redis://localhost:6379
export KUZU_PATH=./test_kuzu.db

# LLM configuration
export OPENAI_API_KEY=your_key_here  # or use --mock-llm

# Test configuration
export TEST_MODE=true
export LOG_LEVEL=INFO
```

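For reference, these variables resolve with the same defaults `run_tests.py` uses, via `os.environ.get` with fallbacks. The helper below is illustrative, not part of the suite:

```python
import os

def database_settings():
    """Resolve the database backend and its connection settings,
    mirroring the defaults used by run_tests.py."""
    provider = os.environ.get('DATABASE_PROVIDER', 'kuzu')
    settings = {
        'neo4j': {
            'uri': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
            'user': os.environ.get('NEO4J_USER', 'neo4j'),
            'password': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
        },
        'falkordb': {'uri': os.environ.get('FALKORDB_URI', 'redis://localhost:6379')},
        'kuzu': {'path': os.environ.get('KUZU_PATH', './test_kuzu.db')},
    }
    return provider, settings[provider]
```
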
### pytest.ini Configuration

The `pytest.ini` file configures:
- Test discovery patterns
- Async mode settings
- Test markers
- Timeout settings
- Output formatting

## Test Fixtures

### Data Generation

The test suite includes comprehensive data generators:

```python
from test_fixtures import TestDataGenerator

# Generate test data
company = TestDataGenerator.generate_company_profile()
conversation = TestDataGenerator.generate_conversation()
document = TestDataGenerator.generate_technical_document()
```

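If you need deterministic data outside the fixture module, a seeded generator along these lines is easy to build with the standard library. The field names here are illustrative only; the real `TestDataGenerator` API lives in `test_fixtures.py`:

```python
import random

def generate_company_profile(seed=None):
    """Produce a small, optionally seeded company record for tests."""
    rng = random.Random(seed)
    return {
        'name': rng.choice(['Acme', 'Globex', 'Initech']) + ' ' + rng.choice(['Labs', 'Corp']),
        'industry': rng.choice(['software', 'biotech', 'energy']),
        'employees': rng.randint(10, 5000),
    }

profile = generate_company_profile(seed=42)
```

Seeding makes failures reproducible: the same seed always yields the same record.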
### Test Client

Simplified client creation:

```python
from test_fixtures import graphiti_test_client

async with graphiti_test_client(database="kuzu") as (session, group_id):
    # Use the session for testing
    result = await session.call_tool('add_memory', {...})
```

## Performance Considerations

### LLM Latency Management

The tests account for LLM inference latency through:

1. **Configurable timeouts** - Different timeouts for different operations
2. **Mock LLM option** - Fast testing without API calls
3. **Intelligent polling** - Adaptive waiting for episode processing
4. **Batch operations** - Testing the efficiency of batched requests

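"Intelligent polling" can be sketched as polling `get_episodes` with exponential backoff. `wait_for_episodes` below is an illustrative helper with parameters of our own choosing, not necessarily a fixture the suite provides:

```python
import asyncio
import json
import time

async def wait_for_episodes(session, group_id, expected, timeout=60.0):
    """Poll get_episodes until `expected` episodes are visible, backing
    off between polls; raise TimeoutError after `timeout` seconds."""
    delay = 0.5
    deadline = time.monotonic() + timeout
    episodes = []
    while time.monotonic() < deadline:
        result = await session.call_tool(
            'get_episodes', {'group_id': group_id, 'last_n': expected}
        )
        episodes = json.loads(result.content[0].text).get('episodes', [])
        if len(episodes) >= expected:
            return episodes
        await asyncio.sleep(delay)
        delay = min(delay * 2, 5.0)  # exponential backoff, capped at 5s
    raise TimeoutError(f'saw {len(episodes)}/{expected} episodes before timeout')
```

Compared to a fixed `asyncio.sleep`, this returns as soon as processing finishes on fast systems while still tolerating slow LLM backends.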
### Resource Management

- Memory leak detection
- Connection pool monitoring
- Resource usage tracking
- Graceful degradation testing

## CI/CD Integration

### GitHub Actions

```yaml
name: MCP Integration Tests

on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest

    services:
      neo4j:
        image: neo4j:5.26
        env:
          NEO4J_AUTH: neo4j/graphiti
        ports:
          - 7687:7687

    steps:
      - uses: actions/checkout@v2

      - name: Install dependencies
        run: |
          pip install uv
          uv sync --extra dev

      - name: Run smoke tests
        run: python tests/run_tests.py smoke --mock-llm

      - name: Run integration tests
        run: python tests/run_tests.py integration --database neo4j
        env:
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
```

## Troubleshooting

### Common Issues

1. **Database connection failures**
   ```bash
   # Check Neo4j
   curl http://localhost:7474

   # Check FalkorDB
   redis-cli ping
   ```

2. **API key issues**
   ```bash
   # Use mock LLM for testing without an API key
   python tests/run_tests.py all --mock-llm
   ```

3. **Timeout errors**
   ```bash
   # Increase the timeout for slow systems
   python tests/run_tests.py integration --timeout 600
   ```

4. **Memory issues**
   ```bash
   # Skip stress tests on low-memory systems
   python tests/run_tests.py all --skip-slow
   ```

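When a connection check fails, it can help to first confirm plain TCP reachability without any database driver installed. A minimal, dependency-free probe (illustrative; `run_tests.py` uses the real drivers for its checks):

```python
import socket

def can_connect(host, port, timeout=1.0):
    """Return True if a TCP connection to (host, port) succeeds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
```

For example, `can_connect('localhost', 7687)` checks Neo4j's Bolt port and `can_connect('localhost', 6379)` checks FalkorDB's Redis port.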
## Test Reports

### Performance Report

After running performance tests:

```python
from test_fixtures import PerformanceBenchmark

benchmark = PerformanceBenchmark()
# ... run tests ...
print(benchmark.report())
```

### Load Test Report

Stress tests generate detailed reports:

```
LOAD TEST REPORT
================
Test Run 1:
  Total Operations: 100
  Success Rate: 95.0%
  Throughput: 12.5 ops/s
  Latency (avg/p50/p95/p99/max): 0.8/0.7/1.5/2.1/3.2s
```

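The latency line of the report can be reproduced from raw samples with the standard library. This is a sketch of the formatting only; the real report code lives in the stress-test module:

```python
import statistics

def latency_summary(samples):
    """Format avg/p50/p95/p99/max from latency samples, in seconds."""
    qs = statistics.quantiles(samples, n=100, method='inclusive')
    avg = sum(samples) / len(samples)
    p50, p95, p99 = qs[49], qs[94], qs[98]
    return f'{avg:.1f}/{p50:.1f}/{p95:.1f}/{p99:.1f}/{max(samples):.1f}s'
```
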
## Contributing

When adding new tests:

1. Use appropriate pytest markers
2. Include docstrings explaining the test's purpose
3. Use fixtures for common operations
4. Consider LLM latency in test design
5. Add timeout handling for long operations
6. Include performance metrics where relevant

## License

See the main project LICENSE file.
mcp_server/tests/pytest.ini (new file, 40 lines)
@@ -0,0 +1,40 @@
[pytest]
# Pytest configuration for Graphiti MCP integration tests

# Test discovery patterns
python_files = test_*.py
python_classes = Test*
python_functions = test_*

# Asyncio configuration
asyncio_mode = auto

# Markers for test categorization
markers =
    slow: marks tests as slow (deselect with '-m "not slow"')
    integration: marks tests as integration tests requiring external services
    unit: marks tests as unit tests
    stress: marks tests as stress/load tests
    requires_neo4j: test requires Neo4j database
    requires_falkordb: test requires FalkorDB
    requires_kuzu: test requires KuzuDB
    requires_openai: test requires OpenAI API key

# Test output options
addopts =
    -v
    --tb=short
    --strict-markers
    --color=yes
    -p no:warnings

# Timeout for tests, in seconds (requires pytest-timeout)
timeout = 300

# Test discovery root
testpaths = tests

# Environment variables for testing (requires the pytest-env plugin)
env =
    TEST_MODE=true
    LOG_LEVEL=INFO
mcp_server/tests/run_tests.py (new file, 336 lines)
@@ -0,0 +1,336 @@
#!/usr/bin/env python3
"""
Test runner for Graphiti MCP integration tests.
Provides various test execution modes and reporting options.
"""

import argparse
import os
import sys
import time
from pathlib import Path
from typing import Dict

import pytest


class TestRunner:
    """Orchestrate test execution with various configurations."""

    def __init__(self, args):
        self.args = args
        self.test_dir = Path(__file__).parent
        self.results = {}

    def check_prerequisites(self) -> Dict[str, bool]:
        """Check whether the required services and dependencies are available."""
        checks = {}

        # Check for an OpenAI API key unless mocks are in use
        if not self.args.mock_llm:
            checks['openai_api_key'] = bool(os.environ.get('OPENAI_API_KEY'))
        else:
            checks['openai_api_key'] = True

        # Check database availability for the selected backend
        if self.args.database == 'neo4j':
            checks['neo4j'] = self._check_neo4j()
        elif self.args.database == 'falkordb':
            checks['falkordb'] = self._check_falkordb()
        elif self.args.database == 'kuzu':
            checks['kuzu'] = True  # KuzuDB is embedded, so nothing to probe

        # Check Python dependencies
        checks['mcp'] = self._check_python_package('mcp')
        checks['pytest'] = self._check_python_package('pytest')
        checks['pytest-asyncio'] = self._check_python_package('pytest-asyncio')

        return checks

    def _check_neo4j(self) -> bool:
        """Check whether Neo4j is reachable."""
        try:
            import neo4j

            uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
            user = os.environ.get('NEO4J_USER', 'neo4j')
            password = os.environ.get('NEO4J_PASSWORD', 'graphiti')

            driver = neo4j.GraphDatabase.driver(uri, auth=(user, password))
            with driver.session() as session:
                session.run('RETURN 1')
            driver.close()
            return True
        except Exception:
            return False

    def _check_falkordb(self) -> bool:
        """Check whether FalkorDB is reachable."""
        try:
            import redis

            uri = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
            r = redis.from_url(uri)
            r.ping()
            return True
        except Exception:
            return False

    def _check_python_package(self, package: str) -> bool:
        """Check whether a Python package is importable."""
        try:
            __import__(package.replace('-', '_'))
            return True
        except ImportError:
            return False

    def run_test_suite(self, suite: str) -> int:
        """Run a specific test suite and return the pytest exit code."""
        pytest_args = ['-v', '--tb=short']
        marker_exprs = []

        # Deselect tests that require a different database backend
        if self.args.database:
            marker_exprs.extend(
                f'not requires_{db}'
                for db in ('neo4j', 'falkordb', 'kuzu')
                if db != self.args.database
            )

        # Add suite-specific arguments
        if suite == 'unit':
            marker_exprs.append('unit')
            pytest_args.append('.')
        elif suite == 'integration':
            marker_exprs.append('(integration or not unit)')
            pytest_args.append('.')
        elif suite == 'comprehensive':
            pytest_args.append('test_comprehensive_integration.py')
        elif suite == 'async':
            pytest_args.append('test_async_operations.py')
        elif suite == 'stress':
            marker_exprs.append('slow')
            pytest_args.append('test_stress_load.py')
        elif suite == 'smoke':
            # Quick smoke test: just the basic operations
            pytest_args.extend([
                'test_comprehensive_integration.py::TestCoreOperations::test_server_initialization',
                'test_comprehensive_integration.py::TestCoreOperations::test_add_text_memory',
            ])
        elif suite == 'all':
            pytest_args.append('.')
        else:
            pytest_args.append(suite)

        # Add coverage if requested
        if self.args.coverage:
            pytest_args.extend(['--cov=../src', '--cov-report=html'])

        # Add parallel execution if requested (requires pytest-xdist)
        if self.args.parallel:
            pytest_args.extend(['-n', str(self.args.parallel)])

        # Add verbosity
        if self.args.verbose:
            pytest_args.append('-vv')

        # Skip slow tests if requested
        if self.args.skip_slow:
            marker_exprs.append('not slow')

        # Add timeout override (requires pytest-timeout)
        if self.args.timeout:
            pytest_args.extend(['--timeout', str(self.args.timeout)])

        # pytest honors only the last -m option, so combine every marker
        # expression into a single one.
        if marker_exprs:
            pytest_args.extend(['-m', ' and '.join(marker_exprs)])

        # Set environment variables for the in-process pytest run
        if self.args.mock_llm:
            os.environ['USE_MOCK_LLM'] = 'true'
        if self.args.database:
            os.environ['DATABASE_PROVIDER'] = self.args.database

        # Run tests
        print(f"Running {suite} tests with pytest args: {' '.join(pytest_args)}")
        return pytest.main(pytest_args)

    def run_performance_benchmark(self) -> int:
        """Run the performance benchmarking suite."""
        print('Running performance benchmarks...')

        # Run only the performance-focused test classes
        return pytest.main([
            '-v',
            'test_comprehensive_integration.py::TestPerformance',
            'test_async_operations.py::TestAsyncPerformance',
        ])

    def generate_report(self) -> str:
        """Generate a test execution report."""
        report = []
        report.append('\n' + '=' * 60)
        report.append('GRAPHITI MCP TEST EXECUTION REPORT')
        report.append('=' * 60)

        # Prerequisites check
        checks = self.check_prerequisites()
        report.append('\nPrerequisites:')
        for check, passed in checks.items():
            status = '✅' if passed else '❌'
            report.append(f'  {status} {check}')

        # Test configuration
        report.append('\nConfiguration:')
        report.append(f'  Database: {self.args.database}')
        report.append(f'  Mock LLM: {self.args.mock_llm}')
        report.append(f'  Parallel: {self.args.parallel or "No"}')
        report.append(f'  Timeout: {self.args.timeout}s')

        # Results summary (if available)
        if self.results:
            report.append('\nResults:')
            for suite, result in self.results.items():
                status = '✅ Passed' if result == 0 else f'❌ Failed ({result})'
                report.append(f'  {suite}: {status}')

        report.append('=' * 60)
        return '\n'.join(report)


def main():
    """Main entry point for the test runner."""
    parser = argparse.ArgumentParser(
        description='Run Graphiti MCP integration tests',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Test Suites:
  unit          - Run unit tests only
  integration   - Run integration tests
  comprehensive - Run the comprehensive integration test suite
  async         - Run async operation tests
  stress        - Run stress and load tests
  smoke         - Run quick smoke tests
  all           - Run all tests

Examples:
  python run_tests.py smoke                     # Quick smoke test
  python run_tests.py integration --parallel 4  # Integration tests in parallel
  python run_tests.py stress --database neo4j   # Stress tests against Neo4j
  python run_tests.py all --coverage            # All tests with coverage
""",
    )

    parser.add_argument(
        'suite',
        choices=['unit', 'integration', 'comprehensive', 'async', 'stress', 'smoke', 'all'],
        help='Test suite to run',
    )
    parser.add_argument(
        '--database',
        choices=['neo4j', 'falkordb', 'kuzu'],
        default='kuzu',
        help='Database backend to test (default: kuzu)',
    )
    parser.add_argument('--mock-llm', action='store_true', help='Use mock LLM for faster testing')
    parser.add_argument('--parallel', type=int, metavar='N', help='Run tests in parallel with N workers')
    parser.add_argument('--coverage', action='store_true', help='Generate coverage report')
    parser.add_argument('--verbose', action='store_true', help='Verbose output')
    parser.add_argument('--skip-slow', action='store_true', help='Skip slow tests')
    parser.add_argument('--timeout', type=int, default=300, help='Test timeout in seconds (default: 300)')
    parser.add_argument('--benchmark-only', action='store_true', help='Run only benchmark tests')
    parser.add_argument('--check-only', action='store_true', help='Only check prerequisites without running tests')

    args = parser.parse_args()

    # Create the test runner
    runner = TestRunner(args)

    # Prerequisites-only mode
    if args.check_only:
        print(runner.generate_report())
        sys.exit(0)

    # Warn about unmet prerequisites before running
    checks = runner.check_prerequisites()
    if not all(checks.values()):
        print('⚠️  Some prerequisites are not met:')
        for check, passed in checks.items():
            if not passed:
                print(f'  ❌ {check}')

        if not args.mock_llm and not checks.get('openai_api_key'):
            print('\nHint: use --mock-llm to run tests without an OpenAI API key')

        response = input('\nContinue anyway? (y/N): ')
        if response.lower() != 'y':
            sys.exit(1)

    # Run tests
    print(f'\n🚀 Starting test execution: {args.suite}')
    start_time = time.time()

    if args.benchmark_only:
        result = runner.run_performance_benchmark()
    else:
        result = runner.run_test_suite(args.suite)

    duration = time.time() - start_time

    # Store results, then generate and print the report
    runner.results[args.suite] = result
    print(runner.generate_report())
    print(f'\n⏱️  Test execution completed in {duration:.2f} seconds')

    # Exit with the pytest result code
    sys.exit(result)


if __name__ == '__main__':
    main()
mcp_server/tests/test_async_operations.py (new file, 494 lines)
@@ -0,0 +1,494 @@
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Asynchronous operation tests for Graphiti MCP Server.
|
||||||
|
Tests concurrent operations, queue management, and async patterns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from test_fixtures import (
|
||||||
|
TestDataGenerator,
|
||||||
|
graphiti_test_client,
|
||||||
|
performance_benchmark,
|
||||||
|
test_data_generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncQueueManagement:
|
||||||
|
"""Test asynchronous queue operations and episode processing."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sequential_queue_processing(self):
|
||||||
|
"""Verify episodes are processed sequentially within a group."""
|
||||||
|
async with graphiti_test_client() as (session, group_id):
|
||||||
|
# Add multiple episodes quickly
|
||||||
|
episodes = []
|
||||||
|
for i in range(5):
|
||||||
|
result = await session.call_tool(
|
||||||
|
'add_memory',
|
||||||
|
{
|
||||||
|
'name': f'Sequential Test {i}',
|
||||||
|
'episode_body': f'Episode {i} with timestamp {time.time()}',
|
||||||
|
'source': 'text',
|
||||||
|
'source_description': 'sequential test',
|
||||||
|
'group_id': group_id,
|
||||||
|
'reference_id': f'seq_{i}', # Add reference for tracking
|
||||||
|
}
|
||||||
|
)
|
||||||
|
episodes.append(result)
|
||||||
|
|
||||||
|
# Wait for processing
|
||||||
|
await asyncio.sleep(10) # Allow time for sequential processing
|
||||||
|
|
||||||
|
# Retrieve episodes and verify order
|
||||||
|
result = await session.call_tool(
|
||||||
|
'get_episodes',
|
||||||
|
{'group_id': group_id, 'last_n': 10}
|
||||||
|
)
|
||||||
|
|
||||||
|
processed_episodes = json.loads(result.content[0].text)['episodes']
|
||||||
|
|
||||||
|
# Verify all episodes were processed
|
||||||
|
assert len(processed_episodes) >= 5, f"Expected at least 5 episodes, got {len(processed_episodes)}"
|
||||||
|
|
||||||
|
# Verify sequential processing (timestamps should be ordered)
|
||||||
|
timestamps = [ep.get('created_at') for ep in processed_episodes]
|
||||||
|
assert timestamps == sorted(timestamps), "Episodes not processed in order"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_group_processing(self):
|
||||||
|
"""Test that different groups can process concurrently."""
|
||||||
|
async with graphiti_test_client() as (session, _):
|
||||||
|
groups = [f'group_{i}_{time.time()}' for i in range(3)]
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
# Create tasks for different groups
|
||||||
|
for group_id in groups:
|
||||||
|
for j in range(2):
|
||||||
|
task = session.call_tool(
|
||||||
|
'add_memory',
|
||||||
|
{
|
||||||
|
'name': f'Group {group_id} Episode {j}',
|
||||||
|
'episode_body': f'Content for {group_id}',
|
||||||
|
'source': 'text',
|
||||||
|
'source_description': 'concurrent test',
|
||||||
|
'group_id': group_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
# Execute all tasks concurrently
|
||||||
|
start_time = time.time()
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Verify all succeeded
|
||||||
|
failures = [r for r in results if isinstance(r, Exception)]
|
||||||
|
assert not failures, f"Concurrent operations failed: {failures}"
|
||||||
|
|
||||||
|
# Check that execution was actually concurrent (should be faster than sequential)
|
||||||
|
# Sequential would take at least 6 * processing_time
|
||||||
|
assert execution_time < 30, f"Concurrent execution too slow: {execution_time}s"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_queue_overflow_handling(self):
|
||||||
|
"""Test behavior when queue reaches capacity."""
|
||||||
|
async with graphiti_test_client() as (session, group_id):
|
||||||
|
# Attempt to add many episodes rapidly
|
||||||
|
tasks = []
|
||||||
|
for i in range(100): # Large number to potentially overflow
|
||||||
|
task = session.call_tool(
|
||||||
|
'add_memory',
|
||||||
|
{
|
||||||
|
'name': f'Overflow Test {i}',
|
||||||
|
'episode_body': f'Episode {i}',
|
||||||
|
'source': 'text',
|
||||||
|
'source_description': 'overflow test',
|
||||||
|
'group_id': group_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
# Execute with gathering to catch any failures
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# Count successful queuing
|
||||||
|
successful = sum(1 for r in results if not isinstance(r, Exception))
|
||||||
|
|
||||||
|
# Should handle overflow gracefully
|
||||||
|
assert successful > 0, "No episodes were queued successfully"
|
||||||
|
|
||||||
|
# Log overflow behavior
|
||||||
|
if successful < 100:
|
||||||
|
print(f"Queue overflow: {successful}/100 episodes queued")
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrentOperations:
|
||||||
|
"""Test concurrent tool calls and operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_search_operations(self):
|
||||||
|
"""Test multiple concurrent search operations."""
|
||||||
|
async with graphiti_test_client() as (session, group_id):
|
||||||
|
# First, add some test data
|
||||||
|
data_gen = TestDataGenerator()
|
||||||
|
|
||||||
|
add_tasks = []
|
||||||
|
for _ in range(5):
|
||||||
|
task = session.call_tool(
|
||||||
|
'add_memory',
|
||||||
|
{
|
||||||
|
'name': 'Search Test Data',
|
||||||
|
'episode_body': data_gen.generate_technical_document(),
|
||||||
|
'source': 'text',
|
||||||
|
                        'source_description': 'search test',
                        'group_id': group_id,
                    }
                )
                add_tasks.append(task)

            await asyncio.gather(*add_tasks)
            await asyncio.sleep(15)  # Wait for processing

            # Now perform concurrent searches
            search_queries = [
                'architecture',
                'performance',
                'implementation',
                'dependencies',
                'latency',
            ]

            search_tasks = []
            for query in search_queries:
                task = session.call_tool(
                    'search_memory_nodes',
                    {
                        'query': query,
                        'group_id': group_id,
                        'limit': 10,
                    }
                )
                search_tasks.append(task)

            start_time = time.time()
            results = await asyncio.gather(*search_tasks, return_exceptions=True)
            search_time = time.time() - start_time

            # Verify all searches completed
            failures = [r for r in results if isinstance(r, Exception)]
            assert not failures, f"Search operations failed: {failures}"

            # Verify concurrent execution efficiency
            assert search_time < len(search_queries) * 2, "Searches not executing concurrently"

    @pytest.mark.asyncio
    async def test_mixed_operation_concurrency(self):
        """Test different types of operations running concurrently."""
        async with graphiti_test_client() as (session, group_id):
            operations = []

            # Add memory operation
            operations.append(session.call_tool(
                'add_memory',
                {
                    'name': 'Mixed Op Test',
                    'episode_body': 'Testing mixed operations',
                    'source': 'text',
                    'source_description': 'test',
                    'group_id': group_id,
                }
            ))

            # Search operation
            operations.append(session.call_tool(
                'search_memory_nodes',
                {
                    'query': 'test',
                    'group_id': group_id,
                    'limit': 5,
                }
            ))

            # Get episodes operation
            operations.append(session.call_tool(
                'get_episodes',
                {
                    'group_id': group_id,
                    'last_n': 10,
                }
            ))

            # Get status operation
            operations.append(session.call_tool(
                'get_status',
                {}
            ))

            # Execute all concurrently
            results = await asyncio.gather(*operations, return_exceptions=True)

            # Check results
            for i, result in enumerate(results):
                assert not isinstance(result, Exception), f"Operation {i} failed: {result}"


class TestAsyncErrorHandling:
    """Test async error handling and recovery."""

    @pytest.mark.asyncio
    async def test_timeout_recovery(self):
        """Test recovery from operation timeouts."""
        async with graphiti_test_client() as (session, group_id):
            # Create a very large episode that might time out
            large_content = "x" * 1000000  # 1MB of data

            try:
                await asyncio.wait_for(
                    session.call_tool(
                        'add_memory',
                        {
                            'name': 'Timeout Test',
                            'episode_body': large_content,
                            'source': 'text',
                            'source_description': 'timeout test',
                            'group_id': group_id,
                        }
                    ),
                    timeout=2.0  # Short timeout
                )
            except asyncio.TimeoutError:
                # Expected timeout
                pass

            # Verify server is still responsive after the timeout
            status_result = await session.call_tool('get_status', {})
            assert status_result is not None, "Server unresponsive after timeout"

    @pytest.mark.asyncio
    async def test_cancellation_handling(self):
        """Test proper handling of cancelled operations."""
        async with graphiti_test_client() as (session, group_id):
            # Start a long-running operation
            task = asyncio.create_task(
                session.call_tool(
                    'add_memory',
                    {
                        'name': 'Cancellation Test',
                        'episode_body': TestDataGenerator.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'cancel test',
                        'group_id': group_id,
                    }
                )
            )

            # Cancel after a short delay
            await asyncio.sleep(0.1)
            task.cancel()

            # Verify cancellation was handled
            with pytest.raises(asyncio.CancelledError):
                await task

            # Server should still be operational
            result = await session.call_tool('get_status', {})
            assert result is not None

    @pytest.mark.asyncio
    async def test_exception_propagation(self):
        """Test that exceptions are properly propagated in async context."""
        async with graphiti_test_client() as (session, group_id):
            # Call with invalid arguments
            with pytest.raises(Exception):
                await session.call_tool(
                    'add_memory',
                    {
                        # Missing required fields
                        'group_id': group_id,
                    }
                )

            # Server should remain operational
            status = await session.call_tool('get_status', {})
            assert status is not None


class TestAsyncPerformance:
    """Performance tests for async operations."""

    @pytest.mark.asyncio
    async def test_async_throughput(self, performance_benchmark):
        """Measure throughput of async operations."""
        async with graphiti_test_client() as (session, group_id):
            num_operations = 50
            start_time = time.time()

            # Create many concurrent operations
            tasks = []
            for i in range(num_operations):
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': f'Throughput Test {i}',
                        'episode_body': f'Content {i}',
                        'source': 'text',
                        'source_description': 'throughput test',
                        'group_id': group_id,
                    }
                )
                tasks.append(task)

            # Execute all
            results = await asyncio.gather(*tasks, return_exceptions=True)
            total_time = time.time() - start_time

            # Calculate metrics
            successful = sum(1 for r in results if not isinstance(r, Exception))
            throughput = successful / total_time

            performance_benchmark.record('async_throughput', throughput)

            # Log results
            print("\nAsync Throughput Test:")
            print(f"  Operations: {num_operations}")
            print(f"  Successful: {successful}")
            print(f"  Total time: {total_time:.2f}s")
            print(f"  Throughput: {throughput:.2f} ops/s")

            # Assert minimum throughput
            assert throughput > 1.0, f"Throughput too low: {throughput:.2f} ops/s"

    @pytest.mark.asyncio
    async def test_latency_under_load(self, performance_benchmark):
        """Test operation latency under concurrent load."""
        async with graphiti_test_client() as (session, group_id):
            # Create background load
            background_tasks = []
            for i in range(10):
                task = asyncio.create_task(
                    session.call_tool(
                        'add_memory',
                        {
                            'name': f'Background {i}',
                            'episode_body': TestDataGenerator.generate_technical_document(),
                            'source': 'text',
                            'source_description': 'background',
                            'group_id': f'background_{group_id}',
                        }
                    )
                )
                background_tasks.append(task)

            # Measure latency of operations under load
            latencies = []
            for _ in range(5):
                start = time.time()
                await session.call_tool('get_status', {})
                latency = time.time() - start
                latencies.append(latency)
                performance_benchmark.record('latency_under_load', latency)

            # Clean up background tasks
            for task in background_tasks:
                task.cancel()

            # Analyze latencies
            avg_latency = sum(latencies) / len(latencies)
            max_latency = max(latencies)

            print("\nLatency Under Load:")
            print(f"  Average: {avg_latency:.3f}s")
            print(f"  Max: {max_latency:.3f}s")

            # Assert acceptable latency
            assert avg_latency < 2.0, f"Average latency too high: {avg_latency:.3f}s"
            assert max_latency < 5.0, f"Max latency too high: {max_latency:.3f}s"


class TestAsyncStreamHandling:
    """Test handling of streaming responses and data."""

    @pytest.mark.asyncio
    async def test_large_response_streaming(self):
        """Test handling of large streamed responses."""
        async with graphiti_test_client() as (session, group_id):
            # Add many episodes
            for i in range(20):
                await session.call_tool(
                    'add_memory',
                    {
                        'name': f'Stream Test {i}',
                        'episode_body': f'Episode content {i}',
                        'source': 'text',
                        'source_description': 'stream test',
                        'group_id': group_id,
                    }
                )

            # Wait for processing
            await asyncio.sleep(30)

            # Request a large result set
            result = await session.call_tool(
                'get_episodes',
                {
                    'group_id': group_id,
                    'last_n': 100,  # Request all
                }
            )

            # Verify response handling
            episodes = json.loads(result.content[0].text)['episodes']
            assert len(episodes) >= 20, f"Expected at least 20 episodes, got {len(episodes)}"

    @pytest.mark.asyncio
    async def test_incremental_processing(self):
        """Test incremental processing of results."""
        async with graphiti_test_client() as (session, group_id):
            # Add episodes incrementally
            for batch in range(3):
                batch_tasks = []
                for i in range(5):
                    task = session.call_tool(
                        'add_memory',
                        {
                            'name': f'Batch {batch} Item {i}',
                            'episode_body': f'Content for batch {batch}',
                            'source': 'text',
                            'source_description': 'incremental test',
                            'group_id': group_id,
                        }
                    )
                    batch_tasks.append(task)

                # Process batch
                await asyncio.gather(*batch_tasks)

                # Wait for this batch to process
                await asyncio.sleep(10)

                # Verify incremental results
                result = await session.call_tool(
                    'get_episodes',
                    {
                        'group_id': group_id,
                        'last_n': 100,
                    }
                )

                episodes = json.loads(result.content[0].text)['episodes']
                expected_min = (batch + 1) * 5
                assert len(episodes) >= expected_min, f"Batch {batch}: expected at least {expected_min} episodes"


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--asyncio-mode=auto"])

696
mcp_server/tests/test_comprehensive_integration.py
Normal file

@@ -0,0 +1,696 @@
#!/usr/bin/env python3
"""
Comprehensive integration test suite for Graphiti MCP Server.
Covers all MCP tools with consideration for LLM inference latency.
"""

import asyncio
import json
import os
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from unittest.mock import patch

import pytest
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

@dataclass
class TestMetrics:
    """Track test performance metrics."""

    operation: str
    start_time: float
    end_time: float
    success: bool
    details: Dict[str, Any]

    @property
    def duration(self) -> float:
        """Calculate operation duration in seconds."""
        return self.end_time - self.start_time

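The `duration` property is the only derived metric on `TestMetrics`; when a run collects many metrics, a percentile summary is usually more informative than a flat list of durations. A minimal, self-contained sketch of such a summary (the helper name and nearest-rank method are illustrative, not part of the suite):

```python
# Illustrative helper: summarize a non-empty list of operation durations
# (seconds) using nearest-rank percentiles. Pure stdlib, independent of
# TestMetrics; a report generator could feed it [m.duration for m in metrics].
def summarize_durations(durations: list[float]) -> dict[str, float]:
    ordered = sorted(durations)

    def pct(p: float) -> float:
        # Nearest-rank index over the sorted sample.
        idx = round(p * (len(ordered) - 1))
        return ordered[idx]

    return {'p50': pct(0.5), 'p95': pct(0.95), 'max': ordered[-1]}
```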
class GraphitiTestClient:
    """Enhanced test client for comprehensive Graphiti MCP testing."""

    def __init__(self, test_group_id: Optional[str] = None):
        self.test_group_id = test_group_id or f'test_{int(time.time())}'
        self.session = None
        self.metrics: List[TestMetrics] = []
        self.default_timeout = 30  # seconds

    async def __aenter__(self):
        """Initialize MCP client session."""
        server_params = StdioServerParameters(
            command='uv',
            args=['run', 'main.py', '--transport', 'stdio'],
            env={
                'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
                'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'),
                'KUZU_PATH': os.environ.get('KUZU_PATH', './test_kuzu.db'),
                'FALKORDB_URI': os.environ.get('FALKORDB_URI', 'redis://localhost:6379'),
            },
        )

        self.client_context = stdio_client(server_params)
        read, write = await self.client_context.__aenter__()
        self.session = ClientSession(read, write)
        await self.session.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Clean up client session."""
        if self.session:
            await self.session.close()
        if hasattr(self, 'client_context'):
            await self.client_context.__aexit__(exc_type, exc_val, exc_tb)

    async def call_tool_with_metrics(
        self, tool_name: str, arguments: Dict[str, Any], timeout: Optional[float] = None
    ) -> tuple[Any, TestMetrics]:
        """Call a tool and capture performance metrics."""
        start_time = time.time()
        timeout = timeout or self.default_timeout

        try:
            result = await asyncio.wait_for(
                self.session.call_tool(tool_name, arguments),
                timeout=timeout
            )
            content = result.content[0].text if result.content else None
            success = True
            details = {'result': content, 'tool': tool_name}

        except asyncio.TimeoutError:
            content = None
            success = False
            details = {'error': f'Timeout after {timeout}s', 'tool': tool_name}

        except Exception as e:
            content = None
            success = False
            details = {'error': str(e), 'tool': tool_name}

        end_time = time.time()
        metric = TestMetrics(
            operation=f'call_{tool_name}',
            start_time=start_time,
            end_time=end_time,
            success=success,
            details=details
        )
        self.metrics.append(metric)

        return content, metric

    async def wait_for_episode_processing(
        self, expected_count: int = 1, max_wait: int = 60, poll_interval: int = 2
    ) -> bool:
        """
        Wait for episodes to be processed with intelligent polling.

        Args:
            expected_count: Number of episodes expected to be processed
            max_wait: Maximum seconds to wait
            poll_interval: Seconds between status checks

        Returns:
            True if episodes were processed successfully
        """
        start_time = time.time()

        while (time.time() - start_time) < max_wait:
            result, _ = await self.call_tool_with_metrics(
                'get_episodes',
                {'group_id': self.test_group_id, 'last_n': 100}
            )

            if result:
                try:
                    episodes = json.loads(result) if isinstance(result, str) else result
                    if len(episodes.get('episodes', [])) >= expected_count:
                        return True
                except (json.JSONDecodeError, AttributeError):
                    pass

            await asyncio.sleep(poll_interval)

        return False

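`wait_for_episode_processing` polls at a fixed interval; with slow LLM inference, an exponential-backoff variant can reduce polling load while staying responsive early on. A hedged sketch of that pattern, assuming only an async `check()` predicate (this helper is illustrative and not part of `GraphitiTestClient`):

```python
import asyncio
import time


# Illustrative variant of the fixed-interval polling above: exponential
# backoff with a cap. `check` is any async predicate returning True once
# processing is complete.
async def poll_with_backoff(check, max_wait=60.0, initial=1.0, factor=2.0, cap=8.0):
    deadline = time.monotonic() + max_wait
    delay = initial
    while time.monotonic() < deadline:
        if await check():
            return True
        # Never sleep past the deadline.
        await asyncio.sleep(min(delay, max(0.0, deadline - time.monotonic())))
        delay = min(delay * factor, cap)
    return False
```

A caller would pass a closure that queries `get_episodes` and compares the count, just as the fixed-interval loop does.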
class TestCoreOperations:
    """Test core Graphiti operations."""

    @pytest.mark.asyncio
    async def test_server_initialization(self):
        """Verify server initializes with all required tools."""
        async with GraphitiTestClient() as client:
            tools_result = await client.session.list_tools()
            tools = {tool.name for tool in tools_result.tools}

            required_tools = {
                'add_memory',
                'search_memory_nodes',
                'search_memory_facts',
                'get_episodes',
                'delete_episode',
                'delete_entity_edge',
                'get_entity_edge',
                'clear_graph',
                'get_status',
            }

            missing_tools = required_tools - tools
            assert not missing_tools, f"Missing required tools: {missing_tools}"

    @pytest.mark.asyncio
    async def test_add_text_memory(self):
        """Test adding text-based memories."""
        async with GraphitiTestClient() as client:
            # Add memory
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Tech Conference Notes',
                    'episode_body': 'The AI conference featured talks on LLMs, RAG systems, and knowledge graphs. Notable speakers included researchers from OpenAI and Anthropic.',
                    'source': 'text',
                    'source_description': 'conference notes',
                    'group_id': client.test_group_id,
                }
            )

            assert metric.success, f"Failed to add memory: {metric.details}"
            assert 'queued' in str(result).lower()

            # Wait for processing
            processed = await client.wait_for_episode_processing(expected_count=1)
            assert processed, "Episode was not processed within timeout"

    @pytest.mark.asyncio
    async def test_add_json_memory(self):
        """Test adding structured JSON memories."""
        async with GraphitiTestClient() as client:
            json_data = {
                'project': {
                    'name': 'GraphitiDB',
                    'version': '2.0.0',
                    'features': ['temporal-awareness', 'hybrid-search', 'custom-entities'],
                },
                'team': {
                    'size': 5,
                    'roles': ['engineering', 'product', 'research'],
                },
            }

            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Project Data',
                    'episode_body': json.dumps(json_data),
                    'source': 'json',
                    'source_description': 'project database',
                    'group_id': client.test_group_id,
                }
            )

            assert metric.success
            assert 'queued' in str(result).lower()

    @pytest.mark.asyncio
    async def test_add_message_memory(self):
        """Test adding conversation/message memories."""
        async with GraphitiTestClient() as client:
            conversation = """
            user: What are the key features of Graphiti?
            assistant: Graphiti offers temporal-aware knowledge graphs, hybrid retrieval, and real-time updates.
            user: How does it handle entity resolution?
            assistant: It uses LLM-based entity extraction and deduplication with semantic similarity matching.
            """

            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Feature Discussion',
                    'episode_body': conversation,
                    'source': 'message',
                    'source_description': 'support chat',
                    'group_id': client.test_group_id,
                }
            )

            assert metric.success
            assert metric.duration < 5, f"Add memory took too long: {metric.duration}s"


class TestSearchOperations:
    """Test search and retrieval operations."""

    @pytest.mark.asyncio
    async def test_search_nodes_semantic(self):
        """Test semantic search for nodes."""
        async with GraphitiTestClient() as client:
            # First add some test data
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Product Launch',
                    'episode_body': 'Our new AI assistant product launches in Q2 2024 with advanced NLP capabilities.',
                    'source': 'text',
                    'source_description': 'product roadmap',
                    'group_id': client.test_group_id,
                }
            )

            # Wait for processing
            await client.wait_for_episode_processing()

            # Search for nodes
            result, metric = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {
                    'query': 'AI product features',
                    'group_id': client.test_group_id,
                    'limit': 10,
                }
            )

            assert metric.success
            assert result is not None

    @pytest.mark.asyncio
    async def test_search_facts_with_filters(self):
        """Test fact search with various filters."""
        async with GraphitiTestClient() as client:
            # Add test data
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Company Facts',
                    'episode_body': 'Acme Corp was founded in 2020. They have 50 employees and $10M in revenue.',
                    'source': 'text',
                    'source_description': 'company profile',
                    'group_id': client.test_group_id,
                }
            )

            await client.wait_for_episode_processing()

            # Search with a date filter
            result, metric = await client.call_tool_with_metrics(
                'search_memory_facts',
                {
                    'query': 'company information',
                    'group_id': client.test_group_id,
                    'created_after': '2020-01-01T00:00:00Z',
                    'limit': 20,
                }
            )

            assert metric.success

    @pytest.mark.asyncio
    async def test_hybrid_search(self):
        """Test hybrid search combining semantic and keyword search."""
        async with GraphitiTestClient() as client:
            # Add diverse test data
            test_memories = [
                {
                    'name': 'Technical Doc',
                    'episode_body': 'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.',
                    'source': 'text',
                },
                {
                    'name': 'Architecture',
                    'episode_body': 'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.',
                    'source': 'text',
                },
            ]

            for memory in test_memories:
                memory['group_id'] = client.test_group_id
                memory['source_description'] = 'documentation'
                await client.call_tool_with_metrics('add_memory', memory)

            await client.wait_for_episode_processing(expected_count=2)

            # Test semantic + keyword search
            result, metric = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {
                    'query': 'Neo4j graph database',
                    'group_id': client.test_group_id,
                    'limit': 10,
                }
            )

            assert metric.success


class TestEpisodeManagement:
    """Test episode lifecycle operations."""

    @pytest.mark.asyncio
    async def test_get_episodes_pagination(self):
        """Test retrieving episodes with pagination."""
        async with GraphitiTestClient() as client:
            # Add multiple episodes
            for i in range(5):
                await client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Episode {i}',
                        'episode_body': f'This is test episode number {i}',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': client.test_group_id,
                    }
                )

            await client.wait_for_episode_processing(expected_count=5)

            # Test pagination
            result, metric = await client.call_tool_with_metrics(
                'get_episodes',
                {
                    'group_id': client.test_group_id,
                    'last_n': 3,
                }
            )

            assert metric.success
            episodes = json.loads(result) if isinstance(result, str) else result
            assert len(episodes.get('episodes', [])) <= 3

    @pytest.mark.asyncio
    async def test_delete_episode(self):
        """Test deleting specific episodes."""
        async with GraphitiTestClient() as client:
            # Add an episode
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'To Delete',
                    'episode_body': 'This episode will be deleted',
                    'source': 'text',
                    'source_description': 'test',
                    'group_id': client.test_group_id,
                }
            )

            await client.wait_for_episode_processing()

            # Get the episode UUID
            result, _ = await client.call_tool_with_metrics(
                'get_episodes',
                {'group_id': client.test_group_id, 'last_n': 1}
            )

            episodes = json.loads(result) if isinstance(result, str) else result
            episode_uuid = episodes['episodes'][0]['uuid']

            # Delete the episode
            result, metric = await client.call_tool_with_metrics(
                'delete_episode',
                {'episode_uuid': episode_uuid}
            )

            assert metric.success
            assert 'deleted' in str(result).lower()


class TestEntityAndEdgeOperations:
    """Test entity and edge management."""

    @pytest.mark.asyncio
    async def test_get_entity_edge(self):
        """Test retrieving entity edges."""
        async with GraphitiTestClient() as client:
            # Add data to create entities and edges
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Relationship Data',
                    'episode_body': 'Alice works at TechCorp. Bob is the CEO of TechCorp.',
                    'source': 'text',
                    'source_description': 'org chart',
                    'group_id': client.test_group_id,
                }
            )

            await client.wait_for_episode_processing()

            # Search for nodes to get UUIDs
            result, _ = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {
                    'query': 'TechCorp',
                    'group_id': client.test_group_id,
                    'limit': 5,
                }
            )

            # Note: this test assumes edges are created between entities.
            # Actual edge retrieval would require valid edge UUIDs.

    @pytest.mark.asyncio
    async def test_delete_entity_edge(self):
        """Test deleting entity edges."""
        # Similar structure to test_get_entity_edge, but with deletion
        pass  # Implement based on actual edge creation patterns


class TestErrorHandling:
    """Test error conditions and edge cases."""

    @pytest.mark.asyncio
    async def test_invalid_tool_arguments(self):
        """Test handling of invalid tool arguments."""
        async with GraphitiTestClient() as client:
            # Missing required arguments
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {'name': 'Incomplete'}  # Missing required fields
            )

            assert not metric.success
            assert 'error' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_timeout_handling(self):
        """Test timeout handling for long operations."""
        async with GraphitiTestClient() as client:
            # Simulate a very large episode that might time out
            large_text = "Large document content. " * 10000

            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Large Document',
                    'episode_body': large_text,
                    'source': 'text',
                    'source_description': 'large file',
                    'group_id': client.test_group_id,
                },
                timeout=5  # Short timeout
            )

            # Check that any timeout was handled gracefully
            if not metric.success:
                assert 'timeout' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_concurrent_operations(self):
        """Test handling of concurrent operations."""
        async with GraphitiTestClient() as client:
            # Launch multiple operations concurrently
            tasks = []
            for i in range(5):
                task = client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Concurrent {i}',
                        'episode_body': f'Concurrent operation {i}',
                        'source': 'text',
                        'source_description': 'concurrent test',
                        'group_id': client.test_group_id,
                    }
                )
                tasks.append(task)

            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Check that operations were queued successfully. With
            # return_exceptions=True, gather may yield exceptions in place of
            # (content, metric) tuples, so filter those before unpacking.
            successful = sum(
                1 for res in results
                if not isinstance(res, Exception) and res[1].success
            )
            assert successful >= 3  # At least 60% should succeed


class TestPerformance:
    """Test performance characteristics and optimization."""

    @pytest.mark.asyncio
    async def test_latency_metrics(self):
        """Measure and validate operation latencies."""
        async with GraphitiTestClient() as client:
            operations = [
                ('add_memory', {
                    'name': 'Perf Test',
                    'episode_body': 'Simple text',
                    'source': 'text',
                    'source_description': 'test',
                    'group_id': client.test_group_id,
                }),
                ('search_memory_nodes', {
                    'query': 'test',
                    'group_id': client.test_group_id,
                    'limit': 10,
                }),
                ('get_episodes', {
                    'group_id': client.test_group_id,
                    'last_n': 10,
                }),
            ]

            for tool_name, args in operations:
                _, metric = await client.call_tool_with_metrics(tool_name, args)

                # Log performance metrics
                print(f"{tool_name}: {metric.duration:.2f}s")

                # Basic latency assertions
                if tool_name == 'get_episodes':
                    assert metric.duration < 2, f"{tool_name} too slow"
                elif tool_name == 'search_memory_nodes':
                    assert metric.duration < 10, f"{tool_name} too slow"

    @pytest.mark.asyncio
    async def test_batch_processing_efficiency(self):
        """Test efficiency of batch operations."""
        async with GraphitiTestClient() as client:
            batch_size = 10
            start_time = time.time()

            # Batch add memories
            for i in range(batch_size):
                await client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Batch {i}',
                        'episode_body': f'Batch content {i}',
                        'source': 'text',
                        'source_description': 'batch test',
                        'group_id': client.test_group_id,
                    }
                )

            # Wait for all to process
            processed = await client.wait_for_episode_processing(
                expected_count=batch_size,
                max_wait=120  # Allow more time for the batch
            )

            total_time = time.time() - start_time
            avg_time_per_item = total_time / batch_size

            assert processed, f"Failed to process {batch_size} items"
            assert avg_time_per_item < 15, f"Batch processing too slow: {avg_time_per_item:.2f}s per item"

            # Generate a performance report
            print("\nBatch Performance Report:")
            print(f"  Total items: {batch_size}")
            print(f"  Total time: {total_time:.2f}s")
            print(f"  Avg per item: {avg_time_per_item:.2f}s")


class TestDatabaseBackends:
|
||||||
|
"""Test different database backend configurations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("database", ["neo4j", "falkordb", "kuzu"])
|
||||||
|
async def test_database_operations(self, database):
|
||||||
|
"""Test operations with different database backends."""
|
||||||
|
env_vars = {
|
||||||
|
'DATABASE_PROVIDER': database,
|
||||||
|
'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'),
|
||||||
|
}
|
||||||
|
|
||||||
|
if database == 'neo4j':
|
||||||
|
env_vars.update({
|
||||||
|
'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
|
||||||
|
'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
|
||||||
|
'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
|
||||||
|
})
|
||||||
|
elif database == 'falkordb':
|
||||||
|
env_vars['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
|
||||||
|
elif database == 'kuzu':
|
||||||
|
env_vars['KUZU_PATH'] = os.environ.get('KUZU_PATH', f'./test_kuzu_{int(time.time())}.db')
|
||||||
|
|
||||||
|
server_params = StdioServerParameters(
|
||||||
|
command='uv',
|
||||||
|
args=['run', 'main.py', '--transport', 'stdio', '--database', database],
|
||||||
|
env=env_vars
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test basic operations with each backend
|
||||||
|
# Implementation depends on database availability
|
||||||
|
|
||||||
|
|
||||||
|
def generate_test_report(client: GraphitiTestClient) -> str:
|
||||||
|
"""Generate a comprehensive test report from metrics."""
|
||||||
|
if not client.metrics:
|
||||||
|
return "No metrics collected"
|
||||||
|
|
||||||
|
report = []
|
||||||
|
report.append("\n" + "="*60)
|
||||||
|
report.append("GRAPHITI MCP TEST REPORT")
|
||||||
|
report.append("="*60)
|
||||||
|
|
||||||
|
# Summary statistics
|
||||||
|
total_ops = len(client.metrics)
|
||||||
|
successful_ops = sum(1 for m in client.metrics if m.success)
|
||||||
|
avg_duration = sum(m.duration for m in client.metrics) / total_ops
|
||||||
|
|
||||||
|
report.append(f"\nTotal Operations: {total_ops}")
|
||||||
|
report.append(f"Successful: {successful_ops} ({successful_ops/total_ops*100:.1f}%)")
|
||||||
|
report.append(f"Average Duration: {avg_duration:.2f}s")
|
||||||
|
|
||||||
|
# Operation breakdown
|
||||||
|
report.append("\nOperation Breakdown:")
|
||||||
|
operation_stats = {}
|
||||||
|
for metric in client.metrics:
|
||||||
|
if metric.operation not in operation_stats:
|
||||||
|
operation_stats[metric.operation] = {
|
||||||
|
'count': 0, 'success': 0, 'total_duration': 0
|
||||||
|
}
|
||||||
|
stats = operation_stats[metric.operation]
|
||||||
|
stats['count'] += 1
|
||||||
|
stats['success'] += 1 if metric.success else 0
|
||||||
|
stats['total_duration'] += metric.duration
|
||||||
|
|
||||||
|
for op, stats in sorted(operation_stats.items()):
|
||||||
|
avg_dur = stats['total_duration'] / stats['count']
|
||||||
|
success_rate = stats['success'] / stats['count'] * 100
|
||||||
|
report.append(
|
||||||
|
f" {op}: {stats['count']} calls, "
|
||||||
|
f"{success_rate:.0f}% success, {avg_dur:.2f}s avg"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Slowest operations
|
||||||
|
slowest = sorted(client.metrics, key=lambda m: m.duration, reverse=True)[:5]
|
||||||
|
report.append("\nSlowest Operations:")
|
||||||
|
for metric in slowest:
|
||||||
|
report.append(f" {metric.operation}: {metric.duration:.2f}s")
|
||||||
|
|
||||||
|
report.append("="*60)
|
||||||
|
return "\n".join(report)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run tests with pytest
|
||||||
|
pytest.main([__file__, "-v", "--asyncio-mode=auto"])
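The per-operation breakdown in `generate_test_report` is a plain fold over metric records. A minimal sketch of the same aggregation, using hypothetical `(operation, duration, success)` tuples in place of the suite's metric objects:

```python
from collections import defaultdict

# Hypothetical stand-ins for the suite's metric objects.
metrics = [
    ('add_memory', 3.2, True),
    ('add_memory', 2.8, True),
    ('search_memory_nodes', 0.9, False),
]

# Fold each record into per-operation counters, as generate_test_report does.
stats = defaultdict(lambda: {'count': 0, 'success': 0, 'total_duration': 0.0})
for op, duration, success in metrics:
    s = stats[op]
    s['count'] += 1
    s['success'] += 1 if success else 0
    s['total_duration'] += duration

for op, s in sorted(stats.items()):
    avg = s['total_duration'] / s['count']
    rate = s['success'] / s['count'] * 100
    print(f"  {op}: {s['count']} calls, {rate:.0f}% success, {avg:.2f}s avg")
```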
324
mcp_server/tests/test_fixtures.py
Normal file
@@ -0,0 +1,324 @@
"""
|
||||||
|
Shared test fixtures and utilities for Graphiti MCP integration tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from faker import Faker
|
||||||
|
from mcp import ClientSession, StdioServerParameters
|
||||||
|
from mcp.client.stdio import stdio_client
|
||||||
|
|
||||||
|
fake = Faker()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataGenerator:
|
||||||
|
"""Generate realistic test data for various scenarios."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_company_profile() -> Dict[str, Any]:
|
||||||
|
"""Generate a realistic company profile."""
|
||||||
|
return {
|
||||||
|
'company': {
|
||||||
|
'name': fake.company(),
|
||||||
|
'founded': random.randint(1990, 2023),
|
||||||
|
'industry': random.choice(['Tech', 'Finance', 'Healthcare', 'Retail']),
|
||||||
|
'employees': random.randint(10, 10000),
|
||||||
|
'revenue': f"${random.randint(1, 1000)}M",
|
||||||
|
'headquarters': fake.city(),
|
||||||
|
},
|
||||||
|
'products': [
|
||||||
|
{
|
||||||
|
'id': fake.uuid4()[:8],
|
||||||
|
'name': fake.catch_phrase(),
|
||||||
|
'category': random.choice(['Software', 'Hardware', 'Service']),
|
||||||
|
'price': random.randint(10, 10000),
|
||||||
|
}
|
||||||
|
for _ in range(random.randint(1, 5))
|
||||||
|
],
|
||||||
|
'leadership': {
|
||||||
|
'ceo': fake.name(),
|
||||||
|
'cto': fake.name(),
|
||||||
|
'cfo': fake.name(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_conversation(turns: int = 3) -> str:
|
||||||
|
"""Generate a realistic conversation."""
|
||||||
|
topics = [
|
||||||
|
"product features",
|
||||||
|
"pricing",
|
||||||
|
"technical support",
|
||||||
|
"integration",
|
||||||
|
"documentation",
|
||||||
|
"performance",
|
||||||
|
]
|
||||||
|
|
||||||
|
conversation = []
|
||||||
|
for _ in range(turns):
|
||||||
|
topic = random.choice(topics)
|
||||||
|
user_msg = f"user: {fake.sentence()} about {topic}?"
|
||||||
|
assistant_msg = f"assistant: {fake.paragraph(nb_sentences=2)}"
|
||||||
|
conversation.extend([user_msg, assistant_msg])
|
||||||
|
|
||||||
|
return "\n".join(conversation)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_technical_document() -> str:
|
||||||
|
"""Generate technical documentation content."""
|
||||||
|
sections = [
|
||||||
|
f"# {fake.catch_phrase()}\n\n{fake.paragraph()}",
|
||||||
|
f"## Architecture\n{fake.paragraph()}",
|
||||||
|
f"## Implementation\n{fake.paragraph()}",
|
||||||
|
f"## Performance\n- Latency: {random.randint(1, 100)}ms\n- Throughput: {random.randint(100, 10000)} req/s",
|
||||||
|
f"## Dependencies\n- {fake.word()}\n- {fake.word()}\n- {fake.word()}",
|
||||||
|
]
|
||||||
|
return "\n\n".join(sections)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_news_article() -> str:
|
||||||
|
"""Generate a news article."""
|
||||||
|
company = fake.company()
|
||||||
|
return f"""
|
||||||
|
{company} Announces {fake.catch_phrase()}
|
||||||
|
|
||||||
|
{fake.city()}, {fake.date()} - {company} today announced {fake.paragraph()}.
|
||||||
|
|
||||||
|
"This is a significant milestone," said {fake.name()}, CEO of {company}.
|
||||||
|
"{fake.sentence()}"
|
||||||
|
|
||||||
|
The announcement comes after {fake.paragraph()}.
|
||||||
|
|
||||||
|
Industry analysts predict {fake.paragraph()}.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_user_profile() -> Dict[str, Any]:
|
||||||
|
"""Generate a user profile."""
|
||||||
|
return {
|
||||||
|
'user_id': fake.uuid4(),
|
||||||
|
'name': fake.name(),
|
||||||
|
'email': fake.email(),
|
||||||
|
'joined': fake.date_time_this_year().isoformat(),
|
||||||
|
'preferences': {
|
||||||
|
'theme': random.choice(['light', 'dark', 'auto']),
|
||||||
|
'notifications': random.choice([True, False]),
|
||||||
|
'language': random.choice(['en', 'es', 'fr', 'de']),
|
||||||
|
},
|
||||||
|
'activity': {
|
||||||
|
'last_login': fake.date_time_this_month().isoformat(),
|
||||||
|
'total_sessions': random.randint(1, 1000),
|
||||||
|
'average_duration': f"{random.randint(1, 60)} minutes",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MockLLMProvider:
|
||||||
|
"""Mock LLM provider for testing without actual API calls."""
|
||||||
|
|
||||||
|
def __init__(self, delay: float = 0.1):
|
||||||
|
self.delay = delay # Simulate LLM latency
|
||||||
|
|
||||||
|
async def generate(self, prompt: str) -> str:
|
||||||
|
"""Simulate LLM generation with delay."""
|
||||||
|
await asyncio.sleep(self.delay)
|
||||||
|
|
||||||
|
# Return deterministic responses based on prompt patterns
|
||||||
|
if "extract entities" in prompt.lower():
|
||||||
|
return json.dumps({
|
||||||
|
'entities': [
|
||||||
|
{'name': 'TestEntity1', 'type': 'PERSON'},
|
||||||
|
{'name': 'TestEntity2', 'type': 'ORGANIZATION'},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
elif "summarize" in prompt.lower():
|
||||||
|
return "This is a test summary of the provided content."
|
||||||
|
else:
|
||||||
|
return "Mock LLM response"


@asynccontextmanager
async def graphiti_test_client(
    group_id: Optional[str] = None,
    database: str = "kuzu",
    use_mock_llm: bool = False,
    config_overrides: Optional[Dict[str, Any]] = None
):
    """
    Context manager for creating test clients with various configurations.

    Args:
        group_id: Test group identifier
        database: Database backend (neo4j, falkordb, kuzu)
        use_mock_llm: Whether to use mock LLM for faster tests
        config_overrides: Additional config overrides
    """
    test_group_id = group_id or f'test_{int(time.time())}_{random.randint(1000, 9999)}'

    env = {
        'DATABASE_PROVIDER': database,
        'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key' if use_mock_llm else None),
    }

    # Database-specific configuration
    if database == 'neo4j':
        env.update({
            'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
            'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
            'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
        })
    elif database == 'falkordb':
        env['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
    elif database == 'kuzu':
        env['KUZU_PATH'] = os.environ.get('KUZU_PATH', f'./test_kuzu_{test_group_id}.db')

    # Apply config overrides
    if config_overrides:
        env.update(config_overrides)

    # Add mock LLM flag if needed
    if use_mock_llm:
        env['USE_MOCK_LLM'] = 'true'

    server_params = StdioServerParameters(
        command='uv',
        args=['run', 'main.py', '--transport', 'stdio'],
        env=env
    )

    async with stdio_client(server_params) as (read, write):
        session = ClientSession(read, write)
        await session.initialize()

        try:
            yield session, test_group_id
        finally:
            # Cleanup: Clear test data
            try:
                await session.call_tool('clear_graph', {'group_id': test_group_id})
            except Exception:
                pass  # Ignore cleanup errors

            await session.close()


class PerformanceBenchmark:
    """Track and analyze performance benchmarks."""

    def __init__(self):
        self.measurements: Dict[str, List[float]] = {}

    def record(self, operation: str, duration: float):
        """Record a performance measurement."""
        if operation not in self.measurements:
            self.measurements[operation] = []
        self.measurements[operation].append(duration)

    def get_stats(self, operation: str) -> Dict[str, float]:
        """Get statistics for an operation."""
        if operation not in self.measurements or not self.measurements[operation]:
            return {}

        durations = self.measurements[operation]
        return {
            'count': len(durations),
            'mean': sum(durations) / len(durations),
            'min': min(durations),
            'max': max(durations),
            'median': sorted(durations)[len(durations) // 2],
        }

    def report(self) -> str:
        """Generate a performance report."""
        lines = ["Performance Benchmark Report", "=" * 40]

        for operation in sorted(self.measurements.keys()):
            stats = self.get_stats(operation)
            lines.append(f"\n{operation}:")
            lines.append(f"  Samples: {stats['count']}")
            lines.append(f"  Mean: {stats['mean']:.3f}s")
            lines.append(f"  Median: {stats['median']:.3f}s")
            lines.append(f"  Min: {stats['min']:.3f}s")
            lines.append(f"  Max: {stats['max']:.3f}s")

        return "\n".join(lines)


# Pytest fixtures
@pytest.fixture
def test_data_generator():
    """Provide test data generator."""
    return TestDataGenerator()


@pytest.fixture
def performance_benchmark():
    """Provide performance benchmark tracker."""
    return PerformanceBenchmark()


@pytest.fixture
async def mock_graphiti_client():
    """Provide a Graphiti client with mocked LLM."""
    async with graphiti_test_client(use_mock_llm=True) as (session, group_id):
        yield session, group_id


@pytest.fixture
async def graphiti_client():
    """Provide a real Graphiti client."""
    async with graphiti_test_client(use_mock_llm=False) as (session, group_id):
        yield session, group_id


# Test data fixtures
@pytest.fixture
def sample_memories():
    """Provide sample memory data for testing."""
    return [
        {
            'name': 'Company Overview',
            'episode_body': TestDataGenerator.generate_company_profile(),
            'source': 'json',
            'source_description': 'company database',
        },
        {
            'name': 'Product Launch',
            'episode_body': TestDataGenerator.generate_news_article(),
            'source': 'text',
            'source_description': 'press release',
        },
        {
            'name': 'Customer Support',
            'episode_body': TestDataGenerator.generate_conversation(),
            'source': 'message',
            'source_description': 'support chat',
        },
        {
            'name': 'Technical Specs',
            'episode_body': TestDataGenerator.generate_technical_document(),
            'source': 'text',
            'source_description': 'documentation',
        },
    ]


@pytest.fixture
def large_dataset():
    """Generate a large dataset for stress testing."""
    return [
        {
            'name': f'Document {i}',
            'episode_body': TestDataGenerator.generate_technical_document(),
            'source': 'text',
            'source_description': 'bulk import',
        }
        for i in range(50)
    ]
524
mcp_server/tests/test_stress_load.py
Normal file
@@ -0,0 +1,524 @@
#!/usr/bin/env python3
"""
Stress and load testing for Graphiti MCP Server.
Tests system behavior under high load, resource constraints, and edge conditions.
"""

import asyncio
import gc
import json
import os
import psutil
import random
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import pytest
from test_fixtures import TestDataGenerator, graphiti_test_client, PerformanceBenchmark


@dataclass
class LoadTestConfig:
    """Configuration for load testing scenarios."""

    num_clients: int = 10
    operations_per_client: int = 100
    ramp_up_time: float = 5.0  # seconds
    test_duration: float = 60.0  # seconds
    target_throughput: Optional[float] = None  # ops/sec
    think_time: float = 0.1  # seconds between ops


@dataclass
class LoadTestResult:
    """Results from a load test run."""

    total_operations: int
    successful_operations: int
    failed_operations: int
    duration: float
    throughput: float
    average_latency: float
    p50_latency: float
    p95_latency: float
    p99_latency: float
    max_latency: float
    errors: Dict[str, int]
    resource_usage: Dict[str, float]


class LoadTester:
    """Orchestrate load testing scenarios."""

    def __init__(self, config: LoadTestConfig):
        self.config = config
        self.metrics: List[Tuple[float, float, bool]] = []  # (start, duration, success)
        self.errors: Dict[str, int] = {}
        self.start_time: Optional[float] = None

    async def run_client_workload(
        self,
        client_id: int,
        session,
        group_id: str
    ) -> Dict[str, int]:
        """Run workload for a single simulated client."""
        stats = {'success': 0, 'failure': 0}
        data_gen = TestDataGenerator()

        # Ramp-up delay
        ramp_delay = (client_id / self.config.num_clients) * self.config.ramp_up_time
        await asyncio.sleep(ramp_delay)

        for op_num in range(self.config.operations_per_client):
            operation_start = time.time()

            try:
                # Randomly select operation type
                operation = random.choice([
                    'add_memory',
                    'search_memory_nodes',
                    'get_episodes',
                ])

                if operation == 'add_memory':
                    args = {
                        'name': f'Load Test {client_id}-{op_num}',
                        'episode_body': data_gen.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'load test',
                        'group_id': group_id,
                    }
                elif operation == 'search_memory_nodes':
                    args = {
                        'query': random.choice(['performance', 'architecture', 'test', 'data']),
                        'group_id': group_id,
                        'limit': 10,
                    }
                else:  # get_episodes
                    args = {
                        'group_id': group_id,
                        'last_n': 10,
                    }

                # Execute operation with timeout
                result = await asyncio.wait_for(
                    session.call_tool(operation, args),
                    timeout=30.0
                )

                duration = time.time() - operation_start
                self.metrics.append((operation_start, duration, True))
                stats['success'] += 1

            except asyncio.TimeoutError:
                duration = time.time() - operation_start
                self.metrics.append((operation_start, duration, False))
                self.errors['timeout'] = self.errors.get('timeout', 0) + 1
                stats['failure'] += 1

            except Exception as e:
                duration = time.time() - operation_start
                self.metrics.append((operation_start, duration, False))
                error_type = type(e).__name__
                self.errors[error_type] = self.errors.get(error_type, 0) + 1
                stats['failure'] += 1

            # Think time between operations
            await asyncio.sleep(self.config.think_time)

            # Stop if we've exceeded test duration
            if self.start_time and (time.time() - self.start_time) > self.config.test_duration:
                break

        return stats

    def calculate_results(self) -> LoadTestResult:
        """Calculate load test results from metrics."""
        if not self.metrics:
            return LoadTestResult(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {}, {})

        successful = [m for m in self.metrics if m[2]]
        failed = [m for m in self.metrics if not m[2]]

        latencies = sorted([m[1] for m in self.metrics])
        duration = max([m[0] + m[1] for m in self.metrics]) - min([m[0] for m in self.metrics])

        # Calculate percentiles
        def percentile(data: List[float], p: float) -> float:
            if not data:
                return 0.0
            idx = int(len(data) * p / 100)
            return data[min(idx, len(data) - 1)]

        # Get resource usage
        process = psutil.Process()
        resource_usage = {
            'cpu_percent': process.cpu_percent(),
            'memory_mb': process.memory_info().rss / 1024 / 1024,
            'num_threads': process.num_threads(),
        }

        return LoadTestResult(
            total_operations=len(self.metrics),
            successful_operations=len(successful),
            failed_operations=len(failed),
            duration=duration,
            throughput=len(self.metrics) / duration if duration > 0 else 0,
            average_latency=sum(latencies) / len(latencies) if latencies else 0,
            p50_latency=percentile(latencies, 50),
            p95_latency=percentile(latencies, 95),
            p99_latency=percentile(latencies, 99),
            max_latency=max(latencies) if latencies else 0,
            errors=self.errors,
            resource_usage=resource_usage,
        )
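The nested `percentile` helper in `calculate_results` uses nearest-rank indexing: it takes index `int(n * p / 100)` of the already-sorted sample, clamped to the last element. Extracted as a stand-alone function to make the convention visible:

```python
# Nearest-rank percentile, matching the helper inside LoadTester.calculate_results.
def percentile(data, p):
    if not data:
        return 0.0
    idx = int(len(data) * p / 100)
    # Clamp so p=100 still maps to the last element.
    return sorted(data)[min(idx, len(data) - 1)]


latencies = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
print(percentile(latencies, 50))  # → 0.6
print(percentile(latencies, 95))  # → 1.0
print(percentile(latencies, 99))  # → 1.0
```

With this indexing, p50 of an even-length sample is the upper median (0.6 here, not 0.55), and p95/p99 of small samples both collapse to the maximum, so the tail percentiles in the report are only meaningful once enough operations have been recorded.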
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadScenarios:
|
||||||
|
"""Various load testing scenarios."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.slow
|
||||||
|
async def test_sustained_load(self):
|
||||||
|
"""Test system under sustained moderate load."""
|
||||||
|
config = LoadTestConfig(
|
||||||
|
num_clients=5,
|
||||||
|
operations_per_client=20,
|
||||||
|
ramp_up_time=2.0,
|
||||||
|
test_duration=30.0,
|
||||||
|
think_time=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with graphiti_test_client() as (session, group_id):
|
||||||
|
tester = LoadTester(config)
|
||||||
|
tester.start_time = time.time()
|
||||||
|
|
||||||
|
# Run client workloads
|
||||||
|
client_tasks = []
|
||||||
|
for client_id in range(config.num_clients):
|
||||||
|
task = tester.run_client_workload(client_id, session, group_id)
|
||||||
|
client_tasks.append(task)
|
||||||
|
|
||||||
|
# Execute all clients
|
||||||
|
await asyncio.gather(*client_tasks)
|
||||||
|
|
||||||
|
# Calculate results
|
||||||
|
results = tester.calculate_results()
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert results.successful_operations > results.failed_operations
|
||||||
|
assert results.average_latency < 5.0, f"Average latency too high: {results.average_latency:.2f}s"
|
||||||
|
assert results.p95_latency < 10.0, f"P95 latency too high: {results.p95_latency:.2f}s"
|
||||||
|
|
||||||
|
# Report results
|
||||||
|
print(f"\nSustained Load Test Results:")
|
||||||
|
print(f" Total operations: {results.total_operations}")
|
||||||
|
print(f" Success rate: {results.successful_operations / results.total_operations * 100:.1f}%")
|
||||||
|
print(f" Throughput: {results.throughput:.2f} ops/s")
|
||||||
|
print(f" Avg latency: {results.average_latency:.2f}s")
|
||||||
|
print(f" P95 latency: {results.p95_latency:.2f}s")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.slow
|
||||||
|
async def test_spike_load(self):
|
||||||
|
"""Test system response to sudden load spikes."""
|
||||||
|
async with graphiti_test_client() as (session, group_id):
|
||||||
|
# Normal load phase
|
||||||
|
normal_tasks = []
|
||||||
|
for i in range(3):
|
||||||
|
task = session.call_tool(
|
||||||
|
'add_memory',
|
||||||
|
{
|
||||||
|
'name': f'Normal Load {i}',
|
||||||
|
'episode_body': 'Normal operation',
|
||||||
|
'source': 'text',
|
||||||
|
'source_description': 'normal',
|
||||||
|
'group_id': group_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normal_tasks.append(task)
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
await asyncio.gather(*normal_tasks)
|
||||||
|
|
||||||
|
# Spike phase - sudden burst of requests
|
||||||
|
spike_start = time.time()
|
||||||
|
spike_tasks = []
|
||||||
|
for i in range(50):
|
||||||
|
task = session.call_tool(
|
||||||
|
'add_memory',
|
||||||
|
{
|
||||||
|
'name': f'Spike Load {i}',
|
||||||
|
'episode_body': TestDataGenerator.generate_technical_document(),
|
||||||
|
'source': 'text',
|
||||||
|
'source_description': 'spike',
|
||||||
|
'group_id': group_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
spike_tasks.append(task)
|
||||||
|
|
||||||
|
# Execute spike
|
||||||
|
spike_results = await asyncio.gather(*spike_tasks, return_exceptions=True)
|
||||||
|
spike_duration = time.time() - spike_start
|
||||||
|
|
||||||
|
# Analyze spike handling
|
||||||
|
spike_failures = sum(1 for r in spike_results if isinstance(r, Exception))
|
||||||
|
spike_success_rate = (len(spike_results) - spike_failures) / len(spike_results)
|
||||||
|
|
||||||
|
print(f"\nSpike Load Test Results:")
|
||||||
|
print(f" Spike size: {len(spike_tasks)} operations")
|
||||||
|
print(f" Duration: {spike_duration:.2f}s")
|
||||||
|
print(f" Success rate: {spike_success_rate * 100:.1f}%")
|
||||||
|
print(f" Throughput: {len(spike_tasks) / spike_duration:.2f} ops/s")
|
||||||
|
|
||||||
|
# System should handle at least 80% of spike
|
||||||
|
assert spike_success_rate > 0.8, f"Too many failures during spike: {spike_failures}"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.slow
|
||||||
|
async def test_memory_leak_detection(self):
|
||||||
|
"""Test for memory leaks during extended operation."""
|
||||||
|
async with graphiti_test_client() as (session, group_id):
|
||||||
|
process = psutil.Process()
|
||||||
|
gc.collect() # Force garbage collection
|
||||||
|
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||||
|
|
||||||
|
# Perform many operations
|
||||||
|
for batch in range(10):
|
||||||
|
batch_tasks = []
|
||||||
|
for i in range(10):
|
||||||
|
task = session.call_tool(
|
||||||
|
'add_memory',
|
||||||
|
{
|
||||||
|
'name': f'Memory Test {batch}-{i}',
|
||||||
|
'episode_body': TestDataGenerator.generate_technical_document(),
|
||||||
|
'source': 'text',
|
||||||
|
'source_description': 'memory test',
|
||||||
|
'group_id': group_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
batch_tasks.append(task)
|
||||||
|
|
||||||
|
await asyncio.gather(*batch_tasks)
|
||||||
|
|
||||||
|
# Force garbage collection between batches
|
||||||
|
gc.collect()
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# Check memory after operations
|
||||||
|
gc.collect()
|
||||||
|
final_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||||
|
memory_growth = final_memory - initial_memory
|
||||||
|
|
||||||
|
print(f"\nMemory Leak Test:")
|
||||||
|
print(f" Initial memory: {initial_memory:.1f} MB")
|
||||||
|
print(f" Final memory: {final_memory:.1f} MB")
|
||||||
|
print(f" Growth: {memory_growth:.1f} MB")
|
||||||
|
|
||||||
|
# Allow for some memory growth but flag potential leaks
|
||||||
|
# This is a soft check - actual threshold depends on system
|
||||||
|
if memory_growth > 100: # More than 100MB growth
|
||||||
|
print(f" ⚠️ Potential memory leak detected: {memory_growth:.1f} MB growth")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.slow
|
||||||
|
async def test_connection_pool_exhaustion(self):
|
||||||
|
"""Test behavior when connection pools are exhausted."""
|
||||||
|
async with graphiti_test_client() as (session, group_id):
|
||||||
|
# Create many concurrent long-running operations
|
||||||
|
long_tasks = []
|
||||||
|
for i in range(100): # Many more than typical pool size
|
||||||
|
task = session.call_tool(
|
||||||
|
'search_memory_nodes',
|
||||||
|
{
|
||||||
|
'query': f'complex query {i} ' + ' '.join([TestDataGenerator.fake.word() for _ in range(10)]),
|
||||||
|
'group_id': group_id,
|
||||||
|
'limit': 100,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
long_tasks.append(task)
|
||||||
|
|
||||||
|
# Execute with timeout
|
||||||
|
try:
|
||||||
|
results = await asyncio.wait_for(
|
||||||
|
asyncio.gather(*long_tasks, return_exceptions=True),
|
||||||
|
timeout=60.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count connection-related errors
|
||||||
|
connection_errors = sum(
|
||||||
|
1 for r in results
|
||||||
|
if isinstance(r, Exception) and 'connection' in str(r).lower()
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nConnection Pool Test:")
|
||||||
|
print(f" Total requests: {len(long_tasks)}")
|
||||||
|
print(f" Connection errors: {connection_errors}")
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print(" Test timed out - possible deadlock or exhaustion")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.slow
|
||||||
|
async def test_gradual_degradation(self):
|
||||||
|
"""Test system degradation under increasing load."""
|
||||||
|
async with graphiti_test_client() as (session, group_id):
|
||||||
|
load_levels = [5, 10, 20, 40, 80] # Increasing concurrent operations
|
||||||
|
results_by_level = {}
|
||||||
|
|
||||||
|
for level in load_levels:
|
||||||
|
level_start = time.time()
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
for i in range(level):
|
||||||
|
task = session.call_tool(
|
||||||
|
'add_memory',
|
||||||
|
{
|
||||||
|
'name': f'Load Level {level} Op {i}',
|
||||||
|
'episode_body': f'Testing at load level {level}',
|
||||||
|
'source': 'text',
|
||||||
|
'source_description': 'degradation test',
|
||||||
|
'group_id': group_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
# Execute level
|
||||||
|
level_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
level_duration = time.time() - level_start
|
||||||
|
|
||||||
|
# Calculate metrics
|
||||||
|
failures = sum(1 for r in level_results if isinstance(r, Exception))
|
||||||
|
success_rate = (level - failures) / level * 100
|
||||||
|
throughput = level / level_duration
|
||||||
|
|
||||||
|
results_by_level[level] = {
|
||||||
|
'success_rate': success_rate,
|
||||||
|
'throughput': throughput,
|
||||||
|
'duration': level_duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"\nLoad Level {level}:")
|
||||||
|
print(f" Success rate: {success_rate:.1f}%")
|
||||||
|
print(f" Throughput: {throughput:.2f} ops/s")
|
||||||
|
print(f" Duration: {level_duration:.2f}s")
|
||||||
|
|
||||||
|
# Brief pause between levels
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# Verify graceful degradation
|
||||||
|
# Success rate should not drop below 50% even at high load
|
||||||
|
for level, metrics in results_by_level.items():
|
||||||
|
assert metrics['success_rate'] > 50, f"Poor performance at load level {level}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestResourceLimits:
    """Test behavior at resource limits."""

    @pytest.mark.asyncio
    async def test_large_payload_handling(self):
        """Test handling of very large payloads."""
        async with graphiti_test_client() as (session, group_id):
            payload_sizes = [
                (1_000, "1KB"),
                (10_000, "10KB"),
                (100_000, "100KB"),
                (1_000_000, "1MB"),
            ]

            for size, label in payload_sizes:
                content = "x" * size

                start_time = time.time()
                try:
                    result = await asyncio.wait_for(
                        session.call_tool(
                            'add_memory',
                            {
                                'name': f'Large Payload {label}',
                                'episode_body': content,
                                'source': 'text',
                                'source_description': 'payload test',
                                'group_id': group_id,
                            },
                        ),
                        timeout=30.0,
                    )
                    duration = time.time() - start_time
                    status = "✅ Success"

                except asyncio.TimeoutError:
                    duration = 30.0
                    status = "⏱️ Timeout"

                except Exception as e:
                    duration = time.time() - start_time
                    status = f"❌ Error: {type(e).__name__}"

                print(f"Payload {label}: {status} ({duration:.2f}s)")

|
    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        """Test handling of rate limits."""
        async with graphiti_test_client() as (session, group_id):
            # Rapid-fire requests to trigger rate limits
            rapid_tasks = []
            for i in range(100):
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': f'Rate Limit Test {i}',
                        'episode_body': f'Testing rate limit {i}',
                        'source': 'text',
                        'source_description': 'rate test',
                        'group_id': group_id,
                    },
                )
                rapid_tasks.append(task)

            # Execute without delays
            results = await asyncio.gather(*rapid_tasks, return_exceptions=True)

            # Count rate limit errors (HTTP 429 or "rate" in the message)
            rate_limit_errors = sum(
                1
                for r in results
                if isinstance(r, Exception) and ('rate' in str(r).lower() or '429' in str(r))
            )

            print("\nRate Limit Test:")
            print(f"  Total requests: {len(rapid_tasks)}")
            print(f"  Rate limit errors: {rate_limit_errors}")
            print(
                f"  Success rate: "
                f"{(len(rapid_tasks) - rate_limit_errors) / len(rapid_tasks) * 100:.1f}%"
            )

|
def generate_load_test_report(results: List[LoadTestResult]) -> str:
    """Generate a comprehensive load test report."""
    report = []
    report.append("\n" + "=" * 60)
    report.append("LOAD TEST REPORT")
    report.append("=" * 60)

    for i, result in enumerate(results):
        report.append(f"\nTest Run {i + 1}:")
        report.append(f"  Total Operations: {result.total_operations}")
        report.append(
            f"  Success Rate: "
            f"{result.successful_operations / result.total_operations * 100:.1f}%"
        )
        report.append(f"  Throughput: {result.throughput:.2f} ops/s")
        report.append(
            f"  Latency (avg/p50/p95/p99/max): "
            f"{result.average_latency:.2f}/{result.p50_latency:.2f}/"
            f"{result.p95_latency:.2f}/{result.p99_latency:.2f}/{result.max_latency:.2f}s"
        )

        if result.errors:
            report.append("  Errors:")
            for error_type, count in result.errors.items():
                report.append(f"    {error_type}: {count}")

        report.append("  Resource Usage:")
        for metric, value in result.resource_usage.items():
            report.append(f"    {metric}: {value:.2f}")

    report.append("=" * 60)
    return "\n".join(report)

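# Sketch (not part of the original module): the p50/p95/p99 figures printed in
# the report above could be derived from raw per-operation latencies as below.
# The helper name `_latency_percentiles` is illustrative; the real LoadTestResult
# construction elsewhere in this suite may compute them differently.
from typing import Dict, List as _List
import statistics


def _latency_percentiles(latencies: _List[float]) -> Dict[str, float]:
    """Return p50/p95/p99 for a list of per-operation latencies in seconds."""
    # quantiles(n=100) yields 99 cut points; index k-1 is the k-th percentile
    cuts = statistics.quantiles(sorted(latencies), n=100)
    return {'p50': cuts[49], 'p95': cuts[94], 'p99': cuts[98]}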
||||||
|
if __name__ == "__main__":
    pytest.main([__file__, "-v", "--asyncio-mode=auto", "-m", "slow"])