graphiti/mcp_server/graphiti_mcp_server.py
Daniel Chalef 452a45cb4e wip
2025-08-30 08:50:48 -07:00

659 lines
25 KiB
Python

#!/usr/bin/env python3
"""
Graphiti MCP Server - Exposes Graphiti functionality through the Model Context Protocol (MCP)
"""
import argparse
import asyncio
import logging
import os
import sys
from datetime import datetime, timezone
from typing import Any, cast
from config_manager import GraphitiConfig
from dotenv import load_dotenv
from entity_types import ENTITY_TYPES
from formatting import format_fact_result
from graphiti_service import GraphitiService
from llm_config import DEFAULT_LLM_MODEL, SMALL_LLM_MODEL
from mcp.server.fastmcp import FastMCP
from queue_service import QueueService
from response_types import (
EpisodeSearchResponse,
ErrorResponse,
FactSearchResponse,
NodeResult,
NodeSearchResponse,
StatusResponse,
SuccessResponse,
)
from server_config import MCPConfig
from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_config_recipes import (
NODE_HYBRID_SEARCH_NODE_DISTANCE,
NODE_HYBRID_SEARCH_RRF,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
# Load variables from a .env file (if one exists) into os.environ. This must run
# before the os.getenv() lookup below so .env values are visible to it.
load_dotenv()
# Semaphore limit for concurrent Graphiti operations.
# Decrease this if you're experiencing 429 rate limit errors from your LLM provider.
# Increase if you have high rate limits.
# NOTE(review): a non-integer SEMAPHORE_LIMIT value raises ValueError at import time.
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 10))
# Configure logging. Output goes to stderr, presumably so that stdout stays free for
# the stdio MCP transport (run_stdio_async) — confirm before changing the stream.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stderr,
)
logger = logging.getLogger(__name__)
# Placeholder config built with defaults; initialize_server() rebinds this global to
# GraphitiConfig.from_cli_and_env(args) once the command-line arguments are parsed.
config = GraphitiConfig()
# MCP server instructions, sent to connecting clients as guidance for using this server.
# Tool names mentioned here must match the @mcp.tool() function names defined below,
# otherwise agents will try to call tools that do not exist.
GRAPHITI_MCP_INSTRUCTIONS = """
Graphiti is a memory service for AI agents built on a knowledge graph. Graphiti performs well
with dynamic data such as user interactions, changing enterprise data, and external information.

Graphiti transforms information into a richly connected knowledge network, allowing you to
capture relationships between concepts, entities, and information. The system organizes data as episodes
(content snippets), nodes (entities), and facts (relationships between entities), creating a dynamic,
queryable memory store that evolves with new information. Graphiti supports multiple data formats, including
structured JSON data, enabling seamless integration with existing data pipelines and systems.

Facts contain temporal metadata, allowing you to track the time of creation and whether a fact is invalid
(superseded by new information).

Key capabilities:
1. Add episodes (text, messages, or JSON) to the knowledge graph with the add_memory tool
2. Search for nodes (entities) in the graph using natural language queries with search_memory_nodes
3. Find relevant facts (relationships between entities) with search_memory_facts
4. Retrieve a specific entity edge by UUID with get_entity_edge, or the most recent episodes of a group with get_episodes
5. Manage the knowledge graph with tools like delete_episode, delete_entity_edge, and clear_graph

The server connects to a database for persistent storage and uses language models for certain operations.
Each piece of information is organized by group_id, allowing you to maintain separate knowledge domains.

When adding information, provide descriptive names and detailed content to improve search quality.
When searching, use specific queries and consider filtering by group_id for more relevant results.

For optimal performance, ensure the database is properly configured and accessible, and valid
API keys are provided for any language model operations.
"""
# The FastMCP application object; the @mcp.tool() and @mcp.resource() decorators
# below register their handlers on it.
mcp = FastMCP('Graphiti Agent Memory', instructions=GRAPHITI_MCP_INSTRUCTIONS)
# Service singletons. Both start as None and are created by initialize_server();
# handlers that use them should check for None / uninitialized state first.
graphiti_service: GraphitiService | None = None
queue_service: QueueService | None = None
@mcp.tool()
async def add_memory(
    name: str,
    episode_body: str,
    group_id: str | None = None,
    source: str = 'text',
    source_description: str = '',
    uuid: str | None = None,
) -> SuccessResponse | ErrorResponse:
    """Add an episode to memory. This is the primary way to add information to the graph.

    This function returns immediately and processes the episode addition in the background.
    Episodes for the same group_id are processed sequentially to avoid race conditions.

    Args:
        name (str): Name of the episode
        episode_body (str): The content of the episode to persist to memory. When source='json', this must be a
                           properly escaped JSON string, not a raw Python dictionary. The JSON data will be
                           automatically processed to extract entities and relationships.
        group_id (str, optional): A unique ID for this graph. If not provided, uses the default group_id from CLI
                                 or a generated one.
        source (str, optional): Source type, must be one of:
                               - 'text': For plain text content (default)
                               - 'json': For structured data
                               - 'message': For conversation-style content
                               Any other value is treated as 'text'.
        source_description (str, optional): Description of the source
        uuid (str, optional): Optional UUID for the episode

    Examples:
        # Adding plain text content
        add_memory(
            name="Company News",
            episode_body="Acme Corp announced a new product line today.",
            source="text",
            source_description="news article",
            group_id="some_arbitrary_string"
        )

        # Adding structured JSON data
        # NOTE: episode_body must be a properly escaped JSON string. Note the triple backslashes
        add_memory(
            name="Customer Profile",
            episode_body="{\\\"company\\\": {\\\"name\\\": \\\"Acme Technologies\\\"}, \\\"products\\\": [{\\\"id\\\": \\\"P001\\\", \\\"name\\\": \\\"CloudSync\\\"}, {\\\"id\\\": \\\"P002\\\", \\\"name\\\": \\\"DataMiner\\\"}]}",
            source="json",
            source_description="CRM data"
        )

        # Adding message-style content
        add_memory(
            name="Customer Conversation",
            episode_body="user: What's your return policy?\nassistant: You can return items within 30 days.",
            source="message",
            source_description="chat transcript",
            group_id="some_arbitrary_string"
        )

    Notes:
        When using source='json':
        - The JSON must be a properly escaped string, not a raw Python dictionary
        - The JSON will be automatically processed to extract entities and relationships
        - Complex nested structures are supported (arrays, nested objects, mixed data types), but keep nesting to a minimum
        - Entities will be created from appropriate JSON properties
        - Relationships between entities will be established based on the JSON structure
    """
    if not graphiti_service or not graphiti_service.is_initialized():
        return ErrorResponse(error='Graphiti service not initialized')

    if not queue_service:
        return ErrorResponse(error='Queue service not initialized')

    try:
        # Map the source string onto the EpisodeType enum (case-insensitive). Unknown
        # values still fall back to plain text, but are logged instead of silently coerced.
        normalized_source = source.lower()
        if normalized_source == 'message':
            source_type = EpisodeType.message
        elif normalized_source == 'json':
            source_type = EpisodeType.json
        else:
            if normalized_source != 'text':
                logger.warning(
                    f"Unknown source type '{source}' for episode '{name}'; treating it as 'text'"
                )
            source_type = EpisodeType.text

        # Use the provided group_id or fall back to the default from config.
        effective_group_id = group_id if group_id is not None else config.group_id
        # The Graphiti client expects a str for group_id, not Optional[str].
        group_id_str = str(effective_group_id) if effective_group_id is not None else ''

        # Capture the reference time when the episode is submitted rather than when the
        # queue worker eventually runs it: episodes are processed sequentially per group,
        # so a queued episode may start much later, which would skew the time against
        # which its content is interpreted.
        reference_time = datetime.now(timezone.utc)

        # Bind the client now so the background task does not re-read the module-level
        # global (this also lets type checkers see it is not None).
        client = graphiti_service.client

        async def process_episode() -> None:
            """Run the actual add_episode call; executed later by the queue service."""
            try:
                logger.info(f"Processing queued episode '{name}' for group_id: {group_id_str}")
                # Use all entity types if use_custom_entities is enabled, otherwise use empty dict
                entity_types = ENTITY_TYPES if config.use_custom_entities else {}

                await client.add_episode(
                    name=name,
                    episode_body=episode_body,
                    source=source_type,
                    source_description=source_description,
                    group_id=group_id_str,
                    uuid=uuid,
                    reference_time=reference_time,
                    entity_types=entity_types,
                )
                logger.info(f"Episode '{name}' processed successfully")
            except Exception as e:
                # Nobody awaits this task's result, so the log is the only record of the
                # failure; include the traceback to make it debuggable.
                logger.exception(
                    f"Error processing episode '{name}' for group_id {group_id_str}: {e}"
                )

        # Add the episode processing function to the per-group queue.
        queue_position = await queue_service.add_episode_task(group_id_str, process_episode)

        # Return immediately with a success message.
        return SuccessResponse(
            message=f"Episode '{name}' queued for processing (position: {queue_position})"
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error queuing episode task: {error_msg}')
        return ErrorResponse(error=f'Error queuing episode task: {error_msg}')
@mcp.tool()
async def search_memory_nodes(
    query: str,
    group_ids: list[str] | None = None,
    max_nodes: int = 10,
    center_node_uuid: str | None = None,
    entity: str = '',  # cursor seems to break with None
) -> NodeSearchResponse | ErrorResponse:
    """Search the graph memory for relevant node summaries.

    These contain a summary of all of a node's relationships with other nodes.

    Note: entity is a single entity type to filter results (permitted: "Preference", "Procedure").

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_nodes: Maximum number of nodes to return (default: 10). Must be a positive integer.
        center_node_uuid: Optional UUID of a node to center the search around
        entity: Optional single entity type to filter results (permitted: "Preference", "Procedure")
    """
    if not graphiti_service or not graphiti_service.is_initialized():
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Reject non-positive limits up front (mirrors the max_facts check in
        # search_memory_facts) instead of handing a meaningless limit to the search.
        if max_nodes <= 0:
            return ErrorResponse(error='max_nodes must be a positive integer')

        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids if group_ids is not None else [config.group_id] if config.group_id else []
        )

        # Pick the recipe: node-distance reranking when a center node is given, RRF otherwise.
        # Deep-copy so setting the limit does not mutate the shared module-level recipe.
        if center_node_uuid is not None:
            search_config = NODE_HYBRID_SEARCH_NODE_DISTANCE.model_copy(deep=True)
        else:
            search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
        search_config.limit = max_nodes

        filters = SearchFilters()
        if entity != '':
            filters.node_labels = [entity]

        client = graphiti_service.client

        # Perform the search using the _search method
        search_results = await client._search(
            query=query,
            config=search_config,
            group_ids=effective_group_ids,
            center_node_uuid=center_node_uuid,
            search_filter=filters,
        )

        if not search_results.nodes:
            return NodeSearchResponse(message='No relevant nodes found', nodes=[])

        # Format the node results; optional attributes default to empty values.
        formatted_nodes: list[NodeResult] = [
            {
                'uuid': node.uuid,
                'name': node.name,
                'summary': node.summary if hasattr(node, 'summary') else '',
                'labels': node.labels if hasattr(node, 'labels') else [],
                'group_id': node.group_id,
                'created_at': node.created_at.isoformat(),
                'attributes': node.attributes if hasattr(node, 'attributes') else {},
            }
            for node in search_results.nodes
        ]

        return NodeSearchResponse(message='Nodes retrieved successfully', nodes=formatted_nodes)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching nodes: {error_msg}')
        return ErrorResponse(error=f'Error searching nodes: {error_msg}')
@mcp.tool()
async def search_memory_facts(
    query: str,
    group_ids: list[str] | None = None,
    max_facts: int = 10,
    center_node_uuid: str | None = None,
) -> FactSearchResponse | ErrorResponse:
    """Search the graph memory for relevant facts.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_facts: Maximum number of facts to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
    """
    # Use the shared GraphitiService (as add_memory / search_memory_nodes do). The former
    # `graphiti_client` global is never defined in this module, so checking it raised
    # NameError on every call.
    if not graphiti_service or not graphiti_service.is_initialized():
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Validate max_facts parameter
        if max_facts <= 0:
            return ErrorResponse(error='max_facts must be a positive integer')

        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids if group_ids is not None else [config.group_id] if config.group_id else []
        )

        client = graphiti_service.client

        relevant_edges = await client.search(
            group_ids=effective_group_ids,
            query=query,
            num_results=max_facts,
            center_node_uuid=center_node_uuid,
        )

        if not relevant_edges:
            return FactSearchResponse(message='No relevant facts found', facts=[])

        facts = [format_fact_result(edge) for edge in relevant_edges]
        return FactSearchResponse(message='Facts retrieved successfully', facts=facts)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching facts: {error_msg}')
        return ErrorResponse(error=f'Error searching facts: {error_msg}')
@mcp.tool()
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an entity edge from the graph memory.

    Args:
        uuid: UUID of the entity edge to delete
    """
    # Use the shared GraphitiService; the former `graphiti_client` global is never
    # defined in this module, so referencing it raised NameError.
    if not graphiti_service or not graphiti_service.is_initialized():
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = graphiti_service.client

        # Look the edge up first, then delete it via its own delete method.
        entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
        await entity_edge.delete(client.driver)
        return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully')
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting entity edge: {error_msg}')
        return ErrorResponse(error=f'Error deleting entity edge: {error_msg}')
@mcp.tool()
async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an episode from the graph memory.

    Args:
        uuid: UUID of the episode to delete
    """
    # Use the shared GraphitiService; the former `graphiti_client` global is never
    # defined in this module, so referencing it raised NameError.
    if not graphiti_service or not graphiti_service.is_initialized():
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = graphiti_service.client

        # Look the episodic node up first, then delete it via its own delete method.
        episodic_node = await EpisodicNode.get_by_uuid(client.driver, uuid)
        await episodic_node.delete(client.driver)
        return SuccessResponse(message=f'Episode with UUID {uuid} deleted successfully')
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting episode: {error_msg}')
        return ErrorResponse(error=f'Error deleting episode: {error_msg}')
@mcp.tool()
async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
    """Get an entity edge from the graph memory by its UUID.

    Args:
        uuid: UUID of the entity edge to retrieve
    """
    # Use the shared GraphitiService; the former `graphiti_client` global is never
    # defined in this module, so referencing it raised NameError.
    if not graphiti_service or not graphiti_service.is_initialized():
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = graphiti_service.client

        # Get the entity edge directly using the EntityEdge class method
        entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)

        # Serialize with the same helper used for fact search results; MCP handles
        # serializing the returned dict.
        return format_fact_result(entity_edge)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting entity edge: {error_msg}')
        return ErrorResponse(error=f'Error getting entity edge: {error_msg}')
@mcp.tool()
async def get_episodes(
    group_id: str | None = None, last_n: int = 10
) -> list[dict[str, Any]] | EpisodeSearchResponse | ErrorResponse:
    """Get the most recent memory episodes for a specific group.

    Args:
        group_id: ID of the group to retrieve episodes from. If not provided, uses the default group_id.
        last_n: Number of most recent episodes to retrieve (default: 10). Must be a positive integer.
    """
    # Use the shared GraphitiService; the former `graphiti_client` global is never
    # defined in this module, so referencing it raised NameError.
    if not graphiti_service or not graphiti_service.is_initialized():
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Same positive-count validation as max_facts in search_memory_facts.
        if last_n <= 0:
            return ErrorResponse(error='last_n must be a positive integer')

        # Use the provided group_id or fall back to the default from config
        effective_group_id = group_id if group_id is not None else config.group_id

        if not isinstance(effective_group_id, str):
            return ErrorResponse(error='Group ID must be a string')

        client = graphiti_service.client

        episodes = await client.retrieve_episodes(
            group_ids=[effective_group_id], last_n=last_n, reference_time=datetime.now(timezone.utc)
        )

        if not episodes:
            return EpisodeSearchResponse(
                message=f'No episodes found for group {effective_group_id}', episodes=[]
            )

        # mode='json' converts datetimes and other non-JSON types into JSON-safe values;
        # MCP handles serializing the returned list.
        return [episode.model_dump(mode='json') for episode in episodes]
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting episodes: {error_msg}')
        return ErrorResponse(error=f'Error getting episodes: {error_msg}')
@mcp.tool()
async def clear_graph() -> SuccessResponse | ErrorResponse:
    """Clear all data from the graph memory and rebuild indices."""
    # Use the shared GraphitiService; the former `graphiti_client` global is never
    # defined in this module, so referencing it raised NameError.
    if not graphiti_service or not graphiti_service.is_initialized():
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = graphiti_service.client

        # Wipe the data, then recreate the indices and constraints.
        await clear_data(client.driver)
        await client.build_indices_and_constraints()
        return SuccessResponse(message='Graph cleared successfully and indices rebuilt')
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error clearing graph: {error_msg}')
        return ErrorResponse(error=f'Error clearing graph: {error_msg}')
@mcp.resource('http://graphiti/status')
async def get_status() -> StatusResponse:
    """Get the status of the Graphiti MCP server and Neo4j connection."""
    # Use the shared GraphitiService; the former `graphiti_client` global is never
    # defined in this module, so referencing it raised NameError.
    if not graphiti_service or not graphiti_service.is_initialized():
        return StatusResponse(status='error', message='Graphiti service not initialized')

    try:
        client = graphiti_service.client

        # Test database connection. NOTE(review): verify_connectivity is reached through
        # the driver's underlying client object, which presumes a Neo4j driver.
        await client.driver.client.verify_connectivity()  # type: ignore

        return StatusResponse(
            status='ok', message='Graphiti MCP server is running and connected to Neo4j'
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error checking Neo4j connection: {error_msg}')
        return StatusResponse(
            status='error',
            message=f'Graphiti MCP server is running but Neo4j connection failed: {error_msg}',
        )
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the Graphiti MCP server."""
    arg_parser = argparse.ArgumentParser(
        description='Run the Graphiti MCP server with optional LLM client'
    )
    arg_parser.add_argument(
        '--group-id',
        help='Namespace for the graph. This is an arbitrary string used to organize related data. '
        'If not provided, a random UUID will be generated.',
    )
    arg_parser.add_argument(
        '--transport',
        choices=['sse', 'stdio'],
        default='sse',
        help='Transport to use for communication with the client. (default: sse)',
    )
    arg_parser.add_argument(
        '--model', help=f'Model name to use with the LLM client. (default: {DEFAULT_LLM_MODEL})'
    )
    arg_parser.add_argument(
        '--small-model',
        help=f'Small model name to use with the LLM client. (default: {SMALL_LLM_MODEL})',
    )
    arg_parser.add_argument(
        '--temperature',
        type=float,
        help='Temperature setting for the LLM (0.0-2.0). Lower values make output more deterministic. (default: 0.7)',
    )
    arg_parser.add_argument('--destroy-graph', action='store_true', help='Destroy all Graphiti graphs')
    arg_parser.add_argument(
        '--use-custom-entities',
        action='store_true',
        help='Enable entity extraction using the predefined ENTITY_TYPES',
    )
    arg_parser.add_argument(
        '--host',
        default=os.environ.get('MCP_SERVER_HOST'),
        help='Host to bind the MCP server to (default: MCP_SERVER_HOST environment variable)',
    )
    return arg_parser


async def initialize_server() -> MCPConfig:
    """Parse CLI arguments and initialize the Graphiti server configuration.

    Rebinds the module-level ``config``, ``graphiti_service`` and ``queue_service``
    globals, initializes the Graphiti service, optionally sets the MCP bind host,
    and returns the transport configuration derived from the same arguments.
    """
    global config, graphiti_service, queue_service

    cli_args = _build_arg_parser().parse_args()

    # Merge CLI arguments with environment variables into the global config.
    config = GraphitiConfig.from_cli_and_env(cli_args)

    logger.info(
        f'Using provided group_id: {config.group_id}'
        if cli_args.group_id
        else f'Generated random group_id: {config.group_id}'
    )
    logger.info(
        'Entity extraction enabled using predefined ENTITY_TYPES'
        if config.use_custom_entities
        else 'Entity extraction disabled (no custom entities will be used)'
    )

    # Create both services before initializing Graphiti.
    graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
    queue_service = QueueService()
    await graphiti_service.initialize()

    # The host comes from --host, which itself defaults to MCP_SERVER_HOST.
    bind_host = cli_args.host
    if bind_host:
        logger.info(f'Setting MCP server host to: {bind_host}')
        mcp.settings.host = bind_host

    return MCPConfig.from_cli(cli_args)
async def run_mcp_server():
    """Initialize the server, then serve MCP requests in the current event loop.

    Raises:
        ValueError: If the configured transport is neither 'stdio' nor 'sse'.
    """
    mcp_config = await initialize_server()

    logger.info(f'Starting MCP server with transport: {mcp_config.transport}')
    if mcp_config.transport == 'stdio':
        await mcp.run_stdio_async()
    elif mcp_config.transport == 'sse':
        logger.info(
            f'Running MCP server with SSE transport on {mcp.settings.host}:{mcp.settings.port}'
        )
        await mcp.run_sse_async()
    else:
        # Previously an unrecognised transport fell through and the process exited
        # silently without serving anything; fail loudly instead.
        raise ValueError(f'Unsupported MCP transport: {mcp_config.transport!r}')
def main():
    """Run the Graphiti MCP server inside a single asyncio event loop.

    Any exception escaping startup or serving is logged and then re-raised, so the
    process still terminates with the original traceback.
    """
    try:
        asyncio.run(run_mcp_server())
    except Exception as startup_error:
        logger.error(f'Error initializing Graphiti MCP server: {str(startup_error)}')
        raise


# Only start the server when executed as a script, not when imported.
if __name__ == '__main__':
    main()