* Fix MCP server telemetry and update graphiti-core to v0.23.0 - Regenerate uv.lock in Docker builds to ensure posthog dependency is included - Update graphiti-core version to 0.23.0 across all configurations - Fix Python 3.11 compatibility for Pydantic TypedDict import - Add FalkorDB Browser UI access information to startup logs 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Add typing-extensions as explicit dependency Required for Python 3.11 compatibility with Pydantic 2.11.7, which mandates typing_extensions.TypedDict on Python < 3.12. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Use built-in TypedDict from typing module Python 3.10+ includes TypedDict in the standard library. Removed typing-extensions dependency as it's unnecessary. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Revert to typing_extensions.TypedDict for Pydantic compatibility Pydantic requires typing_extensions.TypedDict on Python < 3.12. Docker container uses Python 3.11, so this is necessary. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Add Ruff linter rule to prevent typing.TypedDict usage Adds banned-api configuration to both main and MCP server pyproject.toml files to enforce typing_extensions.TypedDict usage, required for Pydantic compatibility on Python < 3.12. Also includes Ruff auto-formatting changes to quote styles. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Bump MCP server version to 1.0.1 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com>
965 lines
36 KiB
Python
965 lines
36 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Graphiti MCP Server - Exposes Graphiti functionality through the Model Context Protocol (MCP)
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any, Optional
|
|
|
|
from dotenv import load_dotenv
|
|
from graphiti_core import Graphiti
|
|
from graphiti_core.edges import EntityEdge
|
|
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
|
from graphiti_core.search.search_filters import SearchFilters
|
|
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
|
from mcp.server.fastmcp import FastMCP
|
|
from pydantic import BaseModel
|
|
from starlette.responses import JSONResponse
|
|
|
|
from config.schema import GraphitiConfig, ServerConfig
|
|
from models.response_types import (
|
|
EpisodeSearchResponse,
|
|
ErrorResponse,
|
|
FactSearchResponse,
|
|
NodeResult,
|
|
NodeSearchResponse,
|
|
StatusResponse,
|
|
SuccessResponse,
|
|
)
|
|
from services.factories import DatabaseDriverFactory, EmbedderFactory, LLMClientFactory
|
|
from services.queue_service import QueueService
|
|
from utils.formatting import format_fact_result
|
|
|
|
# Environment bootstrap: prefer the .env file located in the mcp_server
# directory (two levels above this module). When that file is absent, call
# load_dotenv() without a path so python-dotenv performs its default lookup
# (the original note describes this as the current-working-directory fallback).
mcp_server_dir = Path(__file__).parent.parent
env_file = mcp_server_dir / '.env'
if not env_file.exists():
    load_dotenv()
else:
    load_dotenv(env_file)
|
|
|
|
|
|
# Upper bound on the number of Graphiti episodes processed at the same time.
#
# Processing one episode involves multiple LLM calls (entity extraction,
# deduplication, etc.), so the real number of concurrent LLM requests is
# higher than this value.
#
# TUNING GUIDELINES
#
# LLM provider rate limits (requests per minute):
#   - OpenAI Tier 1 (free): 3 RPM            -> SEMAPHORE_LIMIT=1-2
#   - OpenAI Tier 2: 60 RPM                  -> SEMAPHORE_LIMIT=5-8
#   - OpenAI Tier 3: 500 RPM                 -> SEMAPHORE_LIMIT=10-15
#   - OpenAI Tier 4: 5,000 RPM               -> SEMAPHORE_LIMIT=20-50
#   - Anthropic (default): 50 RPM            -> SEMAPHORE_LIMIT=5-8
#   - Anthropic (high tier): 1,000 RPM       -> SEMAPHORE_LIMIT=15-30
#   - Azure OpenAI (varies): consult your quota and adjust accordingly
#
# SYMPTOMS
#   - Too high: 429 rate limit errors, increased costs from parallel processing
#   - Too low: slow throughput, underutilized API quota
#
# MONITORING
#   - Watch logs for rate limit errors (429)
#   - Monitor episode processing times
#   - Check the LLM provider dashboard for actual request rates
#
# DEFAULT: 10 (suitable for OpenAI Tier 3, mid-tier Anthropic)
SEMAPHORE_LIMIT = int(os.environ.get('SEMAPHORE_LIMIT', 10))
|
|
|
|
|
|
# Structured, timestamped logging. Everything goes to stderr.
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(
    stream=sys.stderr,
    level=logging.INFO,
    datefmt=DATE_FORMAT,
    format=LOG_FORMAT,
)

# Per-logger levels: keep uvicorn at INFO, but quieten the access log and the
# MCP streamable-HTTP manager, which are noisy at INFO.
for _logger_name, _logger_level in (
    ('uvicorn', logging.INFO),
    ('uvicorn.access', logging.WARNING),
    ('mcp.server.streamable_http_manager', logging.WARNING),
):
    logging.getLogger(_logger_name).setLevel(_logger_level)
|
|
|
|
|
|
# Patch uvicorn's logging config to use our format
|
|
def configure_uvicorn_logging():
    """Give the uvicorn loggers a single stderr handler that uses our log format.

    For each of 'uvicorn', 'uvicorn.error' and 'uvicorn.access' the existing
    handlers are dropped, a fresh stderr handler formatted with LOG_FORMAT and
    DATE_FORMAT is attached, and propagation is switched off.
    """
    for target_name in ('uvicorn', 'uvicorn.error', 'uvicorn.access'):
        target = logging.getLogger(target_name)
        # Start from a clean slate so only our handler remains on this logger.
        target.handlers.clear()

        stderr_handler = logging.StreamHandler(sys.stderr)
        formatter = logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT)
        stderr_handler.setFormatter(formatter)

        target.addHandler(stderr_handler)
        # Records stop here instead of also travelling to ancestor loggers.
        target.propagate = False
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Global configuration. Only declared here; initialize_server() assigns it.
config: GraphitiConfig

# Instructions handed to FastMCP when the server instance is created.
# The tool names mentioned below must match the functions registered with
# @mcp.tool() in this module (the fact search tool is `search_memory_facts`).
GRAPHITI_MCP_INSTRUCTIONS = """
Graphiti is a memory service for AI agents built on a knowledge graph. Graphiti performs well
with dynamic data such as user interactions, changing enterprise data, and external information.

Graphiti transforms information into a richly connected knowledge network, allowing you to
capture relationships between concepts, entities, and information. The system organizes data as episodes
(content snippets), nodes (entities), and facts (relationships between entities), creating a dynamic,
queryable memory store that evolves with new information. Graphiti supports multiple data formats, including
structured JSON data, enabling seamless integration with existing data pipelines and systems.

Facts contain temporal metadata, allowing you to track the time of creation and whether a fact is invalid
(superseded by new information).

Key capabilities:
1. Add episodes (text, messages, or JSON) to the knowledge graph with the add_memory tool
2. Search for nodes (entities) in the graph using natural language queries with search_nodes
3. Find relevant facts (relationships between entities) with search_memory_facts
4. Retrieve specific entity edges or episodes by UUID
5. Manage the knowledge graph with tools like delete_episode, delete_entity_edge, and clear_graph

The server connects to a database for persistent storage and uses language models for certain operations.
Each piece of information is organized by group_id, allowing you to maintain separate knowledge domains.

When adding information, provide descriptive names and detailed content to improve search quality.
When searching, use specific queries and consider filtering by group_id for more relevant results.

For optimal performance, ensure the database is properly configured and accessible, and valid
API keys are provided for any language model operations.
"""

# MCP server instance
mcp = FastMCP(
    'Graphiti Agent Memory',
    instructions=GRAPHITI_MCP_INSTRUCTIONS,
)

# Global services, populated by initialize_server().
# Optional['GraphitiService'] is a string forward reference because the class
# is defined further down in this module.
graphiti_service: Optional['GraphitiService'] = None
queue_service: QueueService | None = None

# Global client for backward compatibility
graphiti_client: Graphiti | None = None
# Declared only; initialize_server() assigns the service's semaphore to it.
semaphore: asyncio.Semaphore
|
|
|
|
|
|
class GraphitiService:
    """Graphiti service using the unified configuration system.

    Builds the LLM and embedder clients through their factories, derives
    custom entity types from the configuration, creates the Graphiti client
    for the configured database provider and hands it out via ``get_client``.
    """

    def __init__(self, config: GraphitiConfig, semaphore_limit: int = 10):
        """Store the configuration and create the semaphore; no client is built yet.

        Args:
            config: Configuration whose ``llm``, ``embedder``, ``database`` and
                ``graphiti`` sections are read by ``initialize``.
            semaphore_limit: Size of ``self.semaphore``; also passed to
                Graphiti as ``max_coroutines``.
        """
        self.config = config
        self.semaphore_limit = semaphore_limit
        self.semaphore = asyncio.Semaphore(semaphore_limit)
        self.client: Graphiti | None = None
        self.entity_types = None

    def _build_entity_types(self) -> dict[str, Any] | None:
        """Return one dynamic Pydantic model per configured entity type, or None."""
        configured_types = self.config.graphiti.entity_types
        if not configured_types:
            return None
        # Each model only carries the description as its docstring.
        # Note (from the original code): don't use 'name' as a field, it is a
        # protected Pydantic attribute.
        return {
            entity_type.name: type(
                entity_type.name,
                (BaseModel,),
                {'__doc__': entity_type.description},
            )
            for entity_type in configured_types
        }

    def _create_client(self, db_config: dict[str, Any], llm_client, embedder_client) -> Graphiti:
        """Instantiate Graphiti with the driver matching the configured provider."""
        if self.config.database.provider.lower() == 'falkordb':
            # Imported lazily so the FalkorDB driver is only needed when selected.
            from graphiti_core.driver.falkordb_driver import FalkorDriver

            falkor_driver = FalkorDriver(
                host=db_config['host'],
                port=db_config['port'],
                password=db_config['password'],
                database=db_config['database'],
            )
            return Graphiti(
                graph_driver=falkor_driver,
                llm_client=llm_client,
                embedder=embedder_client,
                max_coroutines=self.semaphore_limit,
            )

        # Neo4j (default): let Graphiti build its driver from the credentials.
        return Graphiti(
            uri=db_config['uri'],
            user=db_config['user'],
            password=db_config['password'],
            llm_client=llm_client,
            embedder=embedder_client,
            max_coroutines=self.semaphore_limit,
        )

    def _connection_error_message(self, db_config: dict[str, Any]) -> str:
        """Build the user-facing message for an unreachable database."""
        banner = '=' * 70
        db_provider = self.config.database.provider
        provider_key = db_provider.lower()

        if provider_key == 'falkordb':
            title = 'FalkorDB'
            details = (
                f'FalkorDB at {db_config["host"]}:{db_config["port"]} is not accessible.\n\n'
                'To start FalkorDB:\n'
                ' - Using Docker Compose: cd mcp_server && docker compose up\n'
                ' - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n'
            )
        elif provider_key == 'neo4j':
            title = 'Neo4j'
            details = (
                f'Neo4j at {db_config.get("uri", "unknown")} is not accessible.\n\n'
                'To start Neo4j:\n'
                ' - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n'
                ' - Or install Neo4j Desktop from: https://neo4j.com/download/\n'
                ' - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n'
            )
        else:
            title = db_provider
            details = (
                f'{db_provider} at {db_config.get("uri", "unknown")} is not accessible.\n\n'
                f'Please ensure {db_provider} is running and accessible.\n\n'
            )

        return (
            f'\n{banner}\n'
            f'Database Connection Error: {title} is not running\n'
            f'{banner}\n\n'
            f'{details}'
            f'{banner}\n'
        )

    async def initialize(self) -> None:
        """Initialize the Graphiti client with factory-created components.

        LLM and embedder creation failures are logged as warnings and the
        client is built without them. Database failures are fatal.

        Raises:
            RuntimeError: If client creation or index building fails with a
                message containing 'connection refused' or 'could not connect'.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        try:
            llm_client = None
            embedder_client = None

            # LLM client is best-effort: a failure only limits functionality.
            try:
                llm_client = LLMClientFactory.create(self.config.llm)
            except Exception as e:
                logger.warning(f'Failed to create LLM client: {e}')

            # Embedder client is best-effort as well.
            try:
                embedder_client = EmbedderFactory.create(self.config.embedder)
            except Exception as e:
                logger.warning(f'Failed to create embedder client: {e}')

            db_config = DatabaseDriverFactory.create_config(self.config.database)

            # Stored for later use when episodes are queued.
            self.entity_types = self._build_entity_types()

            try:
                client = self._create_client(db_config, llm_client, embedder_client)
                # Index creation is the first awaited database call. It lives
                # inside this guard so connection failures that only surface on
                # first use (drivers may connect lazily - confirm for the
                # driver in use) still get the friendly message below.
                await client.build_indices_and_constraints()
            except Exception as db_error:
                error_msg = str(db_error).lower()
                if 'connection refused' in error_msg or 'could not connect' in error_msg:
                    raise RuntimeError(self._connection_error_message(db_config)) from db_error
                # Anything else is not a connectivity problem: re-raise as is.
                raise

            # Publish the client only once it is fully initialized, so a failed
            # attempt never leaves a half-initialized client behind.
            self.client = client

            logger.info('Successfully initialized Graphiti client')

            # Log configuration details
            if llm_client:
                logger.info(
                    f'Using LLM provider: {self.config.llm.provider} / {self.config.llm.model}'
                )
            else:
                logger.info('No LLM client configured - entity extraction will be limited')

            if embedder_client:
                logger.info(f'Using Embedder provider: {self.config.embedder.provider}')
            else:
                logger.info('No Embedder client configured - search will be limited')

            if self.entity_types:
                entity_type_names = list(self.entity_types.keys())
                logger.info(f'Using custom entity types: {", ".join(entity_type_names)}')
            else:
                logger.info('Using default entity types')

            logger.info(f'Using database: {self.config.database.provider}')
            logger.info(f'Using group_id: {self.config.graphiti.group_id}')

        except Exception as e:
            logger.error(f'Failed to initialize Graphiti client: {e}')
            raise

    async def get_client(self) -> Graphiti:
        """Get the Graphiti client, initializing if necessary.

        Raises:
            RuntimeError: If no client is available after initialization.
        """
        if self.client is None:
            await self.initialize()
        if self.client is None:
            raise RuntimeError('Failed to initialize Graphiti client')
        return self.client
|
|
|
|
|
|
@mcp.tool()
async def add_memory(
    name: str,
    episode_body: str,
    group_id: str | None = None,
    source: str = 'text',
    source_description: str = '',
    uuid: str | None = None,
) -> SuccessResponse | ErrorResponse:
    """Add an episode to memory. This is the primary way to add information to the graph.

    This function returns immediately and processes the episode addition in the background.
    Episodes for the same group_id are processed sequentially to avoid race conditions.

    Args:
        name (str): Name of the episode
        episode_body (str): The content of the episode to persist to memory. When source='json', this must be a
                           properly escaped JSON string, not a raw Python dictionary. The JSON data will be
                           automatically processed to extract entities and relationships.
        group_id (str, optional): A unique ID for this graph. If not provided, uses the default group_id from CLI
                                 or a generated one.
        source (str, optional): Source type, must be one of:
                               - 'text': For plain text content (default)
                               - 'json': For structured data
                               - 'message': For conversation-style content
        source_description (str, optional): Description of the source
        uuid (str, optional): Optional UUID for the episode

    Examples:
        # Adding plain text content
        add_memory(
            name="Company News",
            episode_body="Acme Corp announced a new product line today.",
            source="text",
            source_description="news article",
            group_id="some_arbitrary_string"
        )

        # Adding structured JSON data
        # NOTE: episode_body should be a JSON string (standard JSON escaping)
        add_memory(
            name="Customer Profile",
            episode_body='{"company": {"name": "Acme Technologies"}, "products": [{"id": "P001", "name": "CloudSync"}, {"id": "P002", "name": "DataMiner"}]}',
            source="json",
            source_description="CRM data"
        )
    """
    # Both services are created by initialize_server(); refuse to work before that.
    if graphiti_service is None or queue_service is None:
        return ErrorResponse(error='Services not initialized')

    try:
        # An explicit group wins; otherwise use the configured default.
        target_group = group_id if group_id else config.graphiti.group_id

        # Map the source string onto EpisodeType, keeping 'text' when the
        # string is empty or does not name an enum member.
        kind = EpisodeType.text
        if source:
            try:
                kind = EpisodeType[source.lower()]
            except (KeyError, AttributeError):
                logger.warning(f"Unknown source type '{source}', using 'text' as default")

        # Hand the episode to the queue service for asynchronous processing.
        episode_kwargs = {
            'group_id': target_group,
            'name': name,
            'content': episode_body,
            'source_description': source_description,
            'episode_type': kind,
            'entity_types': graphiti_service.entity_types,
            # Falsy values (None or an empty string) are passed on as None.
            'uuid': uuid if uuid else None,
        }
        await queue_service.add_episode(**episode_kwargs)

        return SuccessResponse(
            message=f"Episode '{name}' queued for processing in group '{target_group}'"
        )
    except Exception as exc:
        failure = f'Error queuing episode: {exc}'
        logger.error(failure)
        return ErrorResponse(error=failure)
|
|
|
|
|
|
@mcp.tool()
async def search_nodes(
    query: str,
    group_ids: list[str] | None = None,
    max_nodes: int = 10,
    entity_types: list[str] | None = None,
) -> NodeSearchResponse | ErrorResponse:
    """Search for nodes in the graph memory.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_nodes: Maximum number of nodes to return (default: 10). Must be a positive integer.
        entity_types: Optional list of entity type names to filter by
    """
    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Same validation style as max_facts in search_memory_facts; it also
        # rules out negative values, which would otherwise slice from the end.
        if max_nodes <= 0:
            return ErrorResponse(error='max_nodes must be a positive integer')

        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids
            if group_ids is not None
            else [config.graphiti.group_id]
            if config.graphiti.group_id
            else []
        )

        # Restrict results to the requested entity type labels (None = no filter).
        search_filters = SearchFilters(
            node_labels=entity_types,
        )

        from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF

        # The recipe is a shared module-level object that carries its own
        # result limit. Work on a copy so the requested max_nodes is honoured
        # without mutating the shared recipe.
        search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
        search_config.limit = max_nodes

        results = await client.search_(
            query=query,
            config=search_config,
            group_ids=effective_group_ids,
            search_filter=search_filters,
        )

        # Slice as a safeguard in case more nodes than requested come back.
        nodes = results.nodes[:max_nodes] if results.nodes else []

        if not nodes:
            return NodeSearchResponse(message='No relevant nodes found', nodes=[])

        node_results = []
        for node in nodes:
            # Tolerate a missing or None attributes mapping.
            raw_attrs = getattr(node, 'attributes', None) or {}
            # Never return embedding vectors to the caller.
            attrs = {k: v for k, v in raw_attrs.items() if 'embedding' not in k.lower()}

            node_results.append(
                NodeResult(
                    uuid=node.uuid,
                    name=node.name,
                    labels=node.labels if node.labels else [],
                    created_at=node.created_at.isoformat() if node.created_at else None,
                    summary=node.summary,
                    group_id=node.group_id,
                    attributes=attrs,
                )
            )

        return NodeSearchResponse(message='Nodes retrieved successfully', nodes=node_results)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching nodes: {error_msg}')
        return ErrorResponse(error=f'Error searching nodes: {error_msg}')
|
|
|
|
|
|
@mcp.tool()
async def search_memory_facts(
    query: str,
    group_ids: list[str] | None = None,
    max_facts: int = 10,
    center_node_uuid: str | None = None,
) -> FactSearchResponse | ErrorResponse:
    """Search the graph memory for relevant facts.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_facts: Maximum number of facts to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
    """
    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Reject zero and negative limits before touching the database.
        if max_facts <= 0:
            return ErrorResponse(error='max_facts must be a positive integer')

        client = await graphiti_service.get_client()

        # Group selection: explicit list (even an empty one) > configured
        # default group > no groups at all.
        if group_ids is not None:
            selected_groups = group_ids
        elif config.graphiti.group_id:
            selected_groups = [config.graphiti.group_id]
        else:
            selected_groups = []

        edges = await client.search(
            group_ids=selected_groups,
            query=query,
            num_results=max_facts,
            center_node_uuid=center_node_uuid,
        )

        if not edges:
            return FactSearchResponse(message='No relevant facts found', facts=[])

        # Serialize every edge with the shared formatter.
        formatted = list(map(format_fact_result, edges))
        return FactSearchResponse(message='Facts retrieved successfully', facts=formatted)
    except Exception as exc:
        failure = f'Error searching facts: {exc}'
        logger.error(failure)
        return ErrorResponse(error=failure)
|
|
|
|
|
|
@mcp.tool()
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an entity edge from the graph memory.

    Args:
        uuid: UUID of the entity edge to delete
    """
    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        graphiti = await graphiti_service.get_client()
        driver = graphiti.driver

        # Load the edge first, then ask it to remove itself through the driver.
        edge = await EntityEdge.get_by_uuid(driver, uuid)
        await edge.delete(driver)
    except Exception as exc:
        failure = f'Error deleting entity edge: {exc}'
        logger.error(failure)
        return ErrorResponse(error=failure)

    return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully')
|
|
|
|
|
|
@mcp.tool()
async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an episode from the graph memory.

    Args:
        uuid: UUID of the episode to delete
    """
    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        graphiti = await graphiti_service.get_client()
        driver = graphiti.driver

        # Load the episodic node first, then ask it to remove itself through the driver.
        episode = await EpisodicNode.get_by_uuid(driver, uuid)
        await episode.delete(driver)
    except Exception as exc:
        failure = f'Error deleting episode: {exc}'
        logger.error(failure)
        return ErrorResponse(error=failure)

    return SuccessResponse(message=f'Episode with UUID {uuid} deleted successfully')
|
|
|
|
|
|
@mcp.tool()
async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
    """Get an entity edge from the graph memory by its UUID.

    Args:
        uuid: UUID of the entity edge to retrieve
    """
    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        graphiti = await graphiti_service.get_client()

        # Look the edge up through the EntityEdge class method.
        edge = await EntityEdge.get_by_uuid(graphiti.driver, uuid)

        # The formatter's result is returned as-is; per the original note it
        # is a plain Python dict and MCP takes care of serializing it.
        serialized = format_fact_result(edge)
        return serialized
    except Exception as exc:
        failure = f'Error getting entity edge: {exc}'
        logger.error(failure)
        return ErrorResponse(error=failure)
|
|
|
|
|
|
@mcp.tool()
async def get_episodes(
    group_ids: list[str] | None = None,
    max_episodes: int = 10,
) -> EpisodeSearchResponse | ErrorResponse:
    """Get episodes from the graph memory.

    Args:
        group_ids: Optional list of group IDs to filter results
        max_episodes: Maximum number of episodes to return (default: 10). Must be a positive integer.
    """
    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Same validation style as max_facts in search_memory_facts: never
        # forward a zero or negative limit to the database.
        if max_episodes <= 0:
            return ErrorResponse(error='max_episodes must be a positive integer')

        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids
            if group_ids is not None
            else [config.graphiti.group_id]
            if config.graphiti.group_id
            else []
        )

        # EpisodicNode comes from the module-level import; no local re-import needed.
        if effective_group_ids:
            episodes = await EpisodicNode.get_by_group_ids(
                client.driver, effective_group_ids, limit=max_episodes
            )
        else:
            # Without any group ID there is nothing to look up; return an empty result.
            episodes = []

        if not episodes:
            return EpisodeSearchResponse(message='No episodes found', episodes=[])

        # Format the results as plain dictionaries.
        episode_results = [
            {
                'uuid': episode.uuid,
                'name': episode.name,
                'content': episode.content,
                'created_at': episode.created_at.isoformat() if episode.created_at else None,
                # Enum members expose .value; anything else is stringified.
                'source': episode.source.value
                if hasattr(episode.source, 'value')
                else str(episode.source),
                'source_description': episode.source_description,
                'group_id': episode.group_id,
            }
            for episode in episodes
        ]

        return EpisodeSearchResponse(
            message='Episodes retrieved successfully', episodes=episode_results
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting episodes: {error_msg}')
        return ErrorResponse(error=f'Error getting episodes: {error_msg}')
|
|
|
|
|
|
@mcp.tool()
async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse:
    """Clear all data from the graph for specified group IDs.

    Args:
        group_ids: Optional list of group IDs to clear. If not provided, clears the default group.
    """
    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Explicit group_ids always win; the configured default group is only
        # a fallback. The parentheses matter: without them the conditional
        # applies to the whole `or` expression and explicit group_ids are
        # discarded whenever no default group is configured.
        effective_group_ids = group_ids or (
            [config.graphiti.group_id] if config.graphiti.group_id else []
        )

        if not effective_group_ids:
            return ErrorResponse(error='No group IDs specified for clearing')

        # Clear data for the specified group IDs only.
        await clear_data(client.driver, group_ids=effective_group_ids)

        return SuccessResponse(
            message=f'Graph data cleared successfully for group IDs: {", ".join(effective_group_ids)}'
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error clearing graph: {error_msg}')
        return ErrorResponse(error=f'Error clearing graph: {error_msg}')
|
|
|
|
|
|
@mcp.tool()
async def get_status() -> StatusResponse:
    """Get the status of the Graphiti MCP server and database connection."""
    if graphiti_service is None:
        return StatusResponse(status='error', message='Graphiti service not initialized')

    try:
        graphiti = await graphiti_service.get_client()

        # Probe the database with a trivial node count.
        async with graphiti.driver.session() as db_session:
            probe = await db_session.run('MATCH (n) RETURN count(n) as count')
            # Drain the result so the query is known to have executed.
            if probe:
                async for _record in probe:
                    pass
    except Exception as exc:
        detail = str(exc)
        logger.error(f'Error checking database connection: {detail}')
        return StatusResponse(
            status='error',
            message=f'Graphiti MCP server is running but database connection failed: {detail}',
        )

    # Report the provider from the service's own config, not the module global.
    provider_name = graphiti_service.config.database.provider
    return StatusResponse(
        status='ok',
        message=f'Graphiti MCP server is running and connected to {provider_name} database',
    )
|
|
|
|
|
|
@mcp.custom_route('/health', methods=['GET'])
async def health_check(request) -> JSONResponse:
    """Health check endpoint for Docker and load balancers."""
    # The request is not inspected; the payload is a fixed liveness marker.
    payload = {'status': 'healthy', 'service': 'graphiti-mcp'}
    return JSONResponse(payload)
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the Graphiti MCP server."""
    parser = argparse.ArgumentParser(
        description='Run the Graphiti MCP server with YAML configuration support'
    )

    # Configuration file argument.
    # Defaults to config/config.yaml relative to the mcp_server directory.
    default_config = Path(__file__).parent.parent / 'config' / 'config.yaml'
    parser.add_argument(
        '--config',
        type=Path,
        default=default_config,
        help='Path to YAML configuration file (default: config/config.yaml)',
    )

    # Transport arguments
    parser.add_argument(
        '--transport',
        choices=['sse', 'stdio', 'http'],
        help='Transport to use: http (recommended, default), stdio (standard I/O), or sse (deprecated)',
    )
    parser.add_argument(
        '--host',
        help='Host to bind the MCP server to',
    )
    parser.add_argument(
        '--port',
        type=int,
        help='Port to bind the MCP server to',
    )

    # Provider selection arguments
    parser.add_argument(
        '--llm-provider',
        choices=['openai', 'azure_openai', 'anthropic', 'gemini', 'groq'],
        help='LLM provider to use',
    )
    parser.add_argument(
        '--embedder-provider',
        choices=['openai', 'azure_openai', 'gemini', 'voyage'],
        help='Embedder provider to use',
    )
    parser.add_argument(
        '--database-provider',
        choices=['neo4j', 'falkordb'],
        help='Database provider to use',
    )

    # LLM configuration arguments
    parser.add_argument('--model', help='Model name to use with the LLM client')
    parser.add_argument('--small-model', help='Small model name to use with the LLM client')
    parser.add_argument(
        '--temperature', type=float, help='Temperature setting for the LLM (0.0-2.0)'
    )

    # Embedder configuration arguments
    parser.add_argument('--embedder-model', help='Model name to use with the embedder')

    # Graphiti-specific arguments
    parser.add_argument(
        '--group-id',
        help='Namespace for the graph. If not provided, uses config file or generates random UUID.',
    )
    parser.add_argument(
        '--user-id',
        help='User ID for tracking operations',
    )
    parser.add_argument(
        '--destroy-graph',
        action='store_true',
        help='Destroy all Graphiti graphs on startup',
    )
    return parser


def _detect_graphiti_core_version() -> str | None:
    """Return the graphiti-core version string, or None when it cannot be determined."""
    try:
        import graphiti_core

        declared = getattr(graphiti_core, '__version__', None)
    except Exception:
        declared = None
    if declared:
        return str(declared)

    # Fall back to the Docker-stored version file. This is consulted whenever
    # the package gives no version, not only when the import itself fails.
    version_file = Path('/app/.graphiti-core-version')
    try:
        if version_file.exists():
            return version_file.read_text().strip() or None
    except OSError:
        # Version reporting is purely informational; never fail startup over it.
        return None
    return None


async def initialize_server() -> ServerConfig:
    """Parse CLI arguments and initialize the Graphiti server configuration.

    Populates the module-level ``config``, ``graphiti_service``,
    ``queue_service``, ``graphiti_client`` and ``semaphore`` globals, applies
    the configured host/port to the MCP settings and, when ``--destroy-graph``
    is given, clears all graph data before the services are created.

    Returns:
        The ``server`` section of the loaded configuration.
    """
    global config, graphiti_service, queue_service, graphiti_client, semaphore

    args = _build_arg_parser().parse_args()

    # Set config path in environment for the settings to pick up
    if args.config:
        os.environ['CONFIG_PATH'] = str(args.config)

    # Load configuration with environment variables and YAML
    config = GraphitiConfig()

    # Apply CLI overrides
    config.apply_cli_overrides(args)

    # Legacy CLI flag kept for backward compatibility. The parser always
    # defines --destroy-graph, so the attribute is always present on args.
    config.destroy_graph = args.destroy_graph

    # Log configuration details
    logger.info('Using configuration:')
    logger.info(f' - LLM: {config.llm.provider} / {config.llm.model}')
    logger.info(f' - Embedder: {config.embedder.provider} / {config.embedder.model}')
    logger.info(f' - Database: {config.database.provider}')
    logger.info(f' - Group ID: {config.graphiti.group_id}')
    logger.info(f' - Transport: {config.server.transport}')

    # Log graphiti-core version
    graphiti_version = _detect_graphiti_core_version()
    if graphiti_version:
        logger.info(f' - Graphiti Core: {graphiti_version}')
    else:
        logger.info(' - Graphiti Core: version unavailable')

    # Handle graph destruction if requested
    if getattr(config, 'destroy_graph', False):
        logger.warning('Destroying all Graphiti graphs as requested...')
        temp_service = GraphitiService(config, SEMAPHORE_LIMIT)
        await temp_service.initialize()
        client = await temp_service.get_client()
        try:
            await clear_data(client.driver)
        finally:
            # The temporary client is only needed for the wipe; close it so
            # its database connection is not left open next to the real one.
            await client.close()
        logger.info('All graphs destroyed')

    # Initialize services
    graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
    queue_service = QueueService()
    await graphiti_service.initialize()

    # Set global client for backward compatibility
    graphiti_client = await graphiti_service.get_client()
    semaphore = graphiti_service.semaphore

    # Initialize queue service with the client
    await queue_service.initialize(graphiti_client)

    # Set MCP server settings; unset (falsy) values keep the MCP defaults.
    if config.server.host:
        mcp.settings.host = config.server.host
    if config.server.port:
        mcp.settings.port = config.server.port

    # Return MCP configuration for transport
    return config.server
|
|
|
|
|
|
async def run_mcp_server():
    """Run the MCP server in the current event loop.

    Calls ``initialize_server()`` and then starts the transport named by the
    returned configuration's ``transport`` attribute: ``'stdio'``, ``'sse'``
    or ``'http'`` (streamable HTTP).

    Raises:
        ValueError: If the configured transport is not one of the three
            supported values.
    """
    # Initialize the server
    mcp_config = await initialize_server()
    transport = mcp_config.transport

    # 0.0.0.0 is a bind-all address, not something a client can connect to,
    # so URLs shown to the user substitute localhost. The raw bind address is
    # still logged unchanged in the "Running MCP server ..." lines.
    display_host = 'localhost' if mcp.settings.host == '0.0.0.0' else mcp.settings.host

    # Run the server with configured transport
    logger.info(f'Starting MCP server with transport: {transport}')
    if transport == 'stdio':
        await mcp.run_stdio_async()
    elif transport == 'sse':
        logger.info(
            f'Running MCP server with SSE transport on {mcp.settings.host}:{mcp.settings.port}'
        )
        # Use the display host here as well, matching the HTTP branch below
        logger.info(f'Access the server at: http://{display_host}:{mcp.settings.port}/sse')
        await mcp.run_sse_async()
    elif transport == 'http':
        logger.info(
            f'Running MCP server with streamable HTTP transport on {mcp.settings.host}:{mcp.settings.port}'
        )
        logger.info('=' * 60)
        logger.info('MCP Server Access Information:')
        logger.info(f' Base URL: http://{display_host}:{mcp.settings.port}/')
        logger.info(f' MCP Endpoint: http://{display_host}:{mcp.settings.port}/mcp/')
        logger.info(' Transport: HTTP (streamable)')

        # Show FalkorDB Browser UI access if enabled (the BROWSER environment
        # variable defaults to '1' when unset)
        if os.environ.get('BROWSER', '1') == '1':
            logger.info(f' FalkorDB Browser UI: http://{display_host}:3000/')

        logger.info('=' * 60)
        logger.info('For MCP clients, connect to the /mcp/ endpoint above')

        # Configure uvicorn logging to match our format
        configure_uvicorn_logging()

        await mcp.run_streamable_http_async()
    else:
        raise ValueError(
            f'Unsupported transport: {transport}. Use "sse", "stdio", or "http"'
        )
|
|
|
|
|
|
def main():
    """Start the Graphiti MCP server and block until it stops.

    A ``KeyboardInterrupt`` is logged as a shutdown and swallowed; any other
    exception is logged and then re-raised to the caller.
    """
    try:
        # Initialization and serving share one event loop
        asyncio.run(run_mcp_server())
    except KeyboardInterrupt:
        logger.info('Server shutting down...')
    except Exception as exc:
        logger.error(f'Error initializing Graphiti MCP server: {exc!s}')
        raise
|
|
|
|
|
|
# Start the server only when this file is executed as a script, not on import
if __name__ == '__main__':
    main()
|