feat: Add factory pattern for FastMCP Cloud deployment

- Add create_server() async factory function as entrypoint for FastMCP Cloud
- Add _register_tools() function that registers tools using closures
  to capture service instances (avoids global variable dependencies)
- Maintain backward compatibility with lifespan-based mcp instance for local dev
- Factory pattern initializes services before creating server instance
- Tools registered include: add_memory, search_nodes, search_memory_facts,
  delete_entity_edge, delete_episode, get_entity_edge, get_episodes,
  clear_graph, get_status

FastMCP Cloud entrypoint: src/graphiti_mcp_server.py:create_server

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
donbr 2025-12-07 00:53:08 -08:00
parent b24b835411
commit 883ea86924

View file

@ -153,44 +153,396 @@ graphiti_client: Graphiti | None = None
semaphore: asyncio.Semaphore
# ------------------------------------------------------------------------------
# Tool Registration for Factory Pattern
# ------------------------------------------------------------------------------
def _register_tools(
server: FastMCP,
cfg: GraphitiConfig,
graphiti_svc: 'GraphitiService',
queue_svc: QueueService,
) -> None:
"""Register all MCP tools using closures to capture service instances.
This function is used by the factory pattern (create_server) for FastMCP Cloud.
Tools use closures to access services instead of global variables.
"""
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
@server.tool()
async def add_memory(
name: str,
episode_body: str,
group_id: str | None = None,
source: str = 'text',
source_description: str = '',
uuid: str | None = None,
) -> SuccessResponse | ErrorResponse:
"""Add an episode to memory."""
try:
effective_group_id = group_id or cfg.graphiti.group_id
episode_type = EpisodeType.text
if source:
try:
episode_type = EpisodeType[source.lower()]
except (KeyError, AttributeError):
logger.warning(f"Unknown source type '{source}', using 'text' as default")
episode_type = EpisodeType.text
await queue_svc.add_episode(
group_id=effective_group_id,
name=name,
content=episode_body,
source_description=source_description,
episode_type=episode_type,
entity_types=graphiti_svc.entity_types,
uuid=uuid or None,
)
return SuccessResponse(
message=f"Episode '{name}' queued for processing in group '{effective_group_id}'"
)
except Exception as e:
error_msg = str(e)
logger.error(f'Error queuing episode: {error_msg}')
return ErrorResponse(error=f'Error queuing episode: {error_msg}')
@server.tool()
async def search_nodes(
query: str,
group_ids: list[str] | None = None,
max_nodes: int = 10,
entity_types: list[str] | None = None,
) -> NodeSearchResponse | ErrorResponse:
"""Search for nodes in the graph memory."""
try:
client = await graphiti_svc.get_client()
effective_group_ids = (
group_ids
if group_ids is not None
else [cfg.graphiti.group_id]
if cfg.graphiti.group_id
else []
)
search_filters = SearchFilters(node_labels=entity_types)
results = await client.search_(
query=query,
config=NODE_HYBRID_SEARCH_RRF,
group_ids=effective_group_ids,
search_filter=search_filters,
)
nodes = results.nodes[:max_nodes] if results.nodes else []
if not nodes:
return NodeSearchResponse(message='No relevant nodes found', nodes=[])
node_results = []
for node in nodes:
attrs = node.attributes if hasattr(node, 'attributes') else {}
attrs = {k: v for k, v in attrs.items() if 'embedding' not in k.lower()}
node_results.append(
NodeResult(
uuid=node.uuid,
name=node.name,
labels=node.labels if node.labels else [],
created_at=node.created_at.isoformat() if node.created_at else None,
summary=node.summary,
group_id=node.group_id,
attributes=attrs,
)
)
return NodeSearchResponse(message='Nodes retrieved successfully', nodes=node_results)
except Exception as e:
error_msg = str(e)
logger.error(f'Error searching nodes: {error_msg}')
return ErrorResponse(error=f'Error searching nodes: {error_msg}')
@server.tool()
async def search_memory_facts(
query: str,
group_ids: list[str] | None = None,
max_facts: int = 10,
center_node_uuid: str | None = None,
) -> FactSearchResponse | ErrorResponse:
"""Search the graph memory for relevant facts."""
try:
if max_facts <= 0:
return ErrorResponse(error='max_facts must be a positive integer')
client = await graphiti_svc.get_client()
effective_group_ids = (
group_ids
if group_ids is not None
else [cfg.graphiti.group_id]
if cfg.graphiti.group_id
else []
)
relevant_edges = await client.search(
group_ids=effective_group_ids,
query=query,
num_results=max_facts,
center_node_uuid=center_node_uuid,
)
if not relevant_edges:
return FactSearchResponse(message='No relevant facts found', facts=[])
facts = [format_fact_result(edge) for edge in relevant_edges]
return FactSearchResponse(message='Facts retrieved successfully', facts=facts)
except Exception as e:
error_msg = str(e)
logger.error(f'Error searching facts: {error_msg}')
return ErrorResponse(error=f'Error searching facts: {error_msg}')
@server.tool()
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
"""Delete an entity edge from the graph memory."""
try:
client = await graphiti_svc.get_client()
entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
await entity_edge.delete(client.driver)
return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully')
except Exception as e:
error_msg = str(e)
logger.error(f'Error deleting entity edge: {error_msg}')
return ErrorResponse(error=f'Error deleting entity edge: {error_msg}')
@server.tool()
async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
"""Delete an episode from the graph memory."""
try:
client = await graphiti_svc.get_client()
episodic_node = await EpisodicNode.get_by_uuid(client.driver, uuid)
await episodic_node.delete(client.driver)
return SuccessResponse(message=f'Episode with UUID {uuid} deleted successfully')
except Exception as e:
error_msg = str(e)
logger.error(f'Error deleting episode: {error_msg}')
return ErrorResponse(error=f'Error deleting episode: {error_msg}')
@server.tool()
async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
"""Get an entity edge from the graph memory by its UUID."""
try:
client = await graphiti_svc.get_client()
entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
return format_fact_result(entity_edge)
except Exception as e:
error_msg = str(e)
logger.error(f'Error getting entity edge: {error_msg}')
return ErrorResponse(error=f'Error getting entity edge: {error_msg}')
@server.tool()
async def get_episodes(
group_ids: list[str] | None = None,
max_episodes: int = 10,
) -> EpisodeSearchResponse | ErrorResponse:
"""Get episodes from the graph memory."""
try:
client = await graphiti_svc.get_client()
effective_group_ids = (
group_ids
if group_ids is not None
else [cfg.graphiti.group_id]
if cfg.graphiti.group_id
else []
)
if effective_group_ids:
episodes = await EpisodicNode.get_by_group_ids(
client.driver, effective_group_ids, limit=max_episodes
)
else:
episodes = []
if not episodes:
return EpisodeSearchResponse(message='No episodes found', episodes=[])
episode_results = []
for episode in episodes:
episode_dict = {
'uuid': episode.uuid,
'name': episode.name,
'content': episode.content,
'created_at': episode.created_at.isoformat() if episode.created_at else None,
'source': episode.source.value
if hasattr(episode.source, 'value')
else str(episode.source),
'source_description': episode.source_description,
'group_id': episode.group_id,
}
episode_results.append(episode_dict)
return EpisodeSearchResponse(
message='Episodes retrieved successfully', episodes=episode_results
)
except Exception as e:
error_msg = str(e)
logger.error(f'Error getting episodes: {error_msg}')
return ErrorResponse(error=f'Error getting episodes: {error_msg}')
@server.tool()
async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse:
"""Clear all data from the graph for specified group IDs."""
try:
client = await graphiti_svc.get_client()
effective_group_ids = (
group_ids or [cfg.graphiti.group_id] if cfg.graphiti.group_id else []
)
if not effective_group_ids:
return ErrorResponse(error='No group IDs specified for clearing')
await clear_data(client.driver, group_ids=effective_group_ids)
return SuccessResponse(
message=f'Graph data cleared successfully for group IDs: {", ".join(effective_group_ids)}'
)
except Exception as e:
error_msg = str(e)
logger.error(f'Error clearing graph: {error_msg}')
return ErrorResponse(error=f'Error clearing graph: {error_msg}')
@server.tool()
async def get_status() -> StatusResponse:
"""Get the status of the Graphiti MCP server and database connection."""
try:
client = await graphiti_svc.get_client()
async with client.driver.session() as session:
result = await session.run('MATCH (n) RETURN count(n) as count')
if result:
_ = [record async for record in result]
provider_name = graphiti_svc.config.database.provider
return StatusResponse(
status='ok',
message=f'Graphiti MCP server is running and connected to {provider_name} database',
)
except Exception as e:
error_msg = str(e)
logger.error(f'Error checking database connection: {error_msg}')
return StatusResponse(
status='error',
message=f'Graphiti MCP server is running but database connection failed: {error_msg}',
)
# ------------------------------------------------------------------------------
# Factory Entrypoint for FastMCP Cloud
# ------------------------------------------------------------------------------
async def create_server() -> FastMCP:
"""Factory function that creates and initializes the MCP server.
This is the entrypoint for FastMCP Cloud and `fastmcp dev`.
Configuration comes from environment variables set in the FastMCP Cloud UI.
FastMCP Cloud uses this pattern:
- Entrypoint: src/graphiti_mcp_server.py:create_server
- This function is called to create the server instance
- Tools are registered using closures that capture service instances
"""
global config, graphiti_service, queue_service, graphiti_client, semaphore
logger.info('Initializing Graphiti MCP server via factory pattern...')
# 1. Load configuration from environment variables
config = GraphitiConfig()
# Log configuration details
logger.info('Using configuration:')
logger.info(f' - LLM: {config.llm.provider} / {config.llm.model}')
logger.info(f' - Embedder: {config.embedder.provider} / {config.embedder.model}')
logger.info(f' - Database: {config.database.provider}')
logger.info(f' - Group ID: {config.graphiti.group_id}')
# 2. Initialize Services
graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
queue_service = QueueService()
await graphiti_service.initialize()
# Set global client for backward compatibility
graphiti_client = await graphiti_service.get_client()
semaphore = graphiti_service.semaphore
# Initialize queue service with the client
await queue_service.initialize(graphiti_client)
logger.info('Graphiti services initialized successfully via factory')
# 3. Create Server Instance
server = FastMCP(
'Graphiti Agent Memory',
instructions=GRAPHITI_MCP_INSTRUCTIONS,
)
# 4. Register Tools using closures that capture service instances
_register_tools(server, config, graphiti_service, queue_service)
# 5. Register Custom Routes
@server.custom_route('/health', methods=['GET'])
async def health_check(request):
return JSONResponse({'status': 'healthy', 'service': 'graphiti-mcp'})
@server.custom_route('/status', methods=['GET'])
async def status_check(request):
return JSONResponse({'status': 'ok', 'service': 'graphiti-mcp'})
logger.info('FastMCP server created with factory pattern')
return server
# Also create the module-level mcp instance for backward compatibility with local dev
# This uses a lifespan for the traditional pattern
@asynccontextmanager
async def graphiti_lifespan(app):
"""Lifespan context manager for FastMCP Cloud deployment.
"""Lifespan context manager for local development.
This function initializes the Graphiti service when the server starts.
FastMCP Cloud calls this automatically - it does NOT run the if __name__ == '__main__' block.
Note: FastMCP Cloud should use the create_server() factory function instead.
"""
global config, graphiti_service, queue_service, graphiti_client, semaphore
logger.info('Initializing Graphiti service via lifespan...')
try:
# Load configuration from environment variables (FastMCP Cloud sets these)
config = GraphitiConfig()
# Log configuration details
logger.info('Using configuration:')
logger.info(f' - LLM: {config.llm.provider} / {config.llm.model}')
logger.info(f' - Embedder: {config.embedder.provider} / {config.embedder.model}')
logger.info(f' - Database: {config.database.provider}')
logger.info(f' - Group ID: {config.graphiti.group_id}')
# Initialize services (GraphitiService is defined below, but that's OK
# since this function is only called at runtime, not definition time)
graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
queue_service = QueueService()
await graphiti_service.initialize()
# Set global client for backward compatibility
graphiti_client = await graphiti_service.get_client()
semaphore = graphiti_service.semaphore
# Initialize queue service with the client
await queue_service.initialize(graphiti_client)
logger.info('Graphiti service initialized successfully via lifespan')
yield # Server runs here
yield
except Exception as e:
logger.error(f'Failed to initialize Graphiti service: {e}')
@ -199,7 +551,7 @@ async def graphiti_lifespan(app):
logger.info('Shutting down Graphiti service...')
# MCP server instance - pass lifespan to constructor for FastMCP Cloud
# Module-level MCP server instance for backward compatibility
mcp = FastMCP(
'Graphiti Agent Memory',
instructions=GRAPHITI_MCP_INSTRUCTIONS,