feat: Add factory pattern for FastMCP Cloud deployment
- Add create_server() async factory function as entrypoint for FastMCP Cloud - Add _register_tools() function that registers tools using closures to capture service instances (avoids global variable dependencies) - Maintain backward compatibility with lifespan-based mcp instance for local dev - Factory pattern initializes services before creating server instance - Tools registered include: add_memory, search_nodes, search_memory_facts, delete_entity_edge, delete_episode, get_entity_edge, get_episodes, clear_graph, get_status FastMCP Cloud entrypoint: src/graphiti_mcp_server.py:create_server 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
b24b835411
commit
883ea86924
1 changed files with 363 additions and 11 deletions
|
|
@ -153,44 +153,396 @@ graphiti_client: Graphiti | None = None
|
|||
semaphore: asyncio.Semaphore
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Tool Registration for Factory Pattern
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _register_tools(
|
||||
server: FastMCP,
|
||||
cfg: GraphitiConfig,
|
||||
graphiti_svc: 'GraphitiService',
|
||||
queue_svc: QueueService,
|
||||
) -> None:
|
||||
"""Register all MCP tools using closures to capture service instances.
|
||||
|
||||
This function is used by the factory pattern (create_server) for FastMCP Cloud.
|
||||
Tools use closures to access services instead of global variables.
|
||||
"""
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
|
||||
from graphiti_core.search.search_filters import SearchFilters
|
||||
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||
|
||||
@server.tool()
|
||||
async def add_memory(
|
||||
name: str,
|
||||
episode_body: str,
|
||||
group_id: str | None = None,
|
||||
source: str = 'text',
|
||||
source_description: str = '',
|
||||
uuid: str | None = None,
|
||||
) -> SuccessResponse | ErrorResponse:
|
||||
"""Add an episode to memory."""
|
||||
try:
|
||||
effective_group_id = group_id or cfg.graphiti.group_id
|
||||
|
||||
episode_type = EpisodeType.text
|
||||
if source:
|
||||
try:
|
||||
episode_type = EpisodeType[source.lower()]
|
||||
except (KeyError, AttributeError):
|
||||
logger.warning(f"Unknown source type '{source}', using 'text' as default")
|
||||
episode_type = EpisodeType.text
|
||||
|
||||
await queue_svc.add_episode(
|
||||
group_id=effective_group_id,
|
||||
name=name,
|
||||
content=episode_body,
|
||||
source_description=source_description,
|
||||
episode_type=episode_type,
|
||||
entity_types=graphiti_svc.entity_types,
|
||||
uuid=uuid or None,
|
||||
)
|
||||
|
||||
return SuccessResponse(
|
||||
message=f"Episode '{name}' queued for processing in group '{effective_group_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error queuing episode: {error_msg}')
|
||||
return ErrorResponse(error=f'Error queuing episode: {error_msg}')
|
||||
|
||||
@server.tool()
|
||||
async def search_nodes(
|
||||
query: str,
|
||||
group_ids: list[str] | None = None,
|
||||
max_nodes: int = 10,
|
||||
entity_types: list[str] | None = None,
|
||||
) -> NodeSearchResponse | ErrorResponse:
|
||||
"""Search for nodes in the graph memory."""
|
||||
try:
|
||||
client = await graphiti_svc.get_client()
|
||||
|
||||
effective_group_ids = (
|
||||
group_ids
|
||||
if group_ids is not None
|
||||
else [cfg.graphiti.group_id]
|
||||
if cfg.graphiti.group_id
|
||||
else []
|
||||
)
|
||||
|
||||
search_filters = SearchFilters(node_labels=entity_types)
|
||||
|
||||
results = await client.search_(
|
||||
query=query,
|
||||
config=NODE_HYBRID_SEARCH_RRF,
|
||||
group_ids=effective_group_ids,
|
||||
search_filter=search_filters,
|
||||
)
|
||||
|
||||
nodes = results.nodes[:max_nodes] if results.nodes else []
|
||||
|
||||
if not nodes:
|
||||
return NodeSearchResponse(message='No relevant nodes found', nodes=[])
|
||||
|
||||
node_results = []
|
||||
for node in nodes:
|
||||
attrs = node.attributes if hasattr(node, 'attributes') else {}
|
||||
attrs = {k: v for k, v in attrs.items() if 'embedding' not in k.lower()}
|
||||
|
||||
node_results.append(
|
||||
NodeResult(
|
||||
uuid=node.uuid,
|
||||
name=node.name,
|
||||
labels=node.labels if node.labels else [],
|
||||
created_at=node.created_at.isoformat() if node.created_at else None,
|
||||
summary=node.summary,
|
||||
group_id=node.group_id,
|
||||
attributes=attrs,
|
||||
)
|
||||
)
|
||||
|
||||
return NodeSearchResponse(message='Nodes retrieved successfully', nodes=node_results)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error searching nodes: {error_msg}')
|
||||
return ErrorResponse(error=f'Error searching nodes: {error_msg}')
|
||||
|
||||
@server.tool()
|
||||
async def search_memory_facts(
|
||||
query: str,
|
||||
group_ids: list[str] | None = None,
|
||||
max_facts: int = 10,
|
||||
center_node_uuid: str | None = None,
|
||||
) -> FactSearchResponse | ErrorResponse:
|
||||
"""Search the graph memory for relevant facts."""
|
||||
try:
|
||||
if max_facts <= 0:
|
||||
return ErrorResponse(error='max_facts must be a positive integer')
|
||||
|
||||
client = await graphiti_svc.get_client()
|
||||
|
||||
effective_group_ids = (
|
||||
group_ids
|
||||
if group_ids is not None
|
||||
else [cfg.graphiti.group_id]
|
||||
if cfg.graphiti.group_id
|
||||
else []
|
||||
)
|
||||
|
||||
relevant_edges = await client.search(
|
||||
group_ids=effective_group_ids,
|
||||
query=query,
|
||||
num_results=max_facts,
|
||||
center_node_uuid=center_node_uuid,
|
||||
)
|
||||
|
||||
if not relevant_edges:
|
||||
return FactSearchResponse(message='No relevant facts found', facts=[])
|
||||
|
||||
facts = [format_fact_result(edge) for edge in relevant_edges]
|
||||
return FactSearchResponse(message='Facts retrieved successfully', facts=facts)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error searching facts: {error_msg}')
|
||||
return ErrorResponse(error=f'Error searching facts: {error_msg}')
|
||||
|
||||
@server.tool()
|
||||
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
|
||||
"""Delete an entity edge from the graph memory."""
|
||||
try:
|
||||
client = await graphiti_svc.get_client()
|
||||
entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
|
||||
await entity_edge.delete(client.driver)
|
||||
return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully')
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error deleting entity edge: {error_msg}')
|
||||
return ErrorResponse(error=f'Error deleting entity edge: {error_msg}')
|
||||
|
||||
@server.tool()
|
||||
async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
|
||||
"""Delete an episode from the graph memory."""
|
||||
try:
|
||||
client = await graphiti_svc.get_client()
|
||||
episodic_node = await EpisodicNode.get_by_uuid(client.driver, uuid)
|
||||
await episodic_node.delete(client.driver)
|
||||
return SuccessResponse(message=f'Episode with UUID {uuid} deleted successfully')
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error deleting episode: {error_msg}')
|
||||
return ErrorResponse(error=f'Error deleting episode: {error_msg}')
|
||||
|
||||
@server.tool()
|
||||
async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
|
||||
"""Get an entity edge from the graph memory by its UUID."""
|
||||
try:
|
||||
client = await graphiti_svc.get_client()
|
||||
entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
|
||||
return format_fact_result(entity_edge)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error getting entity edge: {error_msg}')
|
||||
return ErrorResponse(error=f'Error getting entity edge: {error_msg}')
|
||||
|
||||
@server.tool()
|
||||
async def get_episodes(
|
||||
group_ids: list[str] | None = None,
|
||||
max_episodes: int = 10,
|
||||
) -> EpisodeSearchResponse | ErrorResponse:
|
||||
"""Get episodes from the graph memory."""
|
||||
try:
|
||||
client = await graphiti_svc.get_client()
|
||||
|
||||
effective_group_ids = (
|
||||
group_ids
|
||||
if group_ids is not None
|
||||
else [cfg.graphiti.group_id]
|
||||
if cfg.graphiti.group_id
|
||||
else []
|
||||
)
|
||||
|
||||
if effective_group_ids:
|
||||
episodes = await EpisodicNode.get_by_group_ids(
|
||||
client.driver, effective_group_ids, limit=max_episodes
|
||||
)
|
||||
else:
|
||||
episodes = []
|
||||
|
||||
if not episodes:
|
||||
return EpisodeSearchResponse(message='No episodes found', episodes=[])
|
||||
|
||||
episode_results = []
|
||||
for episode in episodes:
|
||||
episode_dict = {
|
||||
'uuid': episode.uuid,
|
||||
'name': episode.name,
|
||||
'content': episode.content,
|
||||
'created_at': episode.created_at.isoformat() if episode.created_at else None,
|
||||
'source': episode.source.value
|
||||
if hasattr(episode.source, 'value')
|
||||
else str(episode.source),
|
||||
'source_description': episode.source_description,
|
||||
'group_id': episode.group_id,
|
||||
}
|
||||
episode_results.append(episode_dict)
|
||||
|
||||
return EpisodeSearchResponse(
|
||||
message='Episodes retrieved successfully', episodes=episode_results
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error getting episodes: {error_msg}')
|
||||
return ErrorResponse(error=f'Error getting episodes: {error_msg}')
|
||||
|
||||
@server.tool()
|
||||
async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse:
|
||||
"""Clear all data from the graph for specified group IDs."""
|
||||
try:
|
||||
client = await graphiti_svc.get_client()
|
||||
|
||||
effective_group_ids = (
|
||||
group_ids or [cfg.graphiti.group_id] if cfg.graphiti.group_id else []
|
||||
)
|
||||
|
||||
if not effective_group_ids:
|
||||
return ErrorResponse(error='No group IDs specified for clearing')
|
||||
|
||||
await clear_data(client.driver, group_ids=effective_group_ids)
|
||||
|
||||
return SuccessResponse(
|
||||
message=f'Graph data cleared successfully for group IDs: {", ".join(effective_group_ids)}'
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error clearing graph: {error_msg}')
|
||||
return ErrorResponse(error=f'Error clearing graph: {error_msg}')
|
||||
|
||||
@server.tool()
|
||||
async def get_status() -> StatusResponse:
|
||||
"""Get the status of the Graphiti MCP server and database connection."""
|
||||
try:
|
||||
client = await graphiti_svc.get_client()
|
||||
|
||||
async with client.driver.session() as session:
|
||||
result = await session.run('MATCH (n) RETURN count(n) as count')
|
||||
if result:
|
||||
_ = [record async for record in result]
|
||||
|
||||
provider_name = graphiti_svc.config.database.provider
|
||||
return StatusResponse(
|
||||
status='ok',
|
||||
message=f'Graphiti MCP server is running and connected to {provider_name} database',
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f'Error checking database connection: {error_msg}')
|
||||
return StatusResponse(
|
||||
status='error',
|
||||
message=f'Graphiti MCP server is running but database connection failed: {error_msg}',
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Factory Entrypoint for FastMCP Cloud
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
async def create_server() -> FastMCP:
|
||||
"""Factory function that creates and initializes the MCP server.
|
||||
|
||||
This is the entrypoint for FastMCP Cloud and `fastmcp dev`.
|
||||
Configuration comes from environment variables set in the FastMCP Cloud UI.
|
||||
|
||||
FastMCP Cloud uses this pattern:
|
||||
- Entrypoint: src/graphiti_mcp_server.py:create_server
|
||||
- This function is called to create the server instance
|
||||
- Tools are registered using closures that capture service instances
|
||||
"""
|
||||
global config, graphiti_service, queue_service, graphiti_client, semaphore
|
||||
|
||||
logger.info('Initializing Graphiti MCP server via factory pattern...')
|
||||
|
||||
# 1. Load configuration from environment variables
|
||||
config = GraphitiConfig()
|
||||
|
||||
# Log configuration details
|
||||
logger.info('Using configuration:')
|
||||
logger.info(f' - LLM: {config.llm.provider} / {config.llm.model}')
|
||||
logger.info(f' - Embedder: {config.embedder.provider} / {config.embedder.model}')
|
||||
logger.info(f' - Database: {config.database.provider}')
|
||||
logger.info(f' - Group ID: {config.graphiti.group_id}')
|
||||
|
||||
# 2. Initialize Services
|
||||
graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
|
||||
queue_service = QueueService()
|
||||
await graphiti_service.initialize()
|
||||
|
||||
# Set global client for backward compatibility
|
||||
graphiti_client = await graphiti_service.get_client()
|
||||
semaphore = graphiti_service.semaphore
|
||||
|
||||
# Initialize queue service with the client
|
||||
await queue_service.initialize(graphiti_client)
|
||||
|
||||
logger.info('Graphiti services initialized successfully via factory')
|
||||
|
||||
# 3. Create Server Instance
|
||||
server = FastMCP(
|
||||
'Graphiti Agent Memory',
|
||||
instructions=GRAPHITI_MCP_INSTRUCTIONS,
|
||||
)
|
||||
|
||||
# 4. Register Tools using closures that capture service instances
|
||||
_register_tools(server, config, graphiti_service, queue_service)
|
||||
|
||||
# 5. Register Custom Routes
|
||||
@server.custom_route('/health', methods=['GET'])
|
||||
async def health_check(request):
|
||||
return JSONResponse({'status': 'healthy', 'service': 'graphiti-mcp'})
|
||||
|
||||
@server.custom_route('/status', methods=['GET'])
|
||||
async def status_check(request):
|
||||
return JSONResponse({'status': 'ok', 'service': 'graphiti-mcp'})
|
||||
|
||||
logger.info('FastMCP server created with factory pattern')
|
||||
return server
|
||||
|
||||
|
||||
# Also create the module-level mcp instance for backward compatibility with local dev
|
||||
# This uses a lifespan for the traditional pattern
|
||||
@asynccontextmanager
|
||||
async def graphiti_lifespan(app):
|
||||
"""Lifespan context manager for FastMCP Cloud deployment.
|
||||
"""Lifespan context manager for local development.
|
||||
|
||||
This function initializes the Graphiti service when the server starts.
|
||||
FastMCP Cloud calls this automatically - it does NOT run the if __name__ == '__main__' block.
|
||||
Note: FastMCP Cloud should use the create_server() factory function instead.
|
||||
"""
|
||||
global config, graphiti_service, queue_service, graphiti_client, semaphore
|
||||
|
||||
logger.info('Initializing Graphiti service via lifespan...')
|
||||
|
||||
try:
|
||||
# Load configuration from environment variables (FastMCP Cloud sets these)
|
||||
config = GraphitiConfig()
|
||||
|
||||
# Log configuration details
|
||||
logger.info('Using configuration:')
|
||||
logger.info(f' - LLM: {config.llm.provider} / {config.llm.model}')
|
||||
logger.info(f' - Embedder: {config.embedder.provider} / {config.embedder.model}')
|
||||
logger.info(f' - Database: {config.database.provider}')
|
||||
logger.info(f' - Group ID: {config.graphiti.group_id}')
|
||||
|
||||
# Initialize services (GraphitiService is defined below, but that's OK
|
||||
# since this function is only called at runtime, not definition time)
|
||||
graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
|
||||
queue_service = QueueService()
|
||||
await graphiti_service.initialize()
|
||||
|
||||
# Set global client for backward compatibility
|
||||
graphiti_client = await graphiti_service.get_client()
|
||||
semaphore = graphiti_service.semaphore
|
||||
|
||||
# Initialize queue service with the client
|
||||
await queue_service.initialize(graphiti_client)
|
||||
|
||||
logger.info('Graphiti service initialized successfully via lifespan')
|
||||
|
||||
yield # Server runs here
|
||||
yield
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to initialize Graphiti service: {e}')
|
||||
|
|
@ -199,7 +551,7 @@ async def graphiti_lifespan(app):
|
|||
logger.info('Shutting down Graphiti service...')
|
||||
|
||||
|
||||
# MCP server instance - pass lifespan to constructor for FastMCP Cloud
|
||||
# Module-level MCP server instance for backward compatibility
|
||||
mcp = FastMCP(
|
||||
'Graphiti Agent Memory',
|
||||
instructions=GRAPHITI_MCP_INSTRUCTIONS,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue