Refactor: Replace dictionary responses with structured response classes in graphiti_mcp_server.py (#668)
Updated the error and success responses in various functions to utilize structured response classes (ErrorResponse, SuccessResponse, FactSearchResponse, EpisodeSearchResponse, StatusResponse) for improved consistency and clarity. Incremented version in pyproject.toml to 0.2.1.
This commit is contained in:
parent
7ce07942b1
commit
743d5e8612
2 changed files with 42 additions and 33 deletions
|
|
@ -755,7 +755,7 @@ async def add_memory(
|
|||
global graphiti_client, episode_queues, queue_workers
|
||||
|
||||
if graphiti_client is None:
|
||||
return {'error': 'Graphiti client not initialized'}
|
||||
return ErrorResponse(error='Graphiti client not initialized')
|
||||
|
||||
try:
|
||||
# Map string source to EpisodeType enum
|
||||
|
|
@ -817,13 +817,13 @@ async def add_memory(
|
|||
asyncio.create_task(process_episode_queue(group_id_str))
|
||||
|
||||
# Return immediately with a success message
|
||||
return {
|
||||
'message': f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})"
|
||||
}
|
||||
return SuccessResponse(
|
||||
message=f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})"
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error queuing episode task: {error_msg}')
|
||||
return {'error': f'Error queuing episode task: {error_msg}'}
|
||||
return ErrorResponse(error=f'Error queuing episode task: {error_msg}')
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
|
@ -925,9 +925,13 @@ async def search_memory_facts(
|
|||
global graphiti_client
|
||||
|
||||
if graphiti_client is None:
|
||||
return {'error': 'Graphiti client not initialized'}
|
||||
return ErrorResponse(error='Graphiti client not initialized')
|
||||
|
||||
try:
|
||||
# Validate max_facts parameter
|
||||
if max_facts <= 0:
|
||||
return ErrorResponse(error='max_facts must be a positive integer')
|
||||
|
||||
# Use the provided group_ids or fall back to the default from config if none provided
|
||||
effective_group_ids = (
|
||||
group_ids if group_ids is not None else [config.group_id] if config.group_id else []
|
||||
|
|
@ -947,14 +951,14 @@ async def search_memory_facts(
|
|||
)
|
||||
|
||||
if not relevant_edges:
|
||||
return {'message': 'No relevant facts found', 'facts': []}
|
||||
return FactSearchResponse(message='No relevant facts found', facts=[])
|
||||
|
||||
facts = [format_fact_result(edge) for edge in relevant_edges]
|
||||
return {'message': 'Facts retrieved successfully', 'facts': facts}
|
||||
return FactSearchResponse(message='Facts retrieved successfully', facts=facts)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error searching facts: {error_msg}')
|
||||
return {'error': f'Error searching facts: {error_msg}'}
|
||||
return ErrorResponse(error=f'Error searching facts: {error_msg}')
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
|
@ -967,7 +971,7 @@ async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
|
|||
global graphiti_client
|
||||
|
||||
if graphiti_client is None:
|
||||
return {'error': 'Graphiti client not initialized'}
|
||||
return ErrorResponse(error='Graphiti client not initialized')
|
||||
|
||||
try:
|
||||
# We've already checked that graphiti_client is not None above
|
||||
|
|
@ -980,11 +984,11 @@ async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
|
|||
entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
|
||||
# Delete the edge using its delete method
|
||||
await entity_edge.delete(client.driver)
|
||||
return {'message': f'Entity edge with UUID {uuid} deleted successfully'}
|
||||
return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully')
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error deleting entity edge: {error_msg}')
|
||||
return {'error': f'Error deleting entity edge: {error_msg}'}
|
||||
return ErrorResponse(error=f'Error deleting entity edge: {error_msg}')
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
|
@ -997,7 +1001,7 @@ async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
|
|||
global graphiti_client
|
||||
|
||||
if graphiti_client is None:
|
||||
return {'error': 'Graphiti client not initialized'}
|
||||
return ErrorResponse(error='Graphiti client not initialized')
|
||||
|
||||
try:
|
||||
# We've already checked that graphiti_client is not None above
|
||||
|
|
@ -1010,11 +1014,11 @@ async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
|
|||
episodic_node = await EpisodicNode.get_by_uuid(client.driver, uuid)
|
||||
# Delete the node using its delete method
|
||||
await episodic_node.delete(client.driver)
|
||||
return {'message': f'Episode with UUID {uuid} deleted successfully'}
|
||||
return SuccessResponse(message=f'Episode with UUID {uuid} deleted successfully')
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error deleting episode: {error_msg}')
|
||||
return {'error': f'Error deleting episode: {error_msg}'}
|
||||
return ErrorResponse(error=f'Error deleting episode: {error_msg}')
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
|
@ -1027,7 +1031,7 @@ async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
|
|||
global graphiti_client
|
||||
|
||||
if graphiti_client is None:
|
||||
return {'error': 'Graphiti client not initialized'}
|
||||
return ErrorResponse(error='Graphiti client not initialized')
|
||||
|
||||
try:
|
||||
# We've already checked that graphiti_client is not None above
|
||||
|
|
@ -1045,7 +1049,7 @@ async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
|
|||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error getting entity edge: {error_msg}')
|
||||
return {'error': f'Error getting entity edge: {error_msg}'}
|
||||
return ErrorResponse(error=f'Error getting entity edge: {error_msg}')
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
|
@ -1061,14 +1065,14 @@ async def get_episodes(
|
|||
global graphiti_client
|
||||
|
||||
if graphiti_client is None:
|
||||
return {'error': 'Graphiti client not initialized'}
|
||||
return ErrorResponse(error='Graphiti client not initialized')
|
||||
|
||||
try:
|
||||
# Use the provided group_id or fall back to the default from config
|
||||
effective_group_id = group_id if group_id is not None else config.group_id
|
||||
|
||||
if not isinstance(effective_group_id, str):
|
||||
return {'error': 'Group ID must be a string'}
|
||||
return ErrorResponse(error='Group ID must be a string')
|
||||
|
||||
# We've already checked that graphiti_client is not None above
|
||||
assert graphiti_client is not None
|
||||
|
|
@ -1081,7 +1085,9 @@ async def get_episodes(
|
|||
)
|
||||
|
||||
if not episodes:
|
||||
return {'message': f'No episodes found for group {effective_group_id}', 'episodes': []}
|
||||
return EpisodeSearchResponse(
|
||||
message=f'No episodes found for group {effective_group_id}', episodes=[]
|
||||
)
|
||||
|
||||
# Use Pydantic's model_dump method for EpisodicNode serialization
|
||||
formatted_episodes = [
|
||||
|
|
@ -1095,7 +1101,7 @@ async def get_episodes(
|
|||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error getting episodes: {error_msg}')
|
||||
return {'error': f'Error getting episodes: {error_msg}'}
|
||||
return ErrorResponse(error=f'Error getting episodes: {error_msg}')
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
|
@ -1104,7 +1110,7 @@ async def clear_graph() -> SuccessResponse | ErrorResponse:
|
|||
global graphiti_client
|
||||
|
||||
if graphiti_client is None:
|
||||
return {'error': 'Graphiti client not initialized'}
|
||||
return ErrorResponse(error='Graphiti client not initialized')
|
||||
|
||||
try:
|
||||
# We've already checked that graphiti_client is not None above
|
||||
|
|
@ -1116,11 +1122,11 @@ async def clear_graph() -> SuccessResponse | ErrorResponse:
|
|||
# clear_data is already imported at the top
|
||||
await clear_data(client.driver)
|
||||
await client.build_indices_and_constraints()
|
||||
return {'message': 'Graph cleared successfully and indices rebuilt'}
|
||||
return SuccessResponse(message='Graph cleared successfully and indices rebuilt')
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error clearing graph: {error_msg}')
|
||||
return {'error': f'Error clearing graph: {error_msg}'}
|
||||
return ErrorResponse(error=f'Error clearing graph: {error_msg}')
|
||||
|
||||
|
||||
@mcp.resource('http://graphiti/status')
|
||||
|
|
@ -1129,7 +1135,7 @@ async def get_status() -> StatusResponse:
|
|||
global graphiti_client
|
||||
|
||||
if graphiti_client is None:
|
||||
return {'status': 'error', 'message': 'Graphiti client not initialized'}
|
||||
return StatusResponse(status='error', message='Graphiti client not initialized')
|
||||
|
||||
try:
|
||||
# We've already checked that graphiti_client is not None above
|
||||
|
|
@ -1138,16 +1144,19 @@ async def get_status() -> StatusResponse:
|
|||
# Use cast to help the type checker understand that graphiti_client is not None
|
||||
client = cast(Graphiti, graphiti_client)
|
||||
|
||||
# Test Neo4j connection
|
||||
await client.driver.verify_connectivity()
|
||||
return {'status': 'ok', 'message': 'Graphiti MCP server is running and connected to Neo4j'}
|
||||
# Test database connection
|
||||
await client.driver.client.verify_connectivity() # type: ignore
|
||||
|
||||
return StatusResponse(
|
||||
status='ok', message='Graphiti MCP server is running and connected to Neo4j'
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error checking Neo4j connection: {error_msg}')
|
||||
return {
|
||||
'status': 'error',
|
||||
'message': f'Graphiti MCP server is running but Neo4j connection failed: {error_msg}',
|
||||
}
|
||||
return StatusResponse(
|
||||
status='error',
|
||||
message=f'Graphiti MCP server is running but Neo4j connection failed: {error_msg}',
|
||||
)
|
||||
|
||||
|
||||
async def initialize_server() -> MCPConfig:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "mcp-server"
|
||||
version = "0.2.0"
|
||||
version = "0.2.1"
|
||||
description = "Graphiti MCP Server"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<4"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue