- Replace bare except with except Exception - Remove unused imports and variables - Fix type hints to use modern syntax - Apply ruff formatting for line length - Ensure all tests pass linting checks
866 lines
32 KiB
Python
866 lines
32 KiB
Python
#!/usr/bin/env python3
"""Graphiti MCP Server - Exposes Graphiti functionality through the Model Context Protocol (MCP)"""

import argparse
import asyncio
import logging
import os
import sys
from pathlib import Path
from typing import Any, Optional

from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel

from config.schema import GraphitiConfig, ServerConfig
from models.entity_types import ENTITY_TYPES
from models.response_types import (
    EpisodeSearchResponse,
    ErrorResponse,
    FactSearchResponse,
    NodeResult,
    NodeSearchResponse,
    StatusResponse,
    SuccessResponse,
)
from services.factories import DatabaseDriverFactory, EmbedderFactory, LLMClientFactory
from services.queue_service import QueueService
from utils.formatting import format_fact_result
|
|
|
|
# Load the .env file that sits next to the mcp_server package; otherwise fall
# back to python-dotenv's default search from the current working directory.
mcp_server_dir = Path(__file__).parent.parent
env_file = mcp_server_dir / '.env'
if env_file.exists():
    load_dotenv(env_file)
else:
    load_dotenv()


# Semaphore limit for concurrent Graphiti operations.
# Decrease this if you're experiencing 429 rate limit errors from your LLM provider.
# Increase if you have high rate limits.
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 10))


# Log to stderr with a layout that matches uvicorn's default output.
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)-8s %(message)s',
    stream=sys.stderr,
)

# Quieten the noisier framework loggers.
logging.getLogger('uvicorn').setLevel(logging.INFO)
logging.getLogger('uvicorn.access').setLevel(logging.WARNING)  # Reduce access log noise
logging.getLogger('mcp.server.streamable_http_manager').setLevel(logging.WARNING)  # Reduce MCP noise

logger = logging.getLogger(__name__)

# Global config instance - populated by initialize_server() at startup.
config: GraphitiConfig

# Instructions surfaced to MCP clients when they connect.
GRAPHITI_MCP_INSTRUCTIONS = """
Graphiti is a memory service for AI agents built on a knowledge graph. Graphiti performs well
with dynamic data such as user interactions, changing enterprise data, and external information.

Graphiti transforms information into a richly connected knowledge network, allowing you to
capture relationships between concepts, entities, and information. The system organizes data as episodes
(content snippets), nodes (entities), and facts (relationships between entities), creating a dynamic,
queryable memory store that evolves with new information. Graphiti supports multiple data formats, including
structured JSON data, enabling seamless integration with existing data pipelines and systems.

Facts contain temporal metadata, allowing you to track the time of creation and whether a fact is invalid
(superseded by new information).

Key capabilities:
1. Add episodes (text, messages, or JSON) to the knowledge graph with the add_memory tool
2. Search for nodes (entities) in the graph using natural language queries with search_nodes
3. Find relevant facts (relationships between entities) with search_facts
4. Retrieve specific entity edges or episodes by UUID
5. Manage the knowledge graph with tools like delete_episode, delete_entity_edge, and clear_graph

The server connects to a database for persistent storage and uses language models for certain operations.
Each piece of information is organized by group_id, allowing you to maintain separate knowledge domains.

When adding information, provide descriptive names and detailed content to improve search quality.
When searching, use specific queries and consider filtering by group_id for more relevant results.

For optimal performance, ensure the database is properly configured and accessible, and valid
API keys are provided for any language model operations.
"""

# MCP server instance
mcp = FastMCP(
    'Graphiti Agent Memory',
    instructions=GRAPHITI_MCP_INSTRUCTIONS,
)

# Global services, wired up in initialize_server().
graphiti_service: Optional['GraphitiService'] = None
queue_service: QueueService | None = None

# Module-level client and semaphore kept for backward compatibility.
graphiti_client: Graphiti | None = None
semaphore: asyncio.Semaphore
|
|
|
|
|
|
class GraphitiService:
    """Graphiti service using the unified configuration system."""

    def __init__(self, config: GraphitiConfig, semaphore_limit: int = 10):
        self.config = config
        # Kept both as a plain int (forwarded to Graphiti as max_coroutines)
        # and as a semaphore for callers that need one.
        self.semaphore_limit = semaphore_limit
        self.semaphore = asyncio.Semaphore(semaphore_limit)
        self.client: Graphiti | None = None
        # Entity-type models resolved during initialize(); None until then.
        self.entity_types = None

    async def initialize(self) -> None:
        """Initialize the Graphiti client with factory-created components."""
        try:
            llm = None
            embedder = None

            # The LLM client is optional: failure to build it only degrades
            # functionality, so log a warning and continue.
            try:
                llm = LLMClientFactory.create(self.config.llm)
            except Exception as e:
                logger.warning(f'Failed to create LLM client: {e}')

            # The embedder is likewise optional.
            try:
                embedder = EmbedderFactory.create(self.config.embedder)
            except Exception as e:
                logger.warning(f'Failed to create embedder client: {e}')

            # Resolve driver settings for the configured database backend.
            db_settings = DatabaseDriverFactory.create_config(self.config.database)

            # Build custom entity types if configured: one dynamic Pydantic
            # model per declared entity type.
            entity_models = None
            if self.config.graphiti.entity_types:
                entity_models = [
                    type(
                        et.name,
                        (BaseModel,),
                        {
                            '__annotations__': {'name': str},
                            '__doc__': et.description,
                        },
                    )
                    for et in self.config.graphiti.entity_types
                ]
            elif hasattr(self.config, 'use_custom_entities') and self.config.use_custom_entities:
                # Legacy path: fall back to the predefined ENTITY_TYPES.
                entity_models = ENTITY_TYPES

            # Remember the resolved types for later episode processing.
            self.entity_types = entity_models

            provider = self.config.database.provider.lower()
            if provider == 'kuzu':
                # KuzuDB requires an explicitly constructed driver instance.
                from graphiti_core.driver.kuzu_driver import KuzuDriver

                self.client = Graphiti(
                    graph_driver=KuzuDriver(
                        db=db_settings['db'],
                        max_concurrent_queries=db_settings['max_concurrent_queries'],
                    ),
                    llm_client=llm,
                    embedder=embedder,
                    max_coroutines=self.semaphore_limit,
                )
            elif provider == 'falkordb':
                # FalkorDB also takes an explicit driver instance.
                from graphiti_core.driver.falkordb_driver import FalkorDriver

                self.client = Graphiti(
                    graph_driver=FalkorDriver(
                        host=db_settings['host'],
                        port=db_settings['port'],
                        password=db_settings['password'],
                        database=db_settings['database'],
                    ),
                    llm_client=llm,
                    embedder=embedder,
                    max_coroutines=self.semaphore_limit,
                )
            else:
                # Neo4j (default): Graphiti builds its own driver from the URI.
                self.client = Graphiti(
                    uri=db_settings['uri'],
                    user=db_settings['user'],
                    password=db_settings['password'],
                    llm_client=llm,
                    embedder=embedder,
                    max_coroutines=self.semaphore_limit,
                )

            # Neo4j and FalkorDB expose verify_connectivity; KuzuDB does not
            # need a connectivity probe.
            if provider != 'kuzu':
                await self.client.driver.client.verify_connectivity()  # type: ignore

            # Create the graph indices and constraints up front.
            await self.client.build_indices_and_constraints()

            logger.info('Successfully initialized Graphiti client')

            # Summarize the effective configuration.
            if llm:
                logger.info(
                    f'Using LLM provider: {self.config.llm.provider} / {self.config.llm.model}'
                )
            else:
                logger.info('No LLM client configured - entity extraction will be limited')

            if embedder:
                logger.info(f'Using Embedder provider: {self.config.embedder.provider}')
            else:
                logger.info('No Embedder client configured - search will be limited')

            logger.info(f'Using database: {self.config.database.provider}')
            logger.info(f'Using group_id: {self.config.graphiti.group_id}')

        except Exception as e:
            logger.error(f'Failed to initialize Graphiti client: {e}')
            raise

    async def get_client(self) -> Graphiti:
        """Get the Graphiti client, initializing if necessary."""
        if self.client is None:
            await self.initialize()
        if self.client is None:
            raise RuntimeError('Failed to initialize Graphiti client')
        return self.client
|
|
|
|
|
|
@mcp.tool()
async def add_memory(
    name: str,
    episode_body: str,
    group_id: str | None = None,
    source: str = 'text',
    source_description: str = '',
    uuid: str | None = None,
) -> SuccessResponse | ErrorResponse:
    """Add an episode to memory. This is the primary way to add information to the graph.

    This function returns immediately and processes the episode addition in the background.
    Episodes for the same group_id are processed sequentially to avoid race conditions.

    Args:
        name (str): Name of the episode
        episode_body (str): The content of the episode to persist to memory. When source='json', this must be a
                           properly escaped JSON string, not a raw Python dictionary. The JSON data will be
                           automatically processed to extract entities and relationships.
        group_id (str, optional): A unique ID for this graph. If not provided, uses the default group_id from CLI
                                 or a generated one.
        source (str, optional): Source type, must be one of:
                               - 'text': For plain text content (default)
                               - 'json': For structured data
                               - 'message': For conversation-style content
        source_description (str, optional): Description of the source
        uuid (str, optional): Optional UUID for the episode

    Examples:
        # Adding plain text content
        add_memory(
            name="Company News",
            episode_body="Acme Corp announced a new product line today.",
            source="text",
            source_description="news article",
            group_id="some_arbitrary_string"
        )

        # Adding structured JSON data
        # NOTE: episode_body must be a properly escaped JSON string. Note the triple backslashes
        add_memory(
            name="Customer Profile",
            episode_body="{\\\"company\\\": {\\\"name\\\": \\\"Acme Technologies\\\"}, \\\"products\\\": [{\\\"id\\\": \\\"P001\\\", \\\"name\\\": \\\"CloudSync\\\"}, {\\\"id\\\": \\\"P002\\\", \\\"name\\\": \\\"DataMiner\\\"}]}",
            source="json",
            source_description="CRM data"
        )
    """
    global graphiti_service, queue_service

    if graphiti_service is None or queue_service is None:
        return ErrorResponse(error='Services not initialized')

    try:
        # Fall back to the configured default group when none is supplied.
        target_group_id = group_id or config.graphiti.group_id

        # Map the source string onto the EpisodeType enum, defaulting to text
        # when it is empty or does not name a known member.
        try:
            episode_type = EpisodeType[source.lower()] if source else EpisodeType.text
        except (KeyError, AttributeError):
            logger.warning(f"Unknown source type '{source}', using 'text' as default")
            episode_type = EpisodeType.text

        # Hand the episode to the queue; actual processing runs in the
        # background, sequentially per group_id.
        await queue_service.add_episode(
            group_id=target_group_id,
            name=name,
            content=episode_body,
            source_description=source_description,
            episode_type=episode_type,
            entity_types=graphiti_service.entity_types,
            uuid=uuid or None,  # normalize falsy values to None
        )

        return SuccessResponse(
            message=f"Episode '{name}' queued for processing in group '{target_group_id}'"
        )
    except Exception as e:
        detail = str(e)
        logger.error(f'Error queuing episode: {detail}')
        return ErrorResponse(error=f'Error queuing episode: {detail}')
|
|
|
|
|
|
@mcp.tool()
async def search_nodes(
    query: str,
    group_ids: list[str] | None = None,
    max_nodes: int = 10,
    entity_types: list[str] | None = None,
) -> NodeSearchResponse | ErrorResponse:
    """Search for nodes in the graph memory.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_nodes: Maximum number of nodes to return (default: 10)
        entity_types: Optional list of entity type names to filter by
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Explicit group_ids win; otherwise use the configured default, if any.
        if group_ids is not None:
            effective_group_ids = group_ids
        elif config.graphiti.group_id:
            effective_group_ids = [config.graphiti.group_id]
        else:
            effective_group_ids = []

        from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF

        # Hybrid node search with RRF re-ranking, optionally filtered by
        # entity labels.
        results = await client.search_(
            query=query,
            config=NODE_HYBRID_SEARCH_RRF,
            group_ids=effective_group_ids,
            search_filter=SearchFilters(node_labels=entity_types),
        )

        # Truncate to the requested maximum.
        nodes = results.nodes[:max_nodes] if results.nodes else []

        if not nodes:
            return NodeSearchResponse(message='No relevant nodes found', nodes=[])

        formatted = [
            NodeResult(
                uuid=node.uuid,
                name=node.name,
                type=node.labels[0] if node.labels else 'Unknown',
                created_at=node.created_at.isoformat() if node.created_at else None,
                summary=node.summary,
            )
            for node in nodes
        ]

        return NodeSearchResponse(message='Nodes retrieved successfully', nodes=formatted)
    except Exception as e:
        detail = str(e)
        logger.error(f'Error searching nodes: {detail}')
        return ErrorResponse(error=f'Error searching nodes: {detail}')
|
|
|
|
|
|
@mcp.tool()
async def search_memory_facts(
    query: str,
    group_ids: list[str] | None = None,
    max_facts: int = 10,
    center_node_uuid: str | None = None,
) -> FactSearchResponse | ErrorResponse:
    """Search the graph memory for relevant facts.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_facts: Maximum number of facts to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Reject nonsensical limits up front.
        if max_facts <= 0:
            return ErrorResponse(error='max_facts must be a positive integer')

        client = await graphiti_service.get_client()

        # Explicit group_ids win; otherwise use the configured default, if any.
        if group_ids is not None:
            effective_group_ids = group_ids
        elif config.graphiti.group_id:
            effective_group_ids = [config.graphiti.group_id]
        else:
            effective_group_ids = []

        relevant_edges = await client.search(
            group_ids=effective_group_ids,
            query=query,
            num_results=max_facts,
            center_node_uuid=center_node_uuid,
        )

        if not relevant_edges:
            return FactSearchResponse(message='No relevant facts found', facts=[])

        facts = [format_fact_result(edge) for edge in relevant_edges]
        return FactSearchResponse(message='Facts retrieved successfully', facts=facts)
    except Exception as e:
        detail = str(e)
        logger.error(f'Error searching facts: {detail}')
        return ErrorResponse(error=f'Error searching facts: {detail}')
|
|
|
|
|
|
@mcp.tool()
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an entity edge from the graph memory.

    Args:
        uuid: UUID of the entity edge to delete
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Look the edge up first, then remove it via its own delete method.
        edge = await EntityEdge.get_by_uuid(client.driver, uuid)
        await edge.delete(client.driver)
        return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully')
    except Exception as e:
        detail = str(e)
        logger.error(f'Error deleting entity edge: {detail}')
        return ErrorResponse(error=f'Error deleting entity edge: {detail}')
|
|
|
|
|
|
@mcp.tool()
async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an episode from the graph memory.

    Args:
        uuid: UUID of the episode to delete
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Fetch the episodic node, then remove it via its own delete method.
        episode = await EpisodicNode.get_by_uuid(client.driver, uuid)
        await episode.delete(client.driver)
        return SuccessResponse(message=f'Episode with UUID {uuid} deleted successfully')
    except Exception as e:
        detail = str(e)
        logger.error(f'Error deleting episode: {detail}')
        return ErrorResponse(error=f'Error deleting episode: {detail}')
|
|
|
|
|
|
@mcp.tool()
async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
    """Get an entity edge from the graph memory by its UUID.

    Args:
        uuid: UUID of the entity edge to retrieve
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Serialize via format_fact_result and return the plain dict;
        # MCP handles the serialization from there.
        edge = await EntityEdge.get_by_uuid(client.driver, uuid)
        return format_fact_result(edge)
    except Exception as e:
        detail = str(e)
        logger.error(f'Error getting entity edge: {detail}')
        return ErrorResponse(error=f'Error getting entity edge: {detail}')
|
|
|
|
|
|
@mcp.tool()
async def get_episodes(
    group_ids: list[str] | None = None,
    max_episodes: int = 10,
) -> EpisodeSearchResponse | ErrorResponse:
    """Get episodes from the graph memory.

    Args:
        group_ids: Optional list of group IDs to filter results
        max_episodes: Maximum number of episodes to return (default: 10)
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Explicit group_ids win; otherwise use the configured default, if any.
        if group_ids is not None:
            effective_group_ids = group_ids
        elif config.graphiti.group_id:
            effective_group_ids = [config.graphiti.group_id]
        else:
            effective_group_ids = []

        from graphiti_core.nodes import EpisodicNode

        if effective_group_ids:
            episodes = await EpisodicNode.get_by_group_ids(
                client.driver, effective_group_ids, limit=max_episodes
            )
        else:
            # Without any group IDs there is no supported query path yet,
            # so report no episodes.
            episodes = []

        if not episodes:
            return EpisodeSearchResponse(message='No episodes found', episodes=[])

        episode_results = [
            {
                'uuid': episode.uuid,
                'name': episode.name,
                'content': episode.content,
                'created_at': episode.created_at.isoformat() if episode.created_at else None,
                'source': episode.source.value
                if hasattr(episode.source, 'value')
                else str(episode.source),
                'source_description': episode.source_description,
                'group_id': episode.group_id,
            }
            for episode in episodes
        ]

        return EpisodeSearchResponse(
            message='Episodes retrieved successfully', episodes=episode_results
        )
    except Exception as e:
        detail = str(e)
        logger.error(f'Error getting episodes: {detail}')
        return ErrorResponse(error=f'Error getting episodes: {detail}')
|
|
|
|
|
|
@mcp.tool()
async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse:
    """Clear all data from the graph for specified group IDs.

    Args:
        group_ids: Optional list of group IDs to clear. If not provided, clears the default group.
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # BUG FIX: the previous expression
        #   group_ids or [config.graphiti.group_id] if config.graphiti.group_id else []
        # parsed as `(group_ids or [...]) if config.graphiti.group_id else []`
        # because `or` binds tighter than the conditional expression, so
        # explicitly provided group_ids were silently discarded whenever no
        # default group_id was configured. Resolve the fallback the same way
        # as search_nodes / search_memory_facts / get_episodes do.
        if group_ids is not None:
            effective_group_ids = group_ids
        elif config.graphiti.group_id:
            effective_group_ids = [config.graphiti.group_id]
        else:
            effective_group_ids = []

        if not effective_group_ids:
            return ErrorResponse(error='No group IDs specified for clearing')

        # Clear data for the specified group IDs.
        await clear_data(client.driver, group_ids=effective_group_ids)

        return SuccessResponse(
            message=f'Graph data cleared successfully for group IDs: {", ".join(effective_group_ids)}'
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error clearing graph: {error_msg}')
        return ErrorResponse(error=f'Error clearing graph: {error_msg}')
|
|
|
|
|
|
@mcp.tool()
async def get_status() -> StatusResponse:
    """Get the status of the Graphiti MCP server and database connection."""
    global graphiti_service

    if graphiti_service is None:
        return StatusResponse(status='error', message='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # A trivial count query confirms connectivity on every supported
        # backend (Neo4j, FalkorDB, KuzuDB).
        async with client.driver.session() as session:
            cursor = await session.run('MATCH (n) RETURN count(n) as count')
            # Drain the cursor so the query actually executes.
            _ = [record async for record in cursor]

        provider_info = f'{config.database.provider} database'
        return StatusResponse(
            status='ok', message=f'Graphiti MCP server is running and connected to {provider_info}'
        )
    except Exception as e:
        detail = str(e)
        logger.error(f'Error checking database connection: {detail}')
        return StatusResponse(
            status='error',
            message=f'Graphiti MCP server is running but database connection failed: {detail}',
        )
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for the Graphiti MCP server."""
    parser = argparse.ArgumentParser(
        description='Run the Graphiti MCP server with YAML configuration support'
    )

    # Configuration file argument.
    # Defaults to config/config.yaml relative to the mcp_server directory.
    default_config = Path(__file__).parent.parent / 'config' / 'config.yaml'
    parser.add_argument(
        '--config',
        type=Path,
        default=default_config,
        help='Path to YAML configuration file (default: config/config.yaml)',
    )

    # Transport arguments.
    parser.add_argument(
        '--transport',
        choices=['sse', 'stdio', 'http'],
        help='Transport to use: sse (Server-Sent Events), stdio (standard I/O), or http (streamable HTTP)',
    )
    parser.add_argument(
        '--host',
        help='Host to bind the MCP server to',
    )
    parser.add_argument(
        '--port',
        type=int,
        help='Port to bind the MCP server to',
    )

    # Provider selection arguments.
    parser.add_argument(
        '--llm-provider',
        choices=['openai', 'azure_openai', 'anthropic', 'gemini', 'groq'],
        help='LLM provider to use',
    )
    parser.add_argument(
        '--embedder-provider',
        choices=['openai', 'azure_openai', 'gemini', 'voyage'],
        help='Embedder provider to use',
    )
    parser.add_argument(
        '--database-provider',
        choices=['neo4j', 'falkordb', 'kuzu'],
        help='Database provider to use',
    )

    # LLM configuration arguments.
    parser.add_argument('--model', help='Model name to use with the LLM client')
    parser.add_argument('--small-model', help='Small model name to use with the LLM client')
    parser.add_argument(
        '--temperature', type=float, help='Temperature setting for the LLM (0.0-2.0)'
    )

    # Embedder configuration arguments.
    parser.add_argument('--embedder-model', help='Model name to use with the embedder')

    # Graphiti-specific arguments.
    parser.add_argument(
        '--group-id',
        help='Namespace for the graph. If not provided, uses config file or generates random UUID.',
    )
    parser.add_argument(
        '--user-id',
        help='User ID for tracking operations',
    )
    parser.add_argument(
        '--destroy-graph',
        action='store_true',
        help='Destroy all Graphiti graphs on startup',
    )
    parser.add_argument(
        '--use-custom-entities',
        action='store_true',
        help='Enable entity extraction using the predefined ENTITY_TYPES',
    )
    return parser


async def initialize_server() -> ServerConfig:
    """Parse CLI arguments and initialize the Graphiti server configuration."""
    global config, graphiti_service, queue_service, graphiti_client, semaphore

    args = _build_arg_parser().parse_args()

    # Expose the config path so GraphitiConfig picks it up from the environment.
    if args.config:
        os.environ['CONFIG_PATH'] = str(args.config)

    # Load configuration from environment variables and YAML.
    config = GraphitiConfig()

    # CLI flags take precedence over file/env settings.
    config.apply_cli_overrides(args)

    # Legacy CLI args kept for backward compatibility.
    if hasattr(args, 'use_custom_entities'):
        config.use_custom_entities = args.use_custom_entities
    if hasattr(args, 'destroy_graph'):
        config.destroy_graph = args.destroy_graph

    # Log the effective configuration.
    logger.info('Using configuration:')
    logger.info(f' - LLM: {config.llm.provider} / {config.llm.model}')
    logger.info(f' - Embedder: {config.embedder.provider} / {config.embedder.model}')
    logger.info(f' - Database: {config.database.provider}')
    logger.info(f' - Group ID: {config.graphiti.group_id}')
    logger.info(f' - Transport: {config.server.transport}')

    # Optionally wipe all graphs before starting.
    if hasattr(config, 'destroy_graph') and config.destroy_graph:
        logger.warning('Destroying all Graphiti graphs as requested...')
        temp_service = GraphitiService(config, SEMAPHORE_LIMIT)
        await temp_service.initialize()
        client = await temp_service.get_client()
        await clear_data(client.driver)
        logger.info('All graphs destroyed')

    # Initialize the long-lived services.
    graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
    queue_service = QueueService()
    await graphiti_service.initialize()

    # Module-level aliases kept for backward compatibility.
    graphiti_client = await graphiti_service.get_client()
    semaphore = graphiti_service.semaphore

    # The queue worker needs the shared client.
    await queue_service.initialize(graphiti_client)

    # Apply server binding overrides to FastMCP.
    if config.server.host:
        mcp.settings.host = config.server.host
    if config.server.port:
        mcp.settings.port = config.server.port

    # Hand back the transport configuration for run_mcp_server().
    return config.server
|
|
|
|
|
|
async def run_mcp_server():
    """Run the MCP server in the current event loop."""
    # Configure everything first; this also parses the CLI.
    mcp_config = await initialize_server()

    logger.info(f'Starting MCP server with transport: {mcp_config.transport}')
    transport = mcp_config.transport
    if transport == 'stdio':
        await mcp.run_stdio_async()
    elif transport == 'sse':
        logger.info(
            f'Running MCP server with SSE transport on {mcp.settings.host}:{mcp.settings.port}'
        )
        logger.info(f'Access the server at: http://{mcp.settings.host}:{mcp.settings.port}/sse')
        await mcp.run_sse_async()
    elif transport == 'http':
        # Show localhost in printed URLs when bound to all interfaces.
        display_host = 'localhost' if mcp.settings.host == '0.0.0.0' else mcp.settings.host
        logger.info(
            f'Running MCP server with streamable HTTP transport on {mcp.settings.host}:{mcp.settings.port}'
        )
        logger.info('=' * 60)
        logger.info('MCP Server Access Information:')
        logger.info(f' Base URL: http://{display_host}:{mcp.settings.port}/')
        logger.info(f' MCP Endpoint: http://{display_host}:{mcp.settings.port}/mcp/')
        logger.info(' Transport: HTTP (streamable)')
        logger.info('=' * 60)
        logger.info('For MCP clients, connect to the /mcp/ endpoint above')
        await mcp.run_streamable_http_async()
    else:
        raise ValueError(
            f'Unsupported transport: {mcp_config.transport}. Use "sse", "stdio", or "http"'
        )
|
|
|
|
|
|
def main():
    """Main function to run the Graphiti MCP server."""
    try:
        # One event loop for the entire server lifetime.
        asyncio.run(run_mcp_server())
    except KeyboardInterrupt:
        # Ctrl-C is a normal shutdown, not an error.
        logger.info('Server shutting down...')
    except Exception as e:
        logger.error(f'Error initializing Graphiti MCP server: {str(e)}')
        raise


if __name__ == '__main__':
    main()
|