fix: Fix all linting errors in test suite
- Replace bare except with except Exception
- Remove unused imports and variables
- Fix type hints to use modern syntax
- Apply ruff formatting for line length
- Ensure all tests pass linting checks
This commit is contained in:
parent
6d1c8ad17a
commit
669aaae785
7 changed files with 403 additions and 453 deletions
|
|
@ -61,7 +61,9 @@ logging.basicConfig(
|
||||||
# Configure specific loggers
|
# Configure specific loggers
|
||||||
logging.getLogger('uvicorn').setLevel(logging.INFO)
|
logging.getLogger('uvicorn').setLevel(logging.INFO)
|
||||||
logging.getLogger('uvicorn.access').setLevel(logging.WARNING) # Reduce access log noise
|
logging.getLogger('uvicorn.access').setLevel(logging.WARNING) # Reduce access log noise
|
||||||
logging.getLogger('mcp.server.streamable_http_manager').setLevel(logging.WARNING) # Reduce MCP noise
|
logging.getLogger('mcp.server.streamable_http_manager').setLevel(
|
||||||
|
logging.WARNING
|
||||||
|
) # Reduce MCP noise
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,14 +5,10 @@ Provides various test execution modes and reporting options.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
@ -34,7 +30,7 @@ class TestRunner:
|
||||||
self.test_dir = Path(__file__).parent
|
self.test_dir = Path(__file__).parent
|
||||||
self.results = {}
|
self.results = {}
|
||||||
|
|
||||||
def check_prerequisites(self) -> Dict[str, bool]:
|
def check_prerequisites(self) -> dict[str, bool]:
|
||||||
"""Check if required services and dependencies are available."""
|
"""Check if required services and dependencies are available."""
|
||||||
checks = {}
|
checks = {}
|
||||||
|
|
||||||
|
|
@ -46,7 +42,9 @@ class TestRunner:
|
||||||
# Check if .env file exists for helpful message
|
# Check if .env file exists for helpful message
|
||||||
env_path = Path(__file__).parent.parent / '.env'
|
env_path = Path(__file__).parent.parent / '.env'
|
||||||
if not env_path.exists():
|
if not env_path.exists():
|
||||||
checks['openai_api_key_hint'] = 'Set OPENAI_API_KEY in environment or create mcp_server/.env file'
|
checks['openai_api_key_hint'] = (
|
||||||
|
'Set OPENAI_API_KEY in environment or create mcp_server/.env file'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
checks['openai_api_key'] = True
|
checks['openai_api_key'] = True
|
||||||
|
|
||||||
|
|
@ -69,6 +67,7 @@ class TestRunner:
|
||||||
"""Check if Neo4j is available."""
|
"""Check if Neo4j is available."""
|
||||||
try:
|
try:
|
||||||
import neo4j
|
import neo4j
|
||||||
|
|
||||||
# Try to connect
|
# Try to connect
|
||||||
uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
|
uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
|
||||||
user = os.environ.get('NEO4J_USER', 'neo4j')
|
user = os.environ.get('NEO4J_USER', 'neo4j')
|
||||||
|
|
@ -76,21 +75,22 @@ class TestRunner:
|
||||||
|
|
||||||
driver = neo4j.GraphDatabase.driver(uri, auth=(user, password))
|
driver = neo4j.GraphDatabase.driver(uri, auth=(user, password))
|
||||||
with driver.session() as session:
|
with driver.session() as session:
|
||||||
session.run("RETURN 1")
|
session.run('RETURN 1')
|
||||||
driver.close()
|
driver.close()
|
||||||
return True
|
return True
|
||||||
except:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _check_falkordb(self) -> bool:
|
def _check_falkordb(self) -> bool:
|
||||||
"""Check if FalkorDB is available."""
|
"""Check if FalkorDB is available."""
|
||||||
try:
|
try:
|
||||||
import redis
|
import redis
|
||||||
|
|
||||||
uri = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
|
uri = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
|
||||||
r = redis.from_url(uri)
|
r = redis.from_url(uri)
|
||||||
r.ping()
|
r.ping()
|
||||||
return True
|
return True
|
||||||
except:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _check_python_package(self, package: str) -> bool:
|
def _check_python_package(self, package: str) -> bool:
|
||||||
|
|
@ -124,10 +124,12 @@ class TestRunner:
|
||||||
pytest_args.extend(['-m', 'slow', 'test_stress_load.py'])
|
pytest_args.extend(['-m', 'slow', 'test_stress_load.py'])
|
||||||
elif suite == 'smoke':
|
elif suite == 'smoke':
|
||||||
# Quick smoke test - just basic operations
|
# Quick smoke test - just basic operations
|
||||||
pytest_args.extend([
|
pytest_args.extend(
|
||||||
'test_comprehensive_integration.py::TestCoreOperations::test_server_initialization',
|
[
|
||||||
'test_comprehensive_integration.py::TestCoreOperations::test_add_text_memory'
|
'test_comprehensive_integration.py::TestCoreOperations::test_server_initialization',
|
||||||
])
|
'test_comprehensive_integration.py::TestCoreOperations::test_add_text_memory',
|
||||||
|
]
|
||||||
|
)
|
||||||
elif suite == 'all':
|
elif suite == 'all':
|
||||||
pytest_args.append('.')
|
pytest_args.append('.')
|
||||||
else:
|
else:
|
||||||
|
|
@ -161,7 +163,7 @@ class TestRunner:
|
||||||
env['DATABASE_PROVIDER'] = self.args.database
|
env['DATABASE_PROVIDER'] = self.args.database
|
||||||
|
|
||||||
# Run tests from the test directory
|
# Run tests from the test directory
|
||||||
print(f"Running {suite} tests with pytest args: {' '.join(pytest_args)}")
|
print(f'Running {suite} tests with pytest args: {" ".join(pytest_args)}')
|
||||||
|
|
||||||
# Change to test directory to run tests
|
# Change to test directory to run tests
|
||||||
original_dir = os.getcwd()
|
original_dir = os.getcwd()
|
||||||
|
|
@ -176,52 +178,52 @@ class TestRunner:
|
||||||
|
|
||||||
def run_performance_benchmark(self):
|
def run_performance_benchmark(self):
|
||||||
"""Run performance benchmarking suite."""
|
"""Run performance benchmarking suite."""
|
||||||
print("Running performance benchmarks...")
|
print('Running performance benchmarks...')
|
||||||
|
|
||||||
# Import test modules
|
# Import test modules
|
||||||
from test_comprehensive_integration import TestPerformance
|
|
||||||
from test_async_operations import TestAsyncPerformance
|
|
||||||
|
|
||||||
# Run performance tests
|
# Run performance tests
|
||||||
result = pytest.main([
|
result = pytest.main(
|
||||||
'-v',
|
[
|
||||||
'test_comprehensive_integration.py::TestPerformance',
|
'-v',
|
||||||
'test_async_operations.py::TestAsyncPerformance',
|
'test_comprehensive_integration.py::TestPerformance',
|
||||||
'--benchmark-only' if self.args.benchmark_only else '',
|
'test_async_operations.py::TestAsyncPerformance',
|
||||||
])
|
'--benchmark-only' if self.args.benchmark_only else '',
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def generate_report(self):
|
def generate_report(self):
|
||||||
"""Generate test execution report."""
|
"""Generate test execution report."""
|
||||||
report = []
|
report = []
|
||||||
report.append("\n" + "=" * 60)
|
report.append('\n' + '=' * 60)
|
||||||
report.append("GRAPHITI MCP TEST EXECUTION REPORT")
|
report.append('GRAPHITI MCP TEST EXECUTION REPORT')
|
||||||
report.append("=" * 60)
|
report.append('=' * 60)
|
||||||
|
|
||||||
# Prerequisites check
|
# Prerequisites check
|
||||||
checks = self.check_prerequisites()
|
checks = self.check_prerequisites()
|
||||||
report.append("\nPrerequisites:")
|
report.append('\nPrerequisites:')
|
||||||
for check, passed in checks.items():
|
for check, passed in checks.items():
|
||||||
status = "✅" if passed else "❌"
|
status = '✅' if passed else '❌'
|
||||||
report.append(f" {status} {check}")
|
report.append(f' {status} {check}')
|
||||||
|
|
||||||
# Test configuration
|
# Test configuration
|
||||||
report.append(f"\nConfiguration:")
|
report.append('\nConfiguration:')
|
||||||
report.append(f" Database: {self.args.database}")
|
report.append(f' Database: {self.args.database}')
|
||||||
report.append(f" Mock LLM: {self.args.mock_llm}")
|
report.append(f' Mock LLM: {self.args.mock_llm}')
|
||||||
report.append(f" Parallel: {self.args.parallel or 'No'}")
|
report.append(f' Parallel: {self.args.parallel or "No"}')
|
||||||
report.append(f" Timeout: {self.args.timeout}s")
|
report.append(f' Timeout: {self.args.timeout}s')
|
||||||
|
|
||||||
# Results summary (if available)
|
# Results summary (if available)
|
||||||
if self.results:
|
if self.results:
|
||||||
report.append(f"\nResults:")
|
report.append('\nResults:')
|
||||||
for suite, result in self.results.items():
|
for suite, result in self.results.items():
|
||||||
status = "✅ Passed" if result == 0 else f"❌ Failed ({result})"
|
status = '✅ Passed' if result == 0 else f'❌ Failed ({result})'
|
||||||
report.append(f" {suite}: {status}")
|
report.append(f' {suite}: {status}')
|
||||||
|
|
||||||
report.append("=" * 60)
|
report.append('=' * 60)
|
||||||
return "\n".join(report)
|
return '\n'.join(report)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
@ -244,70 +246,42 @@ Examples:
|
||||||
python run_tests.py integration --parallel 4 # Run integration tests in parallel
|
python run_tests.py integration --parallel 4 # Run integration tests in parallel
|
||||||
python run_tests.py stress --database neo4j # Run stress tests with Neo4j
|
python run_tests.py stress --database neo4j # Run stress tests with Neo4j
|
||||||
python run_tests.py all --coverage # Run all tests with coverage
|
python run_tests.py all --coverage # Run all tests with coverage
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'suite',
|
'suite',
|
||||||
choices=['unit', 'integration', 'comprehensive', 'async', 'stress', 'smoke', 'all'],
|
choices=['unit', 'integration', 'comprehensive', 'async', 'stress', 'smoke', 'all'],
|
||||||
help='Test suite to run'
|
help='Test suite to run',
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--database',
|
'--database',
|
||||||
choices=['neo4j', 'falkordb', 'kuzu'],
|
choices=['neo4j', 'falkordb', 'kuzu'],
|
||||||
default='kuzu',
|
default='kuzu',
|
||||||
help='Database backend to test (default: kuzu)'
|
help='Database backend to test (default: kuzu)',
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument('--mock-llm', action='store_true', help='Use mock LLM for faster testing')
|
||||||
'--mock-llm',
|
|
||||||
action='store_true',
|
|
||||||
help='Use mock LLM for faster testing'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--parallel',
|
'--parallel', type=int, metavar='N', help='Run tests in parallel with N workers'
|
||||||
type=int,
|
|
||||||
metavar='N',
|
|
||||||
help='Run tests in parallel with N workers'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument('--coverage', action='store_true', help='Generate coverage report')
|
||||||
'--coverage',
|
|
||||||
action='store_true',
|
parser.add_argument('--verbose', action='store_true', help='Verbose output')
|
||||||
help='Generate coverage report'
|
|
||||||
)
|
parser.add_argument('--skip-slow', action='store_true', help='Skip slow tests')
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--verbose',
|
'--timeout', type=int, default=300, help='Test timeout in seconds (default: 300)'
|
||||||
action='store_true',
|
|
||||||
help='Verbose output'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument('--benchmark-only', action='store_true', help='Run only benchmark tests')
|
||||||
'--skip-slow',
|
|
||||||
action='store_true',
|
|
||||||
help='Skip slow tests'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--timeout',
|
'--check-only', action='store_true', help='Only check prerequisites without running tests'
|
||||||
type=int,
|
|
||||||
default=300,
|
|
||||||
help='Test timeout in seconds (default: 300)'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'--benchmark-only',
|
|
||||||
action='store_true',
|
|
||||||
help='Run only benchmark tests'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'--check-only',
|
|
||||||
action='store_true',
|
|
||||||
help='Only check prerequisites without running tests'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -326,26 +300,26 @@ Examples:
|
||||||
validation_checks = {k: v for k, v in checks.items() if not k.endswith('_hint')}
|
validation_checks = {k: v for k, v in checks.items() if not k.endswith('_hint')}
|
||||||
|
|
||||||
if not all(validation_checks.values()):
|
if not all(validation_checks.values()):
|
||||||
print("⚠️ Some prerequisites are not met:")
|
print('⚠️ Some prerequisites are not met:')
|
||||||
for check, passed in checks.items():
|
for check, passed in checks.items():
|
||||||
if check.endswith('_hint'):
|
if check.endswith('_hint'):
|
||||||
continue # Skip hint entries
|
continue # Skip hint entries
|
||||||
if not passed:
|
if not passed:
|
||||||
print(f" ❌ {check}")
|
print(f' ❌ {check}')
|
||||||
# Show hint if available
|
# Show hint if available
|
||||||
hint_key = f"{check}_hint"
|
hint_key = f'{check}_hint'
|
||||||
if hint_key in checks:
|
if hint_key in checks:
|
||||||
print(f" 💡 {checks[hint_key]}")
|
print(f' 💡 {checks[hint_key]}')
|
||||||
|
|
||||||
if not args.mock_llm and not checks.get('openai_api_key'):
|
if not args.mock_llm and not checks.get('openai_api_key'):
|
||||||
print("\n💡 Tip: Use --mock-llm to run tests without OpenAI API key")
|
print('\n💡 Tip: Use --mock-llm to run tests without OpenAI API key')
|
||||||
|
|
||||||
response = input("\nContinue anyway? (y/N): ")
|
response = input('\nContinue anyway? (y/N): ')
|
||||||
if response.lower() != 'y':
|
if response.lower() != 'y':
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Run tests
|
# Run tests
|
||||||
print(f"\n🚀 Starting test execution: {args.suite}")
|
print(f'\n🚀 Starting test execution: {args.suite}')
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if args.benchmark_only:
|
if args.benchmark_only:
|
||||||
|
|
@ -360,11 +334,11 @@ Examples:
|
||||||
|
|
||||||
# Generate and print report
|
# Generate and print report
|
||||||
print(runner.generate_report())
|
print(runner.generate_report())
|
||||||
print(f"\n⏱️ Test execution completed in {duration:.2f} seconds")
|
print(f'\n⏱️ Test execution completed in {duration:.2f} seconds')
|
||||||
|
|
||||||
# Exit with test result code
|
# Exit with test result code
|
||||||
sys.exit(result)
|
sys.exit(result)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
@ -7,15 +7,11 @@ Tests concurrent operations, queue management, and async patterns.
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from test_fixtures import (
|
from test_fixtures import (
|
||||||
TestDataGenerator,
|
TestDataGenerator,
|
||||||
graphiti_test_client,
|
graphiti_test_client,
|
||||||
performance_benchmark,
|
|
||||||
test_data_generator,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -38,7 +34,7 @@ class TestAsyncQueueManagement:
|
||||||
'source_description': 'sequential test',
|
'source_description': 'sequential test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
'reference_id': f'seq_{i}', # Add reference for tracking
|
'reference_id': f'seq_{i}', # Add reference for tracking
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
episodes.append(result)
|
episodes.append(result)
|
||||||
|
|
||||||
|
|
@ -46,19 +42,18 @@ class TestAsyncQueueManagement:
|
||||||
await asyncio.sleep(10) # Allow time for sequential processing
|
await asyncio.sleep(10) # Allow time for sequential processing
|
||||||
|
|
||||||
# Retrieve episodes and verify order
|
# Retrieve episodes and verify order
|
||||||
result = await session.call_tool(
|
result = await session.call_tool('get_episodes', {'group_id': group_id, 'last_n': 10})
|
||||||
'get_episodes',
|
|
||||||
{'group_id': group_id, 'last_n': 10}
|
|
||||||
)
|
|
||||||
|
|
||||||
processed_episodes = json.loads(result.content[0].text)['episodes']
|
processed_episodes = json.loads(result.content[0].text)['episodes']
|
||||||
|
|
||||||
# Verify all episodes were processed
|
# Verify all episodes were processed
|
||||||
assert len(processed_episodes) >= 5, f"Expected at least 5 episodes, got {len(processed_episodes)}"
|
assert len(processed_episodes) >= 5, (
|
||||||
|
f'Expected at least 5 episodes, got {len(processed_episodes)}'
|
||||||
|
)
|
||||||
|
|
||||||
# Verify sequential processing (timestamps should be ordered)
|
# Verify sequential processing (timestamps should be ordered)
|
||||||
timestamps = [ep.get('created_at') for ep in processed_episodes]
|
timestamps = [ep.get('created_at') for ep in processed_episodes]
|
||||||
assert timestamps == sorted(timestamps), "Episodes not processed in order"
|
assert timestamps == sorted(timestamps), 'Episodes not processed in order'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_concurrent_group_processing(self):
|
async def test_concurrent_group_processing(self):
|
||||||
|
|
@ -78,7 +73,7 @@ class TestAsyncQueueManagement:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'concurrent test',
|
'source_description': 'concurrent test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -89,11 +84,11 @@ class TestAsyncQueueManagement:
|
||||||
|
|
||||||
# Verify all succeeded
|
# Verify all succeeded
|
||||||
failures = [r for r in results if isinstance(r, Exception)]
|
failures = [r for r in results if isinstance(r, Exception)]
|
||||||
assert not failures, f"Concurrent operations failed: {failures}"
|
assert not failures, f'Concurrent operations failed: {failures}'
|
||||||
|
|
||||||
# Check that execution was actually concurrent (should be faster than sequential)
|
# Check that execution was actually concurrent (should be faster than sequential)
|
||||||
# Sequential would take at least 6 * processing_time
|
# Sequential would take at least 6 * processing_time
|
||||||
assert execution_time < 30, f"Concurrent execution too slow: {execution_time}s"
|
assert execution_time < 30, f'Concurrent execution too slow: {execution_time}s'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_queue_overflow_handling(self):
|
async def test_queue_overflow_handling(self):
|
||||||
|
|
@ -110,7 +105,7 @@ class TestAsyncQueueManagement:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'overflow test',
|
'source_description': 'overflow test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -121,11 +116,11 @@ class TestAsyncQueueManagement:
|
||||||
successful = sum(1 for r in results if not isinstance(r, Exception))
|
successful = sum(1 for r in results if not isinstance(r, Exception))
|
||||||
|
|
||||||
# Should handle overflow gracefully
|
# Should handle overflow gracefully
|
||||||
assert successful > 0, "No episodes were queued successfully"
|
assert successful > 0, 'No episodes were queued successfully'
|
||||||
|
|
||||||
# Log overflow behavior
|
# Log overflow behavior
|
||||||
if successful < 100:
|
if successful < 100:
|
||||||
print(f"Queue overflow: {successful}/100 episodes queued")
|
print(f'Queue overflow: {successful}/100 episodes queued')
|
||||||
|
|
||||||
|
|
||||||
class TestConcurrentOperations:
|
class TestConcurrentOperations:
|
||||||
|
|
@ -148,7 +143,7 @@ class TestConcurrentOperations:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'search test',
|
'source_description': 'search test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
add_tasks.append(task)
|
add_tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -172,7 +167,7 @@ class TestConcurrentOperations:
|
||||||
'query': query,
|
'query': query,
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
'limit': 10,
|
'limit': 10,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
search_tasks.append(task)
|
search_tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -182,10 +177,10 @@ class TestConcurrentOperations:
|
||||||
|
|
||||||
# Verify all searches completed
|
# Verify all searches completed
|
||||||
failures = [r for r in results if isinstance(r, Exception)]
|
failures = [r for r in results if isinstance(r, Exception)]
|
||||||
assert not failures, f"Search operations failed: {failures}"
|
assert not failures, f'Search operations failed: {failures}'
|
||||||
|
|
||||||
# Verify concurrent execution efficiency
|
# Verify concurrent execution efficiency
|
||||||
assert search_time < len(search_queries) * 2, "Searches not executing concurrently"
|
assert search_time < len(search_queries) * 2, 'Searches not executing concurrently'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mixed_operation_concurrency(self):
|
async def test_mixed_operation_concurrency(self):
|
||||||
|
|
@ -194,48 +189,51 @@ class TestConcurrentOperations:
|
||||||
operations = []
|
operations = []
|
||||||
|
|
||||||
# Add memory operation
|
# Add memory operation
|
||||||
operations.append(session.call_tool(
|
operations.append(
|
||||||
'add_memory',
|
session.call_tool(
|
||||||
{
|
'add_memory',
|
||||||
'name': 'Mixed Op Test',
|
{
|
||||||
'episode_body': 'Testing mixed operations',
|
'name': 'Mixed Op Test',
|
||||||
'source': 'text',
|
'episode_body': 'Testing mixed operations',
|
||||||
'source_description': 'test',
|
'source': 'text',
|
||||||
'group_id': group_id,
|
'source_description': 'test',
|
||||||
}
|
'group_id': group_id,
|
||||||
))
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Search operation
|
# Search operation
|
||||||
operations.append(session.call_tool(
|
operations.append(
|
||||||
'search_memory_nodes',
|
session.call_tool(
|
||||||
{
|
'search_memory_nodes',
|
||||||
'query': 'test',
|
{
|
||||||
'group_id': group_id,
|
'query': 'test',
|
||||||
'limit': 5,
|
'group_id': group_id,
|
||||||
}
|
'limit': 5,
|
||||||
))
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Get episodes operation
|
# Get episodes operation
|
||||||
operations.append(session.call_tool(
|
operations.append(
|
||||||
'get_episodes',
|
session.call_tool(
|
||||||
{
|
'get_episodes',
|
||||||
'group_id': group_id,
|
{
|
||||||
'last_n': 10,
|
'group_id': group_id,
|
||||||
}
|
'last_n': 10,
|
||||||
))
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Get status operation
|
# Get status operation
|
||||||
operations.append(session.call_tool(
|
operations.append(session.call_tool('get_status', {}))
|
||||||
'get_status',
|
|
||||||
{}
|
|
||||||
))
|
|
||||||
|
|
||||||
# Execute all concurrently
|
# Execute all concurrently
|
||||||
results = await asyncio.gather(*operations, return_exceptions=True)
|
results = await asyncio.gather(*operations, return_exceptions=True)
|
||||||
|
|
||||||
# Check results
|
# Check results
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
assert not isinstance(result, Exception), f"Operation {i} failed: {result}"
|
assert not isinstance(result, Exception), f'Operation {i} failed: {result}'
|
||||||
|
|
||||||
|
|
||||||
class TestAsyncErrorHandling:
|
class TestAsyncErrorHandling:
|
||||||
|
|
@ -246,10 +244,10 @@ class TestAsyncErrorHandling:
|
||||||
"""Test recovery from operation timeouts."""
|
"""Test recovery from operation timeouts."""
|
||||||
async with graphiti_test_client() as (session, group_id):
|
async with graphiti_test_client() as (session, group_id):
|
||||||
# Create a very large episode that might timeout
|
# Create a very large episode that might timeout
|
||||||
large_content = "x" * 1000000 # 1MB of data
|
large_content = 'x' * 1000000 # 1MB of data
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
session.call_tool(
|
session.call_tool(
|
||||||
'add_memory',
|
'add_memory',
|
||||||
{
|
{
|
||||||
|
|
@ -258,9 +256,9 @@ class TestAsyncErrorHandling:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'timeout test',
|
'source_description': 'timeout test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
),
|
),
|
||||||
timeout=2.0 # Short timeout
|
timeout=2.0, # Short timeout
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# Expected timeout
|
# Expected timeout
|
||||||
|
|
@ -268,7 +266,7 @@ class TestAsyncErrorHandling:
|
||||||
|
|
||||||
# Verify server is still responsive after timeout
|
# Verify server is still responsive after timeout
|
||||||
status_result = await session.call_tool('get_status', {})
|
status_result = await session.call_tool('get_status', {})
|
||||||
assert status_result is not None, "Server unresponsive after timeout"
|
assert status_result is not None, 'Server unresponsive after timeout'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cancellation_handling(self):
|
async def test_cancellation_handling(self):
|
||||||
|
|
@ -284,7 +282,7 @@ class TestAsyncErrorHandling:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'cancel test',
|
'source_description': 'cancel test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -305,13 +303,13 @@ class TestAsyncErrorHandling:
|
||||||
"""Test that exceptions are properly propagated in async context."""
|
"""Test that exceptions are properly propagated in async context."""
|
||||||
async with graphiti_test_client() as (session, group_id):
|
async with graphiti_test_client() as (session, group_id):
|
||||||
# Call with invalid arguments
|
# Call with invalid arguments
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(ValueError):
|
||||||
await session.call_tool(
|
await session.call_tool(
|
||||||
'add_memory',
|
'add_memory',
|
||||||
{
|
{
|
||||||
# Missing required fields
|
# Missing required fields
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Server should remain operational
|
# Server should remain operational
|
||||||
|
|
@ -340,7 +338,7 @@ class TestAsyncPerformance:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'throughput test',
|
'source_description': 'throughput test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -355,14 +353,14 @@ class TestAsyncPerformance:
|
||||||
performance_benchmark.record('async_throughput', throughput)
|
performance_benchmark.record('async_throughput', throughput)
|
||||||
|
|
||||||
# Log results
|
# Log results
|
||||||
print(f"\nAsync Throughput Test:")
|
print('\nAsync Throughput Test:')
|
||||||
print(f" Operations: {num_operations}")
|
print(f' Operations: {num_operations}')
|
||||||
print(f" Successful: {successful}")
|
print(f' Successful: {successful}')
|
||||||
print(f" Total time: {total_time:.2f}s")
|
print(f' Total time: {total_time:.2f}s')
|
||||||
print(f" Throughput: {throughput:.2f} ops/s")
|
print(f' Throughput: {throughput:.2f} ops/s')
|
||||||
|
|
||||||
# Assert minimum throughput
|
# Assert minimum throughput
|
||||||
assert throughput > 1.0, f"Throughput too low: {throughput:.2f} ops/s"
|
assert throughput > 1.0, f'Throughput too low: {throughput:.2f} ops/s'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_latency_under_load(self, performance_benchmark):
|
async def test_latency_under_load(self, performance_benchmark):
|
||||||
|
|
@ -380,7 +378,7 @@ class TestAsyncPerformance:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'background',
|
'source_description': 'background',
|
||||||
'group_id': f'background_{group_id}',
|
'group_id': f'background_{group_id}',
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
background_tasks.append(task)
|
background_tasks.append(task)
|
||||||
|
|
@ -389,10 +387,7 @@ class TestAsyncPerformance:
|
||||||
latencies = []
|
latencies = []
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
await session.call_tool(
|
await session.call_tool('get_status', {})
|
||||||
'get_status',
|
|
||||||
{}
|
|
||||||
)
|
|
||||||
latency = time.time() - start
|
latency = time.time() - start
|
||||||
latencies.append(latency)
|
latencies.append(latency)
|
||||||
performance_benchmark.record('latency_under_load', latency)
|
performance_benchmark.record('latency_under_load', latency)
|
||||||
|
|
@ -405,13 +400,13 @@ class TestAsyncPerformance:
|
||||||
avg_latency = sum(latencies) / len(latencies)
|
avg_latency = sum(latencies) / len(latencies)
|
||||||
max_latency = max(latencies)
|
max_latency = max(latencies)
|
||||||
|
|
||||||
print(f"\nLatency Under Load:")
|
print('\nLatency Under Load:')
|
||||||
print(f" Average: {avg_latency:.3f}s")
|
print(f' Average: {avg_latency:.3f}s')
|
||||||
print(f" Max: {max_latency:.3f}s")
|
print(f' Max: {max_latency:.3f}s')
|
||||||
|
|
||||||
# Assert acceptable latency
|
# Assert acceptable latency
|
||||||
assert avg_latency < 2.0, f"Average latency too high: {avg_latency:.3f}s"
|
assert avg_latency < 2.0, f'Average latency too high: {avg_latency:.3f}s'
|
||||||
assert max_latency < 5.0, f"Max latency too high: {max_latency:.3f}s"
|
assert max_latency < 5.0, f'Max latency too high: {max_latency:.3f}s'
|
||||||
|
|
||||||
|
|
||||||
class TestAsyncStreamHandling:
|
class TestAsyncStreamHandling:
|
||||||
|
|
@ -431,7 +426,7 @@ class TestAsyncStreamHandling:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'stream test',
|
'source_description': 'stream test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for processing
|
# Wait for processing
|
||||||
|
|
@ -443,12 +438,12 @@ class TestAsyncStreamHandling:
|
||||||
{
|
{
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
'last_n': 100, # Request all
|
'last_n': 100, # Request all
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify response handling
|
# Verify response handling
|
||||||
episodes = json.loads(result.content[0].text)['episodes']
|
episodes = json.loads(result.content[0].text)['episodes']
|
||||||
assert len(episodes) >= 20, f"Expected at least 20 episodes, got {len(episodes)}"
|
assert len(episodes) >= 20, f'Expected at least 20 episodes, got {len(episodes)}'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_incremental_processing(self):
|
async def test_incremental_processing(self):
|
||||||
|
|
@ -466,7 +461,7 @@ class TestAsyncStreamHandling:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'incremental test',
|
'source_description': 'incremental test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
batch_tasks.append(task)
|
batch_tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -482,13 +477,15 @@ class TestAsyncStreamHandling:
|
||||||
{
|
{
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
'last_n': 100,
|
'last_n': 100,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
episodes = json.loads(result.content[0].text)['episodes']
|
episodes = json.loads(result.content[0].text)['episodes']
|
||||||
expected_min = (batch + 1) * 5
|
expected_min = (batch + 1) * 5
|
||||||
assert len(episodes) >= expected_min, f"Batch {batch}: Expected at least {expected_min} episodes"
|
assert len(episodes) >= expected_min, (
|
||||||
|
f'Batch {batch}: Expected at least {expected_min} episodes'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__, "-v", "--asyncio-mode=auto"])
|
pytest.main([__file__, '-v', '--asyncio-mode=auto'])
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta
|
from typing import Any
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from mcp import ClientSession, StdioServerParameters
|
from mcp import ClientSession, StdioServerParameters
|
||||||
|
|
@ -26,7 +24,7 @@ class TestMetrics:
|
||||||
start_time: float
|
start_time: float
|
||||||
end_time: float
|
end_time: float
|
||||||
success: bool
|
success: bool
|
||||||
details: Dict[str, Any]
|
details: dict[str, Any]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def duration(self) -> float:
|
def duration(self) -> float:
|
||||||
|
|
@ -37,10 +35,10 @@ class TestMetrics:
|
||||||
class GraphitiTestClient:
|
class GraphitiTestClient:
|
||||||
"""Enhanced test client for comprehensive Graphiti MCP testing."""
|
"""Enhanced test client for comprehensive Graphiti MCP testing."""
|
||||||
|
|
||||||
def __init__(self, test_group_id: Optional[str] = None):
|
def __init__(self, test_group_id: str | None = None):
|
||||||
self.test_group_id = test_group_id or f'test_{int(time.time())}'
|
self.test_group_id = test_group_id or f'test_{int(time.time())}'
|
||||||
self.session = None
|
self.session = None
|
||||||
self.metrics: List[TestMetrics] = []
|
self.metrics: list[TestMetrics] = []
|
||||||
self.default_timeout = 30 # seconds
|
self.default_timeout = 30 # seconds
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
|
|
@ -76,7 +74,7 @@ class GraphitiTestClient:
|
||||||
await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
|
await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
async def call_tool_with_metrics(
|
async def call_tool_with_metrics(
|
||||||
self, tool_name: str, arguments: Dict[str, Any], timeout: Optional[float] = None
|
self, tool_name: str, arguments: dict[str, Any], timeout: float | None = None
|
||||||
) -> tuple[Any, TestMetrics]:
|
) -> tuple[Any, TestMetrics]:
|
||||||
"""Call a tool and capture performance metrics."""
|
"""Call a tool and capture performance metrics."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
@ -84,8 +82,7 @@ class GraphitiTestClient:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self.session.call_tool(tool_name, arguments),
|
self.session.call_tool(tool_name, arguments), timeout=timeout
|
||||||
timeout=timeout
|
|
||||||
)
|
)
|
||||||
|
|
||||||
content = result.content[0].text if result.content else None
|
content = result.content[0].text if result.content else None
|
||||||
|
|
@ -108,7 +105,7 @@ class GraphitiTestClient:
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
success=success,
|
success=success,
|
||||||
details=details
|
details=details,
|
||||||
)
|
)
|
||||||
self.metrics.append(metric)
|
self.metrics.append(metric)
|
||||||
|
|
||||||
|
|
@ -132,8 +129,7 @@ class GraphitiTestClient:
|
||||||
|
|
||||||
while (time.time() - start_time) < max_wait:
|
while (time.time() - start_time) < max_wait:
|
||||||
result, _ = await self.call_tool_with_metrics(
|
result, _ = await self.call_tool_with_metrics(
|
||||||
'get_episodes',
|
'get_episodes', {'group_id': self.test_group_id, 'last_n': 100}
|
||||||
{'group_id': self.test_group_id, 'last_n': 100}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
|
|
@ -168,11 +164,11 @@ class TestCoreOperations:
|
||||||
'delete_entity_edge',
|
'delete_entity_edge',
|
||||||
'get_entity_edge',
|
'get_entity_edge',
|
||||||
'clear_graph',
|
'clear_graph',
|
||||||
'get_status'
|
'get_status',
|
||||||
}
|
}
|
||||||
|
|
||||||
missing_tools = required_tools - tools
|
missing_tools = required_tools - tools
|
||||||
assert not missing_tools, f"Missing required tools: {missing_tools}"
|
assert not missing_tools, f'Missing required tools: {missing_tools}'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_text_memory(self):
|
async def test_add_text_memory(self):
|
||||||
|
|
@ -187,15 +183,15 @@ class TestCoreOperations:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'conference notes',
|
'source_description': 'conference notes',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert metric.success, f"Failed to add memory: {metric.details}"
|
assert metric.success, f'Failed to add memory: {metric.details}'
|
||||||
assert 'queued' in str(result).lower()
|
assert 'queued' in str(result).lower()
|
||||||
|
|
||||||
# Wait for processing
|
# Wait for processing
|
||||||
processed = await client.wait_for_episode_processing(expected_count=1)
|
processed = await client.wait_for_episode_processing(expected_count=1)
|
||||||
assert processed, "Episode was not processed within timeout"
|
assert processed, 'Episode was not processed within timeout'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_json_memory(self):
|
async def test_add_json_memory(self):
|
||||||
|
|
@ -205,12 +201,9 @@ class TestCoreOperations:
|
||||||
'project': {
|
'project': {
|
||||||
'name': 'GraphitiDB',
|
'name': 'GraphitiDB',
|
||||||
'version': '2.0.0',
|
'version': '2.0.0',
|
||||||
'features': ['temporal-awareness', 'hybrid-search', 'custom-entities']
|
'features': ['temporal-awareness', 'hybrid-search', 'custom-entities'],
|
||||||
},
|
},
|
||||||
'team': {
|
'team': {'size': 5, 'roles': ['engineering', 'product', 'research']},
|
||||||
'size': 5,
|
|
||||||
'roles': ['engineering', 'product', 'research']
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
result, metric = await client.call_tool_with_metrics(
|
result, metric = await client.call_tool_with_metrics(
|
||||||
|
|
@ -221,7 +214,7 @@ class TestCoreOperations:
|
||||||
'source': 'json',
|
'source': 'json',
|
||||||
'source_description': 'project database',
|
'source_description': 'project database',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert metric.success
|
assert metric.success
|
||||||
|
|
@ -246,11 +239,11 @@ class TestCoreOperations:
|
||||||
'source': 'message',
|
'source': 'message',
|
||||||
'source_description': 'support chat',
|
'source_description': 'support chat',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert metric.success
|
assert metric.success
|
||||||
assert metric.duration < 5, f"Add memory took too long: {metric.duration}s"
|
assert metric.duration < 5, f'Add memory took too long: {metric.duration}s'
|
||||||
|
|
||||||
|
|
||||||
class TestSearchOperations:
|
class TestSearchOperations:
|
||||||
|
|
@ -269,7 +262,7 @@ class TestSearchOperations:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'product roadmap',
|
'source_description': 'product roadmap',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for processing
|
# Wait for processing
|
||||||
|
|
@ -278,11 +271,7 @@ class TestSearchOperations:
|
||||||
# Search for nodes
|
# Search for nodes
|
||||||
result, metric = await client.call_tool_with_metrics(
|
result, metric = await client.call_tool_with_metrics(
|
||||||
'search_memory_nodes',
|
'search_memory_nodes',
|
||||||
{
|
{'query': 'AI product features', 'group_id': client.test_group_id, 'limit': 10},
|
||||||
'query': 'AI product features',
|
|
||||||
'group_id': client.test_group_id,
|
|
||||||
'limit': 10
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert metric.success
|
assert metric.success
|
||||||
|
|
@ -301,7 +290,7 @@ class TestSearchOperations:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'company profile',
|
'source_description': 'company profile',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
await client.wait_for_episode_processing()
|
await client.wait_for_episode_processing()
|
||||||
|
|
@ -313,8 +302,8 @@ class TestSearchOperations:
|
||||||
'query': 'company information',
|
'query': 'company information',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
'created_after': '2020-01-01T00:00:00Z',
|
'created_after': '2020-01-01T00:00:00Z',
|
||||||
'limit': 20
|
'limit': 20,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert metric.success
|
assert metric.success
|
||||||
|
|
@ -328,13 +317,13 @@ class TestSearchOperations:
|
||||||
{
|
{
|
||||||
'name': 'Technical Doc',
|
'name': 'Technical Doc',
|
||||||
'episode_body': 'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.',
|
'episode_body': 'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.',
|
||||||
'source': 'text'
|
'source': 'text',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'name': 'Architecture',
|
'name': 'Architecture',
|
||||||
'episode_body': 'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.',
|
'episode_body': 'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.',
|
||||||
'source': 'text'
|
'source': 'text',
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
for memory in test_memories:
|
for memory in test_memories:
|
||||||
|
|
@ -347,11 +336,7 @@ class TestSearchOperations:
|
||||||
# Test semantic + keyword search
|
# Test semantic + keyword search
|
||||||
result, metric = await client.call_tool_with_metrics(
|
result, metric = await client.call_tool_with_metrics(
|
||||||
'search_memory_nodes',
|
'search_memory_nodes',
|
||||||
{
|
{'query': 'Neo4j graph database', 'group_id': client.test_group_id, 'limit': 10},
|
||||||
'query': 'Neo4j graph database',
|
|
||||||
'group_id': client.test_group_id,
|
|
||||||
'limit': 10
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert metric.success
|
assert metric.success
|
||||||
|
|
@ -374,18 +359,14 @@ class TestEpisodeManagement:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'test',
|
'source_description': 'test',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
await client.wait_for_episode_processing(expected_count=5)
|
await client.wait_for_episode_processing(expected_count=5)
|
||||||
|
|
||||||
# Test pagination
|
# Test pagination
|
||||||
result, metric = await client.call_tool_with_metrics(
|
result, metric = await client.call_tool_with_metrics(
|
||||||
'get_episodes',
|
'get_episodes', {'group_id': client.test_group_id, 'last_n': 3}
|
||||||
{
|
|
||||||
'group_id': client.test_group_id,
|
|
||||||
'last_n': 3
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert metric.success
|
assert metric.success
|
||||||
|
|
@ -405,15 +386,14 @@ class TestEpisodeManagement:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'test',
|
'source_description': 'test',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
await client.wait_for_episode_processing()
|
await client.wait_for_episode_processing()
|
||||||
|
|
||||||
# Get episode UUID
|
# Get episode UUID
|
||||||
result, _ = await client.call_tool_with_metrics(
|
result, _ = await client.call_tool_with_metrics(
|
||||||
'get_episodes',
|
'get_episodes', {'group_id': client.test_group_id, 'last_n': 1}
|
||||||
{'group_id': client.test_group_id, 'last_n': 1}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
episodes = json.loads(result) if isinstance(result, str) else result
|
episodes = json.loads(result) if isinstance(result, str) else result
|
||||||
|
|
@ -421,8 +401,7 @@ class TestEpisodeManagement:
|
||||||
|
|
||||||
# Delete the episode
|
# Delete the episode
|
||||||
result, metric = await client.call_tool_with_metrics(
|
result, metric = await client.call_tool_with_metrics(
|
||||||
'delete_episode',
|
'delete_episode', {'episode_uuid': episode_uuid}
|
||||||
{'episode_uuid': episode_uuid}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert metric.success
|
assert metric.success
|
||||||
|
|
@ -445,7 +424,7 @@ class TestEntityAndEdgeOperations:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'org chart',
|
'source_description': 'org chart',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
await client.wait_for_episode_processing()
|
await client.wait_for_episode_processing()
|
||||||
|
|
@ -453,11 +432,7 @@ class TestEntityAndEdgeOperations:
|
||||||
# Search for nodes to get UUIDs
|
# Search for nodes to get UUIDs
|
||||||
result, _ = await client.call_tool_with_metrics(
|
result, _ = await client.call_tool_with_metrics(
|
||||||
'search_memory_nodes',
|
'search_memory_nodes',
|
||||||
{
|
{'query': 'TechCorp', 'group_id': client.test_group_id, 'limit': 5},
|
||||||
'query': 'TechCorp',
|
|
||||||
'group_id': client.test_group_id,
|
|
||||||
'limit': 5
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note: This test assumes edges are created between entities
|
# Note: This test assumes edges are created between entities
|
||||||
|
|
@ -480,7 +455,7 @@ class TestErrorHandling:
|
||||||
# Missing required arguments
|
# Missing required arguments
|
||||||
result, metric = await client.call_tool_with_metrics(
|
result, metric = await client.call_tool_with_metrics(
|
||||||
'add_memory',
|
'add_memory',
|
||||||
{'name': 'Incomplete'} # Missing required fields
|
{'name': 'Incomplete'}, # Missing required fields
|
||||||
)
|
)
|
||||||
|
|
||||||
assert not metric.success
|
assert not metric.success
|
||||||
|
|
@ -491,7 +466,7 @@ class TestErrorHandling:
|
||||||
"""Test timeout handling for long operations."""
|
"""Test timeout handling for long operations."""
|
||||||
async with GraphitiTestClient() as client:
|
async with GraphitiTestClient() as client:
|
||||||
# Simulate a very large episode that might timeout
|
# Simulate a very large episode that might timeout
|
||||||
large_text = "Large document content. " * 10000
|
large_text = 'Large document content. ' * 10000
|
||||||
|
|
||||||
result, metric = await client.call_tool_with_metrics(
|
result, metric = await client.call_tool_with_metrics(
|
||||||
'add_memory',
|
'add_memory',
|
||||||
|
|
@ -502,7 +477,7 @@ class TestErrorHandling:
|
||||||
'source_description': 'large file',
|
'source_description': 'large file',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
},
|
},
|
||||||
timeout=5 # Short timeout
|
timeout=5, # Short timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if timeout was handled gracefully
|
# Check if timeout was handled gracefully
|
||||||
|
|
@ -524,7 +499,7 @@ class TestErrorHandling:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'concurrent test',
|
'source_description': 'concurrent test',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -543,35 +518,34 @@ class TestPerformance:
|
||||||
"""Measure and validate operation latencies."""
|
"""Measure and validate operation latencies."""
|
||||||
async with GraphitiTestClient() as client:
|
async with GraphitiTestClient() as client:
|
||||||
operations = [
|
operations = [
|
||||||
('add_memory', {
|
(
|
||||||
'name': 'Perf Test',
|
'add_memory',
|
||||||
'episode_body': 'Simple text',
|
{
|
||||||
'source': 'text',
|
'name': 'Perf Test',
|
||||||
'source_description': 'test',
|
'episode_body': 'Simple text',
|
||||||
'group_id': client.test_group_id,
|
'source': 'text',
|
||||||
}),
|
'source_description': 'test',
|
||||||
('search_memory_nodes', {
|
'group_id': client.test_group_id,
|
||||||
'query': 'test',
|
},
|
||||||
'group_id': client.test_group_id,
|
),
|
||||||
'limit': 10
|
(
|
||||||
}),
|
'search_memory_nodes',
|
||||||
('get_episodes', {
|
{'query': 'test', 'group_id': client.test_group_id, 'limit': 10},
|
||||||
'group_id': client.test_group_id,
|
),
|
||||||
'last_n': 10
|
('get_episodes', {'group_id': client.test_group_id, 'last_n': 10}),
|
||||||
})
|
|
||||||
]
|
]
|
||||||
|
|
||||||
for tool_name, args in operations:
|
for tool_name, args in operations:
|
||||||
_, metric = await client.call_tool_with_metrics(tool_name, args)
|
_, metric = await client.call_tool_with_metrics(tool_name, args)
|
||||||
|
|
||||||
# Log performance metrics
|
# Log performance metrics
|
||||||
print(f"{tool_name}: {metric.duration:.2f}s")
|
print(f'{tool_name}: {metric.duration:.2f}s')
|
||||||
|
|
||||||
# Basic latency assertions
|
# Basic latency assertions
|
||||||
if tool_name == 'get_episodes':
|
if tool_name == 'get_episodes':
|
||||||
assert metric.duration < 2, f"{tool_name} too slow"
|
assert metric.duration < 2, f'{tool_name} too slow'
|
||||||
elif tool_name == 'search_memory_nodes':
|
elif tool_name == 'search_memory_nodes':
|
||||||
assert metric.duration < 10, f"{tool_name} too slow"
|
assert metric.duration < 10, f'{tool_name} too slow'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_batch_processing_efficiency(self):
|
async def test_batch_processing_efficiency(self):
|
||||||
|
|
@ -590,33 +564,35 @@ class TestPerformance:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'batch test',
|
'source_description': 'batch test',
|
||||||
'group_id': client.test_group_id,
|
'group_id': client.test_group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for all to process
|
# Wait for all to process
|
||||||
processed = await client.wait_for_episode_processing(
|
processed = await client.wait_for_episode_processing(
|
||||||
expected_count=batch_size,
|
expected_count=batch_size,
|
||||||
max_wait=120 # Allow more time for batch
|
max_wait=120, # Allow more time for batch
|
||||||
)
|
)
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
avg_time_per_item = total_time / batch_size
|
avg_time_per_item = total_time / batch_size
|
||||||
|
|
||||||
assert processed, f"Failed to process {batch_size} items"
|
assert processed, f'Failed to process {batch_size} items'
|
||||||
assert avg_time_per_item < 15, f"Batch processing too slow: {avg_time_per_item:.2f}s per item"
|
assert avg_time_per_item < 15, (
|
||||||
|
f'Batch processing too slow: {avg_time_per_item:.2f}s per item'
|
||||||
|
)
|
||||||
|
|
||||||
# Generate performance report
|
# Generate performance report
|
||||||
print(f"\nBatch Performance Report:")
|
print('\nBatch Performance Report:')
|
||||||
print(f" Total items: {batch_size}")
|
print(f' Total items: {batch_size}')
|
||||||
print(f" Total time: {total_time:.2f}s")
|
print(f' Total time: {total_time:.2f}s')
|
||||||
print(f" Avg per item: {avg_time_per_item:.2f}s")
|
print(f' Avg per item: {avg_time_per_item:.2f}s')
|
||||||
|
|
||||||
|
|
||||||
class TestDatabaseBackends:
|
class TestDatabaseBackends:
|
||||||
"""Test different database backend configurations."""
|
"""Test different database backend configurations."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("database", ["neo4j", "falkordb", "kuzu"])
|
@pytest.mark.parametrize('database', ['neo4j', 'falkordb', 'kuzu'])
|
||||||
async def test_database_operations(self, database):
|
async def test_database_operations(self, database):
|
||||||
"""Test operations with different database backends."""
|
"""Test operations with different database backends."""
|
||||||
env_vars = {
|
env_vars = {
|
||||||
|
|
@ -625,53 +601,50 @@ class TestDatabaseBackends:
|
||||||
}
|
}
|
||||||
|
|
||||||
if database == 'neo4j':
|
if database == 'neo4j':
|
||||||
env_vars.update({
|
env_vars.update(
|
||||||
'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
|
{
|
||||||
'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
|
'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
|
||||||
'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
|
'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
|
||||||
})
|
'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
|
||||||
|
}
|
||||||
|
)
|
||||||
elif database == 'falkordb':
|
elif database == 'falkordb':
|
||||||
env_vars['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
|
env_vars['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
|
||||||
elif database == 'kuzu':
|
elif database == 'kuzu':
|
||||||
env_vars['KUZU_PATH'] = os.environ.get('KUZU_PATH', f'./test_kuzu_{int(time.time())}.db')
|
env_vars['KUZU_PATH'] = os.environ.get(
|
||||||
|
'KUZU_PATH', f'./test_kuzu_{int(time.time())}.db'
|
||||||
|
)
|
||||||
|
|
||||||
server_params = StdioServerParameters(
|
# This test would require setting up server with specific database
|
||||||
command='uv',
|
|
||||||
args=['run', 'main.py', '--transport', 'stdio', '--database', database],
|
|
||||||
env=env_vars
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test basic operations with each backend
|
|
||||||
# Implementation depends on database availability
|
# Implementation depends on database availability
|
||||||
|
pass # Placeholder for database-specific tests
|
||||||
|
|
||||||
|
|
||||||
def generate_test_report(client: GraphitiTestClient) -> str:
|
def generate_test_report(client: GraphitiTestClient) -> str:
|
||||||
"""Generate a comprehensive test report from metrics."""
|
"""Generate a comprehensive test report from metrics."""
|
||||||
if not client.metrics:
|
if not client.metrics:
|
||||||
return "No metrics collected"
|
return 'No metrics collected'
|
||||||
|
|
||||||
report = []
|
report = []
|
||||||
report.append("\n" + "="*60)
|
report.append('\n' + '=' * 60)
|
||||||
report.append("GRAPHITI MCP TEST REPORT")
|
report.append('GRAPHITI MCP TEST REPORT')
|
||||||
report.append("="*60)
|
report.append('=' * 60)
|
||||||
|
|
||||||
# Summary statistics
|
# Summary statistics
|
||||||
total_ops = len(client.metrics)
|
total_ops = len(client.metrics)
|
||||||
successful_ops = sum(1 for m in client.metrics if m.success)
|
successful_ops = sum(1 for m in client.metrics if m.success)
|
||||||
avg_duration = sum(m.duration for m in client.metrics) / total_ops
|
avg_duration = sum(m.duration for m in client.metrics) / total_ops
|
||||||
|
|
||||||
report.append(f"\nTotal Operations: {total_ops}")
|
report.append(f'\nTotal Operations: {total_ops}')
|
||||||
report.append(f"Successful: {successful_ops} ({successful_ops/total_ops*100:.1f}%)")
|
report.append(f'Successful: {successful_ops} ({successful_ops / total_ops * 100:.1f}%)')
|
||||||
report.append(f"Average Duration: {avg_duration:.2f}s")
|
report.append(f'Average Duration: {avg_duration:.2f}s')
|
||||||
|
|
||||||
# Operation breakdown
|
# Operation breakdown
|
||||||
report.append("\nOperation Breakdown:")
|
report.append('\nOperation Breakdown:')
|
||||||
operation_stats = {}
|
operation_stats = {}
|
||||||
for metric in client.metrics:
|
for metric in client.metrics:
|
||||||
if metric.operation not in operation_stats:
|
if metric.operation not in operation_stats:
|
||||||
operation_stats[metric.operation] = {
|
operation_stats[metric.operation] = {'count': 0, 'success': 0, 'total_duration': 0}
|
||||||
'count': 0, 'success': 0, 'total_duration': 0
|
|
||||||
}
|
|
||||||
stats = operation_stats[metric.operation]
|
stats = operation_stats[metric.operation]
|
||||||
stats['count'] += 1
|
stats['count'] += 1
|
||||||
stats['success'] += 1 if metric.success else 0
|
stats['success'] += 1 if metric.success else 0
|
||||||
|
|
@ -681,20 +654,19 @@ def generate_test_report(client: GraphitiTestClient) -> str:
|
||||||
avg_dur = stats['total_duration'] / stats['count']
|
avg_dur = stats['total_duration'] / stats['count']
|
||||||
success_rate = stats['success'] / stats['count'] * 100
|
success_rate = stats['success'] / stats['count'] * 100
|
||||||
report.append(
|
report.append(
|
||||||
f" {op}: {stats['count']} calls, "
|
f' {op}: {stats["count"]} calls, {success_rate:.0f}% success, {avg_dur:.2f}s avg'
|
||||||
f"{success_rate:.0f}% success, {avg_dur:.2f}s avg"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Slowest operations
|
# Slowest operations
|
||||||
slowest = sorted(client.metrics, key=lambda m: m.duration, reverse=True)[:5]
|
slowest = sorted(client.metrics, key=lambda m: m.duration, reverse=True)[:5]
|
||||||
report.append("\nSlowest Operations:")
|
report.append('\nSlowest Operations:')
|
||||||
for metric in slowest:
|
for metric in slowest:
|
||||||
report.append(f" {metric.operation}: {metric.duration:.2f}s")
|
report.append(f' {metric.operation}: {metric.duration:.2f}s')
|
||||||
|
|
||||||
report.append("="*60)
|
report.append('=' * 60)
|
||||||
return "\n".join(report)
|
return '\n'.join(report)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
# Run tests with pytest
|
# Run tests with pytest
|
||||||
pytest.main([__file__, "-v", "--asyncio-mode=auto"])
|
pytest.main([__file__, '-v', '--asyncio-mode=auto'])
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from faker import Faker
|
from faker import Faker
|
||||||
|
|
@ -22,7 +22,7 @@ class TestDataGenerator:
|
||||||
"""Generate realistic test data for various scenarios."""
|
"""Generate realistic test data for various scenarios."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_company_profile() -> Dict[str, Any]:
|
def generate_company_profile() -> dict[str, Any]:
|
||||||
"""Generate a realistic company profile."""
|
"""Generate a realistic company profile."""
|
||||||
return {
|
return {
|
||||||
'company': {
|
'company': {
|
||||||
|
|
@ -30,7 +30,7 @@ class TestDataGenerator:
|
||||||
'founded': random.randint(1990, 2023),
|
'founded': random.randint(1990, 2023),
|
||||||
'industry': random.choice(['Tech', 'Finance', 'Healthcare', 'Retail']),
|
'industry': random.choice(['Tech', 'Finance', 'Healthcare', 'Retail']),
|
||||||
'employees': random.randint(10, 10000),
|
'employees': random.randint(10, 10000),
|
||||||
'revenue': f"${random.randint(1, 1000)}M",
|
'revenue': f'${random.randint(1, 1000)}M',
|
||||||
'headquarters': fake.city(),
|
'headquarters': fake.city(),
|
||||||
},
|
},
|
||||||
'products': [
|
'products': [
|
||||||
|
|
@ -46,41 +46,41 @@ class TestDataGenerator:
|
||||||
'ceo': fake.name(),
|
'ceo': fake.name(),
|
||||||
'cto': fake.name(),
|
'cto': fake.name(),
|
||||||
'cfo': fake.name(),
|
'cfo': fake.name(),
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_conversation(turns: int = 3) -> str:
|
def generate_conversation(turns: int = 3) -> str:
|
||||||
"""Generate a realistic conversation."""
|
"""Generate a realistic conversation."""
|
||||||
topics = [
|
topics = [
|
||||||
"product features",
|
'product features',
|
||||||
"pricing",
|
'pricing',
|
||||||
"technical support",
|
'technical support',
|
||||||
"integration",
|
'integration',
|
||||||
"documentation",
|
'documentation',
|
||||||
"performance",
|
'performance',
|
||||||
]
|
]
|
||||||
|
|
||||||
conversation = []
|
conversation = []
|
||||||
for _ in range(turns):
|
for _ in range(turns):
|
||||||
topic = random.choice(topics)
|
topic = random.choice(topics)
|
||||||
user_msg = f"user: {fake.sentence()} about {topic}?"
|
user_msg = f'user: {fake.sentence()} about {topic}?'
|
||||||
assistant_msg = f"assistant: {fake.paragraph(nb_sentences=2)}"
|
assistant_msg = f'assistant: {fake.paragraph(nb_sentences=2)}'
|
||||||
conversation.extend([user_msg, assistant_msg])
|
conversation.extend([user_msg, assistant_msg])
|
||||||
|
|
||||||
return "\n".join(conversation)
|
return '\n'.join(conversation)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_technical_document() -> str:
|
def generate_technical_document() -> str:
|
||||||
"""Generate technical documentation content."""
|
"""Generate technical documentation content."""
|
||||||
sections = [
|
sections = [
|
||||||
f"# {fake.catch_phrase()}\n\n{fake.paragraph()}",
|
f'# {fake.catch_phrase()}\n\n{fake.paragraph()}',
|
||||||
f"## Architecture\n{fake.paragraph()}",
|
f'## Architecture\n{fake.paragraph()}',
|
||||||
f"## Implementation\n{fake.paragraph()}",
|
f'## Implementation\n{fake.paragraph()}',
|
||||||
f"## Performance\n- Latency: {random.randint(1, 100)}ms\n- Throughput: {random.randint(100, 10000)} req/s",
|
f'## Performance\n- Latency: {random.randint(1, 100)}ms\n- Throughput: {random.randint(100, 10000)} req/s',
|
||||||
f"## Dependencies\n- {fake.word()}\n- {fake.word()}\n- {fake.word()}",
|
f'## Dependencies\n- {fake.word()}\n- {fake.word()}\n- {fake.word()}',
|
||||||
]
|
]
|
||||||
return "\n\n".join(sections)
|
return '\n\n'.join(sections)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_news_article() -> str:
|
def generate_news_article() -> str:
|
||||||
|
|
@ -100,7 +100,7 @@ class TestDataGenerator:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_user_profile() -> Dict[str, Any]:
|
def generate_user_profile() -> dict[str, Any]:
|
||||||
"""Generate a user profile."""
|
"""Generate a user profile."""
|
||||||
return {
|
return {
|
||||||
'user_id': fake.uuid4(),
|
'user_id': fake.uuid4(),
|
||||||
|
|
@ -115,8 +115,8 @@ class TestDataGenerator:
|
||||||
'activity': {
|
'activity': {
|
||||||
'last_login': fake.date_time_this_month().isoformat(),
|
'last_login': fake.date_time_this_month().isoformat(),
|
||||||
'total_sessions': random.randint(1, 1000),
|
'total_sessions': random.randint(1, 1000),
|
||||||
'average_duration': f"{random.randint(1, 60)} minutes",
|
'average_duration': f'{random.randint(1, 60)} minutes',
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -131,25 +131,27 @@ class MockLLMProvider:
|
||||||
await asyncio.sleep(self.delay)
|
await asyncio.sleep(self.delay)
|
||||||
|
|
||||||
# Return deterministic responses based on prompt patterns
|
# Return deterministic responses based on prompt patterns
|
||||||
if "extract entities" in prompt.lower():
|
if 'extract entities' in prompt.lower():
|
||||||
return json.dumps({
|
return json.dumps(
|
||||||
'entities': [
|
{
|
||||||
{'name': 'TestEntity1', 'type': 'PERSON'},
|
'entities': [
|
||||||
{'name': 'TestEntity2', 'type': 'ORGANIZATION'},
|
{'name': 'TestEntity1', 'type': 'PERSON'},
|
||||||
]
|
{'name': 'TestEntity2', 'type': 'ORGANIZATION'},
|
||||||
})
|
]
|
||||||
elif "summarize" in prompt.lower():
|
}
|
||||||
return "This is a test summary of the provided content."
|
)
|
||||||
|
elif 'summarize' in prompt.lower():
|
||||||
|
return 'This is a test summary of the provided content.'
|
||||||
else:
|
else:
|
||||||
return "Mock LLM response"
|
return 'Mock LLM response'
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def graphiti_test_client(
|
async def graphiti_test_client(
|
||||||
group_id: Optional[str] = None,
|
group_id: str | None = None,
|
||||||
database: str = "kuzu",
|
database: str = 'kuzu',
|
||||||
use_mock_llm: bool = False,
|
use_mock_llm: bool = False,
|
||||||
config_overrides: Optional[Dict[str, Any]] = None
|
config_overrides: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Context manager for creating test clients with various configurations.
|
Context manager for creating test clients with various configurations.
|
||||||
|
|
@ -169,11 +171,13 @@ async def graphiti_test_client(
|
||||||
|
|
||||||
# Database-specific configuration
|
# Database-specific configuration
|
||||||
if database == 'neo4j':
|
if database == 'neo4j':
|
||||||
env.update({
|
env.update(
|
||||||
'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
|
{
|
||||||
'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
|
'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
|
||||||
'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
|
'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
|
||||||
})
|
'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
|
||||||
|
}
|
||||||
|
)
|
||||||
elif database == 'falkordb':
|
elif database == 'falkordb':
|
||||||
env['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
|
env['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
|
||||||
elif database == 'kuzu':
|
elif database == 'kuzu':
|
||||||
|
|
@ -188,9 +192,7 @@ async def graphiti_test_client(
|
||||||
env['USE_MOCK_LLM'] = 'true'
|
env['USE_MOCK_LLM'] = 'true'
|
||||||
|
|
||||||
server_params = StdioServerParameters(
|
server_params = StdioServerParameters(
|
||||||
command='uv',
|
command='uv', args=['run', 'main.py', '--transport', 'stdio'], env=env
|
||||||
args=['run', 'main.py', '--transport', 'stdio'],
|
|
||||||
env=env
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async with stdio_client(server_params) as (read, write):
|
async with stdio_client(server_params) as (read, write):
|
||||||
|
|
@ -203,7 +205,7 @@ async def graphiti_test_client(
|
||||||
# Cleanup: Clear test data
|
# Cleanup: Clear test data
|
||||||
try:
|
try:
|
||||||
await session.call_tool('clear_graph', {'group_id': test_group_id})
|
await session.call_tool('clear_graph', {'group_id': test_group_id})
|
||||||
except:
|
except Exception:
|
||||||
pass # Ignore cleanup errors
|
pass # Ignore cleanup errors
|
||||||
|
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
@ -213,7 +215,7 @@ class PerformanceBenchmark:
|
||||||
"""Track and analyze performance benchmarks."""
|
"""Track and analyze performance benchmarks."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.measurements: Dict[str, List[float]] = {}
|
self.measurements: dict[str, list[float]] = {}
|
||||||
|
|
||||||
def record(self, operation: str, duration: float):
|
def record(self, operation: str, duration: float):
|
||||||
"""Record a performance measurement."""
|
"""Record a performance measurement."""
|
||||||
|
|
@ -221,7 +223,7 @@ class PerformanceBenchmark:
|
||||||
self.measurements[operation] = []
|
self.measurements[operation] = []
|
||||||
self.measurements[operation].append(duration)
|
self.measurements[operation].append(duration)
|
||||||
|
|
||||||
def get_stats(self, operation: str) -> Dict[str, float]:
|
def get_stats(self, operation: str) -> dict[str, float]:
|
||||||
"""Get statistics for an operation."""
|
"""Get statistics for an operation."""
|
||||||
if operation not in self.measurements or not self.measurements[operation]:
|
if operation not in self.measurements or not self.measurements[operation]:
|
||||||
return {}
|
return {}
|
||||||
|
|
@ -237,18 +239,18 @@ class PerformanceBenchmark:
|
||||||
|
|
||||||
def report(self) -> str:
|
def report(self) -> str:
|
||||||
"""Generate a performance report."""
|
"""Generate a performance report."""
|
||||||
lines = ["Performance Benchmark Report", "=" * 40]
|
lines = ['Performance Benchmark Report', '=' * 40]
|
||||||
|
|
||||||
for operation in sorted(self.measurements.keys()):
|
for operation in sorted(self.measurements.keys()):
|
||||||
stats = self.get_stats(operation)
|
stats = self.get_stats(operation)
|
||||||
lines.append(f"\n{operation}:")
|
lines.append(f'\n{operation}:')
|
||||||
lines.append(f" Samples: {stats['count']}")
|
lines.append(f' Samples: {stats["count"]}')
|
||||||
lines.append(f" Mean: {stats['mean']:.3f}s")
|
lines.append(f' Mean: {stats["mean"]:.3f}s')
|
||||||
lines.append(f" Median: {stats['median']:.3f}s")
|
lines.append(f' Median: {stats["median"]:.3f}s')
|
||||||
lines.append(f" Min: {stats['min']:.3f}s")
|
lines.append(f' Min: {stats["min"]:.3f}s')
|
||||||
lines.append(f" Max: {stats['max']:.3f}s")
|
lines.append(f' Max: {stats["max"]:.3f}s')
|
||||||
|
|
||||||
return "\n".join(lines)
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
|
||||||
# Pytest fixtures
|
# Pytest fixtures
|
||||||
|
|
|
||||||
|
|
@ -6,16 +6,13 @@ Tests system behavior under high load, resource constraints, and edge conditions
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import psutil
|
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
|
import psutil
|
||||||
import pytest
|
import pytest
|
||||||
from test_fixtures import TestDataGenerator, graphiti_test_client, PerformanceBenchmark
|
from test_fixtures import TestDataGenerator, graphiti_test_client
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -26,7 +23,7 @@ class LoadTestConfig:
|
||||||
operations_per_client: int = 100
|
operations_per_client: int = 100
|
||||||
ramp_up_time: float = 5.0 # seconds
|
ramp_up_time: float = 5.0 # seconds
|
||||||
test_duration: float = 60.0 # seconds
|
test_duration: float = 60.0 # seconds
|
||||||
target_throughput: Optional[float] = None # ops/sec
|
target_throughput: float | None = None # ops/sec
|
||||||
think_time: float = 0.1 # seconds between ops
|
think_time: float = 0.1 # seconds between ops
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -44,8 +41,8 @@ class LoadTestResult:
|
||||||
p95_latency: float
|
p95_latency: float
|
||||||
p99_latency: float
|
p99_latency: float
|
||||||
max_latency: float
|
max_latency: float
|
||||||
errors: Dict[str, int]
|
errors: dict[str, int]
|
||||||
resource_usage: Dict[str, float]
|
resource_usage: dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
class LoadTester:
|
class LoadTester:
|
||||||
|
|
@ -53,16 +50,11 @@ class LoadTester:
|
||||||
|
|
||||||
def __init__(self, config: LoadTestConfig):
|
def __init__(self, config: LoadTestConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.metrics: List[Tuple[float, float, bool]] = [] # (start, duration, success)
|
self.metrics: list[tuple[float, float, bool]] = [] # (start, duration, success)
|
||||||
self.errors: Dict[str, int] = {}
|
self.errors: dict[str, int] = {}
|
||||||
self.start_time: Optional[float] = None
|
self.start_time: float | None = None
|
||||||
|
|
||||||
async def run_client_workload(
|
async def run_client_workload(self, client_id: int, session, group_id: str) -> dict[str, int]:
|
||||||
self,
|
|
||||||
client_id: int,
|
|
||||||
session,
|
|
||||||
group_id: str
|
|
||||||
) -> Dict[str, int]:
|
|
||||||
"""Run workload for a single simulated client."""
|
"""Run workload for a single simulated client."""
|
||||||
stats = {'success': 0, 'failure': 0}
|
stats = {'success': 0, 'failure': 0}
|
||||||
data_gen = TestDataGenerator()
|
data_gen = TestDataGenerator()
|
||||||
|
|
@ -76,11 +68,13 @@ class LoadTester:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Randomly select operation type
|
# Randomly select operation type
|
||||||
operation = random.choice([
|
operation = random.choice(
|
||||||
'add_memory',
|
[
|
||||||
'search_memory_nodes',
|
'add_memory',
|
||||||
'get_episodes',
|
'search_memory_nodes',
|
||||||
])
|
'get_episodes',
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if operation == 'add_memory':
|
if operation == 'add_memory':
|
||||||
args = {
|
args = {
|
||||||
|
|
@ -103,10 +97,7 @@ class LoadTester:
|
||||||
}
|
}
|
||||||
|
|
||||||
# Execute operation with timeout
|
# Execute operation with timeout
|
||||||
result = await asyncio.wait_for(
|
await asyncio.wait_for(session.call_tool(operation, args), timeout=30.0)
|
||||||
session.call_tool(operation, args),
|
|
||||||
timeout=30.0
|
|
||||||
)
|
|
||||||
|
|
||||||
duration = time.time() - operation_start
|
duration = time.time() - operation_start
|
||||||
self.metrics.append((operation_start, duration, True))
|
self.metrics.append((operation_start, duration, True))
|
||||||
|
|
@ -146,7 +137,7 @@ class LoadTester:
|
||||||
duration = max([m[0] + m[1] for m in self.metrics]) - min([m[0] for m in self.metrics])
|
duration = max([m[0] + m[1] for m in self.metrics]) - min([m[0] for m in self.metrics])
|
||||||
|
|
||||||
# Calculate percentiles
|
# Calculate percentiles
|
||||||
def percentile(data: List[float], p: float) -> float:
|
def percentile(data: list[float], p: float) -> float:
|
||||||
if not data:
|
if not data:
|
||||||
return 0.0
|
return 0.0
|
||||||
idx = int(len(data) * p / 100)
|
idx = int(len(data) * p / 100)
|
||||||
|
|
@ -209,16 +200,20 @@ class TestLoadScenarios:
|
||||||
|
|
||||||
# Assertions
|
# Assertions
|
||||||
assert results.successful_operations > results.failed_operations
|
assert results.successful_operations > results.failed_operations
|
||||||
assert results.average_latency < 5.0, f"Average latency too high: {results.average_latency:.2f}s"
|
assert results.average_latency < 5.0, (
|
||||||
assert results.p95_latency < 10.0, f"P95 latency too high: {results.p95_latency:.2f}s"
|
f'Average latency too high: {results.average_latency:.2f}s'
|
||||||
|
)
|
||||||
|
assert results.p95_latency < 10.0, f'P95 latency too high: {results.p95_latency:.2f}s'
|
||||||
|
|
||||||
# Report results
|
# Report results
|
||||||
print(f"\nSustained Load Test Results:")
|
print('\nSustained Load Test Results:')
|
||||||
print(f" Total operations: {results.total_operations}")
|
print(f' Total operations: {results.total_operations}')
|
||||||
print(f" Success rate: {results.successful_operations / results.total_operations * 100:.1f}%")
|
print(
|
||||||
print(f" Throughput: {results.throughput:.2f} ops/s")
|
f' Success rate: {results.successful_operations / results.total_operations * 100:.1f}%'
|
||||||
print(f" Avg latency: {results.average_latency:.2f}s")
|
)
|
||||||
print(f" P95 latency: {results.p95_latency:.2f}s")
|
print(f' Throughput: {results.throughput:.2f} ops/s')
|
||||||
|
print(f' Avg latency: {results.average_latency:.2f}s')
|
||||||
|
print(f' P95 latency: {results.p95_latency:.2f}s')
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
|
|
@ -236,7 +231,7 @@ class TestLoadScenarios:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'normal',
|
'source_description': 'normal',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
normal_tasks.append(task)
|
normal_tasks.append(task)
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
@ -255,7 +250,7 @@ class TestLoadScenarios:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'spike',
|
'source_description': 'spike',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
spike_tasks.append(task)
|
spike_tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -267,14 +262,14 @@ class TestLoadScenarios:
|
||||||
spike_failures = sum(1 for r in spike_results if isinstance(r, Exception))
|
spike_failures = sum(1 for r in spike_results if isinstance(r, Exception))
|
||||||
spike_success_rate = (len(spike_results) - spike_failures) / len(spike_results)
|
spike_success_rate = (len(spike_results) - spike_failures) / len(spike_results)
|
||||||
|
|
||||||
print(f"\nSpike Load Test Results:")
|
print('\nSpike Load Test Results:')
|
||||||
print(f" Spike size: {len(spike_tasks)} operations")
|
print(f' Spike size: {len(spike_tasks)} operations')
|
||||||
print(f" Duration: {spike_duration:.2f}s")
|
print(f' Duration: {spike_duration:.2f}s')
|
||||||
print(f" Success rate: {spike_success_rate * 100:.1f}%")
|
print(f' Success rate: {spike_success_rate * 100:.1f}%')
|
||||||
print(f" Throughput: {len(spike_tasks) / spike_duration:.2f} ops/s")
|
print(f' Throughput: {len(spike_tasks) / spike_duration:.2f} ops/s')
|
||||||
|
|
||||||
# System should handle at least 80% of spike
|
# System should handle at least 80% of spike
|
||||||
assert spike_success_rate > 0.8, f"Too many failures during spike: {spike_failures}"
|
assert spike_success_rate > 0.8, f'Too many failures during spike: {spike_failures}'
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
|
|
@ -297,7 +292,7 @@ class TestLoadScenarios:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'memory test',
|
'source_description': 'memory test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
batch_tasks.append(task)
|
batch_tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -312,15 +307,15 @@ class TestLoadScenarios:
|
||||||
final_memory = process.memory_info().rss / 1024 / 1024 # MB
|
final_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||||
memory_growth = final_memory - initial_memory
|
memory_growth = final_memory - initial_memory
|
||||||
|
|
||||||
print(f"\nMemory Leak Test:")
|
print('\nMemory Leak Test:')
|
||||||
print(f" Initial memory: {initial_memory:.1f} MB")
|
print(f' Initial memory: {initial_memory:.1f} MB')
|
||||||
print(f" Final memory: {final_memory:.1f} MB")
|
print(f' Final memory: {final_memory:.1f} MB')
|
||||||
print(f" Growth: {memory_growth:.1f} MB")
|
print(f' Growth: {memory_growth:.1f} MB')
|
||||||
|
|
||||||
# Allow for some memory growth but flag potential leaks
|
# Allow for some memory growth but flag potential leaks
|
||||||
# This is a soft check - actual threshold depends on system
|
# This is a soft check - actual threshold depends on system
|
||||||
if memory_growth > 100: # More than 100MB growth
|
if memory_growth > 100: # More than 100MB growth
|
||||||
print(f" ⚠️ Potential memory leak detected: {memory_growth:.1f} MB growth")
|
print(f' ⚠️ Potential memory leak detected: {memory_growth:.1f} MB growth')
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
|
|
@ -333,32 +328,33 @@ class TestLoadScenarios:
|
||||||
task = session.call_tool(
|
task = session.call_tool(
|
||||||
'search_memory_nodes',
|
'search_memory_nodes',
|
||||||
{
|
{
|
||||||
'query': f'complex query {i} ' + ' '.join([TestDataGenerator.fake.word() for _ in range(10)]),
|
'query': f'complex query {i} '
|
||||||
|
+ ' '.join([TestDataGenerator.fake.word() for _ in range(10)]),
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
'limit': 100,
|
'limit': 100,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
long_tasks.append(task)
|
long_tasks.append(task)
|
||||||
|
|
||||||
# Execute with timeout
|
# Execute with timeout
|
||||||
try:
|
try:
|
||||||
results = await asyncio.wait_for(
|
results = await asyncio.wait_for(
|
||||||
asyncio.gather(*long_tasks, return_exceptions=True),
|
asyncio.gather(*long_tasks, return_exceptions=True), timeout=60.0
|
||||||
timeout=60.0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Count connection-related errors
|
# Count connection-related errors
|
||||||
connection_errors = sum(
|
connection_errors = sum(
|
||||||
1 for r in results
|
1
|
||||||
|
for r in results
|
||||||
if isinstance(r, Exception) and 'connection' in str(r).lower()
|
if isinstance(r, Exception) and 'connection' in str(r).lower()
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"\nConnection Pool Test:")
|
print('\nConnection Pool Test:')
|
||||||
print(f" Total requests: {len(long_tasks)}")
|
print(f' Total requests: {len(long_tasks)}')
|
||||||
print(f" Connection errors: {connection_errors}")
|
print(f' Connection errors: {connection_errors}')
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
print(" Test timed out - possible deadlock or exhaustion")
|
print(' Test timed out - possible deadlock or exhaustion')
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
|
|
@ -381,7 +377,7 @@ class TestLoadScenarios:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'degradation test',
|
'source_description': 'degradation test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -400,10 +396,10 @@ class TestLoadScenarios:
|
||||||
'duration': level_duration,
|
'duration': level_duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"\nLoad Level {level}:")
|
print(f'\nLoad Level {level}:')
|
||||||
print(f" Success rate: {success_rate:.1f}%")
|
print(f' Success rate: {success_rate:.1f}%')
|
||||||
print(f" Throughput: {throughput:.2f} ops/s")
|
print(f' Throughput: {throughput:.2f} ops/s')
|
||||||
print(f" Duration: {level_duration:.2f}s")
|
print(f' Duration: {level_duration:.2f}s')
|
||||||
|
|
||||||
# Brief pause between levels
|
# Brief pause between levels
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
|
|
@ -411,7 +407,7 @@ class TestLoadScenarios:
|
||||||
# Verify graceful degradation
|
# Verify graceful degradation
|
||||||
# Success rate should not drop below 50% even at high load
|
# Success rate should not drop below 50% even at high load
|
||||||
for level, metrics in results_by_level.items():
|
for level, metrics in results_by_level.items():
|
||||||
assert metrics['success_rate'] > 50, f"Poor performance at load level {level}"
|
assert metrics['success_rate'] > 50, f'Poor performance at load level {level}'
|
||||||
|
|
||||||
|
|
||||||
class TestResourceLimits:
|
class TestResourceLimits:
|
||||||
|
|
@ -422,18 +418,18 @@ class TestResourceLimits:
|
||||||
"""Test handling of very large payloads."""
|
"""Test handling of very large payloads."""
|
||||||
async with graphiti_test_client() as (session, group_id):
|
async with graphiti_test_client() as (session, group_id):
|
||||||
payload_sizes = [
|
payload_sizes = [
|
||||||
(1_000, "1KB"),
|
(1_000, '1KB'),
|
||||||
(10_000, "10KB"),
|
(10_000, '10KB'),
|
||||||
(100_000, "100KB"),
|
(100_000, '100KB'),
|
||||||
(1_000_000, "1MB"),
|
(1_000_000, '1MB'),
|
||||||
]
|
]
|
||||||
|
|
||||||
for size, label in payload_sizes:
|
for size, label in payload_sizes:
|
||||||
content = "x" * size
|
content = 'x' * size
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
result = await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
session.call_tool(
|
session.call_tool(
|
||||||
'add_memory',
|
'add_memory',
|
||||||
{
|
{
|
||||||
|
|
@ -442,22 +438,22 @@ class TestResourceLimits:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'payload test',
|
'source_description': 'payload test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
),
|
),
|
||||||
timeout=30.0
|
timeout=30.0,
|
||||||
)
|
)
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
status = "✅ Success"
|
status = '✅ Success'
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
duration = 30.0
|
duration = 30.0
|
||||||
status = "⏱️ Timeout"
|
status = '⏱️ Timeout'
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
status = f"❌ Error: {type(e).__name__}"
|
status = f'❌ Error: {type(e).__name__}'
|
||||||
|
|
||||||
print(f"Payload {label}: {status} ({duration:.2f}s)")
|
print(f'Payload {label}: {status} ({duration:.2f}s)')
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_rate_limit_handling(self):
|
async def test_rate_limit_handling(self):
|
||||||
|
|
@ -474,7 +470,7 @@ class TestResourceLimits:
|
||||||
'source': 'text',
|
'source': 'text',
|
||||||
'source_description': 'rate test',
|
'source_description': 'rate test',
|
||||||
'group_id': group_id,
|
'group_id': group_id,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
rapid_tasks.append(task)
|
rapid_tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -483,42 +479,49 @@ class TestResourceLimits:
|
||||||
|
|
||||||
# Count rate limit errors
|
# Count rate limit errors
|
||||||
rate_limit_errors = sum(
|
rate_limit_errors = sum(
|
||||||
1 for r in results
|
1
|
||||||
|
for r in results
|
||||||
if isinstance(r, Exception) and ('rate' in str(r).lower() or '429' in str(r))
|
if isinstance(r, Exception) and ('rate' in str(r).lower() or '429' in str(r))
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"\nRate Limit Test:")
|
print('\nRate Limit Test:')
|
||||||
print(f" Total requests: {len(rapid_tasks)}")
|
print(f' Total requests: {len(rapid_tasks)}')
|
||||||
print(f" Rate limit errors: {rate_limit_errors}")
|
print(f' Rate limit errors: {rate_limit_errors}')
|
||||||
print(f" Success rate: {(len(rapid_tasks) - rate_limit_errors) / len(rapid_tasks) * 100:.1f}%")
|
print(
|
||||||
|
f' Success rate: {(len(rapid_tasks) - rate_limit_errors) / len(rapid_tasks) * 100:.1f}%'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_load_test_report(results: List[LoadTestResult]) -> str:
|
def generate_load_test_report(results: list[LoadTestResult]) -> str:
|
||||||
"""Generate comprehensive load test report."""
|
"""Generate comprehensive load test report."""
|
||||||
report = []
|
report = []
|
||||||
report.append("\n" + "=" * 60)
|
report.append('\n' + '=' * 60)
|
||||||
report.append("LOAD TEST REPORT")
|
report.append('LOAD TEST REPORT')
|
||||||
report.append("=" * 60)
|
report.append('=' * 60)
|
||||||
|
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
report.append(f"\nTest Run {i + 1}:")
|
report.append(f'\nTest Run {i + 1}:')
|
||||||
report.append(f" Total Operations: {result.total_operations}")
|
report.append(f' Total Operations: {result.total_operations}')
|
||||||
report.append(f" Success Rate: {result.successful_operations / result.total_operations * 100:.1f}%")
|
report.append(
|
||||||
report.append(f" Throughput: {result.throughput:.2f} ops/s")
|
f' Success Rate: {result.successful_operations / result.total_operations * 100:.1f}%'
|
||||||
report.append(f" Latency (avg/p50/p95/p99/max): {result.average_latency:.2f}/{result.p50_latency:.2f}/{result.p95_latency:.2f}/{result.p99_latency:.2f}/{result.max_latency:.2f}s")
|
)
|
||||||
|
report.append(f' Throughput: {result.throughput:.2f} ops/s')
|
||||||
|
report.append(
|
||||||
|
f' Latency (avg/p50/p95/p99/max): {result.average_latency:.2f}/{result.p50_latency:.2f}/{result.p95_latency:.2f}/{result.p99_latency:.2f}/{result.max_latency:.2f}s'
|
||||||
|
)
|
||||||
|
|
||||||
if result.errors:
|
if result.errors:
|
||||||
report.append(" Errors:")
|
report.append(' Errors:')
|
||||||
for error_type, count in result.errors.items():
|
for error_type, count in result.errors.items():
|
||||||
report.append(f" {error_type}: {count}")
|
report.append(f' {error_type}: {count}')
|
||||||
|
|
||||||
report.append(" Resource Usage:")
|
report.append(' Resource Usage:')
|
||||||
for metric, value in result.resource_usage.items():
|
for metric, value in result.resource_usage.items():
|
||||||
report.append(f" {metric}: {value:.2f}")
|
report.append(f' {metric}: {value:.2f}')
|
||||||
|
|
||||||
report.append("=" * 60)
|
report.append('=' * 60)
|
||||||
return "\n".join(report)
|
return '\n'.join(report)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__, "-v", "--asyncio-mode=auto", "-m", "slow"])
|
pytest.main([__file__, '-v', '--asyncio-mode=auto', '-m', 'slow'])
|
||||||
|
|
|
||||||
2
mcp_server/uv.lock
generated
2
mcp_server/uv.lock
generated
|
|
@ -635,7 +635,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphiti-core"
|
name = "graphiti-core"
|
||||||
version = "0.22.0"
|
version = "0.22.1rc2"
|
||||||
source = { editable = "../" }
|
source = { editable = "../" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "diskcache" },
|
{ name = "diskcache" },
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue