From 883ea8692473fbead823944ce630389049c3c3f6 Mon Sep 17 00:00:00 2001
From: donbr
Date: Sun, 7 Dec 2025 00:53:08 -0800
Subject: [PATCH] feat: Add factory pattern for FastMCP Cloud deployment
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add create_server() async factory function as entrypoint for FastMCP Cloud
- Add _register_tools() function that registers tools using closures to
  capture service instances (avoids global variable dependencies)
- Maintain backward compatibility with lifespan-based mcp instance for local dev
- Factory pattern initializes services before creating server instance
- Tools registered include: add_memory, search_nodes, search_memory_facts,
  delete_entity_edge, delete_episode, get_entity_edge, get_episodes,
  clear_graph, get_status

FastMCP Cloud entrypoint: src/graphiti_mcp_server.py:create_server

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5
---
 mcp_server/src/graphiti_mcp_server.py | 374 +++++++++++++++++++++++++-
 1 file changed, 363 insertions(+), 11 deletions(-)

diff --git a/mcp_server/src/graphiti_mcp_server.py b/mcp_server/src/graphiti_mcp_server.py
index 3405e0c0..809269f2 100644
--- a/mcp_server/src/graphiti_mcp_server.py
+++ b/mcp_server/src/graphiti_mcp_server.py
@@ -153,44 +153,396 @@ graphiti_client: Graphiti | None = None
 semaphore: asyncio.Semaphore
 
 
+# ------------------------------------------------------------------------------
+# Tool Registration for Factory Pattern
+# ------------------------------------------------------------------------------
+
+
+def _register_tools(
+    server: FastMCP,
+    cfg: GraphitiConfig,
+    graphiti_svc: 'GraphitiService',
+    queue_svc: QueueService,
+) -> None:
+    """Register all MCP tools using closures to capture service instances.
+
+    This function is used by the factory pattern (create_server) for FastMCP Cloud.
+    Tools use closures to access services instead of global variables.
+    """
+    from graphiti_core.edges import EntityEdge
+    from graphiti_core.nodes import EpisodeType, EpisodicNode
+    from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
+    from graphiti_core.search.search_filters import SearchFilters
+    from graphiti_core.utils.maintenance.graph_data_operations import clear_data
+
+    @server.tool()
+    async def add_memory(
+        name: str,
+        episode_body: str,
+        group_id: str | None = None,
+        source: str = 'text',
+        source_description: str = '',
+        uuid: str | None = None,
+    ) -> SuccessResponse | ErrorResponse:
+        """Add an episode to memory."""
+        try:
+            effective_group_id = group_id or cfg.graphiti.group_id
+
+            episode_type = EpisodeType.text
+            if source:
+                try:
+                    episode_type = EpisodeType[source.lower()]
+                except (KeyError, AttributeError):
+                    logger.warning(f"Unknown source type '{source}', using 'text' as default")
+                    episode_type = EpisodeType.text
+
+            await queue_svc.add_episode(
+                group_id=effective_group_id,
+                name=name,
+                content=episode_body,
+                source_description=source_description,
+                episode_type=episode_type,
+                entity_types=graphiti_svc.entity_types,
+                uuid=uuid or None,
+            )
+
+            return SuccessResponse(
+                message=f"Episode '{name}' queued for processing in group '{effective_group_id}'"
+            )
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f'Error queuing episode: {error_msg}')
+            return ErrorResponse(error=f'Error queuing episode: {error_msg}')
+
+    @server.tool()
+    async def search_nodes(
+        query: str,
+        group_ids: list[str] | None = None,
+        max_nodes: int = 10,
+        entity_types: list[str] | None = None,
+    ) -> NodeSearchResponse | ErrorResponse:
+        """Search for nodes in the graph memory."""
+        try:
+            client = await graphiti_svc.get_client()
+
+            effective_group_ids = (
+                group_ids
+                if group_ids is not None
+                else [cfg.graphiti.group_id]
+                if cfg.graphiti.group_id
+                else []
+            )
+
+            search_filters = SearchFilters(node_labels=entity_types)
+
+            results = await client.search_(
+                query=query,
+                config=NODE_HYBRID_SEARCH_RRF,
+                group_ids=effective_group_ids,
+                search_filter=search_filters,
+            )
+
+            nodes = results.nodes[:max_nodes] if results.nodes else []
+
+            if not nodes:
+                return NodeSearchResponse(message='No relevant nodes found', nodes=[])
+
+            node_results = []
+            for node in nodes:
+                attrs = node.attributes if hasattr(node, 'attributes') else {}
+                attrs = {k: v for k, v in attrs.items() if 'embedding' not in k.lower()}
+
+                node_results.append(
+                    NodeResult(
+                        uuid=node.uuid,
+                        name=node.name,
+                        labels=node.labels if node.labels else [],
+                        created_at=node.created_at.isoformat() if node.created_at else None,
+                        summary=node.summary,
+                        group_id=node.group_id,
+                        attributes=attrs,
+                    )
+                )
+
+            return NodeSearchResponse(message='Nodes retrieved successfully', nodes=node_results)
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f'Error searching nodes: {error_msg}')
+            return ErrorResponse(error=f'Error searching nodes: {error_msg}')
+
+    @server.tool()
+    async def search_memory_facts(
+        query: str,
+        group_ids: list[str] | None = None,
+        max_facts: int = 10,
+        center_node_uuid: str | None = None,
+    ) -> FactSearchResponse | ErrorResponse:
+        """Search the graph memory for relevant facts."""
+        try:
+            if max_facts <= 0:
+                return ErrorResponse(error='max_facts must be a positive integer')
+
+            client = await graphiti_svc.get_client()
+
+            effective_group_ids = (
+                group_ids
+                if group_ids is not None
+                else [cfg.graphiti.group_id]
+                if cfg.graphiti.group_id
+                else []
+            )
+
+            relevant_edges = await client.search(
+                group_ids=effective_group_ids,
+                query=query,
+                num_results=max_facts,
+                center_node_uuid=center_node_uuid,
+            )
+
+            if not relevant_edges:
+                return FactSearchResponse(message='No relevant facts found', facts=[])
+
+            facts = [format_fact_result(edge) for edge in relevant_edges]
+            return FactSearchResponse(message='Facts retrieved successfully', facts=facts)
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f'Error searching facts: {error_msg}')
+            return ErrorResponse(error=f'Error searching facts: {error_msg}')
+
+    @server.tool()
+    async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
+        """Delete an entity edge from the graph memory."""
+        try:
+            client = await graphiti_svc.get_client()
+            entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
+            await entity_edge.delete(client.driver)
+            return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully')
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f'Error deleting entity edge: {error_msg}')
+            return ErrorResponse(error=f'Error deleting entity edge: {error_msg}')
+
+    @server.tool()
+    async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
+        """Delete an episode from the graph memory."""
+        try:
+            client = await graphiti_svc.get_client()
+            episodic_node = await EpisodicNode.get_by_uuid(client.driver, uuid)
+            await episodic_node.delete(client.driver)
+            return SuccessResponse(message=f'Episode with UUID {uuid} deleted successfully')
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f'Error deleting episode: {error_msg}')
+            return ErrorResponse(error=f'Error deleting episode: {error_msg}')
+
+    @server.tool()
+    async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
+        """Get an entity edge from the graph memory by its UUID."""
+        try:
+            client = await graphiti_svc.get_client()
+            entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
+            return format_fact_result(entity_edge)
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f'Error getting entity edge: {error_msg}')
+            return ErrorResponse(error=f'Error getting entity edge: {error_msg}')
+
+    @server.tool()
+    async def get_episodes(
+        group_ids: list[str] | None = None,
+        max_episodes: int = 10,
+    ) -> EpisodeSearchResponse | ErrorResponse:
+        """Get episodes from the graph memory."""
+        try:
+            client = await graphiti_svc.get_client()
+
+            effective_group_ids = (
+                group_ids
+                if group_ids is not None
+                else [cfg.graphiti.group_id]
+                if cfg.graphiti.group_id
+                else []
+            )
+
+            if effective_group_ids:
+                episodes = await EpisodicNode.get_by_group_ids(
+                    client.driver, effective_group_ids, limit=max_episodes
+                )
+            else:
+                episodes = []
+
+            if not episodes:
+                return EpisodeSearchResponse(message='No episodes found', episodes=[])
+
+            episode_results = []
+            for episode in episodes:
+                episode_dict = {
+                    'uuid': episode.uuid,
+                    'name': episode.name,
+                    'content': episode.content,
+                    'created_at': episode.created_at.isoformat() if episode.created_at else None,
+                    'source': episode.source.value
+                    if hasattr(episode.source, 'value')
+                    else str(episode.source),
+                    'source_description': episode.source_description,
+                    'group_id': episode.group_id,
+                }
+                episode_results.append(episode_dict)
+
+            return EpisodeSearchResponse(
+                message='Episodes retrieved successfully', episodes=episode_results
+            )
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f'Error getting episodes: {error_msg}')
+            return ErrorResponse(error=f'Error getting episodes: {error_msg}')
+
+    @server.tool()
+    async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse:
+        """Clear all data from the graph for specified group IDs."""
+        try:
+            client = await graphiti_svc.get_client()
+
+            effective_group_ids = (
+                group_ids
+                if group_ids is not None
+                else [cfg.graphiti.group_id]
+                if cfg.graphiti.group_id
+                else []
+            )
+
+            if not effective_group_ids:
+                return ErrorResponse(error='No group IDs specified for clearing')
+
+            await clear_data(client.driver, group_ids=effective_group_ids)
+
+            return SuccessResponse(
+                message=f'Graph data cleared successfully for group IDs: {", ".join(effective_group_ids)}'
+            )
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f'Error clearing graph: {error_msg}')
+            return ErrorResponse(error=f'Error clearing graph: {error_msg}')
+
+    @server.tool()
+    async def get_status() -> StatusResponse:
+        """Get the status of the Graphiti MCP server and database connection."""
+        try:
+            client = await graphiti_svc.get_client()
+
+            async with client.driver.session() as session:
+                result = await session.run('MATCH (n) RETURN count(n) as count')
+                if result:
+                    _ = [record async for record in result]
+
+            provider_name = graphiti_svc.config.database.provider
+            return StatusResponse(
+                status='ok',
+                message=f'Graphiti MCP server is running and connected to {provider_name} database',
+            )
+        except Exception as e:
+            error_msg = str(e)
+            logger.error(f'Error checking database connection: {error_msg}')
+            return StatusResponse(
+                status='error',
+                message=f'Graphiti MCP server is running but database connection failed: {error_msg}',
+            )
+
+
+# ------------------------------------------------------------------------------
+# Factory Entrypoint for FastMCP Cloud
+# ------------------------------------------------------------------------------
+
+async def create_server() -> FastMCP:
+    """Factory function that creates and initializes the MCP server.
+
+    This is the entrypoint for FastMCP Cloud and `fastmcp dev`.
+    Configuration comes from environment variables set in the FastMCP Cloud UI.
+
+    FastMCP Cloud uses this pattern:
+    - Entrypoint: src/graphiti_mcp_server.py:create_server
+    - This function is called to create the server instance
+    - Tools are registered using closures that capture service instances
+    """
+    global config, graphiti_service, queue_service, graphiti_client, semaphore
+
+    logger.info('Initializing Graphiti MCP server via factory pattern...')
+
+    # 1. Load configuration from environment variables
+    config = GraphitiConfig()
+
+    # Log configuration details
+    logger.info('Using configuration:')
+    logger.info(f'  - LLM: {config.llm.provider} / {config.llm.model}')
+    logger.info(f'  - Embedder: {config.embedder.provider} / {config.embedder.model}')
+    logger.info(f'  - Database: {config.database.provider}')
+    logger.info(f'  - Group ID: {config.graphiti.group_id}')
+
+    # 2. Initialize Services
+    graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
+    queue_service = QueueService()
+    await graphiti_service.initialize()
+
+    # Set global client for backward compatibility
+    graphiti_client = await graphiti_service.get_client()
+    semaphore = graphiti_service.semaphore
+
+    # Initialize queue service with the client
+    await queue_service.initialize(graphiti_client)
+
+    logger.info('Graphiti services initialized successfully via factory')
+
+    # 3. Create Server Instance
+    server = FastMCP(
+        'Graphiti Agent Memory',
+        instructions=GRAPHITI_MCP_INSTRUCTIONS,
+    )
+
+    # 4. Register Tools using closures that capture service instances
+    _register_tools(server, config, graphiti_service, queue_service)
+
+    # 5. Register Custom Routes
+    @server.custom_route('/health', methods=['GET'])
+    async def health_check(request):
+        return JSONResponse({'status': 'healthy', 'service': 'graphiti-mcp'})
+
+    @server.custom_route('/status', methods=['GET'])
+    async def status_check(request):
+        return JSONResponse({'status': 'ok', 'service': 'graphiti-mcp'})
+
+    logger.info('FastMCP server created with factory pattern')
+    return server
+
+
+# Also create the module-level mcp instance for backward compatibility with local dev
+# This uses a lifespan for the traditional pattern
 @asynccontextmanager
 async def graphiti_lifespan(app):
-    """Lifespan context manager for FastMCP Cloud deployment.
+    """Lifespan context manager for local development.
 
-    This function initializes the Graphiti service when the server starts.
-    FastMCP Cloud calls this automatically - it does NOT run the if __name__ == '__main__' block.
+    Note: FastMCP Cloud should use the create_server() factory function instead.
     """
     global config, graphiti_service, queue_service, graphiti_client, semaphore
 
     logger.info('Initializing Graphiti service via lifespan...')
 
     try:
-        # Load configuration from environment variables (FastMCP Cloud sets these)
         config = GraphitiConfig()
 
-        # Log configuration details
         logger.info('Using configuration:')
         logger.info(f'  - LLM: {config.llm.provider} / {config.llm.model}')
         logger.info(f'  - Embedder: {config.embedder.provider} / {config.embedder.model}')
         logger.info(f'  - Database: {config.database.provider}')
         logger.info(f'  - Group ID: {config.graphiti.group_id}')
 
-        # Initialize services (GraphitiService is defined below, but that's OK
-        # since this function is only called at runtime, not definition time)
         graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
         queue_service = QueueService()
         await graphiti_service.initialize()
 
-        # Set global client for backward compatibility
         graphiti_client = await graphiti_service.get_client()
         semaphore = graphiti_service.semaphore
 
-        # Initialize queue service with the client
         await queue_service.initialize(graphiti_client)
 
         logger.info('Graphiti service initialized successfully via lifespan')
 
-        yield  # Server runs here
+        yield
 
     except Exception as e:
         logger.error(f'Failed to initialize Graphiti service: {e}')
@@ -199,7 +551,7 @@ async def graphiti_lifespan(app):
         logger.info('Shutting down Graphiti service...')
 
 
-# MCP server instance - pass lifespan to constructor for FastMCP Cloud
+# Module-level MCP server instance for backward compatibility
 mcp = FastMCP(
     'Graphiti Agent Memory',
     instructions=GRAPHITI_MCP_INSTRUCTIONS,
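
To exercise the new create_server() entrypoint outside FastMCP Cloud, a minimal local smoke test might look like the sketch below. This is not part of the patch: it assumes mcp_server/src is on PYTHONPATH so graphiti_mcp_server is importable, that the usual database/LLM/embedder environment variables are exported, and that the installed fastmcp release exposes FastMCP.run_async().

    import asyncio

    # Assumes mcp_server/src is on PYTHONPATH; adjust the import to your layout.
    from graphiti_mcp_server import create_server


    async def main() -> None:
        # create_server() reads configuration from environment variables and
        # initializes GraphitiService/QueueService before returning the server.
        server = await create_server()
        # run_async() is assumed from fastmcp 2.x; the default transport is stdio.
        await server.run_async()


    if __name__ == '__main__':
        asyncio.run(main())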