Daniel Chalef 2025-07-08 22:20:26 -07:00
parent 119a43b8e4
commit 452a45cb4e
18 changed files with 1877 additions and 641 deletions

View file

@ -34,7 +34,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --frozen --no-dev
# Copy application code
COPY graphiti_mcp_server.py ./
COPY *.py ./
# Change ownership to app user
RUN chown -Rv app:app /app

View file

@ -0,0 +1,51 @@
"""Unified configuration manager for Graphiti MCP Server."""
import argparse
from embedder_config import GraphitiEmbedderConfig
from llm_config import GraphitiLLMConfig
from neo4j_config import Neo4jConfig
from pydantic import BaseModel, Field
class GraphitiConfig(BaseModel):
"""Configuration for Graphiti client.
Centralizes all configuration parameters for the Graphiti client.
"""
llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig)
embedder: GraphitiEmbedderConfig = Field(default_factory=GraphitiEmbedderConfig)
neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig)
group_id: str | None = None
use_custom_entities: bool = False
destroy_graph: bool = False
@classmethod
def from_env(cls) -> 'GraphitiConfig':
"""Create a configuration instance from environment variables."""
return cls(
llm=GraphitiLLMConfig.from_env(),
embedder=GraphitiEmbedderConfig.from_env(),
neo4j=Neo4jConfig.from_env(),
)
@classmethod
def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiConfig':
"""Create configuration from CLI arguments, falling back to environment variables."""
# Start with environment configuration
config = cls.from_env()
# Apply CLI overrides
if args.group_id:
config.group_id = args.group_id
else:
config.group_id = 'default'
config.use_custom_entities = args.use_custom_entities
config.destroy_graph = args.destroy_graph
# Update LLM config using CLI args
config.llm = GraphitiLLMConfig.from_cli_and_env(args)
return config
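As a usage sketch (assuming the CLI defines --group-id, --use-custom-entities, and --destroy-graph flags, and that no Azure environment variables are set):

import argparse

from config_manager import GraphitiConfig

parser = argparse.ArgumentParser()
parser.add_argument('--group-id', dest='group_id', default=None)
parser.add_argument('--use-custom-entities', action='store_true')
parser.add_argument('--destroy-graph', action='store_true')
args = parser.parse_args(['--group-id', 'my_project'])

# CLI values override environment-derived defaults; group_id falls back to 'default'
config = GraphitiConfig.from_cli_and_env(args)
assert config.group_id == 'my_project'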

View file

@ -0,0 +1,124 @@
"""Embedder configuration for Graphiti MCP Server."""
import logging
import os
from openai import AsyncAzureOpenAI
from pydantic import BaseModel
from utils import create_azure_credential_token_provider
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
logger = logging.getLogger(__name__)
DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'
class GraphitiEmbedderConfig(BaseModel):
"""Configuration for the embedder client.
Centralizes all embedding-related configuration parameters.
"""
model: str = DEFAULT_EMBEDDER_MODEL
api_key: str | None = None
azure_openai_endpoint: str | None = None
azure_openai_deployment_name: str | None = None
azure_openai_api_version: str | None = None
azure_openai_use_managed_identity: bool = False
@classmethod
def from_env(cls) -> 'GraphitiEmbedderConfig':
"""Create embedder configuration from environment variables."""
# Get model from environment, or use default if not set or empty
model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL
azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
azure_openai_deployment_name = os.environ.get(
'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
)
azure_openai_use_managed_identity = (
os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
)
if azure_openai_endpoint is not None:
# Setup for Azure OpenAI API
# Log if empty deployment name was provided
azure_openai_deployment_name = os.environ.get(
'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
)
if azure_openai_deployment_name is None:
logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set')
raise ValueError(
'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set'
)
if not azure_openai_use_managed_identity:
# api key
api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
'OPENAI_API_KEY', None
)
else:
# Managed identity
api_key = None
return cls(
azure_openai_use_managed_identity=azure_openai_use_managed_identity,
azure_openai_endpoint=azure_openai_endpoint,
api_key=api_key,
azure_openai_api_version=azure_openai_api_version,
azure_openai_deployment_name=azure_openai_deployment_name,
model=model,
)
else:
return cls(
model=model,
api_key=os.environ.get('OPENAI_API_KEY'),
)
def create_client(self) -> EmbedderClient | None:
"""Create an embedder client based on this configuration.
Returns:
EmbedderClient instance or None if configuration is invalid
"""
if self.azure_openai_endpoint is not None:
# Azure OpenAI API setup
if self.azure_openai_use_managed_identity:
# Use managed identity for authentication
token_provider = create_azure_credential_token_provider()
return AzureOpenAIEmbedderClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
azure_ad_token_provider=token_provider,
),
model=self.model,
)
elif self.api_key:
# Use API key for authentication
return AzureOpenAIEmbedderClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
api_key=self.api_key,
),
model=self.model,
)
else:
logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
return None
else:
# OpenAI API setup
if not self.api_key:
return None
embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model)
return OpenAIEmbedder(config=embedder_config)
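A hedged usage sketch of the two paths create_client can take; the key value is a placeholder:

import os

from embedder_config import GraphitiEmbedderConfig

os.environ['OPENAI_API_KEY'] = 'sk-placeholder'  # illustration only
config = GraphitiEmbedderConfig.from_env()

# With no AZURE_OPENAI_EMBEDDING_ENDPOINT set, this yields an OpenAIEmbedder;
# with the Azure variables set it would yield an AzureOpenAIEmbedderClient instead.
embedder = config.create_client()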

View file

@ -0,0 +1,83 @@
"""Entity type definitions for Graphiti MCP Server."""
from pydantic import BaseModel, Field
class Requirement(BaseModel):
"""A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.
Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
edge that it is a requirement.
Instructions for identifying and extracting requirements:
1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
2. Identify functional specifications that describe what the system should do
3. Pay attention to non-functional requirements like performance, security, or usability criteria
4. Extract constraints or limitations that must be adhered to
5. Focus on clear, specific, and measurable requirements rather than vague wishes
6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
7. Include any dependencies between requirements when explicitly stated
8. Preserve the original intent and scope of the requirement
9. Categorize requirements appropriately based on their domain or function
"""
project_name: str = Field(
...,
description='The name of the project to which the requirement belongs.',
)
description: str = Field(
...,
description='Description of the requirement. Only use information mentioned in the context to write this description.',
)
class Preference(BaseModel):
"""A Preference represents a user's expressed like, dislike, or preference for something.
Instructions for identifying and extracting preferences:
1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X"
2. Pay attention to comparative statements ("I prefer X over Y")
3. Consider the emotional tone when users mention certain topics
4. Extract only preferences that are clearly expressed, not assumptions
5. Categorize the preference appropriately based on its domain (food, music, brands, etc.)
6. Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food")
7. Only extract preferences directly stated by the user, not preferences of others they mention
8. Provide a concise but specific description that captures the nature of the preference
"""
category: str = Field(
...,
description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')",
)
description: str = Field(
...,
description='Brief description of the preference. Only use information mentioned in the context to write this description.',
)
class Procedure(BaseModel):
"""A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.
Instructions for identifying and extracting procedures:
1. Look for sequential instructions or steps ("First do X, then do Y")
2. Identify explicit directives or commands ("Always do X when Y happens")
3. Pay attention to conditional statements ("If X occurs, then do Y")
4. Extract procedures that have clear beginning and end points
5. Focus on actionable instructions rather than general information
6. Preserve the original sequence and dependencies between steps
7. Include any specified conditions or triggers for the procedure
8. Capture any stated purpose or goal of the procedure
9. Summarize complex procedures while maintaining critical details
"""
description: str = Field(
...,
description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
)
ENTITY_TYPES: dict[str, BaseModel] = {
'Requirement': Requirement, # type: ignore
'Preference': Preference, # type: ignore
'Procedure': Procedure, # type: ignore
}
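These models plug into Graphiti's extraction pipeline; a minimal sketch mirroring the add_episode call that appears later in this diff (episode content is illustrative):

from datetime import datetime, timezone

from entity_types import ENTITY_TYPES
from graphiti_core import Graphiti

async def ingest_note(client: Graphiti) -> None:
    # Passing ENTITY_TYPES enables Requirement/Preference/Procedure extraction
    await client.add_episode(
        name='Planning notes',
        episode_body='We need SSO support for the Phoenix project.',
        source_description='meeting notes',
        reference_time=datetime.now(timezone.utc),
        entity_types=ENTITY_TYPES,
    )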

26 mcp_server/formatting.py Normal file
View file

@ -0,0 +1,26 @@
"""Formatting utilities for Graphiti MCP Server."""
from typing import Any
from graphiti_core.edges import EntityEdge
def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
"""Format an entity edge into a readable result.
Since EntityEdge is a Pydantic BaseModel, we can use its built-in serialization capabilities.
Args:
edge: The EntityEdge to format
Returns:
A dictionary representation of the edge with serialized dates and excluded embeddings
"""
result = edge.model_dump(
mode='json',
exclude={
'fact_embedding',
},
)
result.get('attributes', {}).pop('fact_embedding', None)
return result
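A short usage sketch (field access is illustrative; the exact keys depend on EntityEdge's schema):

from formatting import format_fact_result
from graphiti_core.edges import EntityEdge

def log_facts(edges: list[EntityEdge]) -> None:
    for edge in edges:
        result = format_fact_result(edge)
        # Embeddings are stripped, so the dict is safe to log or return over the wire
        print(result.get('fact'), result.get('valid_at'))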

View file

@ -8,25 +8,30 @@ import asyncio
import logging
import os
import sys
from collections.abc import Callable
from datetime import datetime, timezone
from typing import Any, TypedDict, cast
from typing import Any, cast
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from config_manager import GraphitiConfig
from dotenv import load_dotenv
from entity_types import ENTITY_TYPES
from formatting import format_fact_result
from graphiti_service import GraphitiService
from llm_config import DEFAULT_LLM_MODEL, SMALL_LLM_MODEL
from mcp.server.fastmcp import FastMCP
from openai import AsyncAzureOpenAI
from pydantic import BaseModel, Field
from queue_service import QueueService
from response_types import (
EpisodeSearchResponse,
ErrorResponse,
FactSearchResponse,
NodeResult,
NodeSearchResponse,
StatusResponse,
SuccessResponse,
)
from server_config import MCPConfig
from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_client import OpenAIClient
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_config_recipes import (
NODE_HYBRID_SEARCH_NODE_DISTANCE,
@ -38,488 +43,12 @@ from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()
DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
SMALL_LLM_MODEL = 'gpt-4.1-nano'
DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'
# Semaphore limit for concurrent Graphiti operations.
# Decrease this if you're experiencing 429 rate limit errors from your LLM provider.
# Increase if you have high rate limits.
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 10))
class Requirement(BaseModel):
"""A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.
Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
edge that it is a requirement.
Instructions for identifying and extracting requirements:
1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
2. Identify functional specifications that describe what the system should do
3. Pay attention to non-functional requirements like performance, security, or usability criteria
4. Extract constraints or limitations that must be adhered to
5. Focus on clear, specific, and measurable requirements rather than vague wishes
6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
7. Include any dependencies between requirements when explicitly stated
8. Preserve the original intent and scope of the requirement
9. Categorize requirements appropriately based on their domain or function
"""
project_name: str = Field(
...,
description='The name of the project to which the requirement belongs.',
)
description: str = Field(
...,
description='Description of the requirement. Only use information mentioned in the context to write this description.',
)
class Preference(BaseModel):
"""A Preference represents a user's expressed like, dislike, or preference for something.
Instructions for identifying and extracting preferences:
1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X"
2. Pay attention to comparative statements ("I prefer X over Y")
3. Consider the emotional tone when users mention certain topics
4. Extract only preferences that are clearly expressed, not assumptions
5. Categorize the preference appropriately based on its domain (food, music, brands, etc.)
6. Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food")
7. Only extract preferences directly stated by the user, not preferences of others they mention
8. Provide a concise but specific description that captures the nature of the preference
"""
category: str = Field(
...,
description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')",
)
description: str = Field(
...,
description='Brief description of the preference. Only use information mentioned in the context to write this description.',
)
class Procedure(BaseModel):
"""A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.
Instructions for identifying and extracting procedures:
1. Look for sequential instructions or steps ("First do X, then do Y")
2. Identify explicit directives or commands ("Always do X when Y happens")
3. Pay attention to conditional statements ("If X occurs, then do Y")
4. Extract procedures that have clear beginning and end points
5. Focus on actionable instructions rather than general information
6. Preserve the original sequence and dependencies between steps
7. Include any specified conditions or triggers for the procedure
8. Capture any stated purpose or goal of the procedure
9. Summarize complex procedures while maintaining critical details
"""
description: str = Field(
...,
description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
)
ENTITY_TYPES: dict[str, BaseModel] = {
'Requirement': Requirement, # type: ignore
'Preference': Preference, # type: ignore
'Procedure': Procedure, # type: ignore
}
# Type definitions for API responses
class ErrorResponse(TypedDict):
error: str
class SuccessResponse(TypedDict):
message: str
class NodeResult(TypedDict):
uuid: str
name: str
summary: str
labels: list[str]
group_id: str
created_at: str
attributes: dict[str, Any]
class NodeSearchResponse(TypedDict):
message: str
nodes: list[NodeResult]
class FactSearchResponse(TypedDict):
message: str
facts: list[dict[str, Any]]
class EpisodeSearchResponse(TypedDict):
message: str
episodes: list[dict[str, Any]]
class StatusResponse(TypedDict):
status: str
message: str
def create_azure_credential_token_provider() -> Callable[[], str]:
credential = DefaultAzureCredential()
token_provider = get_bearer_token_provider(
credential, 'https://cognitiveservices.azure.com/.default'
)
return token_provider
# Server configuration classes
# The configuration system has a hierarchy:
# - GraphitiConfig is the top-level configuration
# - LLMConfig handles all OpenAI/LLM related settings
# - EmbedderConfig manages embedding settings
# - Neo4jConfig manages database connection details
# - Various other settings like group_id and feature flags
# Configuration values are loaded from:
# 1. Default values in the class definitions
# 2. Environment variables (loaded via load_dotenv())
# 3. Command line arguments (which override environment variables)
class GraphitiLLMConfig(BaseModel):
"""Configuration for the LLM client.
Centralizes all LLM-specific configuration parameters including API keys and model selection.
"""
api_key: str | None = None
model: str = DEFAULT_LLM_MODEL
small_model: str = SMALL_LLM_MODEL
temperature: float = 0.0
azure_openai_endpoint: str | None = None
azure_openai_deployment_name: str | None = None
azure_openai_api_version: str | None = None
azure_openai_use_managed_identity: bool = False
@classmethod
def from_env(cls) -> 'GraphitiLLMConfig':
"""Create LLM configuration from environment variables."""
# Get model from environment, or use default if not set or empty
model_env = os.environ.get('MODEL_NAME', '')
model = model_env if model_env.strip() else DEFAULT_LLM_MODEL
# Get small_model from environment, or use default if not set or empty
small_model_env = os.environ.get('SMALL_MODEL_NAME', '')
small_model = small_model_env if small_model_env.strip() else SMALL_LLM_MODEL
azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
azure_openai_api_version = os.environ.get('AZURE_OPENAI_API_VERSION', None)
azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME', None)
azure_openai_use_managed_identity = (
os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
)
if azure_openai_endpoint is None:
# Setup for OpenAI API
# Log if empty model was provided
if model_env == '':
logger.debug(
f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}'
)
elif not model_env.strip():
logger.warning(
f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}'
)
return cls(
api_key=os.environ.get('OPENAI_API_KEY'),
model=model,
small_model=small_model,
temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
)
else:
# Setup for Azure OpenAI API
# Log if empty deployment name was provided
if azure_openai_deployment_name is None:
logger.error('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
raise ValueError('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
if not azure_openai_use_managed_identity:
# api key
api_key = os.environ.get('OPENAI_API_KEY', None)
else:
# Managed identity
api_key = None
return cls(
azure_openai_use_managed_identity=azure_openai_use_managed_identity,
azure_openai_endpoint=azure_openai_endpoint,
api_key=api_key,
azure_openai_api_version=azure_openai_api_version,
azure_openai_deployment_name=azure_openai_deployment_name,
model=model,
small_model=small_model,
temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
)
@classmethod
def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig':
"""Create LLM configuration from CLI arguments, falling back to environment variables."""
# Start with environment-based config
config = cls.from_env()
# CLI arguments override environment variables when provided
if hasattr(args, 'model') and args.model:
# Only use CLI model if it's not empty
if args.model.strip():
config.model = args.model
else:
# Log that empty model was provided and default is used
logger.warning(f'Empty model name provided, using default: {DEFAULT_LLM_MODEL}')
if hasattr(args, 'small_model') and args.small_model:
if args.small_model.strip():
config.small_model = args.small_model
else:
logger.warning(f'Empty small_model name provided, using default: {SMALL_LLM_MODEL}')
if hasattr(args, 'temperature') and args.temperature is not None:
config.temperature = args.temperature
return config
def create_client(self) -> LLMClient:
"""Create an LLM client based on this configuration.
Returns:
LLMClient instance
"""
if self.azure_openai_endpoint is not None:
# Azure OpenAI API setup
if self.azure_openai_use_managed_identity:
# Use managed identity for authentication
token_provider = create_azure_credential_token_provider()
return AzureOpenAILLMClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
azure_ad_token_provider=token_provider,
),
config=LLMConfig(
api_key=self.api_key,
model=self.model,
small_model=self.small_model,
temperature=self.temperature,
),
)
elif self.api_key:
# Use API key for authentication
return AzureOpenAILLMClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
api_key=self.api_key,
),
config=LLMConfig(
api_key=self.api_key,
model=self.model,
small_model=self.small_model,
temperature=self.temperature,
),
)
else:
raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')
if not self.api_key:
raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')
llm_client_config = LLMConfig(
api_key=self.api_key, model=self.model, small_model=self.small_model
)
# Set temperature
llm_client_config.temperature = self.temperature
return OpenAIClient(config=llm_client_config)
class GraphitiEmbedderConfig(BaseModel):
"""Configuration for the embedder client.
Centralizes all embedding-related configuration parameters.
"""
model: str = DEFAULT_EMBEDDER_MODEL
api_key: str | None = None
azure_openai_endpoint: str | None = None
azure_openai_deployment_name: str | None = None
azure_openai_api_version: str | None = None
azure_openai_use_managed_identity: bool = False
@classmethod
def from_env(cls) -> 'GraphitiEmbedderConfig':
"""Create embedder configuration from environment variables."""
# Get model from environment, or use default if not set or empty
model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL
azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
azure_openai_deployment_name = os.environ.get(
'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
)
azure_openai_use_managed_identity = (
os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
)
if azure_openai_endpoint is not None:
# Setup for Azure OpenAI API
# Log if empty deployment name was provided
azure_openai_deployment_name = os.environ.get(
'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
)
if azure_openai_deployment_name is None:
logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set')
raise ValueError(
'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set'
)
if not azure_openai_use_managed_identity:
# api key
api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
'OPENAI_API_KEY', None
)
else:
# Managed identity
api_key = None
return cls(
azure_openai_use_managed_identity=azure_openai_use_managed_identity,
azure_openai_endpoint=azure_openai_endpoint,
api_key=api_key,
azure_openai_api_version=azure_openai_api_version,
azure_openai_deployment_name=azure_openai_deployment_name,
)
else:
return cls(
model=model,
api_key=os.environ.get('OPENAI_API_KEY'),
)
def create_client(self) -> EmbedderClient | None:
if self.azure_openai_endpoint is not None:
# Azure OpenAI API setup
if self.azure_openai_use_managed_identity:
# Use managed identity for authentication
token_provider = create_azure_credential_token_provider()
return AzureOpenAIEmbedderClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
azure_ad_token_provider=token_provider,
),
model=self.model,
)
elif self.api_key:
# Use API key for authentication
return AzureOpenAIEmbedderClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
api_key=self.api_key,
),
model=self.model,
)
else:
logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
return None
else:
# OpenAI API setup
if not self.api_key:
return None
embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model)
return OpenAIEmbedder(config=embedder_config)
class Neo4jConfig(BaseModel):
"""Configuration for Neo4j database connection."""
uri: str = 'bolt://localhost:7687'
user: str = 'neo4j'
password: str = 'password'
@classmethod
def from_env(cls) -> 'Neo4jConfig':
"""Create Neo4j configuration from environment variables."""
return cls(
uri=os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
user=os.environ.get('NEO4J_USER', 'neo4j'),
password=os.environ.get('NEO4J_PASSWORD', 'password'),
)
class GraphitiConfig(BaseModel):
"""Configuration for Graphiti client.
Centralizes all configuration parameters for the Graphiti client.
"""
llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig)
embedder: GraphitiEmbedderConfig = Field(default_factory=GraphitiEmbedderConfig)
neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig)
group_id: str | None = None
use_custom_entities: bool = False
destroy_graph: bool = False
@classmethod
def from_env(cls) -> 'GraphitiConfig':
"""Create a configuration instance from environment variables."""
return cls(
llm=GraphitiLLMConfig.from_env(),
embedder=GraphitiEmbedderConfig.from_env(),
neo4j=Neo4jConfig.from_env(),
)
@classmethod
def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiConfig':
"""Create configuration from CLI arguments, falling back to environment variables."""
# Start with environment configuration
config = cls.from_env()
# Apply CLI overrides
if args.group_id:
config.group_id = args.group_id
else:
config.group_id = 'default'
config.use_custom_entities = args.use_custom_entities
config.destroy_graph = args.destroy_graph
# Update LLM config using CLI args
config.llm = GraphitiLLMConfig.from_cli_and_env(args)
return config
class MCPConfig(BaseModel):
"""Configuration for MCP server."""
transport: str = 'sse' # Default to SSE transport
@classmethod
def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig':
"""Create MCP configuration from CLI arguments."""
return cls(transport=args.transport)
# Configure logging
logging.basicConfig(
level=logging.INFO,
@ -568,124 +97,9 @@ mcp = FastMCP(
instructions=GRAPHITI_MCP_INSTRUCTIONS,
)
# Initialize Graphiti client
graphiti_client: Graphiti | None = None
async def initialize_graphiti():
"""Initialize the Graphiti client with the configured settings."""
global graphiti_client, config
try:
# Create LLM client if possible
llm_client = config.llm.create_client()
if not llm_client and config.use_custom_entities:
# If custom entities are enabled, we must have an LLM client
raise ValueError('OPENAI_API_KEY must be set when custom entities are enabled')
# Validate Neo4j configuration
if not config.neo4j.uri or not config.neo4j.user or not config.neo4j.password:
raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
embedder_client = config.embedder.create_client()
# Initialize Graphiti client
graphiti_client = Graphiti(
uri=config.neo4j.uri,
user=config.neo4j.user,
password=config.neo4j.password,
llm_client=llm_client,
embedder=embedder_client,
max_coroutines=SEMAPHORE_LIMIT,
)
# Destroy graph if requested
if config.destroy_graph:
logger.info('Destroying graph...')
await clear_data(graphiti_client.driver)
# Initialize the graph database with Graphiti's indices
await graphiti_client.build_indices_and_constraints()
logger.info('Graphiti client initialized successfully')
# Log configuration details for transparency
if llm_client:
logger.info(f'Using OpenAI model: {config.llm.model}')
logger.info(f'Using temperature: {config.llm.temperature}')
else:
logger.info('No LLM client configured - entity extraction will be limited')
logger.info(f'Using group_id: {config.group_id}')
logger.info(
f'Custom entity extraction: {"enabled" if config.use_custom_entities else "disabled"}'
)
logger.info(f'Using concurrency limit: {SEMAPHORE_LIMIT}')
except Exception as e:
logger.error(f'Failed to initialize Graphiti: {str(e)}')
raise
def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
"""Format an entity edge into a readable result.
Since EntityEdge is a Pydantic BaseModel, we can use its built-in serialization capabilities.
Args:
edge: The EntityEdge to format
Returns:
A dictionary representation of the edge with serialized dates and excluded embeddings
"""
result = edge.model_dump(
mode='json',
exclude={
'fact_embedding',
},
)
result.get('attributes', {}).pop('fact_embedding', None)
return result
# Dictionary to store queues for each group_id
# Each queue is a list of tasks to be processed sequentially
episode_queues: dict[str, asyncio.Queue] = {}
# Dictionary to track if a worker is running for each group_id
queue_workers: dict[str, bool] = {}
async def process_episode_queue(group_id: str):
"""Process episodes for a specific group_id sequentially.
This function runs as a long-lived task that processes episodes
from the queue one at a time.
"""
global queue_workers
logger.info(f'Starting episode queue worker for group_id: {group_id}')
queue_workers[group_id] = True
try:
while True:
# Get the next episode processing function from the queue
# This will wait if the queue is empty
process_func = await episode_queues[group_id].get()
try:
# Process the episode
await process_func()
except Exception as e:
logger.error(f'Error processing queued episode for group_id {group_id}: {str(e)}')
finally:
# Mark the task as done regardless of success/failure
episode_queues[group_id].task_done()
except asyncio.CancelledError:
logger.info(f'Episode queue worker for group_id {group_id} was cancelled')
except Exception as e:
logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}')
finally:
queue_workers[group_id] = False
logger.info(f'Stopped episode queue worker for group_id: {group_id}')
# Global services
graphiti_service: GraphitiService | None = None
queue_service: QueueService | None = None
@mcp.tool()
@ -752,10 +166,13 @@ async def add_memory(
- Entities will be created from appropriate JSON properties
- Relationships between entities will be established based on the JSON structure
"""
global graphiti_client, episode_queues, queue_workers
global graphiti_service, queue_service, config
if graphiti_client is None:
return ErrorResponse(error='Graphiti client not initialized')
if not graphiti_service or not graphiti_service.is_initialized():
return ErrorResponse(error='Graphiti service not initialized')
if not queue_service:
return ErrorResponse(error='Queue service not initialized')
try:
# Map string source to EpisodeType enum
@ -772,13 +189,6 @@ async def add_memory(
# The Graphiti client expects a str for group_id, not Optional[str]
group_id_str = str(effective_group_id) if effective_group_id is not None else ''
# We've already checked that graphiti_client is not None above
# This assert statement helps type checkers understand that graphiti_client is defined
assert graphiti_client is not None, 'graphiti_client should not be None here'
# Use cast to help the type checker understand that graphiti_client is not None
client = cast(Graphiti, graphiti_client)
# Define the episode processing function
async def process_episode():
try:
@ -786,7 +196,7 @@ async def add_memory(
# Use all entity types if use_custom_entities is enabled, otherwise use empty dict
entity_types = ENTITY_TYPES if config.use_custom_entities else {}
await client.add_episode(
await graphiti_service.client.add_episode(
name=name,
episode_body=episode_body,
source=source_type,
@ -796,8 +206,6 @@ async def add_memory(
reference_time=datetime.now(timezone.utc),
entity_types=entity_types,
)
logger.info(f"Episode '{name}' added successfully")
logger.info(f"Episode '{name}' processed successfully")
except Exception as e:
error_msg = str(e)
@ -805,20 +213,12 @@ async def add_memory(
f"Error processing episode '{name}' for group_id {group_id_str}: {error_msg}"
)
# Initialize queue for this group_id if it doesn't exist
if group_id_str not in episode_queues:
episode_queues[group_id_str] = asyncio.Queue()
# Add the episode processing function to the queue
await episode_queues[group_id_str].put(process_episode)
# Start a worker for this queue if one isn't already running
if not queue_workers.get(group_id_str, False):
asyncio.create_task(process_episode_queue(group_id_str))
queue_position = await queue_service.add_episode_task(group_id_str, process_episode)
# Return immediately with a success message
return SuccessResponse(
message=f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})"
message=f"Episode '{name}' queued for processing (position: {queue_position})"
)
except Exception as e:
error_msg = str(e)
@ -846,10 +246,10 @@ async def search_memory_nodes(
center_node_uuid: Optional UUID of a node to center the search around
entity: Optional single entity type to filter results (permitted: "Preference", "Procedure")
"""
global graphiti_client
global graphiti_service, config
if graphiti_client is None:
return ErrorResponse(error='Graphiti client not initialized')
if not graphiti_service or not graphiti_service.is_initialized():
return ErrorResponse(error='Graphiti service not initialized')
try:
# Use the provided group_ids or fall back to the default from config if none provided
@ -868,11 +268,7 @@ async def search_memory_nodes(
if entity != '':
filters.node_labels = [entity]
# We've already checked that graphiti_client is not None above
assert graphiti_client is not None
# Use cast to help the type checker understand that graphiti_client is not None
client = cast(Graphiti, graphiti_client)
client = graphiti_service.client
# Perform the search using the _search method
search_results = await client._search(
@ -1218,8 +614,11 @@ async def initialize_server() -> MCPConfig:
else:
logger.info('Entity extraction disabled (no custom entities will be used)')
# Initialize Graphiti
await initialize_graphiti()
# Initialize services
global graphiti_service, queue_service
graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
queue_service = QueueService()
await graphiti_service.initialize()
if args.host:
logger.info(f'Setting MCP server host to: {args.host}')

View file

@ -0,0 +1,110 @@
"""Graphiti service for managing client lifecycle and operations."""
import logging
from config_manager import GraphitiConfig
from graphiti_core import Graphiti
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
logger = logging.getLogger(__name__)
class GraphitiService:
"""Service for managing Graphiti client operations."""
def __init__(self, config: GraphitiConfig, semaphore_limit: int = 10):
"""Initialize the Graphiti service with configuration.
Args:
config: The Graphiti configuration
semaphore_limit: Maximum concurrent operations
"""
self.config = config
self.semaphore_limit = semaphore_limit
self._client: Graphiti | None = None
@property
def client(self) -> Graphiti:
"""Get the Graphiti client instance.
Raises:
RuntimeError: If the client hasn't been initialized
"""
if self._client is None:
raise RuntimeError('Graphiti client not initialized. Call initialize() first.')
return self._client
async def initialize(self) -> None:
"""Initialize the Graphiti client with the configured settings."""
try:
# Create LLM client if possible
llm_client = self.config.llm.create_client()
if not llm_client and self.config.use_custom_entities:
# If custom entities are enabled, we must have an LLM client
raise ValueError('OPENAI_API_KEY must be set when custom entities are enabled')
# Validate Neo4j configuration
if (
not self.config.neo4j.uri
or not self.config.neo4j.user
or not self.config.neo4j.password
):
raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
embedder_client = self.config.embedder.create_client()
# Initialize Graphiti client
self._client = Graphiti(
uri=self.config.neo4j.uri,
user=self.config.neo4j.user,
password=self.config.neo4j.password,
llm_client=llm_client,
embedder=embedder_client,
max_coroutines=self.semaphore_limit,
)
# Destroy graph if requested
if self.config.destroy_graph:
logger.info('Destroying graph...')
await clear_data(self._client.driver)
# Initialize the graph database with Graphiti's indices
await self._client.build_indices_and_constraints()
logger.info('Graphiti client initialized successfully')
# Log configuration details for transparency
if llm_client:
logger.info(f'Using OpenAI model: {self.config.llm.model}')
logger.info(f'Using temperature: {self.config.llm.temperature}')
else:
logger.info('No LLM client configured - entity extraction will be limited')
logger.info(f'Using group_id: {self.config.group_id}')
logger.info(
f'Custom entity extraction: {"enabled" if self.config.use_custom_entities else "disabled"}'
)
logger.info(f'Using concurrency limit: {self.semaphore_limit}')
except Exception as e:
logger.error(f'Failed to initialize Graphiti: {str(e)}')
raise
async def clear_graph(self) -> None:
"""Clear all data from the graph and rebuild indices."""
if self._client is None:
raise RuntimeError('Graphiti client not initialized')
await clear_data(self._client.driver)
await self._client.build_indices_and_constraints()
async def verify_connection(self) -> None:
"""Verify the database connection."""
if self._client is None:
raise RuntimeError('Graphiti client not initialized')
await self._client.driver.client.verify_connectivity() # type: ignore
def is_initialized(self) -> bool:
"""Check if the client is initialized."""
return self._client is not None
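A minimal lifecycle sketch, assuming Neo4j and API-key environment variables are already populated:

import asyncio

from config_manager import GraphitiConfig
from graphiti_service import GraphitiService

async def main() -> None:
    service = GraphitiService(GraphitiConfig.from_env(), semaphore_limit=10)
    await service.initialize()
    assert service.is_initialized()
    await service.verify_connection()

asyncio.run(main())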

182 mcp_server/llm_config.py Normal file
View file

@ -0,0 +1,182 @@
"""LLM configuration for Graphiti MCP Server."""
import argparse
import logging
import os
from openai import AsyncAzureOpenAI
from pydantic import BaseModel
from utils import create_azure_credential_token_provider
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_client import OpenAIClient
logger = logging.getLogger(__name__)
DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
SMALL_LLM_MODEL = 'gpt-4.1-nano'
class GraphitiLLMConfig(BaseModel):
"""Configuration for the LLM client.
Centralizes all LLM-specific configuration parameters including API keys and model selection.
"""
api_key: str | None = None
model: str = DEFAULT_LLM_MODEL
small_model: str = SMALL_LLM_MODEL
temperature: float = 0.0
azure_openai_endpoint: str | None = None
azure_openai_deployment_name: str | None = None
azure_openai_api_version: str | None = None
azure_openai_use_managed_identity: bool = False
@classmethod
def from_env(cls) -> 'GraphitiLLMConfig':
"""Create LLM configuration from environment variables."""
# Get model from environment, or use default if not set or empty
model_env = os.environ.get('MODEL_NAME', '')
model = model_env if model_env.strip() else DEFAULT_LLM_MODEL
# Get small_model from environment, or use default if not set or empty
small_model_env = os.environ.get('SMALL_MODEL_NAME', '')
small_model = small_model_env if small_model_env.strip() else SMALL_LLM_MODEL
azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
azure_openai_api_version = os.environ.get('AZURE_OPENAI_API_VERSION', None)
azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME', None)
azure_openai_use_managed_identity = (
os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
)
if azure_openai_endpoint is None:
# Setup for OpenAI API
# Log if empty model was provided
if model_env == '':
logger.debug(
f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}'
)
elif not model_env.strip():
logger.warning(
f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}'
)
return cls(
api_key=os.environ.get('OPENAI_API_KEY'),
model=model,
small_model=small_model,
temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
)
else:
# Setup for Azure OpenAI API
# Log if empty deployment name was provided
if azure_openai_deployment_name is None:
logger.error('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
raise ValueError('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
if not azure_openai_use_managed_identity:
# api key
api_key = os.environ.get('OPENAI_API_KEY', None)
else:
# Managed identity
api_key = None
return cls(
azure_openai_use_managed_identity=azure_openai_use_managed_identity,
azure_openai_endpoint=azure_openai_endpoint,
api_key=api_key,
azure_openai_api_version=azure_openai_api_version,
azure_openai_deployment_name=azure_openai_deployment_name,
model=model,
small_model=small_model,
temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
)
@classmethod
def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig':
"""Create LLM configuration from CLI arguments, falling back to environment variables."""
# Start with environment-based config
config = cls.from_env()
# CLI arguments override environment variables when provided
if hasattr(args, 'model') and args.model:
# Only use CLI model if it's not empty
if args.model.strip():
config.model = args.model
else:
# Log that empty model was provided and default is used
logger.warning(f'Empty model name provided, using default: {DEFAULT_LLM_MODEL}')
if hasattr(args, 'small_model') and args.small_model:
if args.small_model.strip():
config.small_model = args.small_model
else:
logger.warning(f'Empty small_model name provided, using default: {SMALL_LLM_MODEL}')
if hasattr(args, 'temperature') and args.temperature is not None:
config.temperature = args.temperature
return config
def create_client(self) -> LLMClient:
"""Create an LLM client based on this configuration.
Returns:
LLMClient instance
"""
if self.azure_openai_endpoint is not None:
# Azure OpenAI API setup
if self.azure_openai_use_managed_identity:
# Use managed identity for authentication
token_provider = create_azure_credential_token_provider()
return AzureOpenAILLMClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
azure_ad_token_provider=token_provider,
),
config=LLMConfig(
api_key=self.api_key,
model=self.model,
small_model=self.small_model,
temperature=self.temperature,
),
)
elif self.api_key:
# Use API key for authentication
return AzureOpenAILLMClient(
azure_client=AsyncAzureOpenAI(
azure_endpoint=self.azure_openai_endpoint,
azure_deployment=self.azure_openai_deployment_name,
api_version=self.azure_openai_api_version,
api_key=self.api_key,
),
config=LLMConfig(
api_key=self.api_key,
model=self.model,
small_model=self.small_model,
temperature=self.temperature,
),
)
else:
raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')
if not self.api_key:
raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')
llm_client_config = LLMConfig(
api_key=self.api_key, model=self.model, small_model=self.small_model
)
# Set temperature
llm_client_config.temperature = self.temperature
return OpenAIClient(config=llm_client_config)
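A hedged sketch of the plain-OpenAI path (Azure variables unset; the key is a placeholder):

import os

from llm_config import GraphitiLLMConfig

os.environ['OPENAI_API_KEY'] = 'sk-placeholder'  # illustration only
os.environ['MODEL_NAME'] = 'gpt-4.1-mini'

config = GraphitiLLMConfig.from_env()
llm_client = config.create_client()  # OpenAIClient; raises ValueError if no key is set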

View file

@ -0,0 +1,22 @@
"""Neo4j database configuration for Graphiti MCP Server."""
import os
from pydantic import BaseModel
class Neo4jConfig(BaseModel):
"""Configuration for Neo4j database connection."""
uri: str = 'bolt://localhost:7687'
user: str = 'neo4j'
password: str = 'password'
@classmethod
def from_env(cls) -> 'Neo4jConfig':
"""Create Neo4j configuration from environment variables."""
return cls(
uri=os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
user=os.environ.get('NEO4J_USER', 'neo4j'),
password=os.environ.get('NEO4J_PASSWORD', 'password'),
)
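For example (the values shown are the module's documented defaults, not real credentials):

from neo4j_config import Neo4jConfig

config = Neo4jConfig.from_env()  # falls back to bolt://localhost:7687 / neo4j / password
print(config.uri, config.user)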

View file

@ -11,3 +11,9 @@ dependencies = [
"azure-identity>=1.21.0",
"graphiti-core",
]
[dependency-groups]
dev = [
"httpx>=0.28.1",
"mcp>=1.9.4",
]

View file

@ -0,0 +1,86 @@
"""Queue service for managing episode processing."""
import asyncio
import logging
from collections.abc import Awaitable, Callable
logger = logging.getLogger(__name__)
class QueueService:
"""Service for managing sequential episode processing queues by group_id."""
def __init__(self):
"""Initialize the queue service."""
# Dictionary to store queues for each group_id
self._episode_queues: dict[str, asyncio.Queue] = {}
# Dictionary to track if a worker is running for each group_id
self._queue_workers: dict[str, bool] = {}
async def add_episode_task(
self, group_id: str, process_func: Callable[[], Awaitable[None]]
) -> int:
"""Add an episode processing task to the queue.
Args:
group_id: The group ID for the episode
process_func: The async function to process the episode
Returns:
The position in the queue
"""
# Initialize queue for this group_id if it doesn't exist
if group_id not in self._episode_queues:
self._episode_queues[group_id] = asyncio.Queue()
# Add the episode processing function to the queue
await self._episode_queues[group_id].put(process_func)
# Start a worker for this queue if one isn't already running
if not self._queue_workers.get(group_id, False):
asyncio.create_task(self._process_episode_queue(group_id))
return self._episode_queues[group_id].qsize()
async def _process_episode_queue(self, group_id: str) -> None:
"""Process episodes for a specific group_id sequentially.
This function runs as a long-lived task that processes episodes
from the queue one at a time.
"""
logger.info(f'Starting episode queue worker for group_id: {group_id}')
self._queue_workers[group_id] = True
try:
while True:
# Get the next episode processing function from the queue
# This will wait if the queue is empty
process_func = await self._episode_queues[group_id].get()
try:
# Process the episode
await process_func()
except Exception as e:
logger.error(
f'Error processing queued episode for group_id {group_id}: {str(e)}'
)
finally:
# Mark the task as done regardless of success/failure
self._episode_queues[group_id].task_done()
except asyncio.CancelledError:
logger.info(f'Episode queue worker for group_id {group_id} was cancelled')
except Exception as e:
logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}')
finally:
self._queue_workers[group_id] = False
logger.info(f'Stopped episode queue worker for group_id: {group_id}')
def get_queue_size(self, group_id: str) -> int:
"""Get the current queue size for a group_id."""
if group_id not in self._episode_queues:
return 0
return self._episode_queues[group_id].qsize()
def is_worker_running(self, group_id: str) -> bool:
"""Check if a worker is running for a group_id."""
return self._queue_workers.get(group_id, False)
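A minimal sketch of queueing work through this service (the worker body is a stand-in for real ingestion):

import asyncio

from queue_service import QueueService

async def main() -> None:
    service = QueueService()

    async def process_episode() -> None:
        await asyncio.sleep(0.1)  # stand-in for the real episode processing

    position = await service.add_episode_task('default', process_episode)
    print(f'queued at position {position}')
    await asyncio.sleep(0.5)  # give the background worker time to drain the queue

asyncio.run(main())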

View file

@ -0,0 +1,41 @@
"""Response type definitions for Graphiti MCP Server."""
from typing import Any, TypedDict
class ErrorResponse(TypedDict):
error: str
class SuccessResponse(TypedDict):
message: str
class NodeResult(TypedDict):
uuid: str
name: str
summary: str
labels: list[str]
group_id: str
created_at: str
attributes: dict[str, Any]
class NodeSearchResponse(TypedDict):
message: str
nodes: list[NodeResult]
class FactSearchResponse(TypedDict):
message: str
facts: list[dict[str, Any]]
class EpisodeSearchResponse(TypedDict):
message: str
episodes: list[dict[str, Any]]
class StatusResponse(TypedDict):
status: str
message: str
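Because these are TypedDicts, they are plain dicts at runtime; a short sketch of constructing them:

from response_types import ErrorResponse, SuccessResponse

def queue_ack(name: str, position: int) -> SuccessResponse:
    return SuccessResponse(message=f"Episode '{name}' queued for processing (position: {position})")

def not_ready() -> ErrorResponse:
    return ErrorResponse(error='Graphiti service not initialized')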

View file

@ -0,0 +1,16 @@
"""Server configuration for Graphiti MCP Server."""
import argparse
from pydantic import BaseModel
class MCPConfig(BaseModel):
"""Configuration for MCP server."""
transport: str = 'sse' # Default to SSE transport
@classmethod
def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig':
"""Create MCP configuration from CLI arguments."""
return cls(transport=args.transport)
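For instance, given an argparse namespace carrying a transport value (the flag name matches the from_cli accessor):

import argparse

from server_config import MCPConfig

args = argparse.Namespace(transport='stdio')
mcp_config = MCPConfig.from_cli(args)  # transport defaults to 'sse' when constructed directly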

View file

@ -0,0 +1,363 @@
#!/usr/bin/env python3
"""
Integration test for the refactored Graphiti MCP Server.
Tests all major MCP tools and handles episode processing latency.
"""
import asyncio
import json
import time
from typing import Any
import httpx
class MCPIntegrationTest:
"""Integration test client for Graphiti MCP Server."""
def __init__(self, base_url: str = 'http://localhost:8000'):
self.base_url = base_url
self.client = httpx.AsyncClient(timeout=30.0)
self.test_group_id = f'test_group_{int(time.time())}'
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.client.aclose()
async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
"""Call an MCP tool via the SSE endpoint."""
# MCP protocol message structure
message = {
'jsonrpc': '2.0',
'id': int(time.time() * 1000),
'method': 'tools/call',
'params': {'name': tool_name, 'arguments': arguments},
}
try:
response = await self.client.post(
f'{self.base_url}/message',
json=message,
headers={'Content-Type': 'application/json'},
)
if response.status_code != 200:
return {'error': f'HTTP {response.status_code}: {response.text}'}
result = response.json()
return result.get('result', result)
except Exception as e:
return {'error': str(e)}
async def test_server_status(self) -> bool:
"""Test the get_status resource."""
print('🔍 Testing server status...')
try:
response = await self.client.get(f'{self.base_url}/resources/http://graphiti/status')
if response.status_code == 200:
status = response.json()
print(f' ✅ Server status: {status.get("status", "unknown")}')
return status.get('status') == 'ok'
else:
print(f' ❌ Status check failed: HTTP {response.status_code}')
return False
except Exception as e:
print(f' ❌ Status check failed: {e}')
return False
async def test_add_memory(self) -> dict[str, str]:
"""Test adding various types of memory episodes."""
print('📝 Testing add_memory functionality...')
episode_results = {}
# Test 1: Add text episode
print(' Testing text episode...')
result = await self.call_mcp_tool(
'add_memory',
{
'name': 'Test Company News',
'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
'source': 'text',
'source_description': 'news article',
'group_id': self.test_group_id,
},
)
if 'error' in result:
print(f' ❌ Text episode failed: {result["error"]}')
else:
print(f' ✅ Text episode queued: {result.get("message", "Success")}')
episode_results['text'] = 'success'
# Test 2: Add JSON episode
print(' Testing JSON episode...')
json_data = {
'company': {'name': 'TechCorp', 'founded': 2010},
'products': [
{'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
{'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
],
'employees': 150,
}
result = await self.call_mcp_tool(
'add_memory',
{
'name': 'Company Profile',
'episode_body': json.dumps(json_data),
'source': 'json',
'source_description': 'CRM data',
'group_id': self.test_group_id,
},
)
if 'error' in result:
print(f' ❌ JSON episode failed: {result["error"]}')
else:
print(f' ✅ JSON episode queued: {result.get("message", "Success")}')
episode_results['json'] = 'success'
# Test 3: Add message episode
print(' Testing message episode...')
result = await self.call_mcp_tool(
'add_memory',
{
'name': 'Customer Support Chat',
'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
'source': 'message',
'source_description': 'support chat log',
'group_id': self.test_group_id,
},
)
if 'error' in result:
print(f' ❌ Message episode failed: {result["error"]}')
else:
print(f' ✅ Message episode queued: {result.get("message", "Success")}')
episode_results['message'] = 'success'
return episode_results
async def wait_for_processing(self, max_wait: int = 30) -> None:
"""Wait for episode processing to complete."""
print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')
for i in range(max_wait):
await asyncio.sleep(1)
# Check if we have any episodes
result = await self.call_mcp_tool(
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
)
# Skip error responses; a successful get_episodes call returns a list
if isinstance(result, dict) and 'error' in result:
continue
if isinstance(result, list) and len(result) > 0:
print(f' ✅ Found {len(result)} processed episodes after {i + 1} seconds')
return
print(f' ⚠️ Still waiting after {max_wait} seconds...')
async def test_search_functions(self) -> dict[str, bool]:
"""Test search functionality."""
print('🔍 Testing search functions...')
results = {}
# Test search_memory_nodes
print(' Testing search_memory_nodes...')
result = await self.call_mcp_tool(
'search_memory_nodes',
{
'query': 'Acme Corp product launch',
'group_ids': [self.test_group_id],
'max_nodes': 5,
},
)
if 'error' in result:
print(f' ❌ Node search failed: {result["error"]}')
results['nodes'] = False
else:
nodes = result.get('nodes', [])
print(f' ✅ Node search returned {len(nodes)} nodes')
results['nodes'] = True
# Test search_memory_facts
print(' Testing search_memory_facts...')
result = await self.call_mcp_tool(
'search_memory_facts',
{
'query': 'company products software',
'group_ids': [self.test_group_id],
'max_facts': 5,
},
)
if 'error' in result:
print(f' ❌ Fact search failed: {result["error"]}')
results['facts'] = False
else:
facts = result.get('facts', [])
print(f' ✅ Fact search returned {len(facts)} facts')
results['facts'] = True
return results
async def test_episode_retrieval(self) -> bool:
"""Test episode retrieval."""
print('📚 Testing episode retrieval...')
result = await self.call_mcp_tool(
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
)
if 'error' in result:
print(f' ❌ Episode retrieval failed: {result["error"]}')
return False
if isinstance(result, list):
print(f' ✅ Retrieved {len(result)} episodes')
# Print episode details
for i, episode in enumerate(result[:3]): # Show first 3
name = episode.get('name', 'Unknown')
source = episode.get('source', 'unknown')
print(f' Episode {i + 1}: {name} (source: {source})')
return len(result) > 0
else:
print(f' ❌ Unexpected result format: {type(result)}')
return False
async def test_edge_cases(self) -> dict[str, bool]:
"""Test edge cases and error handling."""
print('🧪 Testing edge cases...')
results = {}
# Test with invalid group_id
print(' Testing invalid group_id...')
result = await self.call_mcp_tool(
'search_memory_nodes',
{'query': 'nonexistent data', 'group_ids': ['nonexistent_group'], 'max_nodes': 5},
)
# Should not error, just return empty results
if 'error' not in result:
nodes = result.get('nodes', [])
print(f' ✅ Invalid group_id handled gracefully (returned {len(nodes)} nodes)')
results['invalid_group'] = True
else:
print(f' ❌ Invalid group_id caused error: {result["error"]}')
results['invalid_group'] = False
# Test empty query
print(' Testing empty query...')
result = await self.call_mcp_tool(
'search_memory_nodes', {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5}
)
if 'error' not in result:
print(' ✅ Empty query handled gracefully')
results['empty_query'] = True
else:
print(f' ❌ Empty query caused error: {result["error"]}')
results['empty_query'] = False
return results
async def run_full_test_suite(self) -> dict[str, Any]:
"""Run the complete integration test suite."""
print('🚀 Starting Graphiti MCP Server Integration Test')
print(f' Test group ID: {self.test_group_id}')
print('=' * 60)
results = {
'server_status': False,
'add_memory': {},
'search': {},
'episodes': False,
'edge_cases': {},
'overall_success': False,
}
# Test 1: Server Status
results['server_status'] = await self.test_server_status()
if not results['server_status']:
print('❌ Server not responding, aborting tests')
return results
print()
# Test 2: Add Memory
results['add_memory'] = await self.test_add_memory()
print()
# Test 3: Wait for processing
await self.wait_for_processing()
print()
# Test 4: Search Functions
results['search'] = await self.test_search_functions()
print()
# Test 5: Episode Retrieval
results['episodes'] = await self.test_episode_retrieval()
print()
# Test 6: Edge Cases
results['edge_cases'] = await self.test_edge_cases()
print()
# Calculate overall success
        memory_success = any(results['add_memory'].values())  # at least one episode type actually succeeded
search_success = any(results['search'].values())
edge_case_success = any(results['edge_cases'].values())
results['overall_success'] = (
results['server_status']
and memory_success
and results['episodes']
and (search_success or edge_case_success) # At least some functionality working
)
# Print summary
print('=' * 60)
print('📊 TEST SUMMARY')
        print(f'   Server Status: {"✅" if results["server_status"] else "❌"}')
        print(
            f'   Memory Operations: {"✅" if memory_success else "❌"} ({len(results["add_memory"])} types)'
        )
        print(f'   Search Functions: {"✅" if search_success else "❌"}')
        print(f'   Episode Retrieval: {"✅" if results["episodes"] else "❌"}')
        print(f'   Edge Cases: {"✅" if edge_case_success else "❌"}')
print()
print(f'🎯 OVERALL: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')
if results['overall_success']:
print(' The refactored MCP server is working correctly!')
else:
print(' Some issues detected. Check individual test results above.')
return results
async def main():
"""Run the integration test."""
async with MCPIntegrationTest() as test:
results = await test.run_full_test_suite()
# Exit with appropriate code
exit_code = 0 if results['overall_success'] else 1
exit(exit_code)
if __name__ == '__main__':
asyncio.run(main())

View file

@@ -0,0 +1,502 @@
#!/usr/bin/env python3
"""
Integration test for the refactored Graphiti MCP Server using the official MCP Python SDK.
Tests all major MCP tools and handles episode processing latency.
"""
import asyncio
import json
import time
from typing import Any
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
class GraphitiMCPIntegrationTest:
"""Integration test client for Graphiti MCP Server using official MCP SDK."""
def __init__(self):
self.test_group_id = f'test_group_{int(time.time())}'
self.session = None
async def __aenter__(self):
"""Start the MCP client session."""
# Configure server parameters to run our refactored server
server_params = StdioServerParameters(
command='uv',
args=['run', 'graphiti_mcp_server.py', '--transport', 'stdio'],
env={
'NEO4J_URI': 'bolt://localhost:7687',
'NEO4J_USER': 'neo4j',
'NEO4J_PASSWORD': 'demodemo',
'OPENAI_API_KEY': 'dummy_key_for_testing', # Will use existing .env
},
)
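        # stdio transport: the client spawns the server as a subprocess and
        # speaks MCP over the child's stdin/stdout.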
print(f'🚀 Starting MCP client session with test group: {self.test_group_id}')
        # Enter the stdio transport and the client session as async context
        # managers; ClientSession must be entered before initialize() is called.
        self.client_context = stdio_client(server_params)
        read, write = await self.client_context.__aenter__()
        self.session_context = ClientSession(read, write)
        self.session = await self.session_context.__aenter__()
        await self.session.initialize()
return self
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the MCP client session."""
        # Exit contexts in reverse order of entry: session first, then transport.
        if hasattr(self, 'session_context'):
            await self.session_context.__aexit__(exc_type, exc_val, exc_tb)
        if hasattr(self, 'client_context'):
            await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""Call an MCP tool and return the result."""
try:
result = await self.session.call_tool(tool_name, arguments)
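            # Tool results arrive as a list of content items; this assumes the
            # server returns a single text item per call.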
return result.content[0].text if result.content else {'error': 'No content returned'}
except Exception as e:
return {'error': str(e)}
async def test_server_initialization(self) -> bool:
"""Test that the server initializes properly."""
print('🔍 Testing server initialization...')
try:
# List available tools to verify server is responding
tools_result = await self.session.list_tools()
tools = [tool.name for tool in tools_result.tools]
expected_tools = [
'add_memory',
'search_memory_nodes',
'search_memory_facts',
'get_episodes',
'delete_episode',
'delete_entity_edge',
'get_entity_edge',
'clear_graph',
]
available_tools = len([tool for tool in expected_tools if tool in tools])
print(
f' ✅ Server responding with {len(tools)} tools ({available_tools}/{len(expected_tools)} expected)'
)
print(f' Available tools: {", ".join(sorted(tools))}')
return available_tools >= len(expected_tools) * 0.8 # 80% of expected tools
except Exception as e:
print(f' ❌ Server initialization failed: {e}')
return False
async def test_add_memory_operations(self) -> dict[str, bool]:
"""Test adding various types of memory episodes."""
print('📝 Testing add_memory operations...')
results = {}
# Test 1: Add text episode
print(' Testing text episode...')
try:
result = await self.call_tool(
'add_memory',
{
'name': 'Test Company News',
'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
'source': 'text',
'source_description': 'news article',
'group_id': self.test_group_id,
},
)
if isinstance(result, str) and 'queued' in result.lower():
print(f' ✅ Text episode: {result}')
results['text'] = True
else:
print(f' ❌ Text episode failed: {result}')
results['text'] = False
except Exception as e:
print(f' ❌ Text episode error: {e}')
results['text'] = False
# Test 2: Add JSON episode
print(' Testing JSON episode...')
try:
json_data = {
'company': {'name': 'TechCorp', 'founded': 2010},
'products': [
{'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
{'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
],
'employees': 150,
}
result = await self.call_tool(
'add_memory',
{
'name': 'Company Profile',
'episode_body': json.dumps(json_data),
'source': 'json',
'source_description': 'CRM data',
'group_id': self.test_group_id,
},
)
if isinstance(result, str) and 'queued' in result.lower():
print(f' ✅ JSON episode: {result}')
results['json'] = True
else:
print(f' ❌ JSON episode failed: {result}')
results['json'] = False
except Exception as e:
print(f' ❌ JSON episode error: {e}')
results['json'] = False
# Test 3: Add message episode
print(' Testing message episode...')
try:
result = await self.call_tool(
'add_memory',
{
'name': 'Customer Support Chat',
'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
'source': 'message',
'source_description': 'support chat log',
'group_id': self.test_group_id,
},
)
if isinstance(result, str) and 'queued' in result.lower():
print(f' ✅ Message episode: {result}')
results['message'] = True
else:
print(f' ❌ Message episode failed: {result}')
results['message'] = False
except Exception as e:
print(f' ❌ Message episode error: {e}')
results['message'] = False
return results
async def wait_for_processing(self, max_wait: int = 45) -> bool:
"""Wait for episode processing to complete."""
print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')
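        # Poll once per second: episodes only become visible after the server's
        # background queue has finished processing the add_memory calls.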
for i in range(max_wait):
await asyncio.sleep(1)
try:
# Check if we have any episodes
result = await self.call_tool(
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
)
# Parse the JSON result if it's a string
if isinstance(result, str):
try:
parsed_result = json.loads(result)
if isinstance(parsed_result, list) and len(parsed_result) > 0:
print(
f' ✅ Found {len(parsed_result)} processed episodes after {i + 1} seconds'
)
return True
except json.JSONDecodeError:
if 'episodes' in result.lower():
print(f' ✅ Episodes detected after {i + 1} seconds')
return True
except Exception as e:
if i == 0: # Only log first error to avoid spam
print(f' ⚠️ Waiting for processing... ({e})')
continue
print(f' ⚠️ Still waiting after {max_wait} seconds...')
return False
async def test_search_operations(self) -> dict[str, bool]:
"""Test search functionality."""
print('🔍 Testing search operations...')
results = {}
# Test search_memory_nodes
print(' Testing search_memory_nodes...')
try:
result = await self.call_tool(
'search_memory_nodes',
{
'query': 'Acme Corp product launch AI',
'group_ids': [self.test_group_id],
'max_nodes': 5,
},
)
success = False
if isinstance(result, str):
try:
parsed = json.loads(result)
nodes = parsed.get('nodes', [])
success = isinstance(nodes, list)
print(f' ✅ Node search returned {len(nodes)} nodes')
except json.JSONDecodeError:
success = 'nodes' in result.lower() and 'successfully' in result.lower()
if success:
print(' ✅ Node search completed successfully')
results['nodes'] = success
if not success:
print(f' ❌ Node search failed: {result}')
except Exception as e:
print(f' ❌ Node search error: {e}')
results['nodes'] = False
# Test search_memory_facts
print(' Testing search_memory_facts...')
try:
result = await self.call_tool(
'search_memory_facts',
{
'query': 'company products software TechCorp',
'group_ids': [self.test_group_id],
'max_facts': 5,
},
)
success = False
if isinstance(result, str):
try:
parsed = json.loads(result)
facts = parsed.get('facts', [])
success = isinstance(facts, list)
print(f' ✅ Fact search returned {len(facts)} facts')
except json.JSONDecodeError:
success = 'facts' in result.lower() and 'successfully' in result.lower()
if success:
print(' ✅ Fact search completed successfully')
results['facts'] = success
if not success:
print(f' ❌ Fact search failed: {result}')
except Exception as e:
print(f' ❌ Fact search error: {e}')
results['facts'] = False
return results
async def test_episode_retrieval(self) -> bool:
"""Test episode retrieval."""
print('📚 Testing episode retrieval...')
try:
result = await self.call_tool(
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
)
if isinstance(result, str):
try:
parsed = json.loads(result)
if isinstance(parsed, list):
print(f' ✅ Retrieved {len(parsed)} episodes')
# Show episode details
for i, episode in enumerate(parsed[:3]):
name = episode.get('name', 'Unknown')
source = episode.get('source', 'unknown')
print(f' Episode {i + 1}: {name} (source: {source})')
return len(parsed) > 0
except json.JSONDecodeError:
# Check if response indicates success
if 'episode' in result.lower():
print(' ✅ Episode retrieval completed')
return True
print(f' ❌ Unexpected result format: {result}')
return False
except Exception as e:
print(f' ❌ Episode retrieval failed: {e}')
return False
async def test_error_handling(self) -> dict[str, bool]:
"""Test error handling and edge cases."""
print('🧪 Testing error handling...')
results = {}
# Test with nonexistent group
print(' Testing nonexistent group handling...')
try:
result = await self.call_tool(
'search_memory_nodes',
{
'query': 'nonexistent data',
'group_ids': ['nonexistent_group_12345'],
'max_nodes': 5,
},
)
# Should handle gracefully, not crash
success = (
'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
)
if success:
print(' ✅ Nonexistent group handled gracefully')
else:
print(f' ❌ Nonexistent group caused issues: {result}')
results['nonexistent_group'] = success
except Exception as e:
print(f' ❌ Nonexistent group test failed: {e}')
results['nonexistent_group'] = False
# Test empty query
print(' Testing empty query handling...')
try:
result = await self.call_tool(
'search_memory_nodes',
{'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5},
)
# Should handle gracefully
success = (
'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
)
if success:
print(' ✅ Empty query handled gracefully')
else:
print(f' ❌ Empty query caused issues: {result}')
results['empty_query'] = success
except Exception as e:
print(f' ❌ Empty query test failed: {e}')
results['empty_query'] = False
return results
async def run_comprehensive_test(self) -> dict[str, Any]:
"""Run the complete integration test suite."""
print('🚀 Starting Comprehensive Graphiti MCP Server Integration Test')
print(f' Test group ID: {self.test_group_id}')
print('=' * 70)
results = {
'server_init': False,
'add_memory': {},
'processing_wait': False,
'search': {},
'episodes': False,
'error_handling': {},
'overall_success': False,
}
# Test 1: Server Initialization
results['server_init'] = await self.test_server_initialization()
if not results['server_init']:
print('❌ Server initialization failed, aborting remaining tests')
return results
print()
# Test 2: Add Memory Operations
results['add_memory'] = await self.test_add_memory_operations()
print()
# Test 3: Wait for Processing
results['processing_wait'] = await self.wait_for_processing()
print()
# Test 4: Search Operations
results['search'] = await self.test_search_operations()
print()
# Test 5: Episode Retrieval
results['episodes'] = await self.test_episode_retrieval()
print()
# Test 6: Error Handling
results['error_handling'] = await self.test_error_handling()
print()
# Calculate overall success
memory_success = any(results['add_memory'].values())
search_success = any(results['search'].values()) if results['search'] else False
error_success = (
any(results['error_handling'].values()) if results['error_handling'] else True
)
results['overall_success'] = (
results['server_init']
and memory_success
and (results['episodes'] or results['processing_wait'])
and error_success
)
# Print comprehensive summary
print('=' * 70)
print('📊 COMPREHENSIVE TEST SUMMARY')
print('-' * 35)
print(f'Server Initialization: {"✅ PASS" if results["server_init"] else "❌ FAIL"}')
memory_stats = f'({sum(results["add_memory"].values())}/{len(results["add_memory"])} types)'
print(
f'Memory Operations: {"✅ PASS" if memory_success else "❌ FAIL"} {memory_stats}'
)
print(f'Processing Pipeline: {"✅ PASS" if results["processing_wait"] else "❌ FAIL"}')
search_stats = (
f'({sum(results["search"].values())}/{len(results["search"])} types)'
if results['search']
else '(0/0 types)'
)
print(
f'Search Operations: {"✅ PASS" if search_success else "❌ FAIL"} {search_stats}'
)
print(f'Episode Retrieval: {"✅ PASS" if results["episodes"] else "❌ FAIL"}')
error_stats = (
f'({sum(results["error_handling"].values())}/{len(results["error_handling"])} cases)'
if results['error_handling']
else '(0/0 cases)'
)
print(
f'Error Handling: {"✅ PASS" if error_success else "❌ FAIL"} {error_stats}'
)
print('-' * 35)
print(f'🎯 OVERALL RESULT: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')
if results['overall_success']:
print('\n🎉 The refactored Graphiti MCP server is working correctly!')
print(' All core functionality has been successfully tested.')
else:
print('\n⚠️ Some issues were detected. Review the test results above.')
print(' The refactoring may need additional attention.')
return results
async def main():
"""Run the integration test."""
try:
async with GraphitiMCPIntegrationTest() as test:
results = await test.run_comprehensive_test()
# Exit with appropriate code
exit_code = 0 if results['overall_success'] else 1
exit(exit_code)
except Exception as e:
print(f'❌ Test setup failed: {e}')
exit(1)
if __name__ == '__main__':
asyncio.run(main())

View file

@@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""
Simple validation test for the refactored Graphiti MCP Server.
Tests basic functionality quickly without timeouts.
"""
import os
import subprocess
import sys
import time
def test_server_startup():
"""Test that the refactored server starts up successfully."""
print('🚀 Testing Graphiti MCP Server Startup...')
try:
# Start the server and capture output
process = subprocess.Popen(
['uv', 'run', 'graphiti_mcp_server.py', '--transport', 'stdio'],
            env={
                **os.environ,  # keep PATH and existing settings so the 'uv' executable resolves
                'NEO4J_URI': 'bolt://localhost:7687',
                'NEO4J_USER': 'neo4j',
                'NEO4J_PASSWORD': 'demodemo',
            },
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
        # Wait for startup logs
        startup_output = ''
        success = False  # flipped to True once the initialization marker is seen
        for _ in range(50):  # Wait up to 5 seconds
if process.poll() is not None:
break
time.sleep(0.1)
# Check if we have output
try:
line = process.stderr.readline()
if line:
startup_output += line
print(f' 📋 {line.strip()}')
# Check for success indicators
if 'Graphiti client initialized successfully' in line:
print(' ✅ Graphiti service initialization: SUCCESS')
success = True
break
except Exception:
continue
else:
print(' ⚠️ Timeout waiting for initialization')
success = False
# Clean shutdown
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
return success
except Exception as e:
print(f' ❌ Server startup failed: {e}')
return False
def test_import_validation():
"""Test that all refactored modules import correctly."""
print('\n🔍 Testing Module Import Validation...')
modules_to_test = [
'config_manager',
'llm_config',
'embedder_config',
'neo4j_config',
'server_config',
'graphiti_service',
'queue_service',
'entity_types',
'response_types',
'formatting',
'utils',
]
success_count = 0
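    # Import each module in a separate interpreter process so one broken
    # module cannot crash the validation loop itself.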
for module in modules_to_test:
try:
result = subprocess.run(
                [sys.executable, '-c', f"import {module}; print('{module}')"],
capture_output=True,
text=True,
timeout=10,
)
            if result.returncode == 0:
                print(f'    ✅ {module}: Import successful')
                success_count += 1
            else:
                print(f'    ❌ {module}: Import failed - {result.stderr.strip()}')
        except subprocess.TimeoutExpired:
            print(f'    ❌ {module}: Import timeout')
        except Exception as e:
            print(f'    ❌ {module}: Import error - {e}')
print(f' 📊 Import Results: {success_count}/{len(modules_to_test)} modules successful')
return success_count == len(modules_to_test)
def test_syntax_validation():
"""Test that all Python files have valid syntax."""
print('\n🔧 Testing Syntax Validation...')
files_to_test = [
'graphiti_mcp_server.py',
'config_manager.py',
'llm_config.py',
'embedder_config.py',
'neo4j_config.py',
'server_config.py',
'graphiti_service.py',
'queue_service.py',
'entity_types.py',
'response_types.py',
'formatting.py',
'utils.py',
]
success_count = 0
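    # py_compile only byte-compiles each file, so syntax is validated without
    # executing any module-level code.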
for file in files_to_test:
try:
result = subprocess.run(
                [sys.executable, '-m', 'py_compile', file], capture_output=True, text=True, timeout=10
)
            if result.returncode == 0:
                print(f'    ✅ {file}: Syntax valid')
                success_count += 1
            else:
                print(f'    ❌ {file}: Syntax error - {result.stderr.strip()}')
        except subprocess.TimeoutExpired:
            print(f'    ❌ {file}: Syntax check timeout')
        except Exception as e:
            print(f'    ❌ {file}: Syntax check error - {e}')
print(f' 📊 Syntax Results: {success_count}/{len(files_to_test)} files valid')
return success_count == len(files_to_test)
def main():
"""Run the validation tests."""
print('🧪 Graphiti MCP Server Refactoring Validation')
print('=' * 55)
results = {}
# Test 1: Syntax validation
results['syntax'] = test_syntax_validation()
# Test 2: Import validation
results['imports'] = test_import_validation()
# Test 3: Server startup
results['startup'] = test_server_startup()
# Summary
print('\n' + '=' * 55)
print('📊 VALIDATION SUMMARY')
print('-' * 25)
print(f'Syntax Validation: {"✅ PASS" if results["syntax"] else "❌ FAIL"}')
print(f'Import Validation: {"✅ PASS" if results["imports"] else "❌ FAIL"}')
print(f'Startup Validation: {"✅ PASS" if results["startup"] else "❌ FAIL"}')
overall_success = all(results.values())
print('-' * 25)
print(f'🎯 OVERALL: {"✅ SUCCESS" if overall_success else "❌ FAILED"}')
if overall_success:
print('\n🎉 Refactoring validation successful!')
print(' ✅ All modules have valid syntax')
print(' ✅ All imports work correctly')
print(' ✅ Server initializes successfully')
print(' ✅ The refactored MCP server is ready for use!')
else:
print('\n⚠️ Some validation issues detected.')
print(' Please review the failed tests above.')
return 0 if overall_success else 1
if __name__ == '__main__':
exit_code = main()
sys.exit(exit_code)

mcp_server/utils.py Normal file
View file

@@ -0,0 +1,14 @@
"""Utility functions for Graphiti MCP Server."""
from collections.abc import Callable
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
def create_azure_credential_token_provider() -> Callable[[], str]:
"""Create Azure credential token provider for managed identity authentication."""
credential = DefaultAzureCredential()
token_provider = get_bearer_token_provider(
credential, 'https://cognitiveservices.azure.com/.default'
)
return token_provider
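
For reference, a minimal usage sketch of this helper (hypothetical endpoint and API version; the returned provider plugs into the OpenAI SDK's azure_ad_token_provider parameter):

from openai import AsyncAzureOpenAI
from utils import create_azure_credential_token_provider

token_provider = create_azure_credential_token_provider()
client = AsyncAzureOpenAI(
    azure_endpoint='https://example.openai.azure.com',  # hypothetical endpoint
    api_version='2024-02-01',  # hypothetical API version
    azure_ad_token_provider=token_provider,  # bearer token via managed identity
)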

mcp_server/uv.lock generated
View file

@@ -457,7 +457,7 @@ wheels = [
[[package]]
name = "mcp-server"
version = "0.1.0"
version = "0.2.1"
source = { virtual = "." }
dependencies = [
{ name = "azure-identity" },
@@ -466,6 +466,12 @@ dependencies = [
{ name = "openai" },
]
[package.dev-dependencies]
dev = [
{ name = "httpx" },
{ name = "mcp" },
]
[package.metadata]
requires-dist = [
{ name = "azure-identity", specifier = ">=1.21.0" },
@@ -475,6 +481,12 @@ requires-dist = [
{ name = "openai", specifier = ">=1.68.2" },
]
[package.metadata.requires-dev]
dev = [
{ name = "httpx", specifier = ">=0.28.1" },
{ name = "mcp", specifier = ">=1.9.4" },
]
[[package]]
name = "msal"
version = "1.32.3"