From 452a45cb4e2db6108fb2c1d3995ab7aab4a5eabd Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Tue, 8 Jul 2025 22:20:26 -0700 Subject: [PATCH] wip --- mcp_server/Dockerfile | 2 +- mcp_server/config_manager.py | 51 ++ mcp_server/embedder_config.py | 124 +++++ mcp_server/entity_types.py | 83 ++++ mcp_server/formatting.py | 26 + mcp_server/graphiti_mcp_server.py | 677 ++------------------------- mcp_server/graphiti_service.py | 110 +++++ mcp_server/llm_config.py | 182 +++++++ mcp_server/neo4j_config.py | 22 + mcp_server/pyproject.toml | 6 + mcp_server/queue_service.py | 86 ++++ mcp_server/response_types.py | 41 ++ mcp_server/server_config.py | 16 + mcp_server/test_integration.py | 363 ++++++++++++++ mcp_server/test_mcp_integration.py | 502 ++++++++++++++++++++ mcp_server/test_simple_validation.py | 199 ++++++++ mcp_server/utils.py | 14 + mcp_server/uv.lock | 14 +- 18 files changed, 1877 insertions(+), 641 deletions(-) create mode 100644 mcp_server/config_manager.py create mode 100644 mcp_server/embedder_config.py create mode 100644 mcp_server/entity_types.py create mode 100644 mcp_server/formatting.py create mode 100644 mcp_server/graphiti_service.py create mode 100644 mcp_server/llm_config.py create mode 100644 mcp_server/neo4j_config.py create mode 100644 mcp_server/queue_service.py create mode 100644 mcp_server/response_types.py create mode 100644 mcp_server/server_config.py create mode 100644 mcp_server/test_integration.py create mode 100644 mcp_server/test_mcp_integration.py create mode 100644 mcp_server/test_simple_validation.py create mode 100644 mcp_server/utils.py diff --git a/mcp_server/Dockerfile b/mcp_server/Dockerfile index e6b84c87..2de8eab7 100644 --- a/mcp_server/Dockerfile +++ b/mcp_server/Dockerfile @@ -34,7 +34,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv sync --frozen --no-dev # Copy application code -COPY graphiti_mcp_server.py ./ +COPY *.py ./ # Change ownership to app user RUN chown -Rv app:app /app diff --git a/mcp_server/config_manager.py b/mcp_server/config_manager.py new file mode 100644 index 00000000..d836ae52 --- /dev/null +++ b/mcp_server/config_manager.py @@ -0,0 +1,51 @@ +"""Unified configuration manager for Graphiti MCP Server.""" + +import argparse + +from embedder_config import GraphitiEmbedderConfig +from llm_config import GraphitiLLMConfig +from neo4j_config import Neo4jConfig +from pydantic import BaseModel, Field + + +class GraphitiConfig(BaseModel): + """Configuration for Graphiti client. + + Centralizes all configuration parameters for the Graphiti client. 
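+
+    A minimal usage sketch (illustrative; assumes NEO4J_URI, NEO4J_USER,
+    NEO4J_PASSWORD and OPENAI_API_KEY are set in the environment):
+
+        config = GraphitiConfig.from_env()
+        llm_client = config.llm.create_client()
+        embedder = config.embedder.create_client()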
+ """ + + llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig) + embedder: GraphitiEmbedderConfig = Field(default_factory=GraphitiEmbedderConfig) + neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig) + group_id: str | None = None + use_custom_entities: bool = False + destroy_graph: bool = False + + @classmethod + def from_env(cls) -> 'GraphitiConfig': + """Create a configuration instance from environment variables.""" + return cls( + llm=GraphitiLLMConfig.from_env(), + embedder=GraphitiEmbedderConfig.from_env(), + neo4j=Neo4jConfig.from_env(), + ) + + @classmethod + def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiConfig': + """Create configuration from CLI arguments, falling back to environment variables.""" + # Start with environment configuration + config = cls.from_env() + + # Apply CLI overrides + if args.group_id: + config.group_id = args.group_id + else: + config.group_id = 'default' + + config.use_custom_entities = args.use_custom_entities + config.destroy_graph = args.destroy_graph + + # Update LLM config using CLI args + config.llm = GraphitiLLMConfig.from_cli_and_env(args) + + return config diff --git a/mcp_server/embedder_config.py b/mcp_server/embedder_config.py new file mode 100644 index 00000000..cec8922e --- /dev/null +++ b/mcp_server/embedder_config.py @@ -0,0 +1,124 @@ +"""Embedder configuration for Graphiti MCP Server.""" + +import logging +import os + +from openai import AsyncAzureOpenAI +from pydantic import BaseModel +from utils import create_azure_credential_token_provider + +from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient +from graphiti_core.embedder.client import EmbedderClient +from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig + +logger = logging.getLogger(__name__) + +DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small' + + +class GraphitiEmbedderConfig(BaseModel): + """Configuration for the embedder client. + + Centralizes all embedding-related configuration parameters. 
+ """ + + model: str = DEFAULT_EMBEDDER_MODEL + api_key: str | None = None + azure_openai_endpoint: str | None = None + azure_openai_deployment_name: str | None = None + azure_openai_api_version: str | None = None + azure_openai_use_managed_identity: bool = False + + @classmethod + def from_env(cls) -> 'GraphitiEmbedderConfig': + """Create embedder configuration from environment variables.""" + # Get model from environment, or use default if not set or empty + model_env = os.environ.get('EMBEDDER_MODEL_NAME', '') + model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL + + azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None) + azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None) + azure_openai_deployment_name = os.environ.get( + 'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None + ) + azure_openai_use_managed_identity = ( + os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true' + ) + + if azure_openai_endpoint is not None: + # Setup for Azure OpenAI API + # Log if empty deployment name was provided + azure_openai_deployment_name = os.environ.get( + 'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None + ) + if azure_openai_deployment_name is None: + logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set') + raise ValueError( + 'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set' + ) + + if not azure_openai_use_managed_identity: + # api key + api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get( + 'OPENAI_API_KEY', None + ) + else: + # Managed identity + api_key = None + + return cls( + azure_openai_use_managed_identity=azure_openai_use_managed_identity, + azure_openai_endpoint=azure_openai_endpoint, + api_key=api_key, + azure_openai_api_version=azure_openai_api_version, + azure_openai_deployment_name=azure_openai_deployment_name, + model=model, + ) + else: + return cls( + model=model, + api_key=os.environ.get('OPENAI_API_KEY'), + ) + + def create_client(self) -> EmbedderClient | None: + """Create an embedder client based on this configuration. 
+ + Returns: + EmbedderClient instance or None if configuration is invalid + """ + if self.azure_openai_endpoint is not None: + # Azure OpenAI API setup + if self.azure_openai_use_managed_identity: + # Use managed identity for authentication + token_provider = create_azure_credential_token_provider() + return AzureOpenAIEmbedderClient( + azure_client=AsyncAzureOpenAI( + azure_endpoint=self.azure_openai_endpoint, + azure_deployment=self.azure_openai_deployment_name, + api_version=self.azure_openai_api_version, + azure_ad_token_provider=token_provider, + ), + model=self.model, + ) + elif self.api_key: + # Use API key for authentication + return AzureOpenAIEmbedderClient( + azure_client=AsyncAzureOpenAI( + azure_endpoint=self.azure_openai_endpoint, + azure_deployment=self.azure_openai_deployment_name, + api_version=self.azure_openai_api_version, + api_key=self.api_key, + ), + model=self.model, + ) + else: + logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API') + return None + else: + # OpenAI API setup + if not self.api_key: + return None + + embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model) + + return OpenAIEmbedder(config=embedder_config) diff --git a/mcp_server/entity_types.py b/mcp_server/entity_types.py new file mode 100644 index 00000000..2b6513db --- /dev/null +++ b/mcp_server/entity_types.py @@ -0,0 +1,83 @@ +"""Entity type definitions for Graphiti MCP Server.""" + +from pydantic import BaseModel, Field + + +class Requirement(BaseModel): + """A Requirement represents a specific need, feature, or functionality that a product or service must fulfill. + + Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the + edge that the requirement is a requirement. + + Instructions for identifying and extracting requirements: + 1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y") + 2. Identify functional specifications that describe what the system should do + 3. Pay attention to non-functional requirements like performance, security, or usability criteria + 4. Extract constraints or limitations that must be adhered to + 5. Focus on clear, specific, and measurable requirements rather than vague wishes + 6. Capture the priority or importance if mentioned ("critical", "high priority", etc.) + 7. Include any dependencies between requirements when explicitly stated + 8. Preserve the original intent and scope of the requirement + 9. Categorize requirements appropriately based on their domain or function + """ + + project_name: str = Field( + ..., + description='The name of the project to which the requirement belongs.', + ) + description: str = Field( + ..., + description='Description of the requirement. Only use information mentioned in the context to write this description.', + ) + + +class Preference(BaseModel): + """A Preference represents a user's expressed like, dislike, or preference for something. + + Instructions for identifying and extracting preferences: + 1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X" + 2. Pay attention to comparative statements ("I prefer X over Y") + 3. Consider the emotional tone when users mention certain topics + 4. Extract only preferences that are clearly expressed, not assumptions + 5. Categorize the preference appropriately based on its domain (food, music, brands, etc.) + 6. 
Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food")
+    7. Only extract preferences directly stated by the user, not preferences of others they mention
+    8. Provide a concise but specific description that captures the nature of the preference
+    """
+
+    category: str = Field(
+        ...,
+        description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')",
+    )
+    description: str = Field(
+        ...,
+        description='Brief description of the preference. Only use information mentioned in the context to write this description.',
+    )
+
+
+class Procedure(BaseModel):
+    """A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.
+
+    Instructions for identifying and extracting procedures:
+    1. Look for sequential instructions or steps ("First do X, then do Y")
+    2. Identify explicit directives or commands ("Always do X when Y happens")
+    3. Pay attention to conditional statements ("If X occurs, then do Y")
+    4. Extract procedures that have clear beginning and end points
+    5. Focus on actionable instructions rather than general information
+    6. Preserve the original sequence and dependencies between steps
+    7. Include any specified conditions or triggers for the procedure
+    8. Capture any stated purpose or goal of the procedure
+    9. Summarize complex procedures while maintaining critical details
+    """
+
+    description: str = Field(
+        ...,
+        description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
+    )
+
+
+ENTITY_TYPES: dict[str, type[BaseModel]] = {
+    'Requirement': Requirement,
+    'Preference': Preference,
+    'Procedure': Procedure,
+}
diff --git a/mcp_server/formatting.py b/mcp_server/formatting.py
new file mode 100644
index 00000000..53ff937e
--- /dev/null
+++ b/mcp_server/formatting.py
@@ -0,0 +1,26 @@
+"""Formatting utilities for Graphiti MCP Server."""
+
+from typing import Any
+
+from graphiti_core.edges import EntityEdge
+
+
+def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
+    """Format an entity edge into a readable result.
+
+    Since EntityEdge is a Pydantic BaseModel, we can use its built-in serialization capabilities.
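+
+    Illustrative shape of the returned dict (sketch; exact fields come from
+    EntityEdge.model_dump and vary by edge):
+
+        {'uuid': '...', 'fact': '...', 'created_at': '2025-07-08T22:20:26Z',
+         'attributes': {...}}  # 'fact_embedding' removed from both levels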
+ + Args: + edge: The EntityEdge to format + + Returns: + A dictionary representation of the edge with serialized dates and excluded embeddings + """ + result = edge.model_dump( + mode='json', + exclude={ + 'fact_embedding', + }, + ) + result.get('attributes', {}).pop('fact_embedding', None) + return result diff --git a/mcp_server/graphiti_mcp_server.py b/mcp_server/graphiti_mcp_server.py index 9b382074..f25f2bc7 100644 --- a/mcp_server/graphiti_mcp_server.py +++ b/mcp_server/graphiti_mcp_server.py @@ -8,25 +8,30 @@ import asyncio import logging import os import sys -from collections.abc import Callable from datetime import datetime, timezone -from typing import Any, TypedDict, cast +from typing import Any, cast -from azure.identity import DefaultAzureCredential, get_bearer_token_provider +from config_manager import GraphitiConfig from dotenv import load_dotenv +from entity_types import ENTITY_TYPES +from formatting import format_fact_result +from graphiti_service import GraphitiService +from llm_config import DEFAULT_LLM_MODEL, SMALL_LLM_MODEL from mcp.server.fastmcp import FastMCP -from openai import AsyncAzureOpenAI -from pydantic import BaseModel, Field +from queue_service import QueueService +from response_types import ( + EpisodeSearchResponse, + ErrorResponse, + FactSearchResponse, + NodeResult, + NodeSearchResponse, + StatusResponse, + SuccessResponse, +) +from server_config import MCPConfig from graphiti_core import Graphiti from graphiti_core.edges import EntityEdge -from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient -from graphiti_core.embedder.client import EmbedderClient -from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig -from graphiti_core.llm_client import LLMClient -from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient -from graphiti_core.llm_client.config import LLMConfig -from graphiti_core.llm_client.openai_client import OpenAIClient from graphiti_core.nodes import EpisodeType, EpisodicNode from graphiti_core.search.search_config_recipes import ( NODE_HYBRID_SEARCH_NODE_DISTANCE, @@ -38,488 +43,12 @@ from graphiti_core.utils.maintenance.graph_data_operations import clear_data load_dotenv() -DEFAULT_LLM_MODEL = 'gpt-4.1-mini' -SMALL_LLM_MODEL = 'gpt-4.1-nano' -DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small' - # Semaphore limit for concurrent Graphiti operations. # Decrease this if you're experiencing 429 rate limit errors from your LLM provider. # Increase if you have high rate limits. SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 10)) -class Requirement(BaseModel): - """A Requirement represents a specific need, feature, or functionality that a product or service must fulfill. - - Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the - edge that the requirement is a requirement. - - Instructions for identifying and extracting requirements: - 1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y") - 2. Identify functional specifications that describe what the system should do - 3. Pay attention to non-functional requirements like performance, security, or usability criteria - 4. Extract constraints or limitations that must be adhered to - 5. Focus on clear, specific, and measurable requirements rather than vague wishes - 6. Capture the priority or importance if mentioned ("critical", "high priority", etc.) - 7. Include any dependencies between requirements when explicitly stated - 8. 
Preserve the original intent and scope of the requirement - 9. Categorize requirements appropriately based on their domain or function - """ - - project_name: str = Field( - ..., - description='The name of the project to which the requirement belongs.', - ) - description: str = Field( - ..., - description='Description of the requirement. Only use information mentioned in the context to write this description.', - ) - - -class Preference(BaseModel): - """A Preference represents a user's expressed like, dislike, or preference for something. - - Instructions for identifying and extracting preferences: - 1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X" - 2. Pay attention to comparative statements ("I prefer X over Y") - 3. Consider the emotional tone when users mention certain topics - 4. Extract only preferences that are clearly expressed, not assumptions - 5. Categorize the preference appropriately based on its domain (food, music, brands, etc.) - 6. Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food") - 7. Only extract preferences directly stated by the user, not preferences of others they mention - 8. Provide a concise but specific description that captures the nature of the preference - """ - - category: str = Field( - ..., - description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')", - ) - description: str = Field( - ..., - description='Brief description of the preference. Only use information mentioned in the context to write this description.', - ) - - -class Procedure(BaseModel): - """A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps. - - Instructions for identifying and extracting procedures: - 1. Look for sequential instructions or steps ("First do X, then do Y") - 2. Identify explicit directives or commands ("Always do X when Y happens") - 3. Pay attention to conditional statements ("If X occurs, then do Y") - 4. Extract procedures that have clear beginning and end points - 5. Focus on actionable instructions rather than general information - 6. Preserve the original sequence and dependencies between steps - 7. Include any specified conditions or triggers for the procedure - 8. Capture any stated purpose or goal of the procedure - 9. Summarize complex procedures while maintaining critical details - """ - - description: str = Field( - ..., - description='Brief description of the procedure. 
Only use information mentioned in the context to write this description.', - ) - - -ENTITY_TYPES: dict[str, BaseModel] = { - 'Requirement': Requirement, # type: ignore - 'Preference': Preference, # type: ignore - 'Procedure': Procedure, # type: ignore -} - - -# Type definitions for API responses -class ErrorResponse(TypedDict): - error: str - - -class SuccessResponse(TypedDict): - message: str - - -class NodeResult(TypedDict): - uuid: str - name: str - summary: str - labels: list[str] - group_id: str - created_at: str - attributes: dict[str, Any] - - -class NodeSearchResponse(TypedDict): - message: str - nodes: list[NodeResult] - - -class FactSearchResponse(TypedDict): - message: str - facts: list[dict[str, Any]] - - -class EpisodeSearchResponse(TypedDict): - message: str - episodes: list[dict[str, Any]] - - -class StatusResponse(TypedDict): - status: str - message: str - - -def create_azure_credential_token_provider() -> Callable[[], str]: - credential = DefaultAzureCredential() - token_provider = get_bearer_token_provider( - credential, 'https://cognitiveservices.azure.com/.default' - ) - return token_provider - - -# Server configuration classes -# The configuration system has a hierarchy: -# - GraphitiConfig is the top-level configuration -# - LLMConfig handles all OpenAI/LLM related settings -# - EmbedderConfig manages embedding settings -# - Neo4jConfig manages database connection details -# - Various other settings like group_id and feature flags -# Configuration values are loaded from: -# 1. Default values in the class definitions -# 2. Environment variables (loaded via load_dotenv()) -# 3. Command line arguments (which override environment variables) -class GraphitiLLMConfig(BaseModel): - """Configuration for the LLM client. - - Centralizes all LLM-specific configuration parameters including API keys and model selection. 
- """ - - api_key: str | None = None - model: str = DEFAULT_LLM_MODEL - small_model: str = SMALL_LLM_MODEL - temperature: float = 0.0 - azure_openai_endpoint: str | None = None - azure_openai_deployment_name: str | None = None - azure_openai_api_version: str | None = None - azure_openai_use_managed_identity: bool = False - - @classmethod - def from_env(cls) -> 'GraphitiLLMConfig': - """Create LLM configuration from environment variables.""" - # Get model from environment, or use default if not set or empty - model_env = os.environ.get('MODEL_NAME', '') - model = model_env if model_env.strip() else DEFAULT_LLM_MODEL - - # Get small_model from environment, or use default if not set or empty - small_model_env = os.environ.get('SMALL_MODEL_NAME', '') - small_model = small_model_env if small_model_env.strip() else SMALL_LLM_MODEL - - azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None) - azure_openai_api_version = os.environ.get('AZURE_OPENAI_API_VERSION', None) - azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME', None) - azure_openai_use_managed_identity = ( - os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true' - ) - - if azure_openai_endpoint is None: - # Setup for OpenAI API - # Log if empty model was provided - if model_env == '': - logger.debug( - f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}' - ) - elif not model_env.strip(): - logger.warning( - f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}' - ) - - return cls( - api_key=os.environ.get('OPENAI_API_KEY'), - model=model, - small_model=small_model, - temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')), - ) - else: - # Setup for Azure OpenAI API - # Log if empty deployment name was provided - if azure_openai_deployment_name is None: - logger.error('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set') - - raise ValueError('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set') - if not azure_openai_use_managed_identity: - # api key - api_key = os.environ.get('OPENAI_API_KEY', None) - else: - # Managed identity - api_key = None - - return cls( - azure_openai_use_managed_identity=azure_openai_use_managed_identity, - azure_openai_endpoint=azure_openai_endpoint, - api_key=api_key, - azure_openai_api_version=azure_openai_api_version, - azure_openai_deployment_name=azure_openai_deployment_name, - model=model, - small_model=small_model, - temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')), - ) - - @classmethod - def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig': - """Create LLM configuration from CLI arguments, falling back to environment variables.""" - # Start with environment-based config - config = cls.from_env() - - # CLI arguments override environment variables when provided - if hasattr(args, 'model') and args.model: - # Only use CLI model if it's not empty - if args.model.strip(): - config.model = args.model - else: - # Log that empty model was provided and default is used - logger.warning(f'Empty model name provided, using default: {DEFAULT_LLM_MODEL}') - - if hasattr(args, 'small_model') and args.small_model: - if args.small_model.strip(): - config.small_model = args.small_model - else: - logger.warning(f'Empty small_model name provided, using default: {SMALL_LLM_MODEL}') - - if hasattr(args, 'temperature') and args.temperature is not None: - config.temperature = args.temperature - - return config - - def create_client(self) -> LLMClient: - """Create an LLM 
client based on this configuration. - - Returns: - LLMClient instance - """ - - if self.azure_openai_endpoint is not None: - # Azure OpenAI API setup - if self.azure_openai_use_managed_identity: - # Use managed identity for authentication - token_provider = create_azure_credential_token_provider() - return AzureOpenAILLMClient( - azure_client=AsyncAzureOpenAI( - azure_endpoint=self.azure_openai_endpoint, - azure_deployment=self.azure_openai_deployment_name, - api_version=self.azure_openai_api_version, - azure_ad_token_provider=token_provider, - ), - config=LLMConfig( - api_key=self.api_key, - model=self.model, - small_model=self.small_model, - temperature=self.temperature, - ), - ) - elif self.api_key: - # Use API key for authentication - return AzureOpenAILLMClient( - azure_client=AsyncAzureOpenAI( - azure_endpoint=self.azure_openai_endpoint, - azure_deployment=self.azure_openai_deployment_name, - api_version=self.azure_openai_api_version, - api_key=self.api_key, - ), - config=LLMConfig( - api_key=self.api_key, - model=self.model, - small_model=self.small_model, - temperature=self.temperature, - ), - ) - else: - raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API') - - if not self.api_key: - raise ValueError('OPENAI_API_KEY must be set when using OpenAI API') - - llm_client_config = LLMConfig( - api_key=self.api_key, model=self.model, small_model=self.small_model - ) - - # Set temperature - llm_client_config.temperature = self.temperature - - return OpenAIClient(config=llm_client_config) - - -class GraphitiEmbedderConfig(BaseModel): - """Configuration for the embedder client. - - Centralizes all embedding-related configuration parameters. - """ - - model: str = DEFAULT_EMBEDDER_MODEL - api_key: str | None = None - azure_openai_endpoint: str | None = None - azure_openai_deployment_name: str | None = None - azure_openai_api_version: str | None = None - azure_openai_use_managed_identity: bool = False - - @classmethod - def from_env(cls) -> 'GraphitiEmbedderConfig': - """Create embedder configuration from environment variables.""" - - # Get model from environment, or use default if not set or empty - model_env = os.environ.get('EMBEDDER_MODEL_NAME', '') - model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL - - azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None) - azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None) - azure_openai_deployment_name = os.environ.get( - 'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None - ) - azure_openai_use_managed_identity = ( - os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true' - ) - if azure_openai_endpoint is not None: - # Setup for Azure OpenAI API - # Log if empty deployment name was provided - azure_openai_deployment_name = os.environ.get( - 'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None - ) - if azure_openai_deployment_name is None: - logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set') - - raise ValueError( - 'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set' - ) - - if not azure_openai_use_managed_identity: - # api key - api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get( - 'OPENAI_API_KEY', None - ) - else: - # Managed identity - api_key = None - - return cls( - azure_openai_use_managed_identity=azure_openai_use_managed_identity, - azure_openai_endpoint=azure_openai_endpoint, - api_key=api_key, - azure_openai_api_version=azure_openai_api_version, - 
azure_openai_deployment_name=azure_openai_deployment_name, - ) - else: - return cls( - model=model, - api_key=os.environ.get('OPENAI_API_KEY'), - ) - - def create_client(self) -> EmbedderClient | None: - if self.azure_openai_endpoint is not None: - # Azure OpenAI API setup - if self.azure_openai_use_managed_identity: - # Use managed identity for authentication - token_provider = create_azure_credential_token_provider() - return AzureOpenAIEmbedderClient( - azure_client=AsyncAzureOpenAI( - azure_endpoint=self.azure_openai_endpoint, - azure_deployment=self.azure_openai_deployment_name, - api_version=self.azure_openai_api_version, - azure_ad_token_provider=token_provider, - ), - model=self.model, - ) - elif self.api_key: - # Use API key for authentication - return AzureOpenAIEmbedderClient( - azure_client=AsyncAzureOpenAI( - azure_endpoint=self.azure_openai_endpoint, - azure_deployment=self.azure_openai_deployment_name, - api_version=self.azure_openai_api_version, - api_key=self.api_key, - ), - model=self.model, - ) - else: - logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API') - return None - else: - # OpenAI API setup - if not self.api_key: - return None - - embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model) - - return OpenAIEmbedder(config=embedder_config) - - -class Neo4jConfig(BaseModel): - """Configuration for Neo4j database connection.""" - - uri: str = 'bolt://localhost:7687' - user: str = 'neo4j' - password: str = 'password' - - @classmethod - def from_env(cls) -> 'Neo4jConfig': - """Create Neo4j configuration from environment variables.""" - return cls( - uri=os.environ.get('NEO4J_URI', 'bolt://localhost:7687'), - user=os.environ.get('NEO4J_USER', 'neo4j'), - password=os.environ.get('NEO4J_PASSWORD', 'password'), - ) - - -class GraphitiConfig(BaseModel): - """Configuration for Graphiti client. - - Centralizes all configuration parameters for the Graphiti client. 
- """ - - llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig) - embedder: GraphitiEmbedderConfig = Field(default_factory=GraphitiEmbedderConfig) - neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig) - group_id: str | None = None - use_custom_entities: bool = False - destroy_graph: bool = False - - @classmethod - def from_env(cls) -> 'GraphitiConfig': - """Create a configuration instance from environment variables.""" - return cls( - llm=GraphitiLLMConfig.from_env(), - embedder=GraphitiEmbedderConfig.from_env(), - neo4j=Neo4jConfig.from_env(), - ) - - @classmethod - def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiConfig': - """Create configuration from CLI arguments, falling back to environment variables.""" - # Start with environment configuration - config = cls.from_env() - - # Apply CLI overrides - if args.group_id: - config.group_id = args.group_id - else: - config.group_id = 'default' - - config.use_custom_entities = args.use_custom_entities - config.destroy_graph = args.destroy_graph - - # Update LLM config using CLI args - config.llm = GraphitiLLMConfig.from_cli_and_env(args) - - return config - - -class MCPConfig(BaseModel): - """Configuration for MCP server.""" - - transport: str = 'sse' # Default to SSE transport - - @classmethod - def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig': - """Create MCP configuration from CLI arguments.""" - return cls(transport=args.transport) - - # Configure logging logging.basicConfig( level=logging.INFO, @@ -568,124 +97,9 @@ mcp = FastMCP( instructions=GRAPHITI_MCP_INSTRUCTIONS, ) -# Initialize Graphiti client -graphiti_client: Graphiti | None = None - - -async def initialize_graphiti(): - """Initialize the Graphiti client with the configured settings.""" - global graphiti_client, config - - try: - # Create LLM client if possible - llm_client = config.llm.create_client() - if not llm_client and config.use_custom_entities: - # If custom entities are enabled, we must have an LLM client - raise ValueError('OPENAI_API_KEY must be set when custom entities are enabled') - - # Validate Neo4j configuration - if not config.neo4j.uri or not config.neo4j.user or not config.neo4j.password: - raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set') - - embedder_client = config.embedder.create_client() - - # Initialize Graphiti client - graphiti_client = Graphiti( - uri=config.neo4j.uri, - user=config.neo4j.user, - password=config.neo4j.password, - llm_client=llm_client, - embedder=embedder_client, - max_coroutines=SEMAPHORE_LIMIT, - ) - - # Destroy graph if requested - if config.destroy_graph: - logger.info('Destroying graph...') - await clear_data(graphiti_client.driver) - - # Initialize the graph database with Graphiti's indices - await graphiti_client.build_indices_and_constraints() - logger.info('Graphiti client initialized successfully') - - # Log configuration details for transparency - if llm_client: - logger.info(f'Using OpenAI model: {config.llm.model}') - logger.info(f'Using temperature: {config.llm.temperature}') - else: - logger.info('No LLM client configured - entity extraction will be limited') - - logger.info(f'Using group_id: {config.group_id}') - logger.info( - f'Custom entity extraction: {"enabled" if config.use_custom_entities else "disabled"}' - ) - logger.info(f'Using concurrency limit: {SEMAPHORE_LIMIT}') - - except Exception as e: - logger.error(f'Failed to initialize Graphiti: {str(e)}') - raise - - -def format_fact_result(edge: EntityEdge) -> dict[str, Any]: - """Format an 
entity edge into a readable result. - - Since EntityEdge is a Pydantic BaseModel, we can use its built-in serialization capabilities. - - Args: - edge: The EntityEdge to format - - Returns: - A dictionary representation of the edge with serialized dates and excluded embeddings - """ - result = edge.model_dump( - mode='json', - exclude={ - 'fact_embedding', - }, - ) - result.get('attributes', {}).pop('fact_embedding', None) - return result - - -# Dictionary to store queues for each group_id -# Each queue is a list of tasks to be processed sequentially -episode_queues: dict[str, asyncio.Queue] = {} -# Dictionary to track if a worker is running for each group_id -queue_workers: dict[str, bool] = {} - - -async def process_episode_queue(group_id: str): - """Process episodes for a specific group_id sequentially. - - This function runs as a long-lived task that processes episodes - from the queue one at a time. - """ - global queue_workers - - logger.info(f'Starting episode queue worker for group_id: {group_id}') - queue_workers[group_id] = True - - try: - while True: - # Get the next episode processing function from the queue - # This will wait if the queue is empty - process_func = await episode_queues[group_id].get() - - try: - # Process the episode - await process_func() - except Exception as e: - logger.error(f'Error processing queued episode for group_id {group_id}: {str(e)}') - finally: - # Mark the task as done regardless of success/failure - episode_queues[group_id].task_done() - except asyncio.CancelledError: - logger.info(f'Episode queue worker for group_id {group_id} was cancelled') - except Exception as e: - logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}') - finally: - queue_workers[group_id] = False - logger.info(f'Stopped episode queue worker for group_id: {group_id}') +# Global services +graphiti_service: GraphitiService | None = None +queue_service: QueueService | None = None @mcp.tool() @@ -752,10 +166,13 @@ async def add_memory( - Entities will be created from appropriate JSON properties - Relationships between entities will be established based on the JSON structure """ - global graphiti_client, episode_queues, queue_workers + global graphiti_service, queue_service, config - if graphiti_client is None: - return ErrorResponse(error='Graphiti client not initialized') + if not graphiti_service or not graphiti_service.is_initialized(): + return ErrorResponse(error='Graphiti service not initialized') + + if not queue_service: + return ErrorResponse(error='Queue service not initialized') try: # Map string source to EpisodeType enum @@ -772,13 +189,6 @@ async def add_memory( # The Graphiti client expects a str for group_id, not Optional[str] group_id_str = str(effective_group_id) if effective_group_id is not None else '' - # We've already checked that graphiti_client is not None above - # This assert statement helps type checkers understand that graphiti_client is defined - assert graphiti_client is not None, 'graphiti_client should not be None here' - - # Use cast to help the type checker understand that graphiti_client is not None - client = cast(Graphiti, graphiti_client) - # Define the episode processing function async def process_episode(): try: @@ -786,7 +196,7 @@ async def add_memory( # Use all entity types if use_custom_entities is enabled, otherwise use empty dict entity_types = ENTITY_TYPES if config.use_custom_entities else {} - await client.add_episode( + await graphiti_service.client.add_episode( name=name, episode_body=episode_body, 
source=source_type, @@ -796,8 +206,6 @@ async def add_memory( reference_time=datetime.now(timezone.utc), entity_types=entity_types, ) - logger.info(f"Episode '{name}' added successfully") - logger.info(f"Episode '{name}' processed successfully") except Exception as e: error_msg = str(e) @@ -805,20 +213,12 @@ async def add_memory( f"Error processing episode '{name}' for group_id {group_id_str}: {error_msg}" ) - # Initialize queue for this group_id if it doesn't exist - if group_id_str not in episode_queues: - episode_queues[group_id_str] = asyncio.Queue() - # Add the episode processing function to the queue - await episode_queues[group_id_str].put(process_episode) - - # Start a worker for this queue if one isn't already running - if not queue_workers.get(group_id_str, False): - asyncio.create_task(process_episode_queue(group_id_str)) + queue_position = await queue_service.add_episode_task(group_id_str, process_episode) # Return immediately with a success message return SuccessResponse( - message=f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})" + message=f"Episode '{name}' queued for processing (position: {queue_position})" ) except Exception as e: error_msg = str(e) @@ -846,10 +246,10 @@ async def search_memory_nodes( center_node_uuid: Optional UUID of a node to center the search around entity: Optional single entity type to filter results (permitted: "Preference", "Procedure") """ - global graphiti_client + global graphiti_service, config - if graphiti_client is None: - return ErrorResponse(error='Graphiti client not initialized') + if not graphiti_service or not graphiti_service.is_initialized(): + return ErrorResponse(error='Graphiti service not initialized') try: # Use the provided group_ids or fall back to the default from config if none provided @@ -868,11 +268,7 @@ async def search_memory_nodes( if entity != '': filters.node_labels = [entity] - # We've already checked that graphiti_client is not None above - assert graphiti_client is not None - - # Use cast to help the type checker understand that graphiti_client is not None - client = cast(Graphiti, graphiti_client) + client = graphiti_service.client # Perform the search using the _search method search_results = await client._search( @@ -1218,8 +614,11 @@ async def initialize_server() -> MCPConfig: else: logger.info('Entity extraction disabled (no custom entities will be used)') - # Initialize Graphiti - await initialize_graphiti() + # Initialize services + global graphiti_service, queue_service + graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT) + queue_service = QueueService() + await graphiti_service.initialize() if args.host: logger.info(f'Setting MCP server host to: {args.host}') diff --git a/mcp_server/graphiti_service.py b/mcp_server/graphiti_service.py new file mode 100644 index 00000000..2ed34b60 --- /dev/null +++ b/mcp_server/graphiti_service.py @@ -0,0 +1,110 @@ +"""Graphiti service for managing client lifecycle and operations.""" + +import logging + +from config_manager import GraphitiConfig + +from graphiti_core import Graphiti +from graphiti_core.utils.maintenance.graph_data_operations import clear_data + +logger = logging.getLogger(__name__) + + +class GraphitiService: + """Service for managing Graphiti client operations.""" + + def __init__(self, config: GraphitiConfig, semaphore_limit: int = 10): + """Initialize the Graphiti service with configuration. 
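+
+        Typical lifecycle (illustrative sketch):
+
+            service = GraphitiService(config, semaphore_limit=10)
+            await service.initialize()
+            graphiti = service.client  # raises RuntimeError before initialize()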
+ + Args: + config: The Graphiti configuration + semaphore_limit: Maximum concurrent operations + """ + self.config = config + self.semaphore_limit = semaphore_limit + self._client: Graphiti | None = None + + @property + def client(self) -> Graphiti: + """Get the Graphiti client instance. + + Raises: + RuntimeError: If the client hasn't been initialized + """ + if self._client is None: + raise RuntimeError('Graphiti client not initialized. Call initialize() first.') + return self._client + + async def initialize(self) -> None: + """Initialize the Graphiti client with the configured settings.""" + try: + # Create LLM client if possible + llm_client = self.config.llm.create_client() + if not llm_client and self.config.use_custom_entities: + # If custom entities are enabled, we must have an LLM client + raise ValueError('OPENAI_API_KEY must be set when custom entities are enabled') + + # Validate Neo4j configuration + if ( + not self.config.neo4j.uri + or not self.config.neo4j.user + or not self.config.neo4j.password + ): + raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set') + + embedder_client = self.config.embedder.create_client() + + # Initialize Graphiti client + self._client = Graphiti( + uri=self.config.neo4j.uri, + user=self.config.neo4j.user, + password=self.config.neo4j.password, + llm_client=llm_client, + embedder=embedder_client, + max_coroutines=self.semaphore_limit, + ) + + # Destroy graph if requested + if self.config.destroy_graph: + logger.info('Destroying graph...') + await clear_data(self._client.driver) + + # Initialize the graph database with Graphiti's indices + await self._client.build_indices_and_constraints() + logger.info('Graphiti client initialized successfully') + + # Log configuration details for transparency + if llm_client: + logger.info(f'Using OpenAI model: {self.config.llm.model}') + logger.info(f'Using temperature: {self.config.llm.temperature}') + else: + logger.info('No LLM client configured - entity extraction will be limited') + + logger.info(f'Using group_id: {self.config.group_id}') + logger.info( + f'Custom entity extraction: {"enabled" if self.config.use_custom_entities else "disabled"}' + ) + logger.info(f'Using concurrency limit: {self.semaphore_limit}') + + except Exception as e: + logger.error(f'Failed to initialize Graphiti: {str(e)}') + raise + + async def clear_graph(self) -> None: + """Clear all data from the graph and rebuild indices.""" + if self._client is None: + raise RuntimeError('Graphiti client not initialized') + + await clear_data(self._client.driver) + await self._client.build_indices_and_constraints() + + async def verify_connection(self) -> None: + """Verify the database connection.""" + if self._client is None: + raise RuntimeError('Graphiti client not initialized') + + await self._client.driver.client.verify_connectivity() # type: ignore + + def is_initialized(self) -> bool: + """Check if the client is initialized.""" + return self._client is not None diff --git a/mcp_server/llm_config.py b/mcp_server/llm_config.py new file mode 100644 index 00000000..687bed8f --- /dev/null +++ b/mcp_server/llm_config.py @@ -0,0 +1,182 @@ +"""LLM configuration for Graphiti MCP Server.""" + +import argparse +import logging +import os +from typing import TYPE_CHECKING + +from openai import AsyncAzureOpenAI +from pydantic import BaseModel +from utils import create_azure_credential_token_provider + +from graphiti_core.llm_client import LLMClient +from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient 
+from graphiti_core.llm_client.config import LLMConfig +from graphiti_core.llm_client.openai_client import OpenAIClient + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +DEFAULT_LLM_MODEL = 'gpt-4.1-mini' +SMALL_LLM_MODEL = 'gpt-4.1-nano' + + +class GraphitiLLMConfig(BaseModel): + """Configuration for the LLM client. + + Centralizes all LLM-specific configuration parameters including API keys and model selection. + """ + + api_key: str | None = None + model: str = DEFAULT_LLM_MODEL + small_model: str = SMALL_LLM_MODEL + temperature: float = 0.0 + azure_openai_endpoint: str | None = None + azure_openai_deployment_name: str | None = None + azure_openai_api_version: str | None = None + azure_openai_use_managed_identity: bool = False + + @classmethod + def from_env(cls) -> 'GraphitiLLMConfig': + """Create LLM configuration from environment variables.""" + # Get model from environment, or use default if not set or empty + model_env = os.environ.get('MODEL_NAME', '') + model = model_env if model_env.strip() else DEFAULT_LLM_MODEL + + # Get small_model from environment, or use default if not set or empty + small_model_env = os.environ.get('SMALL_MODEL_NAME', '') + small_model = small_model_env if small_model_env.strip() else SMALL_LLM_MODEL + + azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None) + azure_openai_api_version = os.environ.get('AZURE_OPENAI_API_VERSION', None) + azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME', None) + azure_openai_use_managed_identity = ( + os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true' + ) + + if azure_openai_endpoint is None: + # Setup for OpenAI API + # Log if empty model was provided + if model_env == '': + logger.debug( + f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}' + ) + elif not model_env.strip(): + logger.warning( + f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}' + ) + + return cls( + api_key=os.environ.get('OPENAI_API_KEY'), + model=model, + small_model=small_model, + temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')), + ) + else: + # Setup for Azure OpenAI API + # Log if empty deployment name was provided + if azure_openai_deployment_name is None: + logger.error('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set') + raise ValueError('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set') + + if not azure_openai_use_managed_identity: + # api key + api_key = os.environ.get('OPENAI_API_KEY', None) + else: + # Managed identity + api_key = None + + return cls( + azure_openai_use_managed_identity=azure_openai_use_managed_identity, + azure_openai_endpoint=azure_openai_endpoint, + api_key=api_key, + azure_openai_api_version=azure_openai_api_version, + azure_openai_deployment_name=azure_openai_deployment_name, + model=model, + small_model=small_model, + temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')), + ) + + @classmethod + def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig': + """Create LLM configuration from CLI arguments, falling back to environment variables.""" + # Start with environment-based config + config = cls.from_env() + + # CLI arguments override environment variables when provided + if hasattr(args, 'model') and args.model: + # Only use CLI model if it's not empty + if args.model.strip(): + config.model = args.model + else: + # Log that empty model was provided and default is used + logger.warning(f'Empty model name provided, using 
default: {DEFAULT_LLM_MODEL}') + + if hasattr(args, 'small_model') and args.small_model: + if args.small_model.strip(): + config.small_model = args.small_model + else: + logger.warning(f'Empty small_model name provided, using default: {SMALL_LLM_MODEL}') + + if hasattr(args, 'temperature') and args.temperature is not None: + config.temperature = args.temperature + + return config + + def create_client(self) -> LLMClient: + """Create an LLM client based on this configuration. + + Returns: + LLMClient instance + """ + if self.azure_openai_endpoint is not None: + # Azure OpenAI API setup + if self.azure_openai_use_managed_identity: + # Use managed identity for authentication + token_provider = create_azure_credential_token_provider() + return AzureOpenAILLMClient( + azure_client=AsyncAzureOpenAI( + azure_endpoint=self.azure_openai_endpoint, + azure_deployment=self.azure_openai_deployment_name, + api_version=self.azure_openai_api_version, + azure_ad_token_provider=token_provider, + ), + config=LLMConfig( + api_key=self.api_key, + model=self.model, + small_model=self.small_model, + temperature=self.temperature, + ), + ) + elif self.api_key: + # Use API key for authentication + return AzureOpenAILLMClient( + azure_client=AsyncAzureOpenAI( + azure_endpoint=self.azure_openai_endpoint, + azure_deployment=self.azure_openai_deployment_name, + api_version=self.azure_openai_api_version, + api_key=self.api_key, + ), + config=LLMConfig( + api_key=self.api_key, + model=self.model, + small_model=self.small_model, + temperature=self.temperature, + ), + ) + else: + raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API') + + if not self.api_key: + raise ValueError('OPENAI_API_KEY must be set when using OpenAI API') + + llm_client_config = LLMConfig( + api_key=self.api_key, model=self.model, small_model=self.small_model + ) + + # Set temperature + llm_client_config.temperature = self.temperature + + return OpenAIClient(config=llm_client_config) diff --git a/mcp_server/neo4j_config.py b/mcp_server/neo4j_config.py new file mode 100644 index 00000000..8d365f81 --- /dev/null +++ b/mcp_server/neo4j_config.py @@ -0,0 +1,22 @@ +"""Neo4j database configuration for Graphiti MCP Server.""" + +import os + +from pydantic import BaseModel + + +class Neo4jConfig(BaseModel): + """Configuration for Neo4j database connection.""" + + uri: str = 'bolt://localhost:7687' + user: str = 'neo4j' + password: str = 'password' + + @classmethod + def from_env(cls) -> 'Neo4jConfig': + """Create Neo4j configuration from environment variables.""" + return cls( + uri=os.environ.get('NEO4J_URI', 'bolt://localhost:7687'), + user=os.environ.get('NEO4J_USER', 'neo4j'), + password=os.environ.get('NEO4J_PASSWORD', 'password'), + ) diff --git a/mcp_server/pyproject.toml b/mcp_server/pyproject.toml index 8460c68a..5d3c255c 100644 --- a/mcp_server/pyproject.toml +++ b/mcp_server/pyproject.toml @@ -11,3 +11,9 @@ dependencies = [ "azure-identity>=1.21.0", "graphiti-core", ] + +[dependency-groups] +dev = [ + "httpx>=0.28.1", + "mcp>=1.9.4", +] diff --git a/mcp_server/queue_service.py b/mcp_server/queue_service.py new file mode 100644 index 00000000..c0d608f4 --- /dev/null +++ b/mcp_server/queue_service.py @@ -0,0 +1,86 @@ +"""Queue service for managing episode processing.""" + +import asyncio +import logging +from collections.abc import Awaitable, Callable + +logger = logging.getLogger(__name__) + + +class QueueService: + """Service for managing sequential episode processing queues by group_id.""" + + def __init__(self): + 
"""Initialize the queue service.""" + # Dictionary to store queues for each group_id + self._episode_queues: dict[str, asyncio.Queue] = {} + # Dictionary to track if a worker is running for each group_id + self._queue_workers: dict[str, bool] = {} + + async def add_episode_task( + self, group_id: str, process_func: Callable[[], Awaitable[None]] + ) -> int: + """Add an episode processing task to the queue. + + Args: + group_id: The group ID for the episode + process_func: The async function to process the episode + + Returns: + The position in the queue + """ + # Initialize queue for this group_id if it doesn't exist + if group_id not in self._episode_queues: + self._episode_queues[group_id] = asyncio.Queue() + + # Add the episode processing function to the queue + await self._episode_queues[group_id].put(process_func) + + # Start a worker for this queue if one isn't already running + if not self._queue_workers.get(group_id, False): + asyncio.create_task(self._process_episode_queue(group_id)) + + return self._episode_queues[group_id].qsize() + + async def _process_episode_queue(self, group_id: str) -> None: + """Process episodes for a specific group_id sequentially. + + This function runs as a long-lived task that processes episodes + from the queue one at a time. + """ + logger.info(f'Starting episode queue worker for group_id: {group_id}') + self._queue_workers[group_id] = True + + try: + while True: + # Get the next episode processing function from the queue + # This will wait if the queue is empty + process_func = await self._episode_queues[group_id].get() + + try: + # Process the episode + await process_func() + except Exception as e: + logger.error( + f'Error processing queued episode for group_id {group_id}: {str(e)}' + ) + finally: + # Mark the task as done regardless of success/failure + self._episode_queues[group_id].task_done() + except asyncio.CancelledError: + logger.info(f'Episode queue worker for group_id {group_id} was cancelled') + except Exception as e: + logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}') + finally: + self._queue_workers[group_id] = False + logger.info(f'Stopped episode queue worker for group_id: {group_id}') + + def get_queue_size(self, group_id: str) -> int: + """Get the current queue size for a group_id.""" + if group_id not in self._episode_queues: + return 0 + return self._episode_queues[group_id].qsize() + + def is_worker_running(self, group_id: str) -> bool: + """Check if a worker is running for a group_id.""" + return self._queue_workers.get(group_id, False) diff --git a/mcp_server/response_types.py b/mcp_server/response_types.py new file mode 100644 index 00000000..ac9a9844 --- /dev/null +++ b/mcp_server/response_types.py @@ -0,0 +1,41 @@ +"""Response type definitions for Graphiti MCP Server.""" + +from typing import Any, TypedDict + + +class ErrorResponse(TypedDict): + error: str + + +class SuccessResponse(TypedDict): + message: str + + +class NodeResult(TypedDict): + uuid: str + name: str + summary: str + labels: list[str] + group_id: str + created_at: str + attributes: dict[str, Any] + + +class NodeSearchResponse(TypedDict): + message: str + nodes: list[NodeResult] + + +class FactSearchResponse(TypedDict): + message: str + facts: list[dict[str, Any]] + + +class EpisodeSearchResponse(TypedDict): + message: str + episodes: list[dict[str, Any]] + + +class StatusResponse(TypedDict): + status: str + message: str diff --git a/mcp_server/server_config.py b/mcp_server/server_config.py new file mode 100644 index 
00000000..e80673e2 --- /dev/null +++ b/mcp_server/server_config.py @@ -0,0 +1,16 @@ +"""Server configuration for Graphiti MCP Server.""" + +import argparse + +from pydantic import BaseModel + + +class MCPConfig(BaseModel): + """Configuration for MCP server.""" + + transport: str = 'sse' # Default to SSE transport + + @classmethod + def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig': + """Create MCP configuration from CLI arguments.""" + return cls(transport=args.transport) diff --git a/mcp_server/test_integration.py b/mcp_server/test_integration.py new file mode 100644 index 00000000..adb3c835 --- /dev/null +++ b/mcp_server/test_integration.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +""" +Integration test for the refactored Graphiti MCP Server. +Tests all major MCP tools and handles episode processing latency. +""" + +import asyncio +import json +import time +from typing import Any + +import httpx + + +class MCPIntegrationTest: + """Integration test client for Graphiti MCP Server.""" + + def __init__(self, base_url: str = 'http://localhost:8000'): + self.base_url = base_url + self.client = httpx.AsyncClient(timeout=30.0) + self.test_group_id = f'test_group_{int(time.time())}' + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.client.aclose() + + async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]: + """Call an MCP tool via the SSE endpoint.""" + # MCP protocol message structure + message = { + 'jsonrpc': '2.0', + 'id': int(time.time() * 1000), + 'method': 'tools/call', + 'params': {'name': tool_name, 'arguments': arguments}, + } + + try: + response = await self.client.post( + f'{self.base_url}/message', + json=message, + headers={'Content-Type': 'application/json'}, + ) + + if response.status_code != 200: + return {'error': f'HTTP {response.status_code}: {response.text}'} + + result = response.json() + return result.get('result', result) + + except Exception as e: + return {'error': str(e)} + + async def test_server_status(self) -> bool: + """Test the get_status resource.""" + print('๐Ÿ” Testing server status...') + + try: + response = await self.client.get(f'{self.base_url}/resources/http://graphiti/status') + if response.status_code == 200: + status = response.json() + print(f' โœ… Server status: {status.get("status", "unknown")}') + return status.get('status') == 'ok' + else: + print(f' โŒ Status check failed: HTTP {response.status_code}') + return False + except Exception as e: + print(f' โŒ Status check failed: {e}') + return False + + async def test_add_memory(self) -> dict[str, str]: + """Test adding various types of memory episodes.""" + print('๐Ÿ“ Testing add_memory functionality...') + + episode_results = {} + + # Test 1: Add text episode + print(' Testing text episode...') + result = await self.call_mcp_tool( + 'add_memory', + { + 'name': 'Test Company News', + 'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. 
diff --git a/mcp_server/server_config.py b/mcp_server/server_config.py
new file mode 100644
index 00000000..e80673e2
--- /dev/null
+++ b/mcp_server/server_config.py
@@ -0,0 +1,16 @@
+"""Server configuration for Graphiti MCP Server."""
+
+import argparse
+
+from pydantic import BaseModel
+
+
+class MCPConfig(BaseModel):
+    """Configuration for MCP server."""
+
+    transport: str = 'sse'  # Default to SSE transport
+
+    @classmethod
+    def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig':
+        """Create MCP configuration from CLI arguments."""
+        return cls(transport=args.transport)
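MCPConfig.from_cli only assumes the parsed namespace carries a transport attribute. A sketch of the matching argparse wiring (assumed; the actual parser lives in graphiti_mcp_server.py, which is not part of this excerpt):

import argparse

from server_config import MCPConfig

parser = argparse.ArgumentParser()
parser.add_argument('--transport', choices=['sse', 'stdio'], default='sse')
args = parser.parse_args(['--transport', 'stdio'])

config = MCPConfig.from_cli(args)
assert config.transport == 'stdio'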
diff --git a/mcp_server/test_integration.py b/mcp_server/test_integration.py
new file mode 100644
index 00000000..adb3c835
--- /dev/null
+++ b/mcp_server/test_integration.py
@@ -0,0 +1,364 @@
+#!/usr/bin/env python3
+"""
+Integration test for the refactored Graphiti MCP Server.
+Tests all major MCP tools and handles episode processing latency.
+"""
+
+import asyncio
+import json
+import time
+from typing import Any
+
+import httpx
+
+
+class MCPIntegrationTest:
+    """Integration test client for Graphiti MCP Server."""
+
+    def __init__(self, base_url: str = 'http://localhost:8000'):
+        self.base_url = base_url
+        self.client = httpx.AsyncClient(timeout=30.0)
+        self.test_group_id = f'test_group_{int(time.time())}'
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.client.aclose()
+
+    async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
+        """Call an MCP tool via the SSE endpoint."""
+        # MCP protocol message structure
+        message = {
+            'jsonrpc': '2.0',
+            'id': int(time.time() * 1000),
+            'method': 'tools/call',
+            'params': {'name': tool_name, 'arguments': arguments},
+        }
+
+        try:
+            response = await self.client.post(
+                f'{self.base_url}/message',
+                json=message,
+                headers={'Content-Type': 'application/json'},
+            )
+
+            if response.status_code != 200:
+                return {'error': f'HTTP {response.status_code}: {response.text}'}
+
+            result = response.json()
+            return result.get('result', result)
+
+        except Exception as e:
+            return {'error': str(e)}
+
+    async def test_server_status(self) -> bool:
+        """Test the get_status resource."""
+        print('๐Ÿ” Testing server status...')
+
+        try:
+            response = await self.client.get(f'{self.base_url}/resources/http://graphiti/status')
+            if response.status_code == 200:
+                status = response.json()
+                print(f'   โœ… Server status: {status.get("status", "unknown")}')
+                return status.get('status') == 'ok'
+            else:
+                print(f'   โŒ Status check failed: HTTP {response.status_code}')
+                return False
+        except Exception as e:
+            print(f'   โŒ Status check failed: {e}')
+            return False
+
+    async def test_add_memory(self) -> dict[str, str]:
+        """Test adding various types of memory episodes."""
+        print('๐Ÿ“ Testing add_memory functionality...')
+
+        episode_results = {}
+
+        # Test 1: Add text episode
+        print('   Testing text episode...')
+        result = await self.call_mcp_tool(
+            'add_memory',
+            {
+                'name': 'Test Company News',
+                'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
+                'source': 'text',
+                'source_description': 'news article',
+                'group_id': self.test_group_id,
+            },
+        )
+
+        if 'error' in result:
+            print(f'   โŒ Text episode failed: {result["error"]}')
+        else:
+            print(f'   โœ… Text episode queued: {result.get("message", "Success")}')
+            episode_results['text'] = 'success'
+
+        # Test 2: Add JSON episode
+        print('   Testing JSON episode...')
+        json_data = {
+            'company': {'name': 'TechCorp', 'founded': 2010},
+            'products': [
+                {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
+                {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
+            ],
+            'employees': 150,
+        }
+
+        result = await self.call_mcp_tool(
+            'add_memory',
+            {
+                'name': 'Company Profile',
+                'episode_body': json.dumps(json_data),
+                'source': 'json',
+                'source_description': 'CRM data',
+                'group_id': self.test_group_id,
+            },
+        )
+
+        if 'error' in result:
+            print(f'   โŒ JSON episode failed: {result["error"]}')
+        else:
+            print(f'   โœ… JSON episode queued: {result.get("message", "Success")}')
+            episode_results['json'] = 'success'
+
+        # Test 3: Add message episode
+        print('   Testing message episode...')
+        result = await self.call_mcp_tool(
+            'add_memory',
+            {
+                'name': 'Customer Support Chat',
+                'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
+                'source': 'message',
+                'source_description': 'support chat log',
+                'group_id': self.test_group_id,
+            },
+        )
+
+        if 'error' in result:
+            print(f'   โŒ Message episode failed: {result["error"]}')
+        else:
+            print(f'   โœ… Message episode queued: {result.get("message", "Success")}')
+            episode_results['message'] = 'success'
+
+        return episode_results
+
+    async def wait_for_processing(self, max_wait: int = 30) -> None:
+        """Wait for episode processing to complete."""
+        print(f'โณ Waiting up to {max_wait} seconds for episode processing...')
+
+        for i in range(max_wait):
+            await asyncio.sleep(1)
+
+            # Check if we have any episodes; skip error responses and keep
+            # polling until a non-empty episode list comes back
+            result = await self.call_mcp_tool(
+                'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
+            )
+
+            if isinstance(result, dict) and 'error' in result:
+                continue
+
+            if isinstance(result, list) and len(result) > 0:
+                print(f'   โœ… Found {len(result)} processed episodes after {i + 1} seconds')
+                return
+
+        print(f'   โš ๏ธ Still waiting after {max_wait} seconds...')
+
+    async def test_search_functions(self) -> dict[str, bool]:
+        """Test search functionality."""
+        print('๐Ÿ” Testing search functions...')
+
+        results = {}
+
+        # Test search_memory_nodes
+        print('   Testing search_memory_nodes...')
+        result = await self.call_mcp_tool(
+            'search_memory_nodes',
+            {
+                'query': 'Acme Corp product launch',
+                'group_ids': [self.test_group_id],
+                'max_nodes': 5,
+            },
+        )
+
+        if 'error' in result:
+            print(f'   โŒ Node search failed: {result["error"]}')
+            results['nodes'] = False
+        else:
+            nodes = result.get('nodes', [])
+            print(f'   โœ… Node search returned {len(nodes)} nodes')
+            results['nodes'] = True
+
+        # Test search_memory_facts
+        print('   Testing search_memory_facts...')
+        result = await self.call_mcp_tool(
+            'search_memory_facts',
+            {
+                'query': 'company products software',
+                'group_ids': [self.test_group_id],
+                'max_facts': 5,
+            },
+        )
+
+        if 'error' in result:
+            print(f'   โŒ Fact search failed: {result["error"]}')
+            results['facts'] = False
+        else:
+            facts = result.get('facts', [])
+            print(f'   โœ… Fact search returned {len(facts)} facts')
+            results['facts'] = True
+
+        return results
+
+    async def test_episode_retrieval(self) -> bool:
+        """Test episode retrieval."""
+        print('๐Ÿ“š Testing episode retrieval...')
+
+        result = await self.call_mcp_tool(
+            'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
+        )
+
+        if 'error' in result:
+            print(f'   โŒ Episode retrieval failed: {result["error"]}')
+            return False
+
+        if isinstance(result, list):
+            print(f'   โœ… Retrieved {len(result)} episodes')
+
+            # Print episode details
+            for i, episode in enumerate(result[:3]):  # Show first 3
+                name = episode.get('name', 'Unknown')
+                source = episode.get('source', 'unknown')
+                print(f'      Episode {i + 1}: {name} (source: {source})')
+
+            return len(result) > 0
+        else:
+            print(f'   โŒ Unexpected result format: {type(result)}')
+            return False
+
+    async def test_edge_cases(self) -> dict[str, bool]:
+        """Test edge cases and error handling."""
+        print('๐Ÿงช Testing edge cases...')
+
+        results = {}
+
+        # Test with invalid group_id
+        print('   Testing invalid group_id...')
+        result = await self.call_mcp_tool(
+            'search_memory_nodes',
+            {'query': 'nonexistent data', 'group_ids': ['nonexistent_group'], 'max_nodes': 5},
+        )
+
+        # Should not error, just return empty results
+        if 'error' not in result:
+            nodes = result.get('nodes', [])
+            print(f'   โœ… Invalid group_id handled gracefully (returned {len(nodes)} nodes)')
+            results['invalid_group'] = True
+        else:
+            print(f'   โŒ Invalid group_id caused error: {result["error"]}')
+            results['invalid_group'] = False
+
+        # Test empty query
+        print('   Testing empty query...')
+        result = await self.call_mcp_tool(
+            'search_memory_nodes', {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5}
+        )
+
+        if 'error' not in result:
+            print('   โœ… Empty query handled gracefully')
+            results['empty_query'] = True
+        else:
+            print(f'   โŒ Empty query caused error: {result["error"]}')
+            results['empty_query'] = False
+
+        return results
+
+    async def run_full_test_suite(self) -> dict[str, Any]:
+        """Run the complete integration test suite."""
+        print('๐Ÿš€ Starting Graphiti MCP Server Integration Test')
+        print(f'   Test group ID: {self.test_group_id}')
+        print('=' * 60)
+
+        results = {
+            'server_status': False,
+            'add_memory': {},
+            'search': {},
+            'episodes': False,
+            'edge_cases': {},
+            'overall_success': False,
+        }
+
+        # Test 1: Server Status
+        results['server_status'] = await self.test_server_status()
+        if not results['server_status']:
+            print('โŒ Server not responding, aborting tests')
+            return results
+
+        print()
+
+        # Test 2: Add Memory
+        results['add_memory'] = await self.test_add_memory()
+        print()
+
+        # Test 3: Wait for processing
+        await self.wait_for_processing()
+        print()
+
+        # Test 4: Search Functions
+        results['search'] = await self.test_search_functions()
+        print()
+
+        # Test 5: Episode Retrieval
+        results['episodes'] = await self.test_episode_retrieval()
+        print()
+
+        # Test 6: Edge Cases
+        results['edge_cases'] = await self.test_edge_cases()
+        print()
+
+        # Calculate overall success
+        memory_success = len(results['add_memory']) > 0
+        search_success = any(results['search'].values())
+        edge_case_success = any(results['edge_cases'].values())
+
+        results['overall_success'] = (
+            results['server_status']
+            and memory_success
+            and results['episodes']
+            and (search_success or edge_case_success)  # At least some functionality working
+        )
+
+        # Print summary
+        print('=' * 60)
+        print('๐Ÿ“Š TEST SUMMARY')
+        print(f'   Server Status: {"โœ…" if results["server_status"] else "โŒ"}')
+        print(
+            f'   Memory Operations: {"โœ…" if memory_success else "โŒ"} ({len(results["add_memory"])} types)'
+        )
+        print(f'   Search Functions: {"โœ…" if search_success else "โŒ"}')
+        print(f'   Episode Retrieval: {"โœ…" if results["episodes"] else "โŒ"}')
+        print(f'   Edge Cases: {"โœ…" if edge_case_success else "โŒ"}')
+        print()
+        print(f'๐ŸŽฏ OVERALL: {"โœ… SUCCESS" if results["overall_success"] else "โŒ FAILED"}')
+
+        if results['overall_success']:
+            print('   The refactored MCP server is working correctly!')
+        else:
+            print('   Some issues detected. Check individual test results above.')
+
+        return results
+
+
+async def main():
+    """Run the integration test."""
+    async with MCPIntegrationTest() as test:
+        results = await test.run_full_test_suite()
+
+        # Exit with appropriate code
+        exit_code = 0 if results['overall_success'] else 1
+        exit(exit_code)
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
({len(results["add_memory"])} types)' + ) + print(f' Search Functions: {"โœ…" if search_success else "โŒ"}') + print(f' Episode Retrieval: {"โœ…" if results["episodes"] else "โŒ"}') + print(f' Edge Cases: {"โœ…" if edge_case_success else "โŒ"}') + print() + print(f'๐ŸŽฏ OVERALL: {"โœ… SUCCESS" if results["overall_success"] else "โŒ FAILED"}') + + if results['overall_success']: + print(' The refactored MCP server is working correctly!') + else: + print(' Some issues detected. Check individual test results above.') + + return results + + +async def main(): + """Run the integration test.""" + async with MCPIntegrationTest() as test: + results = await test.run_full_test_suite() + + # Exit with appropriate code + exit_code = 0 if results['overall_success'] else 1 + exit(exit_code) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/mcp_server/test_mcp_integration.py b/mcp_server/test_mcp_integration.py new file mode 100644 index 00000000..6ce446f3 --- /dev/null +++ b/mcp_server/test_mcp_integration.py @@ -0,0 +1,502 @@ +#!/usr/bin/env python3 +""" +Integration test for the refactored Graphiti MCP Server using the official MCP Python SDK. +Tests all major MCP tools and handles episode processing latency. +""" + +import asyncio +import json +import time +from typing import Any + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client + + +class GraphitiMCPIntegrationTest: + """Integration test client for Graphiti MCP Server using official MCP SDK.""" + + def __init__(self): + self.test_group_id = f'test_group_{int(time.time())}' + self.session = None + + async def __aenter__(self): + """Start the MCP client session.""" + # Configure server parameters to run our refactored server + server_params = StdioServerParameters( + command='uv', + args=['run', 'graphiti_mcp_server.py', '--transport', 'stdio'], + env={ + 'NEO4J_URI': 'bolt://localhost:7687', + 'NEO4J_USER': 'neo4j', + 'NEO4J_PASSWORD': 'demodemo', + 'OPENAI_API_KEY': 'dummy_key_for_testing', # Will use existing .env + }, + ) + + print(f'๐Ÿš€ Starting MCP client session with test group: {self.test_group_id}') + + # Use the async context manager properly + self.client_context = stdio_client(server_params) + read, write = await self.client_context.__aenter__() + self.session = ClientSession(read, write) + await self.session.initialize() + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Close the MCP client session.""" + if self.session: + await self.session.close() + if hasattr(self, 'client_context'): + await self.client_context.__aexit__(exc_type, exc_val, exc_tb) + + async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: + """Call an MCP tool and return the result.""" + try: + result = await self.session.call_tool(tool_name, arguments) + return result.content[0].text if result.content else {'error': 'No content returned'} + except Exception as e: + return {'error': str(e)} + + async def test_server_initialization(self) -> bool: + """Test that the server initializes properly.""" + print('๐Ÿ” Testing server initialization...') + + try: + # List available tools to verify server is responding + tools_result = await self.session.list_tools() + tools = [tool.name for tool in tools_result.tools] + + expected_tools = [ + 'add_memory', + 'search_memory_nodes', + 'search_memory_facts', + 'get_episodes', + 'delete_episode', + 'delete_entity_edge', + 'get_entity_edge', + 'clear_graph', + ] + + available_tools = len([tool for tool in 
+    async def test_server_initialization(self) -> bool:
+        """Test that the server initializes properly."""
+        print('๐Ÿ” Testing server initialization...')
+
+        try:
+            # List available tools to verify server is responding
+            tools_result = await self.session.list_tools()
+            tools = [tool.name for tool in tools_result.tools]
+
+            expected_tools = [
+                'add_memory',
+                'search_memory_nodes',
+                'search_memory_facts',
+                'get_episodes',
+                'delete_episode',
+                'delete_entity_edge',
+                'get_entity_edge',
+                'clear_graph',
+            ]
+
+            available_tools = len([tool for tool in expected_tools if tool in tools])
+            print(
+                f'   โœ… Server responding with {len(tools)} tools ({available_tools}/{len(expected_tools)} expected)'
+            )
+            print(f'   Available tools: {", ".join(sorted(tools))}')
+
+            return available_tools >= len(expected_tools) * 0.8  # 80% of expected tools
+
+        except Exception as e:
+            print(f'   โŒ Server initialization failed: {e}')
+            return False
+
+    async def test_add_memory_operations(self) -> dict[str, bool]:
+        """Test adding various types of memory episodes."""
+        print('๐Ÿ“ Testing add_memory operations...')
+
+        results = {}
+
+        # Test 1: Add text episode
+        print('   Testing text episode...')
+        try:
+            result = await self.call_tool(
+                'add_memory',
+                {
+                    'name': 'Test Company News',
+                    'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
+                    'source': 'text',
+                    'source_description': 'news article',
+                    'group_id': self.test_group_id,
+                },
+            )
+
+            if isinstance(result, str) and 'queued' in result.lower():
+                print(f'   โœ… Text episode: {result}')
+                results['text'] = True
+            else:
+                print(f'   โŒ Text episode failed: {result}')
+                results['text'] = False
+        except Exception as e:
+            print(f'   โŒ Text episode error: {e}')
+            results['text'] = False
+
+        # Test 2: Add JSON episode
+        print('   Testing JSON episode...')
+        try:
+            json_data = {
+                'company': {'name': 'TechCorp', 'founded': 2010},
+                'products': [
+                    {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
+                    {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
+                ],
+                'employees': 150,
+            }
+
+            result = await self.call_tool(
+                'add_memory',
+                {
+                    'name': 'Company Profile',
+                    'episode_body': json.dumps(json_data),
+                    'source': 'json',
+                    'source_description': 'CRM data',
+                    'group_id': self.test_group_id,
+                },
+            )
+
+            if isinstance(result, str) and 'queued' in result.lower():
+                print(f'   โœ… JSON episode: {result}')
+                results['json'] = True
+            else:
+                print(f'   โŒ JSON episode failed: {result}')
+                results['json'] = False
+        except Exception as e:
+            print(f'   โŒ JSON episode error: {e}')
+            results['json'] = False
+
+        # Test 3: Add message episode
+        print('   Testing message episode...')
+        try:
+            result = await self.call_tool(
+                'add_memory',
+                {
+                    'name': 'Customer Support Chat',
+                    'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
+                    'source': 'message',
+                    'source_description': 'support chat log',
+                    'group_id': self.test_group_id,
+                },
+            )
+
+            if isinstance(result, str) and 'queued' in result.lower():
+                print(f'   โœ… Message episode: {result}')
+                results['message'] = True
+            else:
+                print(f'   โŒ Message episode failed: {result}')
+                results['message'] = False
+        except Exception as e:
+            print(f'   โŒ Message episode error: {e}')
+            results['message'] = False
+
+        return results
+
+    async def wait_for_processing(self, max_wait: int = 45) -> bool:
+        """Wait for episode processing to complete."""
+        print(f'โณ Waiting up to {max_wait} seconds for episode processing...')
+
+        for i in range(max_wait):
+            await asyncio.sleep(1)
+
+            try:
+                # Check if we have any episodes
+                result = await self.call_tool(
+                    'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
+                )
+
+                # Parse the JSON result if it's a string
+                if isinstance(result, str):
+                    try:
+                        parsed_result = json.loads(result)
+                        if isinstance(parsed_result, list) and len(parsed_result) > 0:
+                            print(
+                                f'   โœ… Found {len(parsed_result)} processed episodes after {i + 1} seconds'
+                            )
+                            return True
+                    except json.JSONDecodeError:
+                        if 'episodes' in result.lower():
+                            print(f'   โœ… Episodes detected after {i + 1} seconds')
+                            return True
+
+            except Exception as e:
+                if i == 0:  # Only log first error to avoid spam
+                    print(f'   โš ๏ธ Waiting for processing... ({e})')
+                continue
+
+        print(f'   โš ๏ธ Still waiting after {max_wait} seconds...')
+        return False
+    async def test_search_operations(self) -> dict[str, bool]:
+        """Test search functionality."""
+        print('๐Ÿ” Testing search operations...')
+
+        results = {}
+
+        # Test search_memory_nodes
+        print('   Testing search_memory_nodes...')
+        try:
+            result = await self.call_tool(
+                'search_memory_nodes',
+                {
+                    'query': 'Acme Corp product launch AI',
+                    'group_ids': [self.test_group_id],
+                    'max_nodes': 5,
+                },
+            )
+
+            success = False
+            if isinstance(result, str):
+                try:
+                    parsed = json.loads(result)
+                    nodes = parsed.get('nodes', [])
+                    success = isinstance(nodes, list)
+                    print(f'   โœ… Node search returned {len(nodes)} nodes')
+                except json.JSONDecodeError:
+                    success = 'nodes' in result.lower() and 'successfully' in result.lower()
+                    if success:
+                        print('   โœ… Node search completed successfully')
+
+            results['nodes'] = success
+            if not success:
+                print(f'   โŒ Node search failed: {result}')
+
+        except Exception as e:
+            print(f'   โŒ Node search error: {e}')
+            results['nodes'] = False
+
+        # Test search_memory_facts
+        print('   Testing search_memory_facts...')
+        try:
+            result = await self.call_tool(
+                'search_memory_facts',
+                {
+                    'query': 'company products software TechCorp',
+                    'group_ids': [self.test_group_id],
+                    'max_facts': 5,
+                },
+            )
+
+            success = False
+            if isinstance(result, str):
+                try:
+                    parsed = json.loads(result)
+                    facts = parsed.get('facts', [])
+                    success = isinstance(facts, list)
+                    print(f'   โœ… Fact search returned {len(facts)} facts')
+                except json.JSONDecodeError:
+                    success = 'facts' in result.lower() and 'successfully' in result.lower()
+                    if success:
+                        print('   โœ… Fact search completed successfully')
+
+            results['facts'] = success
+            if not success:
+                print(f'   โŒ Fact search failed: {result}')
+
+        except Exception as e:
+            print(f'   โŒ Fact search error: {e}')
+            results['facts'] = False
+
+        return results
+
+    async def test_episode_retrieval(self) -> bool:
+        """Test episode retrieval."""
+        print('๐Ÿ“š Testing episode retrieval...')
+
+        try:
+            result = await self.call_tool(
+                'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
+            )
+
+            if isinstance(result, str):
+                try:
+                    parsed = json.loads(result)
+                    if isinstance(parsed, list):
+                        print(f'   โœ… Retrieved {len(parsed)} episodes')
+
+                        # Show episode details
+                        for i, episode in enumerate(parsed[:3]):
+                            name = episode.get('name', 'Unknown')
+                            source = episode.get('source', 'unknown')
+                            print(f'      Episode {i + 1}: {name} (source: {source})')
+
+                        return len(parsed) > 0
+                except json.JSONDecodeError:
+                    # Check if response indicates success
+                    if 'episode' in result.lower():
+                        print('   โœ… Episode retrieval completed')
+                        return True
+
+            print(f'   โŒ Unexpected result format: {result}')
+            return False
+
+        except Exception as e:
+            print(f'   โŒ Episode retrieval failed: {e}')
+            return False
+    async def test_error_handling(self) -> dict[str, bool]:
+        """Test error handling and edge cases."""
+        print('๐Ÿงช Testing error handling...')
+
+        results = {}
+
+        # Test with nonexistent group
+        print('   Testing nonexistent group handling...')
+        try:
+            result = await self.call_tool(
+                'search_memory_nodes',
+                {
+                    'query': 'nonexistent data',
+                    'group_ids': ['nonexistent_group_12345'],
+                    'max_nodes': 5,
+                },
+            )
+
+            # Should handle gracefully, not crash
+            success = (
+                'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
+            )
+            if success:
+                print('   โœ… Nonexistent group handled gracefully')
+            else:
+                print(f'   โŒ Nonexistent group caused issues: {result}')
+
+            results['nonexistent_group'] = success
+
+        except Exception as e:
+            print(f'   โŒ Nonexistent group test failed: {e}')
+            results['nonexistent_group'] = False
+
+        # Test empty query
+        print('   Testing empty query handling...')
+        try:
+            result = await self.call_tool(
+                'search_memory_nodes',
+                {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5},
+            )
+
+            # Should handle gracefully
+            success = (
+                'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
+            )
+            if success:
+                print('   โœ… Empty query handled gracefully')
+            else:
+                print(f'   โŒ Empty query caused issues: {result}')
+
+            results['empty_query'] = success
+
+        except Exception as e:
+            print(f'   โŒ Empty query test failed: {e}')
+            results['empty_query'] = False
+
+        return results
+    async def run_comprehensive_test(self) -> dict[str, Any]:
+        """Run the complete integration test suite."""
+        print('๐Ÿš€ Starting Comprehensive Graphiti MCP Server Integration Test')
+        print(f'   Test group ID: {self.test_group_id}')
+        print('=' * 70)
+
+        results = {
+            'server_init': False,
+            'add_memory': {},
+            'processing_wait': False,
+            'search': {},
+            'episodes': False,
+            'error_handling': {},
+            'overall_success': False,
+        }
+
+        # Test 1: Server Initialization
+        results['server_init'] = await self.test_server_initialization()
+        if not results['server_init']:
+            print('โŒ Server initialization failed, aborting remaining tests')
+            return results
+
+        print()
+
+        # Test 2: Add Memory Operations
+        results['add_memory'] = await self.test_add_memory_operations()
+        print()
+
+        # Test 3: Wait for Processing
+        results['processing_wait'] = await self.wait_for_processing()
+        print()
+
+        # Test 4: Search Operations
+        results['search'] = await self.test_search_operations()
+        print()
+
+        # Test 5: Episode Retrieval
+        results['episodes'] = await self.test_episode_retrieval()
+        print()
+
+        # Test 6: Error Handling
+        results['error_handling'] = await self.test_error_handling()
+        print()
+
+        # Calculate overall success
+        memory_success = any(results['add_memory'].values())
+        search_success = any(results['search'].values()) if results['search'] else False
+        error_success = (
+            any(results['error_handling'].values()) if results['error_handling'] else True
+        )
+
+        results['overall_success'] = (
+            results['server_init']
+            and memory_success
+            and (results['episodes'] or results['processing_wait'])
+            and error_success
+        )
+
+        # Print comprehensive summary
+        print('=' * 70)
+        print('๐Ÿ“Š COMPREHENSIVE TEST SUMMARY')
+        print('-' * 35)
+        print(f'Server Initialization:  {"โœ… PASS" if results["server_init"] else "โŒ FAIL"}')
+
+        memory_stats = f'({sum(results["add_memory"].values())}/{len(results["add_memory"])} types)'
+        print(
+            f'Memory Operations:      {"โœ… PASS" if memory_success else "โŒ FAIL"} {memory_stats}'
+        )
+
+        print(f'Processing Pipeline:    {"โœ… PASS" if results["processing_wait"] else "โŒ FAIL"}')
+
+        search_stats = (
+            f'({sum(results["search"].values())}/{len(results["search"])} types)'
+            if results['search']
+            else '(0/0 types)'
+        )
+        print(
+            f'Search Operations:      {"โœ… PASS" if search_success else "โŒ FAIL"} {search_stats}'
+        )
+
+        print(f'Episode Retrieval:      {"โœ… PASS" if results["episodes"] else "โŒ FAIL"}')
+
+        error_stats = (
+            f'({sum(results["error_handling"].values())}/{len(results["error_handling"])} cases)'
+            if results['error_handling']
+            else '(0/0 cases)'
+        )
+        print(
+            f'Error Handling:         {"โœ… PASS" if error_success else "โŒ FAIL"} {error_stats}'
+        )
+
+        print('-' * 35)
+        print(f'๐ŸŽฏ OVERALL RESULT: {"โœ… SUCCESS" if results["overall_success"] else "โŒ FAILED"}')
+
+        if results['overall_success']:
+            print('\n๐ŸŽ‰ The refactored Graphiti MCP server is working correctly!')
+            print('   All core functionality has been successfully tested.')
+        else:
+            print('\nโš ๏ธ Some issues were detected. Review the test results above.')
+            print('   The refactoring may need additional attention.')
+
+        return results
+
+
+async def main():
+    """Run the integration test."""
+    try:
+        async with GraphitiMCPIntegrationTest() as test:
+            results = await test.run_comprehensive_test()
+
+            # Exit with appropriate code
+            exit_code = 0 if results['overall_success'] else 1
+            exit(exit_code)
+    except Exception as e:
+        print(f'โŒ Test setup failed: {e}')
+        exit(1)
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
diff --git a/mcp_server/test_simple_validation.py b/mcp_server/test_simple_validation.py
new file mode 100644
index 00000000..7cdac4b1
--- /dev/null
+++ b/mcp_server/test_simple_validation.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python3
+"""
+Simple validation test for the refactored Graphiti MCP Server.
+Tests basic functionality quickly without timeouts.
+"""
+
+import subprocess
+import sys
+import time
+
+
+def test_server_startup():
+    """Test that the refactored server starts up successfully."""
+    print('๐Ÿš€ Testing Graphiti MCP Server Startup...')
+
+    try:
+        # Start the server and capture output
+        process = subprocess.Popen(
+            ['uv', 'run', 'graphiti_mcp_server.py', '--transport', 'stdio'],
+            env={
+                'NEO4J_URI': 'bolt://localhost:7687',
+                'NEO4J_USER': 'neo4j',
+                'NEO4J_PASSWORD': 'demodemo',
+            },
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+
+        # Wait for startup logs; success is initialized up front so an early
+        # server exit cannot leave it unbound
+        success = False
+        startup_output = ''
+        for _ in range(50):  # Wait up to 5 seconds
+            if process.poll() is not None:
+                break
+            time.sleep(0.1)
+
+            # Check if we have output
+            try:
+                line = process.stderr.readline()
+                if line:
+                    startup_output += line
+                    print(f'   ๐Ÿ“‹ {line.strip()}')
+
+                    # Check for success indicators
+                    if 'Graphiti client initialized successfully' in line:
+                        print('   โœ… Graphiti service initialization: SUCCESS')
+                        success = True
+                        break
+
+            except Exception:
+                continue
+        else:
+            print('   โš ๏ธ Timeout waiting for initialization')
+            success = False
+
+        # Clean shutdown
+        process.terminate()
+        try:
+            process.wait(timeout=5)
+        except subprocess.TimeoutExpired:
+            process.kill()
+
+        return success
+
+    except Exception as e:
+        print(f'   โŒ Server startup failed: {e}')
+        return False
+
+
+def test_import_validation():
+    """Test that all refactored modules import correctly."""
+    print('\n๐Ÿ” Testing Module Import Validation...')
+
+    modules_to_test = [
+        'config_manager',
+        'llm_config',
+        'embedder_config',
+        'neo4j_config',
+        'server_config',
+        'graphiti_service',
+        'queue_service',
+        'entity_types',
+        'response_types',
+        'formatting',
+        'utils',
+    ]
+
+    success_count = 0
+
+    for module in modules_to_test:
+        try:
+            result = subprocess.run(
+                ['python', '-c', f"import {module}; print('โœ… {module}')"],
+                capture_output=True,
+                text=True,
+                timeout=10,
+            )
+
+            if result.returncode == 0:
+                print(f'   โœ… {module}: Import successful')
+                success_count += 1
+            else:
+                print(f'   โŒ {module}: Import failed - {result.stderr.strip()}')
+
+        except subprocess.TimeoutExpired:
+            print(f'   โŒ {module}: Import timeout')
+        except Exception as e:
+            print(f'   โŒ {module}: Import error - {e}')
+
+    print(f'   ๐Ÿ“Š Import Results: {success_count}/{len(modules_to_test)} modules successful')
+    return success_count == len(modules_to_test)
+
+
+def test_syntax_validation():
+    """Test that all Python files have valid syntax."""
+    print('\n๐Ÿ”ง Testing Syntax Validation...')
+
+    files_to_test = [
+        'graphiti_mcp_server.py',
+        'config_manager.py',
+        'llm_config.py',
+        'embedder_config.py',
+        'neo4j_config.py',
+        'server_config.py',
+        'graphiti_service.py',
+        'queue_service.py',
+        'entity_types.py',
+        'response_types.py',
+        'formatting.py',
+        'utils.py',
+    ]
+
+    success_count = 0
+
+    for file in files_to_test:
+        try:
+            result = subprocess.run(
+                ['python', '-m', 'py_compile', file], capture_output=True, text=True, timeout=10
+            )
+
+            if result.returncode == 0:
+                print(f'   โœ… {file}: Syntax valid')
+                success_count += 1
+            else:
+                print(f'   โŒ {file}: Syntax error - {result.stderr.strip()}')
+
+        except subprocess.TimeoutExpired:
+            print(f'   โŒ {file}: Syntax check timeout')
+        except Exception as e:
+            print(f'   โŒ {file}: Syntax check error - {e}')
+
+    print(f'   ๐Ÿ“Š Syntax Results: {success_count}/{len(files_to_test)} files valid')
+    return success_count == len(files_to_test)
+
+
+def main():
+    """Run the validation tests."""
+    print('๐Ÿงช Graphiti MCP Server Refactoring Validation')
+    print('=' * 55)
+
+    results = {}
+
+    # Test 1: Syntax validation
+    results['syntax'] = test_syntax_validation()
+
+    # Test 2: Import validation
+    results['imports'] = test_import_validation()
+
+    # Test 3: Server startup
+    results['startup'] = test_server_startup()
+
+    # Summary
+    print('\n' + '=' * 55)
+    print('๐Ÿ“Š VALIDATION SUMMARY')
+    print('-' * 25)
+    print(f'Syntax Validation:  {"โœ… PASS" if results["syntax"] else "โŒ FAIL"}')
+    print(f'Import Validation:  {"โœ… PASS" if results["imports"] else "โŒ FAIL"}')
+    print(f'Startup Validation: {"โœ… PASS" if results["startup"] else "โŒ FAIL"}')
+
+    overall_success = all(results.values())
+    print('-' * 25)
+    print(f'๐ŸŽฏ OVERALL: {"โœ… SUCCESS" if overall_success else "โŒ FAILED"}')
+
+    if overall_success:
+        print('\n๐ŸŽ‰ Refactoring validation successful!')
+        print('   โœ… All modules have valid syntax')
+        print('   โœ… All imports work correctly')
+        print('   โœ… Server initializes successfully')
+        print('   โœ… The refactored MCP server is ready for use!')
+    else:
+        print('\nโš ๏ธ Some validation issues detected.')
+        print('   Please review the failed tests above.')
+
+    return 0 if overall_success else 1
+
+
+if __name__ == '__main__':
+    exit_code = main()
+    sys.exit(exit_code)
diff --git a/mcp_server/utils.py b/mcp_server/utils.py
new file mode 100644
index 00000000..31c3afb3
--- /dev/null
+++ b/mcp_server/utils.py
@@ -0,0 +1,14 @@
+"""Utility functions for Graphiti MCP Server."""
+
+from collections.abc import Callable
+
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+
+
+def create_azure_credential_token_provider() -> Callable[[], str]:
+    """Create Azure credential token provider for managed identity authentication."""
+    credential = DefaultAzureCredential()
+    token_provider = get_bearer_token_provider(
+        credential, 'https://cognitiveservices.azure.com/.default'
+    )
+    return token_provider
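For context, the token provider returned here is typically handed straight to the Azure OpenAI client in place of an API key. A sketch of that wiring (assumed; the config modules that actually consume it sit earlier in this patch, and the endpoint and API version below are placeholders):

from openai import AsyncAzureOpenAI

from utils import create_azure_credential_token_provider

token_provider = create_azure_credential_token_provider()
client = AsyncAzureOpenAI(
    azure_endpoint='https://example-resource.openai.azure.com',  # placeholder
    api_version='2024-10-21',  # assumed; use your deployment's API version
    azure_ad_token_provider=token_provider,  # bearer tokens via managed identity
)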
-version = "0.1.0" +version = "0.2.1" source = { virtual = "." } dependencies = [ { name = "azure-identity" }, @@ -466,6 +466,12 @@ dependencies = [ { name = "openai" }, ] +[package.dev-dependencies] +dev = [ + { name = "httpx" }, + { name = "mcp" }, +] + [package.metadata] requires-dist = [ { name = "azure-identity", specifier = ">=1.21.0" }, @@ -475,6 +481,12 @@ requires-dist = [ { name = "openai", specifier = ">=1.68.2" }, ] +[package.metadata.requires-dev] +dev = [ + { name = "httpx", specifier = ">=0.28.1" }, + { name = "mcp", specifier = ">=1.9.4" }, +] + [[package]] name = "msal" version = "1.32.3"