1397 lines
56 KiB
Python
1397 lines
56 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Graphiti MCP Server - Exposes Graphiti functionality through the Model Context Protocol (MCP)
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import sys
|
|
import uuid
|
|
from collections.abc import Callable
|
|
from datetime import datetime, timezone
|
|
from typing import Any, TypedDict, cast
|
|
|
|
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
|
from dotenv import load_dotenv
|
|
from mcp.server.fastmcp import FastMCP
|
|
from openai import AsyncAzureOpenAI
|
|
from pydantic import BaseModel, Field
|
|
|
|
from graphiti_core import Graphiti
|
|
from graphiti_core.cross_encoder.client import CrossEncoderClient
|
|
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
|
from graphiti_core.edges import EntityEdge
|
|
from graphiti_core.embedder.client import EmbedderClient
|
|
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
|
|
from graphiti_core.llm_client import LLMClient
|
|
from graphiti_core.llm_client.anthropic_client import AnthropicClient
|
|
from graphiti_core.llm_client.config import LLMConfig
|
|
from graphiti_core.llm_client.openai_client import OpenAIClient
|
|
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
|
from graphiti_core.search.search_config_recipes import (
|
|
NODE_HYBRID_SEARCH_NODE_DISTANCE,
|
|
NODE_HYBRID_SEARCH_RRF,
|
|
)
|
|
from graphiti_core.search.search_filters import SearchFilters
|
|
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
|
|
|
load_dotenv()
|
|
|
|
# Fallback model identifiers used when MODEL_NAME / EMBEDDER_MODEL_NAME are
# unset or blank in the environment.
DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'
|
|
|
|
|
|
class Requirement(BaseModel):
    """A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.

    Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
    edge that the requirement is a requirement.

    Instructions for identifying and extracting requirements:
    1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
    2. Identify functional specifications that describe what the system should do
    3. Pay attention to non-functional requirements like performance, security, or usability criteria
    4. Extract constraints or limitations that must be adhered to
    5. Focus on clear, specific, and measurable requirements rather than vague wishes
    6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
    7. Include any dependencies between requirements when explicitly stated
    8. Preserve the original intent and scope of the requirement
    9. Categorize requirements appropriately based on their domain or function
    """

    # NOTE(review): this class is registered in ENTITY_TYPES, and its docstring
    # reads like an LLM extraction prompt — presumably it is passed to the model
    # as instructions, so wording changes here may change extraction behavior.
    # Confirm against graphiti-core before editing the docstring.
    project_name: str = Field(
        ...,
        description='The name of the project to which the requirement belongs.',
    )
    description: str = Field(
        ...,
        description='Description of the requirement. Only use information mentioned in the context to write this description.',
    )
|
|
|
|
|
|
class Preference(BaseModel):
    """A Preference represents a user's expressed like, dislike, or preference for something.

    Instructions for identifying and extracting preferences:
    1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X"
    2. Pay attention to comparative statements ("I prefer X over Y")
    3. Consider the emotional tone when users mention certain topics
    4. Extract only preferences that are clearly expressed, not assumptions
    5. Categorize the preference appropriately based on its domain (food, music, brands, etc.)
    6. Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food")
    7. Only extract preferences directly stated by the user, not preferences of others they mention
    8. Provide a concise but specific description that captures the nature of the preference
    """

    # NOTE(review): registered in ENTITY_TYPES; the docstring presumably doubles
    # as the LLM extraction prompt for this entity type, so treat it as
    # behavior-affecting text rather than plain documentation.
    category: str = Field(
        ...,
        description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')",
    )
    description: str = Field(
        ...,
        description='Brief description of the preference. Only use information mentioned in the context to write this description.',
    )
|
|
|
|
|
|
class Procedure(BaseModel):
    """A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.

    Instructions for identifying and extracting procedures:
    1. Look for sequential instructions or steps ("First do X, then do Y")
    2. Identify explicit directives or commands ("Always do X when Y happens")
    3. Pay attention to conditional statements ("If X occurs, then do Y")
    4. Extract procedures that have clear beginning and end points
    5. Focus on actionable instructions rather than general information
    6. Preserve the original sequence and dependencies between steps
    7. Include any specified conditions or triggers for the procedure
    8. Capture any stated purpose or goal of the procedure
    9. Summarize complex procedures while maintaining critical details
    """

    # NOTE(review): registered in ENTITY_TYPES; the docstring presumably doubles
    # as the LLM extraction prompt for this entity type — edit with care.
    description: str = Field(
        ...,
        description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
    )
|
|
|
|
|
|
# Registry of custom entity types offered to graphiti-core when custom entity
# extraction is enabled. Keys are entity type names; values are the Pydantic
# model *classes* themselves (not instances), hence type[BaseModel] — the
# previous dict[str, BaseModel] annotation was wrong and forced a
# `# type: ignore` on every entry.
ENTITY_TYPES: dict[str, type[BaseModel]] = {
    'Requirement': Requirement,
    'Preference': Preference,
    'Procedure': Procedure,
}
|
|
|
|
|
|
# Type definitions for API responses
|
|
class ErrorResponse(TypedDict):
    """Envelope returned by MCP tools when an operation fails."""

    error: str  # human-readable failure description
|
|
|
|
|
|
class SuccessResponse(TypedDict):
    """Envelope returned by MCP tools on success with no payload."""

    message: str  # human-readable status description
|
|
|
|
|
|
class NodeResult(TypedDict):
    """Serialized entity node as returned by node-search tools."""

    uuid: str
    name: str
    summary: str
    labels: list[str]
    group_id: str  # partition/namespace this node belongs to
    created_at: str  # creation time serialized as a string
    attributes: dict[str, Any]  # free-form extra properties on the node
|
|
|
|
|
|
class NodeSearchResponse(TypedDict):
    """Successful node-search result: a status message plus matching nodes."""

    message: str
    nodes: list[NodeResult]
|
|
|
|
|
|
class FactSearchResponse(TypedDict):
    """Successful fact-search result: a status message plus serialized edges."""

    message: str
    facts: list[dict[str, Any]]  # entity edges serialized via format_fact_result
|
|
|
|
|
|
class EpisodeSearchResponse(TypedDict):
    """Successful episode-search result: a status message plus episodes."""

    message: str
    episodes: list[dict[str, Any]]  # serialized episodic nodes
|
|
|
|
|
|
class StatusResponse(TypedDict):
    """Generic status report (e.g. health/readiness checks)."""

    status: str  # machine-readable state
    message: str  # human-readable detail
|
|
|
|
|
|
def create_azure_credential_token_provider() -> Callable[[], str]:
    """Build a bearer-token provider backed by Azure's DefaultAzureCredential.

    Used for managed-identity authentication against Azure OpenAI. Each call
    of the returned callable yields a token scoped to Azure Cognitive
    Services.
    """
    return get_bearer_token_provider(
        DefaultAzureCredential(), 'https://cognitiveservices.azure.com/.default'
    )
|
|
|
|
|
|
# Server configuration classes
|
|
# The configuration system has a hierarchy:
|
|
# - GraphitiConfig is the top-level configuration
|
|
# - LLMConfig handles all OpenAI/LLM related settings
|
|
# - EmbedderConfig manages embedding settings
|
|
# - Neo4jConfig manages database connection details
|
|
# - Various other settings like group_id and feature flags
|
|
# Configuration values are loaded from:
|
|
# 1. Default values in the class definitions
|
|
# 2. Environment variables (loaded via load_dotenv())
|
|
# 3. Command line arguments (which override environment variables)
|
|
class GraphitiLLMConfig(BaseModel):
    """Configuration for the LLM client.

    Centralizes all LLM-specific configuration parameters including API keys
    and model selection.

    Supported providers: 'openai' (default), 'anthropic', and Azure OpenAI —
    spelled either 'azure' or 'azure_openai'; both are accepted everywhere.
    """

    model_provider: str = 'openai'
    # OpenAI API key; Azure OpenAI key-based auth also reads this field.
    api_key: str | None = None
    anthropic_api_key: str | None = None
    model: str = DEFAULT_LLM_MODEL
    temperature: float = 0.0
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiLLMConfig':
        """Create LLM configuration from environment variables.

        Loads potential keys and default settings. Provider-specific logic is
        handled in from_cli_and_env.
        """
        # Load all potentially relevant keys and settings from the environment.
        model_provider = os.environ.get('MODEL_PROVIDER', 'openai').lower()
        model_env = os.environ.get('MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_LLM_MODEL
        openai_api_key = os.environ.get('OPENAI_API_KEY')
        anthropic_api_key = os.environ.get('ANTHROPIC_API_KEY')
        azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
        azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME')
        azure_api_version = os.environ.get('AZURE_OPENAI_API_VERSION')
        azure_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )
        temperature = float(os.environ.get('LLM_TEMPERATURE', '0.0'))

        # Return a config with all loaded env vars; provider logic is deferred.
        return cls(
            model_provider=model_provider,  # Initial provider from env
            api_key=openai_api_key,
            anthropic_api_key=anthropic_api_key,
            model=model,
            temperature=temperature,
            azure_openai_endpoint=azure_endpoint,
            azure_openai_deployment_name=azure_deployment,
            azure_openai_api_version=azure_api_version,
            azure_openai_use_managed_identity=azure_use_managed_identity,
        )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig':
        """Create LLM configuration from CLI arguments, falling back to environment variables.

        Determines the final provider and loads settings accordingly.
        Precedence for every setting: CLI > environment variable > default.
        """
        # 1. Determine the final model provider (CLI > Env Var > Default)
        final_model_provider = 'openai'  # Default
        if 'MODEL_PROVIDER' in os.environ:
            final_model_provider = os.environ['MODEL_PROVIDER'].lower()
        if hasattr(args, 'model_provider') and args.model_provider:
            final_model_provider = args.model_provider.lower()

        # 2. Determine final model name (CLI > Env Var > Default)
        final_model = DEFAULT_LLM_MODEL  # Default
        if 'MODEL_NAME' in os.environ and os.environ['MODEL_NAME'].strip():
            final_model = os.environ['MODEL_NAME']
        if hasattr(args, 'model') and args.model and args.model.strip():
            final_model = args.model
        elif final_model_provider != 'openai' and final_model == DEFAULT_LLM_MODEL:
            # Warn if using a non-openai provider but no model was specified.
            logger.warning(
                f"Model provider is '{final_model_provider}' but no model specified. "
                f"Using default '{DEFAULT_LLM_MODEL}', which might be incorrect."
            )

        # 3. Determine final temperature (CLI > Env Var > Default)
        final_temperature = 0.0  # Default
        if 'LLM_TEMPERATURE' in os.environ:
            try:
                final_temperature = float(os.environ['LLM_TEMPERATURE'])
            except ValueError:
                logger.warning('Invalid LLM_TEMPERATURE in env, using default 0.0')
        if hasattr(args, 'temperature') and args.temperature is not None:
            final_temperature = args.temperature

        # 4. Load relevant API keys and Azure settings from the environment.
        openai_api_key = os.environ.get('OPENAI_API_KEY')
        anthropic_api_key = os.environ.get('ANTHROPIC_API_KEY')
        azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
        azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME')
        azure_api_version = os.environ.get('AZURE_OPENAI_API_VERSION')
        azure_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )

        # 5. Construct the final config object based on the provider.
        config_params: dict[str, Any] = {
            'model_provider': final_model_provider,
            'model': final_model,
            'temperature': final_temperature,
        }

        if final_model_provider == 'anthropic':
            config_params['anthropic_api_key'] = anthropic_api_key
        # BUGFIX: accept both 'azure' and 'azure_openai' so the settings
        # populated here match what create_client()/create_cross_encoder_client()
        # check for. Previously only 'azure_openai' was matched here while
        # create_client() matched only 'azure', so Azure never worked.
        elif final_model_provider in ('azure', 'azure_openai'):
            config_params['azure_openai_endpoint'] = azure_endpoint
            config_params['azure_openai_deployment_name'] = azure_deployment
            config_params['azure_openai_api_version'] = azure_api_version
            config_params['azure_openai_use_managed_identity'] = azure_use_managed_identity
            if not azure_use_managed_identity:
                # Azure uses the OPENAI_API_KEY env var for key-based auth.
                config_params['api_key'] = openai_api_key
        else:  # Default to openai
            config_params['api_key'] = openai_api_key

        # Log a redacted snippet of the Anthropic key, for debugging only.
        if final_model_provider == 'anthropic':
            key_snippet = (
                f'{anthropic_api_key[:5]}...{anthropic_api_key[-4:]}'
                if anthropic_api_key
                else 'None'
            )
            logger.info(
                f'Anthropic provider selected. Read ANTHROPIC_API_KEY from env: {key_snippet}'
            )

        config = cls(**config_params)

        return config

    def create_client(self) -> LLMClient | None:
        """Create an LLM client based on this configuration.

        Returns:
            LLMClient instance if able, None otherwise (missing keys or
            unknown provider — the caller decides whether that is fatal).
        """

        # --- Anthropic Client ---
        if self.model_provider == 'anthropic':
            logger.info(f'Using Anthropic model: {self.model}')
            if not self.anthropic_api_key:
                logger.error('ANTHROPIC_API_KEY must be set when model_provider is anthropic')
                return None
            anthropic_config = LLMConfig(
                api_key=self.anthropic_api_key, model=self.model, temperature=self.temperature
            )
            return AnthropicClient(config=anthropic_config)

        # --- Azure OpenAI Client ---
        # BUGFIX: previously matched only 'azure' while the rest of this class
        # used 'azure_openai', so Azure configs built by from_cli_and_env fell
        # through to the "Unknown model provider" branch below.
        elif self.model_provider in ('azure', 'azure_openai'):
            logger.info(f'Attempting to use Azure OpenAI model: {self.model}')

            # Check required Azure configuration
            if not self.azure_openai_endpoint:
                logger.error('AZURE_OPENAI_ENDPOINT must be set when model_provider is azure')
                return None
            if not self.azure_openai_deployment_name:
                logger.error(
                    'AZURE_OPENAI_DEPLOYMENT_NAME must be set when model_provider is azure'
                )
                return None
            if not self.azure_openai_api_version:
                logger.error('AZURE_OPENAI_API_VERSION must be set when model_provider is azure')
                return None
            if not self.azure_openai_use_managed_identity and not self.api_key:
                logger.error(
                    'OPENAI_API_KEY must be set for Azure OpenAI when not using managed identity'
                )
                return None

            llm_client: AsyncAzureOpenAI | None = None
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                llm_client = AsyncAzureOpenAI(
                    azure_endpoint=self.azure_openai_endpoint,
                    azure_deployment=self.azure_openai_deployment_name,
                    api_version=self.azure_openai_api_version,
                    azure_ad_token_provider=token_provider,
                )
            elif self.api_key:  # We already checked api_key exists if needed
                # Use API key for authentication
                llm_client = AsyncAzureOpenAI(
                    azure_endpoint=self.azure_openai_endpoint,
                    azure_deployment=self.azure_openai_deployment_name,
                    api_version=self.azure_openai_api_version,
                    api_key=self.api_key,
                )

            # Wrap the raw Azure client in OpenAIClient
            if llm_client:
                # Assuming OpenAIClient can accept an initialized client
                return OpenAIClient(client=llm_client)  # type: ignore
            return None

        # --- Standard OpenAI Client ---
        elif self.model_provider == 'openai':
            logger.info(f'Using standard OpenAI model: {self.model}')
            if not self.api_key:
                logger.error('OPENAI_API_KEY must be set when model_provider is openai')
                return None

            llm_client_config = LLMConfig(api_key=self.api_key, model=self.model)

            # Set temperature
            llm_client_config.temperature = self.temperature

            return OpenAIClient(config=llm_client_config)
        else:
            logger.error(f'Unknown model provider configured: {self.model_provider}')
            return None

    def create_cross_encoder_client(self) -> CrossEncoderClient | None:
        """Create a cross-encoder (reranker) client based on this configuration.

        Raises:
            ValueError: when the provider is Anthropic, which
                OpenAIRerankerClient does not support.
        """
        # Note: OpenAIRerankerClient currently only supports OpenAI models.
        # With Anthropic as the main LLM there is no supported reranker; we
        # warn here and raise in the Anthropic branch below.
        if self.model_provider == 'anthropic':
            logger.warning(
                'Anthropic selected as LLM provider. Cross-Encoder (reranking) is not supported '
                'as it currently relies on OpenAI models; an error will be raised.'
            )

        # Raw Azure client, built only on the Azure path and consumed at the end.
        reranker_azure_client: AsyncAzureOpenAI | None = None

        # BUGFIX: accept both 'azure' and 'azure_openai' spellings, matching
        # create_client() and from_cli_and_env().
        if self.model_provider in ('azure', 'azure_openai'):
            # Check required Azure configuration for reranker
            if not all(
                [
                    self.azure_openai_endpoint,
                    self.azure_openai_deployment_name,
                    self.azure_openai_api_version,
                ]
            ):
                logger.error('Azure config (endpoint, deployment, version) missing for reranker')
                return None
            if not self.azure_openai_use_managed_identity and not self.api_key:
                logger.error('OPENAI_API_KEY needed for Azure reranker without managed identity')
                return None

            # Create the raw Azure client again for OpenAIRerankerClient
            if self.azure_openai_use_managed_identity:
                token_provider = create_azure_credential_token_provider()
                reranker_azure_client = AsyncAzureOpenAI(
                    azure_endpoint=self.azure_openai_endpoint,  # type: ignore[arg-type]
                    azure_deployment=self.azure_openai_deployment_name,  # Use LLM deployment name
                    api_version=self.azure_openai_api_version,  # Use LLM API version
                    azure_ad_token_provider=token_provider,
                )
            elif self.api_key:
                reranker_azure_client = AsyncAzureOpenAI(
                    azure_endpoint=self.azure_openai_endpoint,  # type: ignore[arg-type]
                    azure_deployment=self.azure_openai_deployment_name,  # Use LLM deployment name
                    api_version=self.azure_openai_api_version,  # Use LLM API version
                    api_key=self.api_key,
                )
        elif self.model_provider == 'openai':
            if not self.api_key:
                logger.error('OPENAI_API_KEY needed for OpenAI reranker')
                return None
            # Standard OpenAI config for reranker. Passing a config is
            # consistent with how OpenAIClient itself can be initialized.
            llm_client_config = LLMConfig(api_key=self.api_key, model=self.model)
            return OpenAIRerankerClient(config=llm_client_config)
        elif self.model_provider == 'anthropic':
            if not self.anthropic_api_key:
                logger.error('ANTHROPIC_API_KEY needed for Anthropic reranker attempt')
                return None
            # OpenAIRerankerClient does not support Anthropic — fail loudly
            # rather than returning a client that cannot work.
            raise ValueError(
                'Cannot create CrossEncoderClient (reranker): OpenAIRerankerClient currently only supports '
                'OpenAI or Azure OpenAI models. Anthropic was selected as the model_provider.'
            )
        else:
            logger.error(f"Unsupported model provider '{self.model_provider}' for CrossEncoder")
            return None

        # Pass the raw Azure client instance if created
        if reranker_azure_client:
            return OpenAIRerankerClient(client=reranker_azure_client)
        else:
            # This case should ideally not be reached if provider is Azure due to checks above
            logger.error(
                'Failed to create Azure client for CrossEncoder despite provider being Azure'
            )
            return None
|
|
|
|
|
|
class GraphitiEmbedderConfig(BaseModel):
    """Configuration for the embedder client.

    Centralizes all embedding-related configuration parameters. Azure OpenAI
    is used when an endpoint is configured; otherwise the standard OpenAI
    embeddings API is used.
    """

    model: str = DEFAULT_EMBEDDER_MODEL
    api_key: str | None = None
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiEmbedderConfig':
        """Create embedder configuration from environment variables.

        Raises:
            ValueError: if AZURE_OPENAI_ENDPOINT is set but
                AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME is not.
        """
        # Get model from environment, or use default if not set or empty
        model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
        azure_openai_deployment_name = os.environ.get(
            'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
        )
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )
        if azure_openai_endpoint is not None:
            # Azure path: the embedding deployment name is mandatory.
            # (Previously this re-read AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME
            # from the environment a second time — redundant, removed.)
            if azure_openai_deployment_name is None:
                logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set')

                raise ValueError(
                    'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set'
                )

            if not azure_openai_use_managed_identity:
                # Key-based auth: Azure uses the OPENAI_API_KEY env var.
                api_key = os.environ.get('OPENAI_API_KEY', None)
            else:
                # Managed identity: no key needed.
                api_key = None

            return cls(
                # Carried through for consistency (previously dropped here);
                # the Azure client path identifies the model via the
                # deployment name, so this does not change client creation.
                model=model,
                azure_openai_use_managed_identity=azure_openai_use_managed_identity,
                azure_openai_endpoint=azure_openai_endpoint,
                api_key=api_key,
                azure_openai_api_version=azure_openai_api_version,
                azure_openai_deployment_name=azure_openai_deployment_name,
            )
        else:
            return cls(
                model=model,
                api_key=os.environ.get('OPENAI_API_KEY'),
            )

    def create_client(self) -> EmbedderClient | None:
        """Build an embedder client, or return None when no usable credentials exist."""
        embedder_client: AsyncAzureOpenAI | None = None
        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                embedder_client = AsyncAzureOpenAI(
                    azure_endpoint=self.azure_openai_endpoint,
                    azure_deployment=self.azure_openai_deployment_name,
                    api_version=self.azure_openai_api_version,
                    azure_ad_token_provider=token_provider,
                )
            elif self.api_key:
                # Use API key for authentication
                embedder_client = AsyncAzureOpenAI(
                    azure_endpoint=self.azure_openai_endpoint,
                    azure_deployment=self.azure_openai_deployment_name,
                    api_version=self.azure_openai_api_version,
                    api_key=self.api_key,
                )
            else:
                logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
                return None
            # Wrap the raw Azure client in OpenAIEmbedder
            if embedder_client:
                # Assuming OpenAIEmbedder can accept an initialized client
                return OpenAIEmbedder(client=embedder_client)  # type: ignore
            return None
        else:
            # Standard OpenAI API setup; embeddings are silently disabled
            # (None returned) when no key is configured.
            if not self.api_key:
                return None

            embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model)

            return OpenAIEmbedder(config=embedder_config)
|
|
|
|
|
|
class Neo4jConfig(BaseModel):
    """Connection settings for the Neo4j graph database."""

    uri: str = 'bolt://localhost:7687'
    user: str = 'neo4j'
    password: str = 'password'

    @classmethod
    def from_env(cls) -> 'Neo4jConfig':
        """Build a Neo4jConfig from NEO4J_* environment variables.

        Missing variables fall back to the same local-development defaults
        as the class attributes.
        """
        env = os.environ
        return cls(
            uri=env.get('NEO4J_URI', 'bolt://localhost:7687'),
            user=env.get('NEO4J_USER', 'neo4j'),
            password=env.get('NEO4J_PASSWORD', 'password'),
        )
|
|
|
|
|
|
class GraphitiConfig(BaseModel):
    """Top-level configuration for the Graphiti client.

    Aggregates the LLM, embedder, and Neo4j sub-configurations together with
    server-level flags.
    """

    llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig)
    embedder: GraphitiEmbedderConfig = Field(default_factory=GraphitiEmbedderConfig)
    neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig)
    group_id: str | None = None
    use_custom_entities: bool = False
    destroy_graph: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiConfig':
        """Assemble a configuration purely from environment variables."""
        return cls(
            llm=GraphitiLLMConfig.from_env(),
            embedder=GraphitiEmbedderConfig.from_env(),
            neo4j=Neo4jConfig.from_env(),
        )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiConfig':
        """Layer CLI overrides on top of the environment-derived configuration."""
        config = cls.from_env()

        # group_id: an explicit CLI value wins; otherwise mint a short random id.
        config.group_id = args.group_id if args.group_id else f'graph_{uuid.uuid4().hex[:8]}'

        config.use_custom_entities = args.use_custom_entities
        config.destroy_graph = args.destroy_graph

        # LLM settings perform their own CLI/env merge.
        config.llm = GraphitiLLMConfig.from_cli_and_env(args)

        return config
|
|
|
|
|
|
class MCPConfig(BaseModel):
    """Configuration for the MCP server itself."""

    # Transport mechanism; SSE is the default when the CLI gives none.
    transport: str = 'sse'

    @classmethod
    def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig':
        """Derive the MCP configuration from parsed CLI arguments."""
        chosen_transport = args.transport
        return cls(transport=chosen_transport)
|
|
|
|
|
|
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stderr,  # NOTE(review): stderr presumably keeps stdout free for MCP stdio traffic — confirm
)
logger = logging.getLogger(__name__)

# Create global config instance - will be properly initialized later
# (presumably rebuilt via GraphitiConfig.from_cli_and_env once CLI args are parsed)
config = GraphitiConfig()

# MCP server instructions
# NOTE: this string is sent to connecting MCP clients verbatim — it is runtime
# data, not documentation.
GRAPHITI_MCP_INSTRUCTIONS = """
Welcome to Graphiti MCP - a memory service for AI agents built on a knowledge graph. Graphiti performs well
with dynamic data such as user interactions, changing enterprise data, and external information.

Graphiti transforms information into a richly connected knowledge network, allowing you to
capture relationships between concepts, entities, and information. The system organizes data as episodes
(content snippets), nodes (entities), and facts (relationships between entities), creating a dynamic,
queryable memory store that evolves with new information. Graphiti supports multiple data formats, including
structured JSON data, enabling seamless integration with existing data pipelines and systems.

Facts contain temporal metadata, allowing you to track the time of creation and whether a fact is invalid
(superseded by new information).

Key capabilities:
1. Add episodes (text, messages, or JSON) to the knowledge graph with the add_episode tool
2. Search for nodes (entities) in the graph using natural language queries with search_nodes
3. Find relevant facts (relationships between entities) with search_facts
4. Retrieve specific entity edges or episodes by UUID
5. Manage the knowledge graph with tools like delete_episode, delete_entity_edge, and clear_graph

The server connects to a database for persistent storage and uses language models for certain operations.
Each piece of information is organized by group_id, allowing you to maintain separate knowledge domains.

When adding information, provide descriptive names and detailed content to improve search quality.
When searching, use specific queries and consider filtering by group_id for more relevant results.

For optimal performance, ensure the database is properly configured and accessible, and valid
API keys are provided for any language model operations.
"""

# MCP server instance
mcp = FastMCP(
    'graphiti',
    instructions=GRAPHITI_MCP_INSTRUCTIONS,
)

# Initialize Graphiti client
# Populated by initialize_graphiti(); None until initialization succeeds.
graphiti_client: Graphiti | None = None
|
|
|
|
|
|
async def initialize_graphiti():
    """Initialize the global Graphiti client with the configured settings.

    Validates embedder, LLM, and Neo4j configuration, constructs the client,
    optionally wipes the graph, and builds indices/constraints.

    Raises:
        ValueError: when a required API key or Neo4j setting is missing.
        Exception: any error from Graphiti construction or index building is
            logged and re-raised.
    """
    global graphiti_client, config

    try:
        # --- Critical Check: Embedder API Key ---
        # Embeddings always use OpenAI/Azure, so we need the key regardless of LLM provider.
        # Check if using standard OpenAI embeddings and API key is missing:
        if not config.embedder.azure_openai_endpoint and not config.embedder.api_key:
            raise ValueError(
                'OPENAI_API_KEY must be set for embeddings (via OPENAI_API_KEY env var), '
                'even if using a different LLM provider like Anthropic.'
            )
        # Check if using Azure embeddings and relevant Azure config/key is missing:
        elif (
            config.embedder.azure_openai_endpoint
            and not config.embedder.azure_openai_use_managed_identity
            and not config.embedder.api_key
        ):
            raise ValueError(
                'OPENAI_API_KEY must be set for Azure embeddings (via OPENAI_API_KEY env var) '
                'if not using managed identity.'
            )

        # Create LLM client if possible
        llm_client = config.llm.create_client()
        # Custom entity extraction requires a working LLM client.
        if config.use_custom_entities and not llm_client:
            missing_key = (
                'ANTHROPIC_API_KEY'
                if config.llm.model_provider == 'anthropic'
                else 'OPENAI_API_KEY or Azure equivalent'
            )
            raise ValueError(
                f"{missing_key} must be set when custom entities are enabled and model provider is '{config.llm.model_provider}'."
            )

        # Validate Neo4j configuration
        if not config.neo4j.uri or not config.neo4j.user or not config.neo4j.password:
            raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')

        embedder_client = config.embedder.create_client()
        cross_encoder_client = config.llm.create_cross_encoder_client()

        # Initialize Graphiti client
        graphiti_client = Graphiti(
            uri=config.neo4j.uri,
            user=config.neo4j.user,
            password=config.neo4j.password,
            llm_client=llm_client,
            embedder=embedder_client,
            cross_encoder=cross_encoder_client,
        )

        # Destroy graph if requested
        if config.destroy_graph:
            logger.info('Destroying graph...')
            await clear_data(graphiti_client.driver)

        # Initialize the graph database with Graphiti's indices
        await graphiti_client.build_indices_and_constraints()
        logger.info('Graphiti client initialized successfully')

        # Log configuration details for transparency
        if llm_client:
            # BUGFIX: the previous message said "Using OpenAI model" even when
            # the provider was Anthropic or Azure — report provider and model.
            logger.info(f'Using LLM provider: {config.llm.model_provider}')
            logger.info(f'Using LLM model: {config.llm.model}')
            logger.info(f'Using temperature: {config.llm.temperature}')
        else:
            logger.info('No LLM client configured - entity extraction will be limited')

        logger.info(f'Using group_id: {config.group_id}')
        logger.info(
            f'Custom entity extraction: {"enabled" if config.use_custom_entities else "disabled"}'
        )

    except Exception as e:
        logger.error(f'Failed to initialize Graphiti: {str(e)}')
        raise
|
|
|
|
|
|
def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
    """Serialize an entity edge into a JSON-friendly dictionary.

    EntityEdge is a Pydantic BaseModel, so serialization is delegated to its
    built-in model_dump: dates are rendered in JSON mode and the (large)
    fact-embedding vector is excluded from the output.

    Args:
        edge: The EntityEdge to format

    Returns:
        A dictionary representation of the edge with serialized dates and
        excluded embeddings
    """
    excluded_fields = {'fact_embedding'}
    return edge.model_dump(mode='json', exclude=excluded_fields)
|
|
|
|
|
|
# Dictionary to store queues for each group_id.
# Each queue holds async callables to be processed sequentially by
# process_episode_queue, guaranteeing per-group ordering.
episode_queues: dict[str, asyncio.Queue] = {}
# Dictionary to track if a worker is running for each group_id,
# so only one worker task is spawned per group.
queue_workers: dict[str, bool] = {}
|
|
|
|
|
|
async def process_episode_queue(group_id: str):
    """Process episodes for a specific group_id sequentially.

    This function runs as a long-lived task that processes episodes
    from the queue one at a time.

    Args:
        group_id: the group whose queue (in the module-level episode_queues
            dict) this worker drains.
    """
    global queue_workers

    logger.info(f'Starting episode queue worker for group_id: {group_id}')
    # Mark this group's worker as alive so callers don't spawn a second one.
    queue_workers[group_id] = True

    try:
        while True:
            # Get the next episode processing function from the queue.
            # This will wait if the queue is empty.
            process_func = await episode_queues[group_id].get()

            try:
                # Process the episode; failures are logged but do not kill
                # the worker, so later queued episodes still run.
                await process_func()
            except Exception as e:
                logger.error(f'Error processing queued episode for group_id {group_id}: {str(e)}')
            finally:
                # Mark the task as done regardless of success/failure,
                # keeping the queue's join/unfinished-count accounting correct.
                episode_queues[group_id].task_done()
    except asyncio.CancelledError:
        # Normal shutdown path: the worker task was cancelled.
        logger.info(f'Episode queue worker for group_id {group_id} was cancelled')
    except Exception as e:
        logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}')
    finally:
        # Always clear the alive flag so a future episode can restart a worker.
        queue_workers[group_id] = False
        logger.info(f'Stopped episode queue worker for group_id: {group_id}')
|
|
|
|
|
|
@mcp.tool()
async def add_episode(
    name: str,
    episode_body: str,
    group_id: str | None = None,
    source: str = 'text',
    source_description: str = '',
    # NOTE(review): this parameter shadows the stdlib `uuid` module imported at
    # the top of the file for the duration of this function. The name is part
    # of the tool's public interface, so it is left unchanged.
    uuid: str | None = None,
) -> SuccessResponse | ErrorResponse:
    """Add an episode to the Graphiti knowledge graph. This is the primary way to add information to the graph.

    This function returns immediately and processes the episode addition in the background.
    Episodes for the same group_id are processed sequentially to avoid race conditions.

    Args:
        name (str): Name of the episode
        episode_body (str): The content of the episode. When source='json', this must be a properly escaped JSON string,
            not a raw Python dictionary. The JSON data will be automatically processed
            to extract entities and relationships.
        group_id (str, optional): A unique ID for this graph. If not provided, uses the default group_id from CLI
            or a generated one.
        source (str, optional): Source type, must be one of:
            - 'text': For plain text content (default)
            - 'json': For structured data
            - 'message': For conversation-style content
        source_description (str, optional): Description of the source
        uuid (str, optional): Optional UUID for the episode

    Examples:
        # Adding plain text content
        add_episode(
            name="Company News",
            episode_body="Acme Corp announced a new product line today.",
            source="text",
            source_description="news article",
            group_id="some_arbitrary_string"
        )

        # Adding structured JSON data
        # NOTE: episode_body must be a properly escaped JSON string. Note the triple backslashes
        add_episode(
            name="Customer Profile",
            episode_body="{\\\"company\\\": {\\\"name\\\": \\\"Acme Technologies\\\"}, \\\"products\\\": [{\\\"id\\\": \\\"P001\\\", \\\"name\\\": \\\"CloudSync\\\"}, {\\\"id\\\": \\\"P002\\\", \\\"name\\\": \\\"DataMiner\\\"}]}",
            source="json",
            source_description="CRM data"
        )

        # Adding message-style content
        add_episode(
            name="Customer Conversation",
            episode_body="user: What's your return policy?\nassistant: You can return items within 30 days.",
            source="message",
            source_description="chat transcript",
            group_id="some_arbitrary_string"
        )

    Notes:
        When using source='json':
        - The JSON must be a properly escaped string, not a raw Python dictionary
        - The JSON will be automatically processed to extract entities and relationships
        - Complex nested structures are supported (arrays, nested objects, mixed data types), but keep nesting to a minimum
        - Entities will be created from appropriate JSON properties
        - Relationships between entities will be established based on the JSON structure
    """
    global graphiti_client, episode_queues, queue_workers

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Map string source to EpisodeType enum
        # (any unrecognized value silently falls back to 'text')
        source_type = EpisodeType.text
        if source.lower() == 'message':
            source_type = EpisodeType.message
        elif source.lower() == 'json':
            source_type = EpisodeType.json

        # Use the provided group_id or fall back to the default from config
        effective_group_id = group_id if group_id is not None else config.group_id

        # Cast group_id to str to satisfy type checker
        # The Graphiti client expects a str for group_id, not Optional[str]
        group_id_str = str(effective_group_id) if effective_group_id is not None else ''

        # We've already checked that graphiti_client is not None above
        # This assert statement helps type checkers understand that graphiti_client is defined
        assert graphiti_client is not None, 'graphiti_client should not be None here'

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Define the episode processing function
        # (a closure over name/episode_body/etc. that runs later in the worker)
        async def process_episode():
            try:
                logger.info(f"Processing queued episode '{name}' for group_id: {group_id_str}")
                # Use all entity types if use_custom_entities is enabled, otherwise use empty dict
                entity_types = ENTITY_TYPES if config.use_custom_entities else {}

                await client.add_episode(
                    name=name,
                    episode_body=episode_body,
                    source=source_type,
                    source_description=source_description,
                    group_id=group_id_str,  # Using the string version of group_id
                    uuid=uuid,
                    reference_time=datetime.now(timezone.utc),
                    entity_types=entity_types,
                )
                logger.info(f"Episode '{name}' added successfully")

                # Rebuild community structure after each ingested episode.
                logger.info(f"Building communities after episode '{name}'")
                await client.build_communities()

                logger.info(f"Episode '{name}' processed successfully")
            except Exception as e:
                # Errors here happen in the background worker; they are logged
                # but never surfaced to the caller, who already got a success
                # response when the episode was queued.
                error_msg = str(e)
                logger.error(
                    f"Error processing episode '{name}' for group_id {group_id_str}: {error_msg}"
                )

        # Initialize queue for this group_id if it doesn't exist
        if group_id_str not in episode_queues:
            episode_queues[group_id_str] = asyncio.Queue()

        # Add the episode processing function to the queue
        await episode_queues[group_id_str].put(process_episode)

        # Start a worker for this queue if one isn't already running
        # NOTE(review): the task returned by create_task is not stored anywhere;
        # the asyncio docs recommend keeping a reference so the task cannot be
        # garbage-collected before it finishes — consider holding one.
        if not queue_workers.get(group_id_str, False):
            asyncio.create_task(process_episode_queue(group_id_str))

        # Return immediately with a success message
        return {
            'message': f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})"
        }
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error queuing episode task: {error_msg}')
        return {'error': f'Error queuing episode task: {error_msg}'}
|
|
|
|
|
|
@mcp.tool()
async def search_nodes(
    query: str,
    group_ids: list[str] | None = None,
    max_nodes: int = 10,
    center_node_uuid: str | None = None,
    entity: str = '',  # cursor seems to break with None
) -> NodeSearchResponse | ErrorResponse:
    """Search the Graphiti knowledge graph for relevant node summaries.
    These contain a summary of all of a node's relationships with other nodes.

    Note: entity is a single entity type to filter results (permitted: "Preference", "Procedure").

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_nodes: Maximum number of nodes to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
        entity: Optional single entity type to filter results (permitted: "Preference", "Procedure")
    """
    global graphiti_client

    if graphiti_client is None:
        return ErrorResponse(error='Graphiti client not initialized')

    try:
        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids if group_ids is not None else [config.group_id] if config.group_id else []
        )

        # Configure the search: distance-reranked when centered on a node,
        # reciprocal-rank-fusion otherwise. Copy the recipe so mutating the
        # limit below does not change the shared module-level config.
        if center_node_uuid is not None:
            search_config = NODE_HYBRID_SEARCH_NODE_DISTANCE.model_copy(deep=True)
        else:
            search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
        search_config.limit = max_nodes

        # Restrict results to a single node label when an entity type is given.
        filters = SearchFilters()
        if entity != '':
            filters.node_labels = [entity]

        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Perform the search using the _search method
        # NOTE(review): _search is an underscore-prefixed (private) Graphiti
        # API; watch for upstream changes.
        search_results = await client._search(
            query=query,
            config=search_config,
            group_ids=effective_group_ids,
            center_node_uuid=center_node_uuid,
            search_filter=filters,
        )

        if not search_results.nodes:
            return NodeSearchResponse(message='No relevant nodes found', nodes=[])

        # Format the node results
        # (hasattr guards tolerate node objects lacking optional fields)
        formatted_nodes: list[NodeResult] = [
            {
                'uuid': node.uuid,
                'name': node.name,
                'summary': node.summary if hasattr(node, 'summary') else '',
                'labels': node.labels if hasattr(node, 'labels') else [],
                'group_id': node.group_id,
                'created_at': node.created_at.isoformat(),
                'attributes': node.attributes if hasattr(node, 'attributes') else {},
            }
            for node in search_results.nodes
        ]

        return NodeSearchResponse(message='Nodes retrieved successfully', nodes=formatted_nodes)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching nodes: {error_msg}')
        return ErrorResponse(error=f'Error searching nodes: {error_msg}')
|
|
|
|
|
|
@mcp.tool()
async def search_facts(
    query: str,
    group_ids: list[str] | None = None,
    max_facts: int = 10,
    center_node_uuid: str | None = None,
) -> FactSearchResponse | ErrorResponse:
    """Search the Graphiti knowledge graph for relevant facts.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_facts: Maximum number of facts to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids if group_ids is not None else [config.group_id] if config.group_id else []
        )

        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Hybrid search over entity edges; returns EntityEdge objects.
        relevant_edges = await client.search(
            group_ids=effective_group_ids,
            query=query,
            num_results=max_facts,
            center_node_uuid=center_node_uuid,
        )

        if not relevant_edges:
            return {'message': 'No relevant facts found', 'facts': []}

        # Serialize each edge (dates to ISO strings, embeddings excluded).
        facts = [format_fact_result(edge) for edge in relevant_edges]
        return {'message': 'Facts retrieved successfully', 'facts': facts}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching facts: {error_msg}')
        return {'error': f'Error searching facts: {error_msg}'}
|
|
|
|
|
|
@mcp.tool()
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an entity edge from the Graphiti knowledge graph.

    Args:
        uuid: UUID of the entity edge to delete
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Get the entity edge by UUID
        # (raises if no edge with this UUID exists; handled by the except below)
        entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
        # Delete the edge using its delete method
        await entity_edge.delete(client.driver)
        return {'message': f'Entity edge with UUID {uuid} deleted successfully'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting entity edge: {error_msg}')
        return {'error': f'Error deleting entity edge: {error_msg}'}
|
|
|
|
|
|
@mcp.tool()
async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an episode from the Graphiti knowledge graph.

    Args:
        uuid: UUID of the episode to delete
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Get the episodic node by UUID - EpisodicNode is already imported at the top
        # (raises if no episode with this UUID exists; handled by the except below)
        episodic_node = await EpisodicNode.get_by_uuid(client.driver, uuid)
        # Delete the node using its delete method
        await episodic_node.delete(client.driver)
        return {'message': f'Episode with UUID {uuid} deleted successfully'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting episode: {error_msg}')
        return {'error': f'Error deleting episode: {error_msg}'}
|
|
|
|
|
|
@mcp.tool()
async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
    """Get an entity edge from the Graphiti knowledge graph by its UUID.

    Args:
        uuid: UUID of the entity edge to retrieve
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Get the entity edge directly using the EntityEdge class method
        entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)

        # Use the format_fact_result function to serialize the edge
        # Return the Python dict directly - MCP will handle serialization
        return format_fact_result(entity_edge)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting entity edge: {error_msg}')
        return {'error': f'Error getting entity edge: {error_msg}'}
|
|
|
|
|
|
@mcp.tool()
async def get_episodes(
    group_id: str | None = None, last_n: int = 10
) -> list[dict[str, Any]] | EpisodeSearchResponse | ErrorResponse:
    """Get the most recent episodes for a specific group.

    Args:
        group_id: ID of the group to retrieve episodes from. If not provided, uses the default group_id.
        last_n: Number of most recent episodes to retrieve (default: 10)
    """
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # Use the provided group_id or fall back to the default from config
        effective_group_id = group_id if group_id is not None else config.group_id

        # Guard against a missing/non-string group_id before querying.
        if not isinstance(effective_group_id, str):
            return {'error': 'Group ID must be a string'}

        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Fetch the last_n episodes created at or before "now" (UTC).
        episodes = await client.retrieve_episodes(
            group_ids=[effective_group_id], last_n=last_n, reference_time=datetime.now(timezone.utc)
        )

        if not episodes:
            return {'message': f'No episodes found for group {effective_group_id}', 'episodes': []}

        # Use Pydantic's model_dump method for EpisodicNode serialization
        formatted_episodes = [
            # Use mode='json' to handle datetime serialization
            episode.model_dump(mode='json')
            for episode in episodes
        ]

        # Return the Python list directly - MCP will handle serialization
        return formatted_episodes
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting episodes: {error_msg}')
        return {'error': f'Error getting episodes: {error_msg}'}
|
|
|
|
|
|
@mcp.tool()
async def clear_graph() -> SuccessResponse | ErrorResponse:
    """Clear all data from the Graphiti knowledge graph and rebuild indices."""
    global graphiti_client

    if graphiti_client is None:
        return {'error': 'Graphiti client not initialized'}

    try:
        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # clear_data is already imported at the top
        # WARNING: this irreversibly removes all graph data before rebuilding
        # the indices and constraints.
        await clear_data(client.driver)
        await client.build_indices_and_constraints()
        return {'message': 'Graph cleared successfully and indices rebuilt'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error clearing graph: {error_msg}')
        return {'error': f'Error clearing graph: {error_msg}'}
|
|
|
|
|
|
@mcp.resource('http://graphiti/status')
async def get_status() -> StatusResponse:
    """Get the status of the Graphiti MCP server and Neo4j connection."""
    global graphiti_client

    if graphiti_client is None:
        return {'status': 'error', 'message': 'Graphiti client not initialized'}

    try:
        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Test Neo4j connection
        # (verify_connectivity raises on failure, caught below)
        await client.driver.verify_connectivity()
        return {'status': 'ok', 'message': 'Graphiti MCP server is running and connected to Neo4j'}
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error checking Neo4j connection: {error_msg}')
        return {
            'status': 'error',
            'message': f'Graphiti MCP server is running but Neo4j connection failed: {error_msg}',
        }
|
|
|
|
|
|
async def initialize_server() -> MCPConfig:
    """Parse command-line options, build the global config, and bring up Graphiti.

    Returns:
        The MCPConfig derived from the parsed CLI arguments.
    """
    global config

    # Declare all supported command-line options.
    arg_parser = argparse.ArgumentParser(
        description='Run the Graphiti MCP server with optional LLM client'
    )
    arg_parser.add_argument(
        '--group-id',
        help='Namespace for the graph. This is an arbitrary string used to organize related data. '
        'If not provided, a random UUID will be generated.',
    )
    arg_parser.add_argument(
        '--transport',
        choices=['sse', 'stdio'],
        default='sse',
        help='Transport to use for communication with the client. (default: sse)',
    )
    arg_parser.add_argument(
        '--model-provider',
        choices=['openai', 'anthropic', 'azure_openai'],
        default='openai',
        help='Which LLM provider to use (openai, azure_openai, anthropic). Affects --model and API key usage. (default: openai)',
    )
    arg_parser.add_argument(
        '--model', help=f'Model name to use with the LLM client. (default: {DEFAULT_LLM_MODEL})'
    )
    arg_parser.add_argument(
        '--temperature',
        type=float,
        help='Temperature setting for the LLM (0.0-2.0). Lower values make output more deterministic. (default: 0.7)',
    )
    arg_parser.add_argument(
        '--destroy-graph', action='store_true', help='Destroy all Graphiti graphs'
    )
    arg_parser.add_argument(
        '--use-custom-entities',
        action='store_true',
        help='Enable entity extraction using the predefined ENTITY_TYPES',
    )

    cli_args = arg_parser.parse_args()

    # Merge CLI options with environment variables into the global config.
    config = GraphitiConfig.from_cli_and_env(cli_args)

    # Report whether the group_id was user-supplied or generated.
    if cli_args.group_id:
        logger.info(f'Using provided group_id: {config.group_id}')
    else:
        logger.info(f'Generated random group_id: {config.group_id}')

    # Report provider and entity-extraction settings.
    logger.info(f'Using LLM Provider: {config.llm.model_provider}')
    if config.use_custom_entities:
        logger.info('Entity extraction enabled using predefined ENTITY_TYPES')
    else:
        logger.info('Entity extraction disabled (no custom entities will be used)')

    # Stand up the Graphiti client (connects to Neo4j, builds indices).
    await initialize_graphiti()

    # Hand back the transport configuration for the MCP run loop.
    return MCPConfig.from_cli(cli_args)
|
|
|
|
|
|
async def run_mcp_server():
    """Initialize the server and serve MCP over the configured transport."""
    # Configuration + Graphiti setup happen first, on this same event loop.
    mcp_config = await initialize_server()

    logger.info(f'Starting MCP server with transport: {mcp_config.transport}')
    transport = mcp_config.transport
    if transport == 'stdio':
        await mcp.run_stdio_async()
    elif transport == 'sse':
        logger.info(
            f'Running MCP server with SSE transport on {mcp.settings.host}:{mcp.settings.port}'
        )
        await mcp.run_sse_async()
|
|
|
|
|
|
def main():
    """Entry point: run the Graphiti MCP server inside a single event loop."""
    try:
        # One asyncio.run keeps initialization and serving on the same loop.
        asyncio.run(run_mcp_server())
    except Exception as exc:
        logger.error(f'Error initializing Graphiti MCP server: {str(exc)}')
        raise
|
|
|
|
|
|
# Standard script entry guard: start the server only when executed directly.
if __name__ == '__main__':
    main()
|