#!/usr/bin/env python3
"""
Graphiti MCP Server - Exposes Graphiti functionality through the Model Context Protocol (MCP)
"""

import argparse
import asyncio
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Optional

from config_schema import GraphitiConfig
from dotenv import load_dotenv
from entity_types import ENTITY_TYPES
from factories import DatabaseDriverFactory, EmbedderFactory, LLMClientFactory
from formatting import format_fact_result
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel
from queue_service import QueueService
from response_types import (
    EpisodeSearchResponse,
    ErrorResponse,
    FactSearchResponse,
    NodeResult,
    NodeSearchResponse,
    StatusResponse,
    SuccessResponse,
)
from server_config import MCPConfig

from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_config_recipes import (
    NODE_HYBRID_SEARCH_RRF,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.maintenance.graph_data_operations import clear_data

load_dotenv()

# Semaphore limit for concurrent Graphiti operations.
# Decrease this if you're experiencing 429 rate limit errors from your LLM provider.
# Increase if you have high rate limits.
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', '10'))

# Configure logging (stderr, so stdout stays free for the stdio transport)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stderr,
)
logger = logging.getLogger(__name__)

# Global config instance - declared here, assigned in initialize_server()
config: GraphitiConfig

# MCP server instructions
GRAPHITI_MCP_INSTRUCTIONS = """
Graphiti is a memory service for AI agents built on a knowledge graph. Graphiti performs well
with dynamic data such as user interactions, changing enterprise data, and external information.

Graphiti transforms information into a richly connected knowledge network, allowing you to
capture relationships between concepts, entities, and information. The system organizes data as episodes
(content snippets), nodes (entities), and facts (relationships between entities), creating a dynamic,
queryable memory store that evolves with new information. Graphiti supports multiple data formats, including
structured JSON data, enabling seamless integration with existing data pipelines and systems.

Facts contain temporal metadata, allowing you to track the time of creation and whether a fact is invalid
(superseded by new information).

Key capabilities:
1. Add episodes (text, messages, or JSON) to the knowledge graph with the add_memory tool
2. Search for nodes (entities) in the graph using natural language queries with search_nodes
3. Find relevant facts (relationships between entities) with search_memory_facts
4. Retrieve specific entity edges or episodes by UUID
5. Manage the knowledge graph with tools like delete_episode, delete_entity_edge, and clear_graph

The server connects to a database for persistent storage and uses language models for certain operations.
Each piece of information is organized by group_id, allowing you to maintain separate knowledge domains.

When adding information, provide descriptive names and detailed content to improve search quality.
When searching, use specific queries and consider filtering by group_id for more relevant results.

For optimal performance, ensure the database is properly configured and accessible, and valid
API keys are provided for any language model operations.
"""

# MCP server instance
mcp = FastMCP(
    'Graphiti Agent Memory',
    instructions=GRAPHITI_MCP_INSTRUCTIONS,
)

# Global services
graphiti_service: Optional['GraphitiService'] = None
queue_service: QueueService | None = None

# Global client for backward compatibility
graphiti_client: Graphiti | None = None
semaphore: asyncio.Semaphore


class GraphitiService:
    """Graphiti service using the unified configuration system.

    Owns a lazily-created ``Graphiti`` client plus the entity types that were
    resolved from configuration when the client was built.
    """

    def __init__(self, config: GraphitiConfig, semaphore_limit: int = 10):
        """Store configuration; no connection is opened until ``initialize()``.

        Args:
            config: Loaded Graphiti configuration.
            semaphore_limit: Concurrency limit, used both for the local
                semaphore and as ``max_coroutines`` for the Graphiti client.
        """
        self.config = config
        # Kept as an attribute because initialize() passes it to Graphiti();
        # previously it was never stored and initialize() raised AttributeError.
        self.semaphore_limit = semaphore_limit
        self.semaphore = asyncio.Semaphore(semaphore_limit)
        self.client: Graphiti | None = None
        self.entity_types = None

    async def initialize(self) -> None:
        """Initialize the Graphiti client with factory-created components.

        Builds the optional LLM and embedder clients, resolves custom entity
        types, constructs the ``Graphiti`` client, verifies connectivity and
        builds indices/constraints.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            # Create clients using factories
            llm_client = None
            embedder_client = None

            # Only create LLM client if API key is available
            # NOTE(review): only the OpenAI provider entry is inspected here, even
            # when another provider is selected - confirm this is intended.
            if self.config.llm.providers.openai and self.config.llm.providers.openai.api_key:
                llm_client = LLMClientFactory.create(self.config.llm)

            # Only create embedder client if API key is available
            # NOTE(review): same OpenAI-only check as for the LLM client above.
            if (
                self.config.embedder.providers.openai
                and self.config.embedder.providers.openai.api_key
            ):
                embedder_client = EmbedderFactory.create(self.config.embedder)

            # Get database configuration
            db_config = DatabaseDriverFactory.create_config(self.config.database)

            # Build custom entity types if configured
            custom_types = None
            if self.config.graphiti.entity_types:
                custom_types = []
                for entity_type in self.config.graphiti.entity_types:
                    # Create a dynamic Pydantic model for each entity type
                    entity_model = type(
                        entity_type.name,
                        (BaseModel,),
                        {
                            '__annotations__': {'name': str},
                            '__doc__': entity_type.description,
                        },
                    )
                    custom_types.append(entity_model)
            # Also support the existing ENTITY_TYPES if use_custom_entities is set
            elif hasattr(self.config, 'use_custom_entities') and self.config.use_custom_entities:
                custom_types = ENTITY_TYPES

            # Store entity types for later use
            self.entity_types = custom_types

            # Initialize Graphiti client with database connection params
            self.client = Graphiti(
                uri=db_config['uri'],
                user=db_config['user'],
                password=db_config['password'],
                llm_client=llm_client,
                embedder=embedder_client,
                custom_node_types=custom_types,
                max_coroutines=self.semaphore_limit,
            )

            # Test connection
            await self.client.driver.client.verify_connectivity()  # type: ignore

            # Build indices
            await self.client.build_indices_and_constraints()

            logger.info('Successfully initialized Graphiti client')

            # Log configuration details
            if llm_client:
                logger.info(
                    f'Using LLM provider: {self.config.llm.provider} / {self.config.llm.model}'
                )
            else:
                logger.info('No LLM client configured - entity extraction will be limited')

            if embedder_client:
                logger.info(f'Using Embedder provider: {self.config.embedder.provider}')
            else:
                logger.info('No Embedder client configured - search will be limited')

            logger.info(f'Using database: {self.config.database.provider}')
            logger.info(f'Using group_id: {self.config.graphiti.group_id}')

        except Exception as e:
            logger.error(f'Failed to initialize Graphiti client: {e}')
            raise

    async def get_client(self) -> Graphiti:
        """Get the Graphiti client, initializing if necessary.

        Raises:
            RuntimeError: If the client is still unset after ``initialize()``.
        """
        if self.client is None:
            await self.initialize()
        if self.client is None:
            raise RuntimeError('Failed to initialize Graphiti client')
        return self.client


@mcp.tool()
async def add_memory(
    name: str,
    episode_body: str,
    group_id: str | None = None,
    source: str = 'text',
    source_description: str = '',
    uuid: str | None = None,
) -> SuccessResponse | ErrorResponse:
    """Add an episode to memory. This is the primary way to add information to the graph.

    This function returns immediately and processes the episode addition in the background.
    Episodes for the same group_id are processed sequentially to avoid race conditions.

    Args:
        name (str): Name of the episode
        episode_body (str): The content of the episode to persist to memory. When source='json', this must be a
                           properly escaped JSON string, not a raw Python dictionary. The JSON data will be
                           automatically processed to extract entities and relationships.
        group_id (str, optional): A unique ID for this graph. If not provided, uses the default group_id from CLI
                                 or a generated one.
        source (str, optional): Source type, must be one of:
                               - 'text': For plain text content (default)
                               - 'json': For structured data
                               - 'message': For conversation-style content
        source_description (str, optional): Description of the source
        uuid (str, optional): Optional UUID for the episode

    Examples:
        # Adding plain text content
        add_memory(
            name="Company News",
            episode_body="Acme Corp announced a new product line today.",
            source="text",
            source_description="news article",
            group_id="some_arbitrary_string"
        )

        # Adding structured JSON data
        # NOTE: episode_body must be a properly escaped JSON string. Note the triple backslashes
        add_memory(
            name="Customer Profile",
            episode_body="{\\\"company\\\": {\\\"name\\\": \\\"Acme Technologies\\\"}, \\\"products\\\": [{\\\"id\\\": \\\"P001\\\", \\\"name\\\": \\\"CloudSync\\\"}, {\\\"id\\\": \\\"P002\\\", \\\"name\\\": \\\"DataMiner\\\"}]}",
            source="json",
            source_description="CRM data"
        )
    """
    global graphiti_service, queue_service

    if graphiti_service is None or queue_service is None:
        return ErrorResponse(error='Services not initialized')

    try:
        # Use the provided group_id or fall back to the default from config
        effective_group_id = group_id or config.graphiti.group_id

        # Try to parse the source as an EpisodeType enum, with fallback to text
        episode_type = EpisodeType.text  # Default
        if source:
            try:
                episode_type = EpisodeType[source.lower()]
            except (KeyError, AttributeError):
                # If the source doesn't match any enum value, use text as default
                logger.warning(f"Unknown source type '{source}', using 'text' as default")
                episode_type = EpisodeType.text

        # Submit to queue service for async processing
        await queue_service.add_episode(
            group_id=effective_group_id,
            name=name,
            content=episode_body,
            source_description=source_description,
            episode_type=episode_type,
            custom_types=graphiti_service.entity_types,
            uuid=uuid,
        )

        return SuccessResponse(
            message=f"Episode '{name}' queued for processing in group '{effective_group_id}'"
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error queuing episode: {error_msg}')
        return ErrorResponse(error=f'Error queuing episode: {error_msg}')


@mcp.tool()
async def search_nodes(
    query: str,
    group_ids: list[str] | None = None,
    max_nodes: int = 10,
    entity_types: list[str] | None = None,
) -> NodeSearchResponse | ErrorResponse:
    """Search for nodes in the graph memory.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_nodes: Maximum number of nodes to return (default: 10)
        entity_types: Optional list of entity type names to filter by
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided.
        # An explicitly passed empty list is respected (no fallback).
        if group_ids is not None:
            effective_group_ids = group_ids
        elif config.graphiti.group_id:
            effective_group_ids = [config.graphiti.group_id]
        else:
            effective_group_ids = []

        # Create search filters
        search_filters = SearchFilters(
            group_ids=effective_group_ids,
            node_labels=entity_types,
        )

        # Perform the search
        nodes = await client.search_nodes(
            query=query,
            limit=max_nodes,
            search_config=NODE_HYBRID_SEARCH_RRF,
            search_filters=search_filters,
        )

        if not nodes:
            return NodeSearchResponse(message='No relevant nodes found', nodes=[])

        # Format the results
        node_results = [
            NodeResult(
                uuid=node.uuid,
                name=node.name,
                type=node.type or 'Unknown',
                created_at=node.created_at.isoformat() if node.created_at else None,
                summary=node.summary,
            )
            for node in nodes
        ]

        return NodeSearchResponse(message='Nodes retrieved successfully', nodes=node_results)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching nodes: {error_msg}')
        return ErrorResponse(error=f'Error searching nodes: {error_msg}')
@mcp.tool()
async def search_memory_facts(
    query: str,
    group_ids: list[str] | None = None,
    max_facts: int = 10,
    center_node_uuid: str | None = None,
) -> FactSearchResponse | ErrorResponse:
    """Search the graph memory for relevant facts.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_facts: Maximum number of facts to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Validate max_facts parameter
        if max_facts <= 0:
            return ErrorResponse(error='max_facts must be a positive integer')

        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided.
        # An explicitly passed empty list is respected (no fallback).
        if group_ids is not None:
            effective_group_ids = group_ids
        elif config.graphiti.group_id:
            effective_group_ids = [config.graphiti.group_id]
        else:
            effective_group_ids = []

        relevant_edges = await client.search(
            group_ids=effective_group_ids,
            query=query,
            num_results=max_facts,
            center_node_uuid=center_node_uuid,
        )

        if not relevant_edges:
            return FactSearchResponse(message='No relevant facts found', facts=[])

        facts = [format_fact_result(edge) for edge in relevant_edges]
        return FactSearchResponse(message='Facts retrieved successfully', facts=facts)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching facts: {error_msg}')
        return ErrorResponse(error=f'Error searching facts: {error_msg}')


@mcp.tool()
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an entity edge from the graph memory.

    Args:
        uuid: UUID of the entity edge to delete
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Get the entity edge by UUID
        entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
        # Delete the edge using its delete method
        await entity_edge.delete(client.driver)
        return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully')
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting entity edge: {error_msg}')
        return ErrorResponse(error=f'Error deleting entity edge: {error_msg}')


@mcp.tool()
async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an episode from the graph memory.

    Args:
        uuid: UUID of the episode to delete
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Get the episodic node by UUID
        episodic_node = await EpisodicNode.get_by_uuid(client.driver, uuid)
        # Delete the node using its delete method
        await episodic_node.delete(client.driver)
        return SuccessResponse(message=f'Episode with UUID {uuid} deleted successfully')
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting episode: {error_msg}')
        return ErrorResponse(error=f'Error deleting episode: {error_msg}')


@mcp.tool()
async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
    """Get an entity edge from the graph memory by its UUID.

    Args:
        uuid: UUID of the entity edge to retrieve
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Get the entity edge directly using the EntityEdge class method
        entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)

        # Use the format_fact_result function to serialize the edge
        # Return the Python dict directly - MCP will handle serialization
        return format_fact_result(entity_edge)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting entity edge: {error_msg}')
        return ErrorResponse(error=f'Error getting entity edge: {error_msg}')


@mcp.tool()
async def search_episodes(
    query: str | None = None,
    group_ids: list[str] | None = None,
    max_episodes: int = 10,
    start_date: str | None = None,
    end_date: str | None = None,
) -> EpisodeSearchResponse | ErrorResponse:
    """Search for episodes in the graph memory.

    Args:
        query: Optional search query for semantic search
        group_ids: Optional list of group IDs to filter results
        max_episodes: Maximum number of episodes to return (default: 10)
        start_date: Optional start date (ISO format) to filter episodes
        end_date: Optional end date (ISO format) to filter episodes
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided.
        # An explicitly passed empty list is respected (no fallback).
        if group_ids is not None:
            effective_group_ids = group_ids
        elif config.graphiti.group_id:
            effective_group_ids = [config.graphiti.group_id]
        else:
            effective_group_ids = []

        # Convert date strings to datetime objects if provided; a malformed date
        # raises ValueError, which is reported through the handler below.
        start_dt = datetime.fromisoformat(start_date) if start_date else None
        end_dt = datetime.fromisoformat(end_date) if end_date else None

        # Search for episodes
        episodes = await client.search_episodes(
            query=query,
            group_ids=effective_group_ids,
            limit=max_episodes,
            start_date=start_dt,
            end_date=end_dt,
        )

        if not episodes:
            return EpisodeSearchResponse(message='No episodes found', episodes=[])

        # Format the results
        episode_results = [
            {
                'uuid': episode.uuid,
                'name': episode.name,
                'content': episode.content,
                'created_at': episode.created_at.isoformat() if episode.created_at else None,
                'source': episode.source,
                'source_description': episode.source_description,
                'group_id': episode.group_id,
            }
            for episode in episodes
        ]

        return EpisodeSearchResponse(
            message='Episodes retrieved successfully', episodes=episode_results
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching episodes: {error_msg}')
        return ErrorResponse(error=f'Error searching episodes: {error_msg}')


@mcp.tool()
async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse:
    """Clear all data from the graph for specified group IDs.

    Args:
        group_ids: Optional list of group IDs to clear. If not provided, clears the default group.
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided.
        # The parentheses matter: without them the conditional expression binds
        # looser than `or`, and explicitly supplied group_ids were discarded
        # whenever no default group_id was configured.
        effective_group_ids = group_ids or (
            [config.graphiti.group_id] if config.graphiti.group_id else []
        )

        if not effective_group_ids:
            return ErrorResponse(error='No group IDs specified for clearing')

        # Clear data for the specified group IDs
        await clear_data(client.driver, group_ids=effective_group_ids)

        return SuccessResponse(
            message=f"Graph data cleared successfully for group IDs: {', '.join(effective_group_ids)}"
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error clearing graph: {error_msg}')
        return ErrorResponse(error=f'Error clearing graph: {error_msg}')


@mcp.tool()
async def get_status() -> StatusResponse:
    """Get the status of the Graphiti MCP server and database connection."""
    global graphiti_service

    if graphiti_service is None:
        return StatusResponse(status='error', message='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Test database connection
        await client.driver.client.verify_connectivity()  # type: ignore

        provider_info = f'{config.database.provider} database'
        return StatusResponse(
            status='ok', message=f'Graphiti MCP server is running and connected to {provider_info}'
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error checking database connection: {error_msg}')
        return StatusResponse(
            status='error',
            message=f'Graphiti MCP server is running but database connection failed: {error_msg}',
        )


async def initialize_server() -> MCPConfig:
    """Parse CLI arguments and initialize the Graphiti server configuration.

    Populates the module-level ``config``, ``graphiti_service``,
    ``queue_service``, ``graphiti_client`` and ``semaphore`` globals, applies
    host/port settings to the MCP server, and optionally wipes all graph data
    when ``--destroy-graph`` is given.

    Returns:
        The MCP configuration carrying the selected transport.
    """
    global config, graphiti_service, queue_service, graphiti_client, semaphore

    parser = argparse.ArgumentParser(
        description='Run the Graphiti MCP server with YAML configuration support'
    )

    # Configuration file argument
    parser.add_argument(
        '--config',
        type=Path,
        default=Path('config.yaml'),
        help='Path to YAML configuration file (default: config.yaml)',
    )

    # Transport arguments
    parser.add_argument(
        '--transport',
        choices=['sse', 'stdio'],
        help='Transport to use for communication with the client',
    )
    parser.add_argument(
        '--host',
        help='Host to bind the MCP server to',
    )
    parser.add_argument(
        '--port',
        type=int,
        help='Port to bind the MCP server to',
    )

    # Provider selection arguments
    parser.add_argument(
        '--llm-provider',
        choices=['openai', 'azure_openai', 'anthropic', 'gemini', 'groq'],
        help='LLM provider to use',
    )
    parser.add_argument(
        '--embedder-provider',
        choices=['openai', 'azure_openai', 'gemini', 'voyage'],
        help='Embedder provider to use',
    )
    parser.add_argument(
        '--database-provider',
        choices=['neo4j', 'falkordb'],
        help='Database provider to use',
    )

    # LLM configuration arguments
    parser.add_argument('--model', help='Model name to use with the LLM client')
    parser.add_argument('--small-model', help='Small model name to use with the LLM client')
    parser.add_argument(
        '--temperature', type=float, help='Temperature setting for the LLM (0.0-2.0)'
    )

    # Embedder configuration arguments
    parser.add_argument('--embedder-model', help='Model name to use with the embedder')

    # Graphiti-specific arguments
    parser.add_argument(
        '--group-id',
        help='Namespace for the graph. If not provided, uses config file or generates random UUID.',
    )
    parser.add_argument(
        '--user-id',
        help='User ID for tracking operations',
    )
    parser.add_argument(
        '--destroy-graph',
        action='store_true',
        help='Destroy all Graphiti graphs on startup',
    )
    parser.add_argument(
        '--use-custom-entities',
        action='store_true',
        help='Enable entity extraction using the predefined ENTITY_TYPES',
    )

    args = parser.parse_args()

    # Set config path in environment for the settings to pick up
    if args.config:
        os.environ['CONFIG_PATH'] = str(args.config)

    # Load configuration with environment variables and YAML
    config = GraphitiConfig()

    # Apply CLI overrides
    config.apply_cli_overrides(args)

    # Also apply legacy CLI args for backward compatibility. Both flags are
    # declared above with action='store_true', so they are always present on args.
    config.use_custom_entities = args.use_custom_entities
    config.destroy_graph = args.destroy_graph

    # Log configuration details
    logger.info('Using configuration:')
    logger.info(f'  - LLM: {config.llm.provider} / {config.llm.model}')
    logger.info(f'  - Embedder: {config.embedder.provider} / {config.embedder.model}')
    logger.info(f'  - Database: {config.database.provider}')
    logger.info(f'  - Group ID: {config.graphiti.group_id}')
    logger.info(f'  - Transport: {config.server.transport}')

    # Handle graph destruction if requested
    if getattr(config, 'destroy_graph', False):
        logger.warning('Destroying all Graphiti graphs as requested...')
        temp_service = GraphitiService(config, SEMAPHORE_LIMIT)
        await temp_service.initialize()
        client = await temp_service.get_client()
        try:
            await clear_data(client.driver)
        finally:
            # The temporary client is only needed for the wipe; close it so its
            # database connection is not leaked for the life of the server.
            await client.close()
        logger.info('All graphs destroyed')

    # Initialize services
    graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
    queue_service = QueueService()
    await graphiti_service.initialize()

    # Set global client for backward compatibility
    graphiti_client = await graphiti_service.get_client()
    semaphore = graphiti_service.semaphore

    # Initialize queue service with the client
    await queue_service.initialize(graphiti_client)

    # Set MCP server settings
    if config.server.host:
        mcp.settings.host = config.server.host
    if config.server.port:
        mcp.settings.port = config.server.port

    # Return MCP configuration for transport
    return MCPConfig(transport=config.server.transport)


async def run_mcp_server() -> None:
    """Run the MCP server in the current event loop.

    Raises:
        ValueError: If the configured transport is neither 'stdio' nor 'sse'.
    """
    # Initialize the server
    mcp_config = await initialize_server()

    # Run the server with configured transport
    logger.info(f'Starting MCP server with transport: {mcp_config.transport}')
    if mcp_config.transport == 'stdio':
        await mcp.run_stdio_async()
    elif mcp_config.transport == 'sse':
        logger.info(
            f'Running MCP server with SSE transport on {mcp.settings.host}:{mcp.settings.port}'
        )
        await mcp.run_sse_async()
    else:
        # The CLI restricts --transport, but the value may also come from the
        # config file; fail loudly instead of exiting without serving anything.
        raise ValueError(f'Unsupported MCP transport: {mcp_config.transport!r}')


def main() -> None:
    """Main function to run the Graphiti MCP server."""
    try:
        # Run everything in a single event loop
        asyncio.run(run_mcp_server())
    except KeyboardInterrupt:
        logger.info('Server shutting down...')
    except Exception as e:
        logger.error(f'Error initializing Graphiti MCP server: {e}')
        raise


if __name__ == '__main__':
    main()