# NOTE(review): the following header lines were web-viewer residue (commit log
# fragments and file-size metadata: "experimental/wip/cleanup", "763 lines",
# "28 KiB", "Python") and are preserved here as a comment so the file parses.
#!/usr/bin/env python3
|
|
"""
|
|
Graphiti MCP Server - Exposes Graphiti functionality through the Model Context Protocol (MCP)
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import sys
|
|
import uuid
|
|
from collections.abc import Awaitable, Callable
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Optional, TypedDict, TypeVar, Union, cast
|
|
|
|
from dotenv import load_dotenv
|
|
from mcp.server.fastmcp import FastMCP
|
|
from pydantic import BaseModel
|
|
|
|
# graphiti_core imports
|
|
from graphiti_core import Graphiti
|
|
from graphiti_core.edges import EntityEdge
|
|
from graphiti_core.llm_client import LLMClient
|
|
from graphiti_core.llm_client.config import LLMConfig
|
|
from graphiti_core.llm_client.openai_client import OpenAIClient
|
|
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
|
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
|
|
from graphiti_core.search.search_filters import SearchFilters
|
|
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
|
|
|
load_dotenv()
|
|
|
|
|
|
# Type definitions for API responses
|
|
class ErrorResponse(TypedDict):
    """Error payload returned by tool handlers when an operation fails."""

    error: str  # human-readable description of what went wrong
|
|
|
|
|
|
class SuccessResponse(TypedDict):
    """Success payload returned by tool handlers that have no data to return."""

    message: str  # human-readable confirmation of the completed operation
|
|
|
|
|
|
class NodeResult(TypedDict):
    """Serialized entity node as returned by search_nodes."""

    uuid: str
    name: str
    summary: str
    labels: list[str]
    group_id: str
    created_at: str  # ISO-8601 timestamp (see isoformat() call in search_nodes)
    attributes: dict[str, Any]
|
|
|
|
|
|
class NodeSearchResponse(TypedDict):
    """Response shape for search_nodes: a status message plus matching nodes."""

    message: str
    nodes: list[NodeResult]
|
|
|
|
|
|
class FactSearchResponse(TypedDict):
    """Response shape for search_facts: a status message plus serialized edges."""

    message: str
    facts: list[dict[str, Any]]  # entries produced by format_fact_result()
|
|
|
|
|
|
class EpisodeSearchResponse(TypedDict):
    """Response shape for get_episodes when no episodes are found."""

    message: str
    episodes: list[dict[str, Any]]  # Pydantic model_dump(mode='json') of each episode
|
|
|
|
|
|
class StatusResponse(TypedDict):
    """Response shape for the get_status resource."""

    status: str  # 'ok' or 'error'
    message: str
|
|
|
|
|
|
# Server configuration classes
|
|
class GraphitiConfig(BaseModel):
    """Configuration for Graphiti client.

    Centralizes all configuration parameters for the Graphiti client,
    including database connection details and LLM settings.
    """

    # Neo4j connection settings; defaults target a local development instance.
    neo4j_uri: str = 'bolt://localhost:7687'
    neo4j_user: str = 'neo4j'
    neo4j_password: str = 'password'
    # OpenAI / LLM settings; all optional, validated later in initialize_graphiti.
    openai_api_key: Optional[str] = None
    openai_base_url: Optional[str] = None
    model_name: Optional[str] = None
    # Default namespace for graph data; set at startup by initialize_server.
    group_id: Optional[str] = None

    @classmethod
    def from_env(cls) -> 'GraphitiConfig':
        """Create a configuration instance from environment variables.

        Note that group_id is intentionally not read from the environment;
        it is assigned from the CLI (or generated) in initialize_server.
        """
        return cls(
            neo4j_uri=os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
            neo4j_user=os.environ.get('NEO4J_USER', 'neo4j'),
            neo4j_password=os.environ.get('NEO4J_PASSWORD', 'password'),
            openai_api_key=os.environ.get('OPENAI_API_KEY'),
            openai_base_url=os.environ.get('OPENAI_BASE_URL'),
            model_name=os.environ.get('MODEL_NAME'),
        )
|
|
|
|
|
|
class MCPConfig(BaseModel):
    """Configuration for MCP server."""

    # Transport selected on the CLI: either 'sse' or 'stdio'.
    transport: str
|
|
|
|
|
|
# Configure logging
|
|
# Configure logging. Logs are written to stderr, which keeps stdout free
# when the server runs over the stdio MCP transport.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stderr,
)
logger = logging.getLogger(__name__)

# Create global config instance (populated from environment variables).
config = GraphitiConfig.from_env()
|
|
|
|
# MCP server instructions
|
|
GRAPHITI_MCP_INSTRUCTIONS = """
|
|
Welcome to Graphiti MCP - a memory service for AI agents built on a knowledge graph. Graphiti performs well
|
|
with dynamic data such as user interactions, changing enterprise data, and external information.
|
|
|
|
Graphiti transforms information into a richly connected knowledge network, allowing you to
|
|
capture relationships between concepts, entities, and information. The system organizes data as episodes
|
|
(content snippets), nodes (entities), and facts (relationships between entities), creating a dynamic,
|
|
queryable memory store that evolves with new information. Graphiti supports multiple data formats, including
|
|
structured JSON data, enabling seamless integration with existing data pipelines and systems.
|
|
|
|
Facts contain temporal metadata, allowing you to track the time of creation and whether a fact is invalid
|
|
(superseded by new information).
|
|
|
|
Key capabilities:
|
|
1. Add episodes (text, messages, or JSON) to the knowledge graph with the add_episode tool
|
|
2. Search for nodes (entities) in the graph using natural language queries with search_nodes
|
|
3. Find relevant facts (relationships between entities) with search_facts
|
|
4. Retrieve specific entity edges or episodes by UUID
|
|
5. Manage the knowledge graph with tools like delete_episode, delete_entity_edge, and clear_graph
|
|
|
|
The server connects to a database for persistent storage and uses language models for certain operations.
|
|
Each piece of information is organized by group_id, allowing you to maintain separate knowledge domains.
|
|
|
|
When adding information, provide descriptive names and detailed content to improve search quality.
|
|
When searching, use specific queries and consider filtering by group_id for more relevant results.
|
|
|
|
For optimal performance, ensure the database is properly configured and accessible, and valid
|
|
API keys are provided for any language model operations.
|
|
"""
|
|
|
|
|
|
# MCP server instance
|
|
mcp = FastMCP(
|
|
'graphiti',
|
|
instructions=GRAPHITI_MCP_INSTRUCTIONS,
|
|
)
|
|
|
|
|
|
# Initialize Graphiti client
|
|
graphiti_client: Optional[Graphiti] = None
|
|
|
|
# Type for functions that can be wrapped with graphiti_error_handler
|
|
T = TypeVar('T')
|
|
GraphitiFunc = Callable[..., Awaitable[T]]
|
|
|
|
|
|
# Note: We've removed the error handler decorator in favor of inline error handling
|
|
# This is to avoid type checking issues with the global graphiti_client variable
|
|
|
|
|
|
async def initialize_graphiti(llm_client: Optional[LLMClient] = None, destroy_graph: bool = False):
    """Initialize the Graphiti client with the provided settings.

    Creates the module-global `graphiti_client`, optionally wiping all
    existing graph data first, and builds the database indices Graphiti needs.

    Args:
        llm_client: Optional LLMClient instance to use for LLM operations
        destroy_graph: Optional boolean to destroy all Graphiti graphs

    Raises:
        ValueError: If no llm_client is given and OPENAI_API_KEY is unset, or
            if any of the Neo4j connection settings are empty.
    """
    global graphiti_client

    # If no client is provided, create a default OpenAI client
    if not llm_client:
        if config.openai_api_key:
            llm_config = LLMConfig(api_key=config.openai_api_key)
            # base_url and model are only overridden when explicitly configured.
            if config.openai_base_url:
                llm_config.base_url = config.openai_base_url
            if config.model_name:
                llm_config.model = config.model_name
            llm_client = OpenAIClient(config=llm_config)
            logger.info('Using OpenAI as LLM client')
        else:
            raise ValueError('OPENAI_API_KEY must be set when not using a custom LLM client')

    # Defaults make these non-empty unless env vars were set to empty strings.
    if not config.neo4j_uri or not config.neo4j_user or not config.neo4j_password:
        raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')

    graphiti_client = Graphiti(
        uri=config.neo4j_uri,
        user=config.neo4j_user,
        password=config.neo4j_password,
        llm_client=llm_client,
    )

    # Destroying must happen after client creation and before index building.
    if destroy_graph:
        logger.info('Destroying graph...')
        await clear_data(graphiti_client.driver)

    # Initialize the graph database with Graphiti's indices
    await graphiti_client.build_indices_and_constraints()
    logger.info('Graphiti client initialized successfully')
|
|
|
|
|
|
def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
    """Format an entity edge into a readable result.

    Since EntityEdge is a Pydantic BaseModel, we can rely on its built-in
    serialization rather than walking the fields by hand.

    Args:
        edge: The EntityEdge to format

    Returns:
        A dictionary representation of the edge with serialized dates and
        embedding data excluded
    """
    # Embeddings are large numeric vectors and not useful in tool output.
    excluded_fields = {
        'fact_embedding',
    }
    # mode='json' makes Pydantic render datetimes and other rich types
    # as JSON-safe values.
    return edge.model_dump(mode='json', exclude=excluded_fields)
|
|
|
|
|
|
@mcp.tool()
async def add_episode(
    name: str,
    episode_body: str,
    group_id: Optional[str] = None,
    source: str = 'text',
    source_description: str = '',
    uuid: Optional[str] = None,
) -> Union[SuccessResponse, ErrorResponse]:
    """Add an episode to the Graphiti knowledge graph. This is the primary way to add information to the graph.

    Args:
        name (str): Name of the episode
        episode_body (str): The content of the episode. When source='json', this must be a properly escaped JSON string,
            not a raw Python dictionary. The JSON data will be automatically processed
            to extract entities and relationships.
        group_id (str, optional): A unique ID for this graph. If not provided, uses the default group_id from CLI
            or a generated one.
        source (str, optional): Source type, must be one of:
            - 'text': For plain text content (default)
            - 'json': For structured data
            - 'message': For conversation-style content
        source_description (str, optional): Description of the source
        uuid (str, optional): Optional UUID for the episode

    Examples:
        # Adding plain text content
        add_episode(
            name="Company News",
            episode_body="Acme Corp announced a new product line today.",
            source="text",
            source_description="news article",
            group_id="some_arbitrary_string"
        )

        # Adding structured JSON data
        # NOTE: episode_body must be a properly escaped JSON string
        add_episode(
            name="Customer Profile",
            episode_body='{"company": {"name": "Acme Technologies"}, "products": [{"id": "P001", "name": "CloudSync"}, {"id": "P002", "name": "DataMiner"}]}',
            source="json",
            source_description="CRM data"
        )

        # Adding more complex JSON with arrays and nested objects
        add_episode(
            name="Product Catalog",
            episode_body='{"catalog": {"company": "Tech Solutions Inc.", "products": [{"id": "P001", "name": "Product X", "features": ["Feature A", "Feature B"]}]}}',
            source="json",
            source_description="Product catalog"
        )

        # Adding message-style content
        add_episode(
            name="Customer Conversation",
            episode_body="user: What's your return policy?\nassistant: You can return items within 30 days.",
            source="message",
            source_description="chat transcript",
            group_id="some_arbitrary_string"
        )

    Notes:
        When using source='json':
        - The JSON must be a properly escaped string, not a raw Python dictionary
        - The JSON will be automatically processed to extract entities and relationships
        - Complex nested structures are supported (arrays, nested objects, mixed data types), but keep nesting to a minimum
        - Entities will be created from appropriate JSON properties
        - Relationships between entities will be established based on the JSON structure
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Map string source to EpisodeType enum; unrecognized values fall back to text.
        source_type = EpisodeType.text
        if source.lower() == 'message':
            source_type = EpisodeType.message
        elif source.lower() == 'json':
            source_type = EpisodeType.json

        # Use the provided group_id or fall back to the default from config
        effective_group_id = group_id if group_id is not None else config.group_id

        # Cast group_id to str to satisfy type checker
        # The Graphiti client expects a str for group_id, not Optional[str]
        group_id_str = str(effective_group_id) if effective_group_id is not None else ''

        # We've already checked that graphiti_client is not None above
        # This assert statement helps type checkers understand that graphiti_client is defined
        # from this point forward in the function
        assert graphiti_client is not None, 'graphiti_client should not be None here'

        # Use cast to help the type checker understand that graphiti_client is not None
        # This doesn't change the runtime behavior, only helps with static type checking
        client = cast(Graphiti, graphiti_client)

        # Type checking will now know that client is a Graphiti instance (not None)
        # and group_id is a str, not Optional[str]
        await client.add_episode(
            name=name,
            episode_body=episode_body,
            source=source_type,
            source_description=source_description,
            group_id=group_id_str,  # Using the string version of group_id
            uuid=uuid,
            reference_time=datetime.now(timezone.utc),
        )
        return {'message': f"Episode '{name}' added successfully"}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error adding episode: {error_msg}')
        return {'error': f'Error adding episode: {error_msg}'}
|
|
|
|
|
|
@mcp.tool()
async def search_nodes(
    query: str,
    group_ids: Optional[list[str]] = None,
    max_nodes: int = 10,
    center_node_uuid: Optional[str] = None,
) -> Union[NodeSearchResponse, ErrorResponse]:
    """Search the Graphiti knowledge graph for relevant node summaries.
    These contain a summary of all of a node's relationships with other nodes.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_nodes: Maximum number of nodes to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
    """
    global graphiti_client

    if graphiti_client is None:
        return ErrorResponse(error='Graphiti client not initialized')

    try:
        # Fall back to the default group_id from config when the caller gave none.
        if group_ids is not None:
            effective_group_ids = group_ids
        else:
            effective_group_ids = [config.group_id] if config.group_id else []

        # Deep-copy the shared recipe so mutating the limit never leaks
        # into other requests.
        search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
        search_config.limit = max_nodes

        # Narrow the Optional global for the type checker; checked above.
        assert graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Perform the search using the _search method
        search_results = await client._search(
            query=query,
            config=search_config,
            group_ids=effective_group_ids,
            center_node_uuid=center_node_uuid,
            search_filter=SearchFilters(),
        )

        if not search_results.nodes:
            return NodeSearchResponse(message='No relevant nodes found', nodes=[])

        # Flatten each node into the NodeResult shape; optional attributes
        # default to empty values when missing.
        formatted_nodes: list[NodeResult] = []
        for found in search_results.nodes:
            formatted_nodes.append(
                {
                    'uuid': found.uuid,
                    'name': found.name,
                    'summary': getattr(found, 'summary', ''),
                    'labels': getattr(found, 'labels', []),
                    'group_id': found.group_id,
                    'created_at': found.created_at.isoformat(),
                    'attributes': getattr(found, 'attributes', {}),
                }
            )

        return NodeSearchResponse(message='Nodes retrieved successfully', nodes=formatted_nodes)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching nodes: {error_msg}')
        return ErrorResponse(error=f'Error searching nodes: {error_msg}')
|
|
|
|
|
|
@mcp.tool()
async def search_facts(
    query: str,
    group_ids: Optional[list[str]] = None,
    max_facts: int = 10,
    center_node_uuid: Optional[str] = None,
) -> Union[FactSearchResponse, ErrorResponse]:
    """Search the Graphiti knowledge graph for relevant facts.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_facts: Maximum number of facts to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
    """
    global graphiti_client

    if graphiti_client is None:
        # Use the TypedDict constructors for consistency with search_nodes;
        # this produces the exact same dict at runtime.
        return ErrorResponse(error='Graphiti client not initialized')

    try:
        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids if group_ids is not None else [config.group_id] if config.group_id else []
        )

        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        relevant_edges = await client.search(
            group_ids=effective_group_ids,
            query=query,
            num_results=max_facts,
            center_node_uuid=center_node_uuid,
        )

        if not relevant_edges:
            return FactSearchResponse(message='No relevant facts found', facts=[])

        facts = [format_fact_result(edge) for edge in relevant_edges]
        return FactSearchResponse(message='Facts retrieved successfully', facts=facts)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching facts: {error_msg}')
        return ErrorResponse(error=f'Error searching facts: {error_msg}')
|
|
|
|
|
|
@mcp.tool()
async def delete_entity_edge(uuid: str) -> Union[SuccessResponse, ErrorResponse]:
    """Delete an entity edge from the Graphiti knowledge graph.

    Args:
        uuid: UUID of the entity edge to delete
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Narrow the Optional global for the type checker; checked above.
        assert graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Look the edge up first, then ask it to remove itself from the store.
        edge_to_remove = await EntityEdge.get_by_uuid(client.driver, uuid)
        await edge_to_remove.delete(client.driver)
        return {'message': f'Entity edge with UUID {uuid} deleted successfully'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting entity edge: {error_msg}')
        return {'error': f'Error deleting entity edge: {error_msg}'}
|
|
|
|
|
|
@mcp.tool()
async def delete_episode(uuid: str) -> Union[SuccessResponse, ErrorResponse]:
    """Delete an episode from the Graphiti knowledge graph.

    Args:
        uuid: UUID of the episode to delete
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Narrow the Optional global for the type checker; checked above.
        assert graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Fetch the episodic node, then ask it to remove itself from the store.
        node_to_remove = await EpisodicNode.get_by_uuid(client.driver, uuid)
        await node_to_remove.delete(client.driver)
        return {'message': f'Episode with UUID {uuid} deleted successfully'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting episode: {error_msg}')
        return {'error': f'Error deleting episode: {error_msg}'}
|
|
|
|
|
|
@mcp.tool()
async def get_entity_edge(uuid: str) -> Union[dict[str, Any], ErrorResponse]:
    """Get an entity edge from the Graphiti knowledge graph by its UUID.

    Args:
        uuid: UUID of the entity edge to retrieve
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Narrow the Optional global for the type checker; checked above.
        assert graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Fetch the edge and hand back its JSON-safe serialized form.
        # MCP handles the final serialization of the returned dict.
        found_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
        return format_fact_result(found_edge)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting entity edge: {error_msg}')
        return {'error': f'Error getting entity edge: {error_msg}'}
|
|
|
|
|
|
@mcp.tool()
async def get_episodes(
    group_id: Optional[str] = None, last_n: int = 10
) -> Union[list[dict[str, Any]], EpisodeSearchResponse, ErrorResponse]:
    """Get the most recent episodes for a specific group.

    Args:
        group_id: ID of the group to retrieve episodes from. If not provided, uses the default group_id.
        last_n: Number of most recent episodes to retrieve (default: 10)
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Fall back to the default group_id from config when none was given.
        effective_group_id = config.group_id if group_id is None else group_id

        # Guards against a missing default as well as non-string values.
        if not isinstance(effective_group_id, str):
            return {'error': 'Group ID must be a string'}

        # Narrow the Optional global for the type checker; checked above.
        assert graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        episodes = await client.retrieve_episodes(
            group_ids=[effective_group_id], last_n=last_n, reference_time=datetime.now(timezone.utc)
        )

        if not episodes:
            return {'message': f'No episodes found for group {effective_group_id}', 'episodes': []}

        # Episodes are Pydantic models; mode='json' handles datetime
        # serialization. MCP serializes the returned list itself.
        return [episode.model_dump(mode='json') for episode in episodes]
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting episodes: {error_msg}')
        return {'error': f'Error getting episodes: {error_msg}'}
|
|
|
|
|
|
@mcp.tool()
async def clear_graph() -> Union[SuccessResponse, ErrorResponse]:
    """Clear all data from the Graphiti knowledge graph and rebuild indices."""
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Narrow the Optional global for the type checker; checked above.
        assert graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Wipe everything, then restore the indices Graphiti depends on.
        await clear_data(client.driver)
        await client.build_indices_and_constraints()
        return {'message': 'Graph cleared successfully and indices rebuilt'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error clearing graph: {error_msg}')
        return {'error': f'Error clearing graph: {error_msg}'}
|
|
|
|
|
|
@mcp.resource('http://graphiti/status')
async def get_status() -> StatusResponse:
    """Get the status of the Graphiti MCP server and Neo4j connection."""
    global graphiti_client

    if graphiti_client is None:
        return {'status': 'error', 'message': 'Graphiti client not initialized'}

    try:
        # Narrow the Optional global for the type checker; checked above.
        assert graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # A connectivity probe against the driver confirms Neo4j is reachable.
        await client.driver.verify_connectivity()
        return {'status': 'ok', 'message': 'Graphiti MCP server is running and connected to Neo4j'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error checking Neo4j connection: {error_msg}')
        return {
            'status': 'error',
            'message': f'Graphiti MCP server is running but Neo4j connection failed: {error_msg}',
        }
|
|
|
|
|
|
def create_llm_client(
    client_type: str = 'openai', api_key: Optional[str] = None, model: Optional[str] = None
) -> LLMClient:
    """Create an OpenAI LLM client.

    Args:
        client_type: Type of LLM client to create (only 'openai' or 'openai_generic' supported)
        api_key: API key for the OpenAI service
        model: Model name to use

    Returns:
        An instance of the OpenAI LLM client
    """
    # NOTE(review): client_type is currently ignored — an OpenAI client is
    # always produced regardless of its value. Kept for interface stability.
    llm_config = LLMConfig(api_key=api_key)
    if model:
        llm_config.model = model
    return OpenAIClient(config=llm_config)
|
|
|
|
|
|
async def initialize_server() -> MCPConfig:
    """Initialize the Graphiti server with the specified LLM client.

    Parses CLI arguments, fills in the global config's group_id and model,
    builds an LLM client when credentials are available, and initializes the
    global Graphiti client.

    Returns:
        MCPConfig carrying the transport chosen on the command line.
    """
    global config

    parser = argparse.ArgumentParser(
        description='Run the Graphiti MCP server with optional LLM client'
    )
    parser.add_argument(
        '--group-id',
        help='Namespace for the graph. This is an arbitrary string used to organize related data. '
        'If not provided, a random UUID will be generated.',
    )
    parser.add_argument(
        '--transport',
        choices=['sse', 'stdio'],
        default='sse',
        help='Transport to use for communication with the client. (default: sse)',
    )
    # OpenAI is the only supported LLM client
    parser.add_argument('--model', help='Model name to use with the LLM client')
    parser.add_argument('--destroy-graph', action='store_true', help='Destroy all Graphiti graphs')

    args = parser.parse_args()

    # Set the group_id from CLI argument or generate a random one
    if args.group_id:
        config.group_id = args.group_id
        logger.info(f'Using provided group_id: {config.group_id}')
    else:
        config.group_id = f'graph_{uuid.uuid4().hex[:8]}'
        logger.info(f'Generated random group_id: {config.group_id}')

    llm_client = None

    # Create OpenAI client if model is specified or if OPENAI_API_KEY is available
    if args.model or config.openai_api_key:
        # Override model from command line if specified
        if args.model:
            config.model_name = args.model

        # Create the OpenAI client
        llm_client = create_llm_client(
            client_type='openai', api_key=config.openai_api_key, model=config.model_name
        )

    # Initialize Graphiti with the specified LLM client
    # (initialize_graphiti raises if no client could be built and no API key is set)
    await initialize_graphiti(llm_client, destroy_graph=args.destroy_graph)

    return MCPConfig(transport=args.transport)
|
|
|
|
|
|
async def run_mcp_server():
    """Run the MCP server in the current event loop."""
    # Parse CLI args, build clients, and connect to the database first.
    mcp_config = await initialize_server()

    # Serve on the same event loop the initialization ran on.
    transport = mcp_config.transport
    logger.info(f'Starting MCP server with transport: {transport}')
    if transport == 'stdio':
        await mcp.run_stdio_async()
    elif transport == 'sse':
        logger.info(
            f'Running MCP server with SSE transport on {mcp.settings.host}:{mcp.settings.port}'
        )
        await mcp.run_sse_async()
|
|
|
|
|
|
def main():
    """Main function to run the Graphiti MCP server."""
    try:
        # A single asyncio.run keeps initialization and serving on one event loop.
        asyncio.run(run_mcp_server())
    except Exception as e:
        # Log before re-raising so the failure reason lands in the server log.
        logger.error(f'Error initializing Graphiti MCP server: {str(e)}')
        raise


if __name__ == '__main__':
    main()
|