diff --git a/mcp_server/graphiti_mcp_server.py b/mcp_server/graphiti_mcp_server.py index 35d5182b..9b382074 100644 --- a/mcp_server/graphiti_mcp_server.py +++ b/mcp_server/graphiti_mcp_server.py @@ -755,7 +755,7 @@ async def add_memory( global graphiti_client, episode_queues, queue_workers if graphiti_client is None: - return {'error': 'Graphiti client not initialized'} + return ErrorResponse(error='Graphiti client not initialized') try: # Map string source to EpisodeType enum @@ -817,13 +817,13 @@ async def add_memory( asyncio.create_task(process_episode_queue(group_id_str)) # Return immediately with a success message - return { - 'message': f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})" - } + return SuccessResponse( + message=f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})" + ) except Exception as e: error_msg = str(e) logger.error(f'Error queuing episode task: {error_msg}') - return {'error': f'Error queuing episode task: {error_msg}'} + return ErrorResponse(error=f'Error queuing episode task: {error_msg}') @mcp.tool() @@ -925,9 +925,13 @@ async def search_memory_facts( global graphiti_client if graphiti_client is None: - return {'error': 'Graphiti client not initialized'} + return ErrorResponse(error='Graphiti client not initialized') try: + # Validate max_facts parameter + if max_facts <= 0: + return ErrorResponse(error='max_facts must be a positive integer') + # Use the provided group_ids or fall back to the default from config if none provided effective_group_ids = ( group_ids if group_ids is not None else [config.group_id] if config.group_id else [] @@ -947,14 +951,14 @@ async def search_memory_facts( ) if not relevant_edges: - return {'message': 'No relevant facts found', 'facts': []} + return FactSearchResponse(message='No relevant facts found', facts=[]) facts = [format_fact_result(edge) for edge in relevant_edges] - return {'message': 'Facts retrieved successfully', 'facts': facts} + return FactSearchResponse(message='Facts retrieved successfully', facts=facts) except Exception as e: error_msg = str(e) logger.error(f'Error searching facts: {error_msg}') - return {'error': f'Error searching facts: {error_msg}'} + return ErrorResponse(error=f'Error searching facts: {error_msg}') @mcp.tool() @@ -967,7 +971,7 @@ async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse: global graphiti_client if graphiti_client is None: - return {'error': 'Graphiti client not initialized'} + return ErrorResponse(error='Graphiti client not initialized') try: # We've already checked that graphiti_client is not None above @@ -980,11 +984,11 @@ async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse: entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid) # Delete the edge using its delete method await entity_edge.delete(client.driver) - return {'message': f'Entity edge with UUID {uuid} deleted successfully'} + return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully') except Exception as e: error_msg = str(e) logger.error(f'Error deleting entity edge: {error_msg}') - return {'error': f'Error deleting entity edge: {error_msg}'} + return ErrorResponse(error=f'Error deleting entity edge: {error_msg}') @mcp.tool() @@ -997,7 +1001,7 @@ async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse: global graphiti_client if graphiti_client is None: - return {'error': 'Graphiti client not initialized'} + return ErrorResponse(error='Graphiti client not initialized') try: # We've already checked that graphiti_client is not None above @@ -1010,11 +1014,11 @@ async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse: episodic_node = await EpisodicNode.get_by_uuid(client.driver, uuid) # Delete the node using its delete method await episodic_node.delete(client.driver) - return {'message': f'Episode with UUID {uuid} deleted successfully'} + return SuccessResponse(message=f'Episode with UUID {uuid} deleted successfully') except Exception as e: error_msg = str(e) logger.error(f'Error deleting episode: {error_msg}') - return {'error': f'Error deleting episode: {error_msg}'} + return ErrorResponse(error=f'Error deleting episode: {error_msg}') @mcp.tool() @@ -1027,7 +1031,7 @@ async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse: global graphiti_client if graphiti_client is None: - return {'error': 'Graphiti client not initialized'} + return ErrorResponse(error='Graphiti client not initialized') try: # We've already checked that graphiti_client is not None above @@ -1045,7 +1049,7 @@ async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse: except Exception as e: error_msg = str(e) logger.error(f'Error getting entity edge: {error_msg}') - return {'error': f'Error getting entity edge: {error_msg}'} + return ErrorResponse(error=f'Error getting entity edge: {error_msg}') @mcp.tool() @@ -1061,14 +1065,14 @@ async def get_episodes( global graphiti_client if graphiti_client is None: - return {'error': 'Graphiti client not initialized'} + return ErrorResponse(error='Graphiti client not initialized') try: # Use the provided group_id or fall back to the default from config effective_group_id = group_id if group_id is not None else config.group_id if not isinstance(effective_group_id, str): - return {'error': 'Group ID must be a string'} + return ErrorResponse(error='Group ID must be a string') # We've already checked that graphiti_client is not None above assert graphiti_client is not None @@ -1081,7 +1085,9 @@ async def get_episodes( ) if not episodes: - return {'message': f'No episodes found for group {effective_group_id}', 'episodes': []} + return EpisodeSearchResponse( + message=f'No episodes found for group {effective_group_id}', episodes=[] + ) # Use Pydantic's model_dump method for EpisodicNode serialization formatted_episodes = [ @@ -1095,7 +1101,7 @@ async def get_episodes( except Exception as e: error_msg = str(e) logger.error(f'Error getting episodes: {error_msg}') - return {'error': f'Error getting episodes: {error_msg}'} + return ErrorResponse(error=f'Error getting episodes: {error_msg}') @mcp.tool() @@ -1104,7 +1110,7 @@ async def clear_graph() -> SuccessResponse | ErrorResponse: global graphiti_client if graphiti_client is None: - return {'error': 'Graphiti client not initialized'} + return ErrorResponse(error='Graphiti client not initialized') try: # We've already checked that graphiti_client is not None above @@ -1116,11 +1122,11 @@ async def clear_graph() -> SuccessResponse | ErrorResponse: # clear_data is already imported at the top await clear_data(client.driver) await client.build_indices_and_constraints() - return {'message': 'Graph cleared successfully and indices rebuilt'} + return SuccessResponse(message='Graph cleared successfully and indices rebuilt') except Exception as e: error_msg = str(e) logger.error(f'Error clearing graph: {error_msg}') - return {'error': f'Error clearing graph: {error_msg}'} + return ErrorResponse(error=f'Error clearing graph: {error_msg}') @mcp.resource('http://graphiti/status') @@ -1129,7 +1135,7 @@ async def get_status() -> StatusResponse: global graphiti_client if graphiti_client is None: - return {'status': 'error', 'message': 'Graphiti client not initialized'} + return StatusResponse(status='error', message='Graphiti client not initialized') try: # We've already checked that graphiti_client is not None above @@ -1138,16 +1144,19 @@ async def get_status() -> StatusResponse: # Use cast to help the type checker understand that graphiti_client is not None client = cast(Graphiti, graphiti_client) - # Test Neo4j connection - await client.driver.verify_connectivity() - return {'status': 'ok', 'message': 'Graphiti MCP server is running and connected to Neo4j'} + # Test database connection + await client.driver.client.verify_connectivity() # type: ignore + + return StatusResponse( + status='ok', message='Graphiti MCP server is running and connected to Neo4j' + ) except Exception as e: error_msg = str(e) logger.error(f'Error checking Neo4j connection: {error_msg}') - return { - 'status': 'error', - 'message': f'Graphiti MCP server is running but Neo4j connection failed: {error_msg}', - } + return StatusResponse( + status='error', + message=f'Graphiti MCP server is running but Neo4j connection failed: {error_msg}', + ) async def initialize_server() -> MCPConfig: diff --git a/mcp_server/pyproject.toml b/mcp_server/pyproject.toml index d6b9a3fc..8c0b144e 100644 --- a/mcp_server/pyproject.toml +++ b/mcp_server/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mcp-server" -version = "0.2.0" +version = "0.2.1" description = "Graphiti MCP Server" readme = "README.md" requires-python = ">=3.10,<4"