graphiti/mcp_server/tests/test_async_operations.py
#!/usr/bin/env python3
"""
Asynchronous operation tests for Graphiti MCP Server.
Tests concurrent operations, queue management, and async patterns.
"""

import asyncio
import json
import time

import pytest

from test_fixtures import (
    TestDataGenerator,
    graphiti_test_client,
)
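
# Note: the `performance_benchmark` fixture used by TestAsyncPerformance below is
# not imported here; it is assumed to be provided by the suite's shared pytest
# conftest.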


class TestAsyncQueueManagement:
    """Test asynchronous queue operations and episode processing."""

    @pytest.mark.asyncio
    async def test_sequential_queue_processing(self):
        """Verify episodes are processed sequentially within a group."""
        async with graphiti_test_client() as (session, group_id):
            # Add multiple episodes quickly
            episodes = []
            for i in range(5):
                result = await session.call_tool(
                    'add_memory',
                    {
                        'name': f'Sequential Test {i}',
                        'episode_body': f'Episode {i} with timestamp {time.time()}',
                        'source': 'text',
                        'source_description': 'sequential test',
                        'group_id': group_id,
                        'reference_id': f'seq_{i}',  # Add reference for tracking
                    },
                )
                episodes.append(result)

            # Wait for processing
            await asyncio.sleep(10)  # Allow time for sequential processing

            # Retrieve episodes and verify order
            result = await session.call_tool('get_episodes', {'group_id': group_id, 'last_n': 10})
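            # The tool result arrives as MCP content blocks; the first block's text is
            # expected to hold the JSON payload with the episode list.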
            processed_episodes = json.loads(result.content[0].text)['episodes']

            # Verify all episodes were processed
            assert len(processed_episodes) >= 5, (
                f'Expected at least 5 episodes, got {len(processed_episodes)}'
            )

            # Verify sequential processing (timestamps should be ordered)
            timestamps = [ep.get('created_at') for ep in processed_episodes]
            assert timestamps == sorted(timestamps), 'Episodes not processed in order'

    @pytest.mark.asyncio
    async def test_concurrent_group_processing(self):
        """Test that different groups can process concurrently."""
        async with graphiti_test_client() as (session, _):
            groups = [f'group_{i}_{time.time()}' for i in range(3)]
            tasks = []

            # Create tasks for different groups
            for group_id in groups:
                for j in range(2):
                    task = session.call_tool(
                        'add_memory',
                        {
                            'name': f'Group {group_id} Episode {j}',
                            'episode_body': f'Content for {group_id}',
                            'source': 'text',
                            'source_description': 'concurrent test',
                            'group_id': group_id,
                        },
                    )
                    tasks.append(task)

            # Execute all tasks concurrently
            start_time = time.time()
            results = await asyncio.gather(*tasks, return_exceptions=True)
            execution_time = time.time() - start_time

            # Verify all succeeded
            failures = [r for r in results if isinstance(r, Exception)]
            assert not failures, f'Concurrent operations failed: {failures}'

            # Check that execution was actually concurrent (should be faster than sequential)
            # Sequential would take at least 6 * processing_time
            assert execution_time < 30, f'Concurrent execution too slow: {execution_time}s'

    @pytest.mark.asyncio
    async def test_queue_overflow_handling(self):
        """Test behavior when queue reaches capacity."""
        async with graphiti_test_client() as (session, group_id):
            # Attempt to add many episodes rapidly
            tasks = []
            for i in range(100):  # Large number to potentially overflow
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': f'Overflow Test {i}',
                        'episode_body': f'Episode {i}',
                        'source': 'text',
                        'source_description': 'overflow test',
                        'group_id': group_id,
                    },
                )
                tasks.append(task)

            # Execute with gathering to catch any failures
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Count successful queuing
            successful = sum(1 for r in results if not isinstance(r, Exception))

            # Should handle overflow gracefully
            assert successful > 0, 'No episodes were queued successfully'

            # Log overflow behavior
            if successful < 100:
                print(f'Queue overflow: {successful}/100 episodes queued')


class TestConcurrentOperations:
    """Test concurrent tool calls and operations."""

    @pytest.mark.asyncio
    async def test_concurrent_search_operations(self):
        """Test multiple concurrent search operations."""
        async with graphiti_test_client() as (session, group_id):
            # First, add some test data
            data_gen = TestDataGenerator()
            add_tasks = []
            for _ in range(5):
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': 'Search Test Data',
                        'episode_body': data_gen.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'search test',
                        'group_id': group_id,
                    },
                )
                add_tasks.append(task)

            await asyncio.gather(*add_tasks)
            await asyncio.sleep(15)  # Wait for processing
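            # The fixed sleep assumes background ingestion completes within this
            # window; slower environments may need a longer wait or polling.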

            # Now perform concurrent searches
            search_queries = [
                'architecture',
                'performance',
                'implementation',
                'dependencies',
                'latency',
            ]

            search_tasks = []
            for query in search_queries:
                task = session.call_tool(
                    'search_memory_nodes',
                    {
                        'query': query,
                        'group_id': group_id,
                        'limit': 10,
                    },
                )
                search_tasks.append(task)

            start_time = time.time()
            results = await asyncio.gather(*search_tasks, return_exceptions=True)
            search_time = time.time() - start_time

            # Verify all searches completed
            failures = [r for r in results if isinstance(r, Exception)]
            assert not failures, f'Search operations failed: {failures}'

            # Verify concurrent execution efficiency
            assert search_time < len(search_queries) * 2, 'Searches not executing concurrently'

    @pytest.mark.asyncio
    async def test_mixed_operation_concurrency(self):
        """Test different types of operations running concurrently."""
        async with graphiti_test_client() as (session, group_id):
            operations = []

            # Add memory operation
            operations.append(
                session.call_tool(
                    'add_memory',
                    {
                        'name': 'Mixed Op Test',
                        'episode_body': 'Testing mixed operations',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': group_id,
                    },
                )
            )

            # Search operation
            operations.append(
                session.call_tool(
                    'search_memory_nodes',
                    {
                        'query': 'test',
                        'group_id': group_id,
                        'limit': 5,
                    },
                )
            )

            # Get episodes operation
            operations.append(
                session.call_tool(
                    'get_episodes',
                    {
                        'group_id': group_id,
                        'last_n': 10,
                    },
                )
            )

            # Get status operation
            operations.append(session.call_tool('get_status', {}))

            # Execute all concurrently
            results = await asyncio.gather(*operations, return_exceptions=True)

            # Check results
            for i, result in enumerate(results):
                assert not isinstance(result, Exception), f'Operation {i} failed: {result}'


class TestAsyncErrorHandling:
    """Test async error handling and recovery."""

    @pytest.mark.asyncio
    async def test_timeout_recovery(self):
        """Test recovery from operation timeouts."""
        async with graphiti_test_client() as (session, group_id):
            # Create a very large episode that might timeout
            large_content = 'x' * 1000000  # 1MB of data

            try:
                await asyncio.wait_for(
                    session.call_tool(
                        'add_memory',
                        {
                            'name': 'Timeout Test',
                            'episode_body': large_content,
                            'source': 'text',
                            'source_description': 'timeout test',
                            'group_id': group_id,
                        },
                    ),
                    timeout=2.0,  # Short timeout
                )
            except asyncio.TimeoutError:
                # Expected timeout
                pass

            # Verify server is still responsive after timeout
            status_result = await session.call_tool('get_status', {})
            assert status_result is not None, 'Server unresponsive after timeout'

    @pytest.mark.asyncio
    async def test_cancellation_handling(self):
        """Test proper handling of cancelled operations."""
        async with graphiti_test_client() as (session, group_id):
            # Start a long-running operation
            task = asyncio.create_task(
                session.call_tool(
                    'add_memory',
                    {
                        'name': 'Cancellation Test',
                        'episode_body': TestDataGenerator.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'cancel test',
                        'group_id': group_id,
                    },
                )
            )

            # Cancel after a short delay
            await asyncio.sleep(0.1)
            task.cancel()
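            # Cancellation is client-side only: the server may still process any
            # episode that was already queued before the call was cancelled.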

            # Verify cancellation was handled
            with pytest.raises(asyncio.CancelledError):
                await task

            # Server should still be operational
            result = await session.call_tool('get_status', {})
            assert result is not None

    @pytest.mark.asyncio
    async def test_exception_propagation(self):
        """Test that exceptions are properly propagated in async context."""
        async with graphiti_test_client() as (session, group_id):
            # Call with invalid arguments
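            # Missing required fields are expected to surface as a client-side
            # ValueError rather than leaving the session in a broken state.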
            with pytest.raises(ValueError):
                await session.call_tool(
                    'add_memory',
                    {
                        # Missing required fields
                        'group_id': group_id,
                    },
                )

            # Server should remain operational
            status = await session.call_tool('get_status', {})
            assert status is not None


class TestAsyncPerformance:
    """Performance tests for async operations."""

    @pytest.mark.asyncio
    async def test_async_throughput(self, performance_benchmark):
        """Measure throughput of async operations."""
        async with graphiti_test_client() as (session, group_id):
            num_operations = 50
            start_time = time.time()

            # Create many concurrent operations
            tasks = []
            for i in range(num_operations):
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': f'Throughput Test {i}',
                        'episode_body': f'Content {i}',
                        'source': 'text',
                        'source_description': 'throughput test',
                        'group_id': group_id,
                    },
                )
                tasks.append(task)

            # Execute all
            results = await asyncio.gather(*tasks, return_exceptions=True)
            total_time = time.time() - start_time

            # Calculate metrics
            successful = sum(1 for r in results if not isinstance(r, Exception))
            throughput = successful / total_time
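            # This counts successful tool calls per second; if add_memory only
            # enqueues work, it measures queuing throughput rather than
            # end-to-end episode processing.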

            performance_benchmark.record('async_throughput', throughput)

            # Log results
            print('\nAsync Throughput Test:')
            print(f' Operations: {num_operations}')
            print(f' Successful: {successful}')
            print(f' Total time: {total_time:.2f}s')
            print(f' Throughput: {throughput:.2f} ops/s')

            # Assert minimum throughput
            assert throughput > 1.0, f'Throughput too low: {throughput:.2f} ops/s'

    @pytest.mark.asyncio
    async def test_latency_under_load(self, performance_benchmark):
        """Test operation latency under concurrent load."""
        async with graphiti_test_client() as (session, group_id):
            # Create background load
            background_tasks = []
            for i in range(10):
                task = asyncio.create_task(
                    session.call_tool(
                        'add_memory',
                        {
                            'name': f'Background {i}',
                            'episode_body': TestDataGenerator.generate_technical_document(),
                            'source': 'text',
                            'source_description': 'background',
                            'group_id': f'background_{group_id}',
                        },
                    )
                )
                background_tasks.append(task)

            # Measure latency of operations under load
            latencies = []
            for _ in range(5):
                start = time.time()
                await session.call_tool('get_status', {})
                latency = time.time() - start
                latencies.append(latency)
                performance_benchmark.record('latency_under_load', latency)

            # Clean up background tasks
            for task in background_tasks:
                task.cancel()
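            # Best-effort cleanup: the cancelled background tasks are not awaited,
            # so any partial results from them are simply discarded.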

            # Analyze latencies
            avg_latency = sum(latencies) / len(latencies)
            max_latency = max(latencies)

            print('\nLatency Under Load:')
            print(f' Average: {avg_latency:.3f}s')
            print(f' Max: {max_latency:.3f}s')

            # Assert acceptable latency
            assert avg_latency < 2.0, f'Average latency too high: {avg_latency:.3f}s'
            assert max_latency < 5.0, f'Max latency too high: {max_latency:.3f}s'


class TestAsyncStreamHandling:
    """Test handling of streaming responses and data."""

    @pytest.mark.asyncio
    async def test_large_response_streaming(self):
        """Test handling of large streamed responses."""
        async with graphiti_test_client() as (session, group_id):
            # Add many episodes
            for i in range(20):
                await session.call_tool(
                    'add_memory',
                    {
                        'name': f'Stream Test {i}',
                        'episode_body': f'Episode content {i}',
                        'source': 'text',
                        'source_description': 'stream test',
                        'group_id': group_id,
                    },
                )

            # Wait for processing
            await asyncio.sleep(30)

            # Request large result set
            result = await session.call_tool(
                'get_episodes',
                {
                    'group_id': group_id,
                    'last_n': 100,  # Request all
                },
            )

            # Verify response handling
            episodes = json.loads(result.content[0].text)['episodes']
            assert len(episodes) >= 20, f'Expected at least 20 episodes, got {len(episodes)}'

    @pytest.mark.asyncio
    async def test_incremental_processing(self):
        """Test incremental processing of results."""
        async with graphiti_test_client() as (session, group_id):
            # Add episodes incrementally
            for batch in range(3):
                batch_tasks = []
                for i in range(5):
                    task = session.call_tool(
                        'add_memory',
                        {
                            'name': f'Batch {batch} Item {i}',
                            'episode_body': f'Content for batch {batch}',
                            'source': 'text',
                            'source_description': 'incremental test',
                            'group_id': group_id,
                        },
                    )
                    batch_tasks.append(task)

                # Process batch
                await asyncio.gather(*batch_tasks)

                # Wait for this batch to process
                await asyncio.sleep(10)

                # Verify incremental results
                result = await session.call_tool(
                    'get_episodes',
                    {
                        'group_id': group_id,
                        'last_n': 100,
                    },
                )

                episodes = json.loads(result.content[0].text)['episodes']
                expected_min = (batch + 1) * 5
                assert len(episodes) >= expected_min, (
                    f'Batch {batch}: Expected at least {expected_min} episodes'
                )


if __name__ == '__main__':
    pytest.main([__file__, '-v', '--asyncio-mode=auto'])