Fix: Priority 1 & 2 code review fixes for graph exploration tools
Priority 1 (Critical): - Fix NULL handling in get_entity_timeline sort using datetime.min fallback - Add UUID validation to both get_entity_connections and get_entity_timeline - Add comprehensive automated integration tests (test_graph_exploration_int.py) Priority 2 (Important): - Improve exception handling with specific Neo4jError, ValueError, AttributeError - Add performance limitation documentation to tool docstrings - Generalize coaching references to technical/research examples Technical Details: - Import datetime and UUID for validation/sorting - Import Neo4jError for specific exception handling - Validate entity_uuid format before database calls - Handle None values in chronological sort with datetime.min - Add exc_info=True to error logging for better debugging - Sanitize error messages to avoid exposing internal details Test Coverage: - Tool availability verification - Invalid UUID validation - Nonexistent entity handling - Real data flow (add → search → explore) - Max limits parameter testing - Chronological ordering verification Performance Notes Added: - Document fetching behavior (database → application filtering) - Warn about high-degree nodes (1000+ connections/episodes) - Recommend smaller limits for large result sets All changes maintain backward compatibility and follow existing code patterns. Refs: mcp_server/src/graphiti_mcp_server.py:1560,1656
This commit is contained in:
parent
21c0eae78f
commit
f33aaa82bb
2 changed files with 374 additions and 17 deletions
|
|
@ -8,8 +8,10 @@ import asyncio
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from graphiti_core import Graphiti
|
||||
|
|
@ -19,6 +21,7 @@ from graphiti_core.nodes import EpisodeType, EpisodicNode
|
|||
from graphiti_core.search.search_filters import SearchFilters
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from neo4j.exceptions import Neo4jError
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
|
@ -223,24 +226,24 @@ for node in nodes:
|
|||
# Synthesize comprehensive answer from COMPLETE data
|
||||
```
|
||||
|
||||
### Example 3: Pattern Recognition
|
||||
### Example 3: Pattern Recognition (Research/Notes Context)
|
||||
|
||||
```python
|
||||
# User: "I'm feeling stressed today"
|
||||
# User: "Found another article about performance optimization"
|
||||
|
||||
nodes = search_nodes(query="stress")
|
||||
nodes = search_nodes(query="performance optimization")
|
||||
connections = get_entity_connections(entity_uuid=nodes[0]['uuid'])
|
||||
# Discovers: stress ↔ work, sleep, project-deadline, coffee-intake
|
||||
# Discovers: performance ↔ caching, database-indexing, lazy-loading, CDN
|
||||
|
||||
timeline = get_entity_timeline(entity_uuid=nodes[0]['uuid'])
|
||||
# Shows: First mentioned 3 months ago, frequency increasing
|
||||
# Shows: First researched 3 months ago, techniques evolving over time
|
||||
|
||||
# Can now make informed observations based on complete data
|
||||
|
||||
add_memory(
|
||||
name="Stress discussion",
|
||||
episode_body="Discussed stress today. User recognizes connection to
|
||||
project deadlines and sleep quality from our past conversations."
|
||||
name="Performance research",
|
||||
episode_body="New article discusses lazy loading techniques. Relates to
|
||||
our previous research on caching and database optimization."
|
||||
)
|
||||
```
|
||||
|
||||
|
|
@ -1593,6 +1596,11 @@ async def get_entity_connections(
|
|||
Returns:
|
||||
FactSearchResponse with all connected relationships and temporal metadata
|
||||
|
||||
Performance Notes:
|
||||
- Fetches all connections from database, then filters/limits in application code
|
||||
- For entities with 100+ connections, consider using smaller max_connections
|
||||
- High-degree nodes (1000+ connections) may have slower response times
|
||||
|
||||
Examples:
|
||||
# After finding entity
|
||||
nodes = search_nodes(query="Django project")
|
||||
|
|
@ -1613,6 +1621,12 @@ async def get_entity_connections(
|
|||
if graphiti_service is None:
|
||||
return ErrorResponse(error='Graphiti service not initialized')
|
||||
|
||||
# Validate UUID format
|
||||
try:
|
||||
UUID(entity_uuid)
|
||||
except (ValueError, AttributeError):
|
||||
return ErrorResponse(error='Invalid UUID format provided for entity_uuid')
|
||||
|
||||
try:
|
||||
client = await graphiti_service.get_client()
|
||||
|
||||
|
|
@ -1638,10 +1652,12 @@ async def get_entity_connections(
|
|||
message=f'Found {len(facts)} connection(s) for entity', facts=facts
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except (Neo4jError, ValueError, AttributeError) as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error getting entity connections: {error_msg}')
|
||||
return ErrorResponse(error=f'Error getting entity connections: {error_msg}')
|
||||
logger.error(f'Error getting entity connections: {error_msg}', exc_info=True)
|
||||
return ErrorResponse(
|
||||
error='Failed to retrieve entity connections. Please check the entity UUID and try again.'
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool(
|
||||
|
|
@ -1679,7 +1695,7 @@ async def get_entity_timeline(
|
|||
Use Cases:
|
||||
- "When did we first discuss microservices architecture?"
|
||||
- "Show all mentions of the deployment pipeline"
|
||||
- "Timeline of stress mentions"
|
||||
- "Timeline of performance optimization research"
|
||||
- "How did our understanding of GraphQL evolve?"
|
||||
|
||||
Args:
|
||||
|
|
@ -1690,6 +1706,11 @@ async def get_entity_timeline(
|
|||
Returns:
|
||||
EpisodeSearchResponse with episodes ordered chronologically
|
||||
|
||||
Performance Notes:
|
||||
- Fetches all episodes from database, then sorts/limits in application code
|
||||
- For entities mentioned in 100+ episodes, consider using smaller max_episodes
|
||||
- Very frequently mentioned entities (1000+ episodes) may have slower response times
|
||||
|
||||
Examples:
|
||||
# After finding entity
|
||||
nodes = search_nodes(query="microservices")
|
||||
|
|
@ -1710,6 +1731,12 @@ async def get_entity_timeline(
|
|||
if graphiti_service is None:
|
||||
return ErrorResponse(error='Graphiti service not initialized')
|
||||
|
||||
# Validate UUID format
|
||||
try:
|
||||
UUID(entity_uuid)
|
||||
except (ValueError, AttributeError):
|
||||
return ErrorResponse(error='Invalid UUID format provided for entity_uuid')
|
||||
|
||||
try:
|
||||
client = await graphiti_service.get_client()
|
||||
|
||||
|
|
@ -1720,8 +1747,8 @@ async def get_entity_timeline(
|
|||
if group_ids:
|
||||
episodes = [e for e in episodes if e.group_id in group_ids]
|
||||
|
||||
# Sort by valid_at (chronological order)
|
||||
episodes.sort(key=lambda e: e.valid_at)
|
||||
# Sort by valid_at (chronological order), handle None values
|
||||
episodes.sort(key=lambda e: e.valid_at or datetime.min)
|
||||
|
||||
# Limit results
|
||||
episodes = episodes[:max_episodes]
|
||||
|
|
@ -1751,10 +1778,12 @@ async def get_entity_timeline(
|
|||
episodes=episode_results,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except (Neo4jError, ValueError, AttributeError) as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error getting entity timeline: {error_msg}')
|
||||
return ErrorResponse(error=f'Error getting entity timeline: {error_msg}')
|
||||
logger.error(f'Error getting entity timeline: {error_msg}', exc_info=True)
|
||||
return ErrorResponse(
|
||||
error='Failed to retrieve entity timeline. Please check the entity UUID and try again.'
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool(
|
||||
|
|
|
|||
328
mcp_server/tests/test_graph_exploration_int.py
Normal file
328
mcp_server/tests/test_graph_exploration_int.py
Normal file
|
|
@ -0,0 +1,328 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration tests for graph exploration tools: get_entity_connections and get_entity_timeline.
|
||||
Tests UUID validation, error handling, result formatting, and chronological ordering.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
|
||||
class GraphExplorationToolsTest:
|
||||
"""Integration test for get_entity_connections and get_entity_timeline tools."""
|
||||
|
||||
def __init__(self):
|
||||
self.test_group_id = f'test_graph_exp_{int(time.time())}'
|
||||
self.session = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Start the MCP client session."""
|
||||
server_params = StdioServerParameters(
|
||||
command='uv',
|
||||
args=['run', 'main.py', '--transport', 'stdio'],
|
||||
env={
|
||||
'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
|
||||
'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
|
||||
'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
|
||||
'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'dummy_key_for_testing'),
|
||||
},
|
||||
)
|
||||
|
||||
print(f'🚀 Starting test session with group: {self.test_group_id}')
|
||||
|
||||
self.client_context = stdio_client(server_params)
|
||||
read, write = await self.client_context.__aenter__()
|
||||
self.session = ClientSession(read, write)
|
||||
await self.session.initialize()
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Close the MCP client session."""
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
if hasattr(self, 'client_context'):
|
||||
await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""Call an MCP tool and return the result."""
|
||||
try:
|
||||
result = await self.session.call_tool(tool_name, arguments)
|
||||
# Parse JSON response from text content
|
||||
import json
|
||||
|
||||
if result.content:
|
||||
text_content = result.content[0].text
|
||||
return json.loads(text_content)
|
||||
return {'error': 'No content returned'}
|
||||
except Exception as e:
|
||||
return {'error': str(e)}
|
||||
|
||||
async def test_tools_available(self) -> bool:
|
||||
"""Test that the new tools are available."""
|
||||
print('🔍 Test 1: Verifying new tools are available...')
|
||||
|
||||
try:
|
||||
tools_result = await self.session.list_tools()
|
||||
tools = [tool.name for tool in tools_result.tools]
|
||||
|
||||
required_tools = ['get_entity_connections', 'get_entity_timeline']
|
||||
|
||||
for tool_name in required_tools:
|
||||
if tool_name in tools:
|
||||
print(f' ✅ {tool_name} is available')
|
||||
else:
|
||||
print(f' ❌ {tool_name} is NOT available')
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Test failed: {e}')
|
||||
return False
|
||||
|
||||
async def test_invalid_uuid_validation(self) -> bool:
|
||||
"""Test UUID validation for invalid UUIDs."""
|
||||
print('🔍 Test 2: Testing UUID validation with invalid UUID...')
|
||||
|
||||
try:
|
||||
# Test get_entity_connections with invalid UUID
|
||||
result1 = await self.call_tool(
|
||||
'get_entity_connections', {'entity_uuid': 'not-a-valid-uuid'}
|
||||
)
|
||||
|
||||
if 'error' in result1 and 'Invalid UUID' in result1['error']:
|
||||
print(' ✅ get_entity_connections correctly rejects invalid UUID')
|
||||
else:
|
||||
print(f' ❌ get_entity_connections did not validate UUID: {result1}')
|
||||
return False
|
||||
|
||||
# Test get_entity_timeline with invalid UUID
|
||||
result2 = await self.call_tool('get_entity_timeline', {'entity_uuid': 'also-not-valid'})
|
||||
|
||||
if 'error' in result2 and 'Invalid UUID' in result2['error']:
|
||||
print(' ✅ get_entity_timeline correctly rejects invalid UUID')
|
||||
else:
|
||||
print(f' ❌ get_entity_timeline did not validate UUID: {result2}')
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Test failed: {e}')
|
||||
return False
|
||||
|
||||
async def test_nonexistent_entity(self) -> bool:
|
||||
"""Test behavior with valid UUID but nonexistent entity."""
|
||||
print('🔍 Test 3: Testing with valid but nonexistent UUID...')
|
||||
|
||||
try:
|
||||
# Use a valid UUID format but one that doesn't exist
|
||||
fake_uuid = '00000000-0000-0000-0000-000000000000'
|
||||
|
||||
# Test get_entity_connections
|
||||
result1 = await self.call_tool(
|
||||
'get_entity_connections', {'entity_uuid': fake_uuid, 'max_connections': 10}
|
||||
)
|
||||
|
||||
if 'facts' in result1 and len(result1['facts']) == 0:
|
||||
print(' ✅ get_entity_connections returns empty list for nonexistent entity')
|
||||
else:
|
||||
print(f' ❌ Unexpected result for nonexistent entity: {result1}')
|
||||
return False
|
||||
|
||||
# Test get_entity_timeline
|
||||
result2 = await self.call_tool(
|
||||
'get_entity_timeline', {'entity_uuid': fake_uuid, 'max_episodes': 10}
|
||||
)
|
||||
|
||||
if 'episodes' in result2 and len(result2['episodes']) == 0:
|
||||
print(' ✅ get_entity_timeline returns empty list for nonexistent entity')
|
||||
else:
|
||||
print(f' ❌ Unexpected result for nonexistent entity: {result2}')
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Test failed: {e}')
|
||||
return False
|
||||
|
||||
async def test_with_real_data(self) -> bool:
|
||||
"""Test with actual data - add memory, search, and explore connections."""
|
||||
print('🔍 Test 4: Testing with real data (add → search → explore)...')
|
||||
|
||||
try:
|
||||
# Step 1: Add some test data
|
||||
print(' Step 1: Adding test episodes...')
|
||||
await self.call_tool(
|
||||
'add_memory',
|
||||
{
|
||||
'name': 'Database Decision',
|
||||
'episode_body': 'We chose PostgreSQL for the new service because it has better JSON support than MySQL.',
|
||||
'source': 'text',
|
||||
'group_id': self.test_group_id,
|
||||
},
|
||||
)
|
||||
|
||||
await self.call_tool(
|
||||
'add_memory',
|
||||
{
|
||||
'name': 'Second Discussion',
|
||||
'episode_body': 'Discussed PostgreSQL performance tuning. Need to optimize connection pooling.',
|
||||
'source': 'text',
|
||||
'group_id': self.test_group_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Wait for processing
|
||||
print(' Waiting 5 seconds for episode processing...')
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Step 2: Search for the entity
|
||||
print(' Step 2: Searching for PostgreSQL entity...')
|
||||
search_result = await self.call_tool(
|
||||
'search_nodes', {'query': 'PostgreSQL', 'group_ids': [self.test_group_id]}
|
||||
)
|
||||
|
||||
if 'nodes' not in search_result or len(search_result['nodes']) == 0:
|
||||
print(' ⚠️ No PostgreSQL entity found yet (may need more processing time)')
|
||||
return True # Don't fail the test, processing might not be complete
|
||||
|
||||
entity_uuid = search_result['nodes'][0]['uuid']
|
||||
print(f' ✅ Found PostgreSQL entity: {entity_uuid[:8]}...')
|
||||
|
||||
# Step 3: Test get_entity_connections
|
||||
print(' Step 3: Testing get_entity_connections...')
|
||||
connections_result = await self.call_tool(
|
||||
'get_entity_connections',
|
||||
{
|
||||
'entity_uuid': entity_uuid,
|
||||
'group_ids': [self.test_group_id],
|
||||
'max_connections': 20,
|
||||
},
|
||||
)
|
||||
|
||||
if 'facts' in connections_result:
|
||||
print(
|
||||
f' ✅ get_entity_connections returned {len(connections_result["facts"])} connection(s)'
|
||||
)
|
||||
else:
|
||||
print(f' ⚠️ get_entity_connections result: {connections_result}')
|
||||
|
||||
# Step 4: Test get_entity_timeline
|
||||
print(' Step 4: Testing get_entity_timeline...')
|
||||
timeline_result = await self.call_tool(
|
||||
'get_entity_timeline',
|
||||
{
|
||||
'entity_uuid': entity_uuid,
|
||||
'group_ids': [self.test_group_id],
|
||||
'max_episodes': 20,
|
||||
},
|
||||
)
|
||||
|
||||
if 'episodes' in timeline_result:
|
||||
episodes = timeline_result['episodes']
|
||||
print(f' ✅ get_entity_timeline returned {len(episodes)} episode(s)')
|
||||
|
||||
# Verify chronological order
|
||||
if len(episodes) > 1:
|
||||
valid_at_values = [ep['valid_at'] for ep in episodes if ep.get('valid_at')]
|
||||
is_sorted = valid_at_values == sorted(valid_at_values)
|
||||
if is_sorted:
|
||||
print(' ✅ Episodes are in chronological order')
|
||||
else:
|
||||
print(' ❌ Episodes are NOT in chronological order')
|
||||
return False
|
||||
else:
|
||||
print(f' ⚠️ get_entity_timeline result: {timeline_result}')
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Test failed: {e}')
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def test_max_limits(self) -> bool:
|
||||
"""Test that max_connections and max_episodes limits work."""
|
||||
print('🔍 Test 5: Testing max limits parameters...')
|
||||
|
||||
try:
|
||||
# This test just verifies the parameters are accepted
|
||||
# Actual limit testing would require creating many connections
|
||||
|
||||
fake_uuid = '00000000-0000-0000-0000-000000000000'
|
||||
|
||||
# Test with different max values
|
||||
result1 = await self.call_tool(
|
||||
'get_entity_connections', {'entity_uuid': fake_uuid, 'max_connections': 5}
|
||||
)
|
||||
|
||||
result2 = await self.call_tool(
|
||||
'get_entity_timeline', {'entity_uuid': fake_uuid, 'max_episodes': 5}
|
||||
)
|
||||
|
||||
if 'facts' in result1 and 'episodes' in result2:
|
||||
print(' ✅ Both tools accept max limit parameters')
|
||||
return True
|
||||
else:
|
||||
print(' ❌ Tools did not accept max limit parameters')
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Test failed: {e}')
|
||||
return False
|
||||
|
||||
async def run_all_tests(self):
|
||||
"""Run all test scenarios."""
|
||||
print('\n' + '=' * 70)
|
||||
print('Graph Exploration Tools Integration Tests')
|
||||
print('=' * 70 + '\n')
|
||||
|
||||
results = {
|
||||
'Tools Available': await self.test_tools_available(),
|
||||
'Invalid UUID Validation': await self.test_invalid_uuid_validation(),
|
||||
'Nonexistent Entity Handling': await self.test_nonexistent_entity(),
|
||||
'Real Data Flow': await self.test_with_real_data(),
|
||||
'Max Limits Parameters': await self.test_max_limits(),
|
||||
}
|
||||
|
||||
print('\n' + '=' * 70)
|
||||
print('Test Results Summary')
|
||||
print('=' * 70)
|
||||
|
||||
all_passed = True
|
||||
for test_name, passed in results.items():
|
||||
status = '✅ PASSED' if passed else '❌ FAILED'
|
||||
print(f'{test_name}: {status}')
|
||||
if not passed:
|
||||
all_passed = False
|
||||
|
||||
print('=' * 70)
|
||||
|
||||
if all_passed:
|
||||
print('🎉 All tests PASSED!')
|
||||
else:
|
||||
print('⚠️ Some tests FAILED')
|
||||
|
||||
return all_passed
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run the integration tests."""
|
||||
async with GraphExplorationToolsTest() as test:
|
||||
success = await test.run_all_tests()
|
||||
return 0 if success else 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
exit_code = asyncio.run(main())
|
||||
exit(exit_code)
|
||||
Loading…
Add table
Reference in a new issue