wip
This commit is contained in:
parent 119a43b8e4
commit 452a45cb4e
18 changed files with 1877 additions and 641 deletions
@@ -34,7 +34,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
    uv sync --frozen --no-dev

# Copy application code
COPY graphiti_mcp_server.py ./
COPY *.py ./

# Change ownership to app user
RUN chown -Rv app:app /app
51 mcp_server/config_manager.py Normal file
@@ -0,0 +1,51 @@
"""Unified configuration manager for Graphiti MCP Server."""

import argparse

from embedder_config import GraphitiEmbedderConfig
from llm_config import GraphitiLLMConfig
from neo4j_config import Neo4jConfig
from pydantic import BaseModel, Field


class GraphitiConfig(BaseModel):
    """Configuration for Graphiti client.

    Centralizes all configuration parameters for the Graphiti client.
    """

    llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig)
    embedder: GraphitiEmbedderConfig = Field(default_factory=GraphitiEmbedderConfig)
    neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig)
    group_id: str | None = None
    use_custom_entities: bool = False
    destroy_graph: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiConfig':
        """Create a configuration instance from environment variables."""
        return cls(
            llm=GraphitiLLMConfig.from_env(),
            embedder=GraphitiEmbedderConfig.from_env(),
            neo4j=Neo4jConfig.from_env(),
        )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiConfig':
        """Create configuration from CLI arguments, falling back to environment variables."""
        # Start with environment configuration
        config = cls.from_env()

        # Apply CLI overrides
        if args.group_id:
            config.group_id = args.group_id
        else:
            config.group_id = 'default'

        config.use_custom_entities = args.use_custom_entities
        config.destroy_graph = args.destroy_graph

        # Update LLM config using CLI args
        config.llm = GraphitiLLMConfig.from_cli_and_env(args)

        return config
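A minimal usage sketch (editor's addition, not part of the commit) showing how this class is meant to be driven from argparse; the parser below is hypothetical and only defines the attribute names that from_cli_and_env reads:

import argparse

from config_manager import GraphitiConfig

# Hypothetical CLI wiring; the real server defines its own parser.
parser = argparse.ArgumentParser()
parser.add_argument('--group-id', dest='group_id', default=None)
parser.add_argument('--use-custom-entities', dest='use_custom_entities', action='store_true')
parser.add_argument('--destroy-graph', dest='destroy_graph', action='store_true')
parser.add_argument('--model', default=None)
parser.add_argument('--small-model', dest='small_model', default=None)
parser.add_argument('--temperature', type=float, default=None)
args = parser.parse_args(['--group-id', 'demo'])

config = GraphitiConfig.from_cli_and_env(args)  # env values first, CLI overrides applied on top
print(config.group_id)  # 'demo'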
124 mcp_server/embedder_config.py Normal file
@@ -0,0 +1,124 @@
"""Embedder configuration for Graphiti MCP Server."""

import logging
import os

from openai import AsyncAzureOpenAI
from pydantic import BaseModel
from utils import create_azure_credential_token_provider

from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig

logger = logging.getLogger(__name__)

DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'


class GraphitiEmbedderConfig(BaseModel):
    """Configuration for the embedder client.

    Centralizes all embedding-related configuration parameters.
    """

    model: str = DEFAULT_EMBEDDER_MODEL
    api_key: str | None = None
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiEmbedderConfig':
        """Create embedder configuration from environment variables."""
        # Get model from environment, or use default if not set or empty
        model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
        azure_openai_deployment_name = os.environ.get(
            'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
        )
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )

        if azure_openai_endpoint is not None:
            # Setup for Azure OpenAI API
            # Log if empty deployment name was provided
            azure_openai_deployment_name = os.environ.get(
                'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
            )
            if azure_openai_deployment_name is None:
                logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set')
                raise ValueError(
                    'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set'
                )

            if not azure_openai_use_managed_identity:
                # api key
                api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
                    'OPENAI_API_KEY', None
                )
            else:
                # Managed identity
                api_key = None

            return cls(
                azure_openai_use_managed_identity=azure_openai_use_managed_identity,
                azure_openai_endpoint=azure_openai_endpoint,
                api_key=api_key,
                azure_openai_api_version=azure_openai_api_version,
                azure_openai_deployment_name=azure_openai_deployment_name,
                model=model,
            )
        else:
            return cls(
                model=model,
                api_key=os.environ.get('OPENAI_API_KEY'),
            )

    def create_client(self) -> EmbedderClient | None:
        """Create an embedder client based on this configuration.

        Returns:
            EmbedderClient instance or None if configuration is invalid
        """
        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                return AzureOpenAIEmbedderClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        azure_ad_token_provider=token_provider,
                    ),
                    model=self.model,
                )
            elif self.api_key:
                # Use API key for authentication
                return AzureOpenAIEmbedderClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        api_key=self.api_key,
                    ),
                    model=self.model,
                )
            else:
                logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
                return None
        else:
            # OpenAI API setup
            if not self.api_key:
                return None

            embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model)

            return OpenAIEmbedder(config=embedder_config)
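A minimal sketch of the environment-driven selection above (editor's addition; the API key is a placeholder and no request is made when the client is constructed):

import os

from embedder_config import GraphitiEmbedderConfig

# With only OPENAI_API_KEY set, from_env() yields a plain OpenAI embedder config.
os.environ['OPENAI_API_KEY'] = 'sk-placeholder'
os.environ.pop('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)

cfg = GraphitiEmbedderConfig.from_env()
client = cfg.create_client()  # OpenAIEmbedder, defaulting to text-embedding-3-small
print(cfg.model, type(client).__name__)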
83 mcp_server/entity_types.py Normal file
@@ -0,0 +1,83 @@
"""Entity type definitions for Graphiti MCP Server."""

from pydantic import BaseModel, Field


class Requirement(BaseModel):
    """A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.

    Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
    edge that the requirement is a requirement.

    Instructions for identifying and extracting requirements:
    1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
    2. Identify functional specifications that describe what the system should do
    3. Pay attention to non-functional requirements like performance, security, or usability criteria
    4. Extract constraints or limitations that must be adhered to
    5. Focus on clear, specific, and measurable requirements rather than vague wishes
    6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
    7. Include any dependencies between requirements when explicitly stated
    8. Preserve the original intent and scope of the requirement
    9. Categorize requirements appropriately based on their domain or function
    """

    project_name: str = Field(
        ...,
        description='The name of the project to which the requirement belongs.',
    )
    description: str = Field(
        ...,
        description='Description of the requirement. Only use information mentioned in the context to write this description.',
    )


class Preference(BaseModel):
    """A Preference represents a user's expressed like, dislike, or preference for something.

    Instructions for identifying and extracting preferences:
    1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X"
    2. Pay attention to comparative statements ("I prefer X over Y")
    3. Consider the emotional tone when users mention certain topics
    4. Extract only preferences that are clearly expressed, not assumptions
    5. Categorize the preference appropriately based on its domain (food, music, brands, etc.)
    6. Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food")
    7. Only extract preferences directly stated by the user, not preferences of others they mention
    8. Provide a concise but specific description that captures the nature of the preference
    """

    category: str = Field(
        ...,
        description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')",
    )
    description: str = Field(
        ...,
        description='Brief description of the preference. Only use information mentioned in the context to write this description.',
    )


class Procedure(BaseModel):
    """A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.

    Instructions for identifying and extracting procedures:
    1. Look for sequential instructions or steps ("First do X, then do Y")
    2. Identify explicit directives or commands ("Always do X when Y happens")
    3. Pay attention to conditional statements ("If X occurs, then do Y")
    4. Extract procedures that have clear beginning and end points
    5. Focus on actionable instructions rather than general information
    6. Preserve the original sequence and dependencies between steps
    7. Include any specified conditions or triggers for the procedure
    8. Capture any stated purpose or goal of the procedure
    9. Summarize complex procedures while maintaining critical details
    """

    description: str = Field(
        ...,
        description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
    )


ENTITY_TYPES: dict[str, BaseModel] = {
    'Requirement': Requirement,  # type: ignore
    'Preference': Preference,  # type: ignore
    'Procedure': Procedure,  # type: ignore
}
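How this registry is consumed, in rough outline (editor's sketch based on the add_memory hunk further down in this diff):

from entity_types import ENTITY_TYPES

# When custom entities are enabled, the full registry is passed to Graphiti's
# add_episode(); otherwise an empty dict disables typed entity extraction.
use_custom_entities = True
entity_types = ENTITY_TYPES if use_custom_entities else {}
# await graphiti_client.add_episode(..., entity_types=entity_types)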
26 mcp_server/formatting.py Normal file
@@ -0,0 +1,26 @@
"""Formatting utilities for Graphiti MCP Server."""

from typing import Any

from graphiti_core.edges import EntityEdge


def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
    """Format an entity edge into a readable result.

    Since EntityEdge is a Pydantic BaseModel, we can use its built-in serialization capabilities.

    Args:
        edge: The EntityEdge to format

    Returns:
        A dictionary representation of the edge with serialized dates and excluded embeddings
    """
    result = edge.model_dump(
        mode='json',
        exclude={
            'fact_embedding',
        },
    )
    result.get('attributes', {}).pop('fact_embedding', None)
    return result
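To make the exclusion behaviour concrete, a tiny stand-in model (editor's addition; not the real EntityEdge) shows what the `exclude` set plus the `attributes` pop achieves:

from typing import Any

from pydantic import BaseModel


class FakeEdge(BaseModel):
    # Stand-in with only the fields format_fact_result cares about.
    fact: str
    fact_embedding: list[float]
    attributes: dict[str, Any]


edge = FakeEdge(fact='Alice prefers tea', fact_embedding=[0.1, 0.2], attributes={'fact_embedding': [0.1]})
result = edge.model_dump(mode='json', exclude={'fact_embedding'})
result.get('attributes', {}).pop('fact_embedding', None)
print(result)  # {'fact': 'Alice prefers tea', 'attributes': {}}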
@@ -8,25 +8,30 @@ import asyncio
import logging
import os
import sys
from collections.abc import Callable
from datetime import datetime, timezone
from typing import Any, TypedDict, cast
from typing import Any, cast

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from config_manager import GraphitiConfig
from dotenv import load_dotenv
from entity_types import ENTITY_TYPES
from formatting import format_fact_result
from graphiti_service import GraphitiService
from llm_config import DEFAULT_LLM_MODEL, SMALL_LLM_MODEL
from mcp.server.fastmcp import FastMCP
from openai import AsyncAzureOpenAI
from pydantic import BaseModel, Field
from queue_service import QueueService
from response_types import (
    EpisodeSearchResponse,
    ErrorResponse,
    FactSearchResponse,
    NodeResult,
    NodeSearchResponse,
    StatusResponse,
    SuccessResponse,
)
from server_config import MCPConfig

from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_client import OpenAIClient
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_config_recipes import (
    NODE_HYBRID_SEARCH_NODE_DISTANCE,
@@ -38,488 +43,12 @@ from graphiti_core.utils.maintenance.graph_data_operations import clear_data
load_dotenv()


DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
SMALL_LLM_MODEL = 'gpt-4.1-nano'
DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'

# Semaphore limit for concurrent Graphiti operations.
# Decrease this if you're experiencing 429 rate limit errors from your LLM provider.
# Increase if you have high rate limits.
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 10))


class Requirement(BaseModel):
    """A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.

    Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
    edge that the requirement is a requirement.

    Instructions for identifying and extracting requirements:
    1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
    2. Identify functional specifications that describe what the system should do
    3. Pay attention to non-functional requirements like performance, security, or usability criteria
    4. Extract constraints or limitations that must be adhered to
    5. Focus on clear, specific, and measurable requirements rather than vague wishes
    6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
    7. Include any dependencies between requirements when explicitly stated
    8. Preserve the original intent and scope of the requirement
    9. Categorize requirements appropriately based on their domain or function
    """

    project_name: str = Field(
        ...,
        description='The name of the project to which the requirement belongs.',
    )
    description: str = Field(
        ...,
        description='Description of the requirement. Only use information mentioned in the context to write this description.',
    )


class Preference(BaseModel):
    """A Preference represents a user's expressed like, dislike, or preference for something.

    Instructions for identifying and extracting preferences:
    1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X"
    2. Pay attention to comparative statements ("I prefer X over Y")
    3. Consider the emotional tone when users mention certain topics
    4. Extract only preferences that are clearly expressed, not assumptions
    5. Categorize the preference appropriately based on its domain (food, music, brands, etc.)
    6. Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food")
    7. Only extract preferences directly stated by the user, not preferences of others they mention
    8. Provide a concise but specific description that captures the nature of the preference
    """

    category: str = Field(
        ...,
        description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')",
    )
    description: str = Field(
        ...,
        description='Brief description of the preference. Only use information mentioned in the context to write this description.',
    )


class Procedure(BaseModel):
    """A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.

    Instructions for identifying and extracting procedures:
    1. Look for sequential instructions or steps ("First do X, then do Y")
    2. Identify explicit directives or commands ("Always do X when Y happens")
    3. Pay attention to conditional statements ("If X occurs, then do Y")
    4. Extract procedures that have clear beginning and end points
    5. Focus on actionable instructions rather than general information
    6. Preserve the original sequence and dependencies between steps
    7. Include any specified conditions or triggers for the procedure
    8. Capture any stated purpose or goal of the procedure
    9. Summarize complex procedures while maintaining critical details
    """

    description: str = Field(
        ...,
        description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
    )


ENTITY_TYPES: dict[str, BaseModel] = {
    'Requirement': Requirement,  # type: ignore
    'Preference': Preference,  # type: ignore
    'Procedure': Procedure,  # type: ignore
}


# Type definitions for API responses
class ErrorResponse(TypedDict):
    error: str


class SuccessResponse(TypedDict):
    message: str


class NodeResult(TypedDict):
    uuid: str
    name: str
    summary: str
    labels: list[str]
    group_id: str
    created_at: str
    attributes: dict[str, Any]


class NodeSearchResponse(TypedDict):
    message: str
    nodes: list[NodeResult]


class FactSearchResponse(TypedDict):
    message: str
    facts: list[dict[str, Any]]


class EpisodeSearchResponse(TypedDict):
    message: str
    episodes: list[dict[str, Any]]


class StatusResponse(TypedDict):
    status: str
    message: str


def create_azure_credential_token_provider() -> Callable[[], str]:
    credential = DefaultAzureCredential()
    token_provider = get_bearer_token_provider(
        credential, 'https://cognitiveservices.azure.com/.default'
    )
    return token_provider


# Server configuration classes
# The configuration system has a hierarchy:
# - GraphitiConfig is the top-level configuration
# - LLMConfig handles all OpenAI/LLM related settings
# - EmbedderConfig manages embedding settings
# - Neo4jConfig manages database connection details
# - Various other settings like group_id and feature flags
# Configuration values are loaded from:
# 1. Default values in the class definitions
# 2. Environment variables (loaded via load_dotenv())
# 3. Command line arguments (which override environment variables)
class GraphitiLLMConfig(BaseModel):
    """Configuration for the LLM client.

    Centralizes all LLM-specific configuration parameters including API keys and model selection.
    """

    api_key: str | None = None
    model: str = DEFAULT_LLM_MODEL
    small_model: str = SMALL_LLM_MODEL
    temperature: float = 0.0
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiLLMConfig':
        """Create LLM configuration from environment variables."""
        # Get model from environment, or use default if not set or empty
        model_env = os.environ.get('MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_LLM_MODEL

        # Get small_model from environment, or use default if not set or empty
        small_model_env = os.environ.get('SMALL_MODEL_NAME', '')
        small_model = small_model_env if small_model_env.strip() else SMALL_LLM_MODEL

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_API_VERSION', None)
        azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME', None)
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )

        if azure_openai_endpoint is None:
            # Setup for OpenAI API
            # Log if empty model was provided
            if model_env == '':
                logger.debug(
                    f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}'
                )
            elif not model_env.strip():
                logger.warning(
                    f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}'
                )

            return cls(
                api_key=os.environ.get('OPENAI_API_KEY'),
                model=model,
                small_model=small_model,
                temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
            )
        else:
            # Setup for Azure OpenAI API
            # Log if empty deployment name was provided
            if azure_openai_deployment_name is None:
                logger.error('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')

                raise ValueError('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
            if not azure_openai_use_managed_identity:
                # api key
                api_key = os.environ.get('OPENAI_API_KEY', None)
            else:
                # Managed identity
                api_key = None

            return cls(
                azure_openai_use_managed_identity=azure_openai_use_managed_identity,
                azure_openai_endpoint=azure_openai_endpoint,
                api_key=api_key,
                azure_openai_api_version=azure_openai_api_version,
                azure_openai_deployment_name=azure_openai_deployment_name,
                model=model,
                small_model=small_model,
                temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
            )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig':
        """Create LLM configuration from CLI arguments, falling back to environment variables."""
        # Start with environment-based config
        config = cls.from_env()

        # CLI arguments override environment variables when provided
        if hasattr(args, 'model') and args.model:
            # Only use CLI model if it's not empty
            if args.model.strip():
                config.model = args.model
            else:
                # Log that empty model was provided and default is used
                logger.warning(f'Empty model name provided, using default: {DEFAULT_LLM_MODEL}')

        if hasattr(args, 'small_model') and args.small_model:
            if args.small_model.strip():
                config.small_model = args.small_model
            else:
                logger.warning(f'Empty small_model name provided, using default: {SMALL_LLM_MODEL}')

        if hasattr(args, 'temperature') and args.temperature is not None:
            config.temperature = args.temperature

        return config

    def create_client(self) -> LLMClient:
        """Create an LLM client based on this configuration.

        Returns:
            LLMClient instance
        """

        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        azure_ad_token_provider=token_provider,
                    ),
                    config=LLMConfig(
                        api_key=self.api_key,
                        model=self.model,
                        small_model=self.small_model,
                        temperature=self.temperature,
                    ),
                )
            elif self.api_key:
                # Use API key for authentication
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        api_key=self.api_key,
                    ),
                    config=LLMConfig(
                        api_key=self.api_key,
                        model=self.model,
                        small_model=self.small_model,
                        temperature=self.temperature,
                    ),
                )
            else:
                raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')

        if not self.api_key:
            raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')

        llm_client_config = LLMConfig(
            api_key=self.api_key, model=self.model, small_model=self.small_model
        )

        # Set temperature
        llm_client_config.temperature = self.temperature

        return OpenAIClient(config=llm_client_config)


class GraphitiEmbedderConfig(BaseModel):
    """Configuration for the embedder client.

    Centralizes all embedding-related configuration parameters.
    """

    model: str = DEFAULT_EMBEDDER_MODEL
    api_key: str | None = None
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiEmbedderConfig':
        """Create embedder configuration from environment variables."""

        # Get model from environment, or use default if not set or empty
        model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
        azure_openai_deployment_name = os.environ.get(
            'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
        )
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )
        if azure_openai_endpoint is not None:
            # Setup for Azure OpenAI API
            # Log if empty deployment name was provided
            azure_openai_deployment_name = os.environ.get(
                'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
            )
            if azure_openai_deployment_name is None:
                logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set')

                raise ValueError(
                    'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set'
                )

            if not azure_openai_use_managed_identity:
                # api key
                api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
                    'OPENAI_API_KEY', None
                )
            else:
                # Managed identity
                api_key = None

            return cls(
                azure_openai_use_managed_identity=azure_openai_use_managed_identity,
                azure_openai_endpoint=azure_openai_endpoint,
                api_key=api_key,
                azure_openai_api_version=azure_openai_api_version,
                azure_openai_deployment_name=azure_openai_deployment_name,
            )
        else:
            return cls(
                model=model,
                api_key=os.environ.get('OPENAI_API_KEY'),
            )

    def create_client(self) -> EmbedderClient | None:
        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                return AzureOpenAIEmbedderClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        azure_ad_token_provider=token_provider,
                    ),
                    model=self.model,
                )
            elif self.api_key:
                # Use API key for authentication
                return AzureOpenAIEmbedderClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        api_key=self.api_key,
                    ),
                    model=self.model,
                )
            else:
                logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
                return None
        else:
            # OpenAI API setup
            if not self.api_key:
                return None

            embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model)

            return OpenAIEmbedder(config=embedder_config)


class Neo4jConfig(BaseModel):
    """Configuration for Neo4j database connection."""

    uri: str = 'bolt://localhost:7687'
    user: str = 'neo4j'
    password: str = 'password'

    @classmethod
    def from_env(cls) -> 'Neo4jConfig':
        """Create Neo4j configuration from environment variables."""
        return cls(
            uri=os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
            user=os.environ.get('NEO4J_USER', 'neo4j'),
            password=os.environ.get('NEO4J_PASSWORD', 'password'),
        )


class GraphitiConfig(BaseModel):
    """Configuration for Graphiti client.

    Centralizes all configuration parameters for the Graphiti client.
    """

    llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig)
    embedder: GraphitiEmbedderConfig = Field(default_factory=GraphitiEmbedderConfig)
    neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig)
    group_id: str | None = None
    use_custom_entities: bool = False
    destroy_graph: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiConfig':
        """Create a configuration instance from environment variables."""
        return cls(
            llm=GraphitiLLMConfig.from_env(),
            embedder=GraphitiEmbedderConfig.from_env(),
            neo4j=Neo4jConfig.from_env(),
        )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiConfig':
        """Create configuration from CLI arguments, falling back to environment variables."""
        # Start with environment configuration
        config = cls.from_env()

        # Apply CLI overrides
        if args.group_id:
            config.group_id = args.group_id
        else:
            config.group_id = 'default'

        config.use_custom_entities = args.use_custom_entities
        config.destroy_graph = args.destroy_graph

        # Update LLM config using CLI args
        config.llm = GraphitiLLMConfig.from_cli_and_env(args)

        return config


class MCPConfig(BaseModel):
    """Configuration for MCP server."""

    transport: str = 'sse'  # Default to SSE transport

    @classmethod
    def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig':
        """Create MCP configuration from CLI arguments."""
        return cls(transport=args.transport)


# Configure logging
logging.basicConfig(
    level=logging.INFO,
@@ -568,124 +97,9 @@ mcp = FastMCP(
    instructions=GRAPHITI_MCP_INSTRUCTIONS,
)

# Initialize Graphiti client
graphiti_client: Graphiti | None = None


async def initialize_graphiti():
    """Initialize the Graphiti client with the configured settings."""
    global graphiti_client, config

    try:
        # Create LLM client if possible
        llm_client = config.llm.create_client()
        if not llm_client and config.use_custom_entities:
            # If custom entities are enabled, we must have an LLM client
            raise ValueError('OPENAI_API_KEY must be set when custom entities are enabled')

        # Validate Neo4j configuration
        if not config.neo4j.uri or not config.neo4j.user or not config.neo4j.password:
            raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')

        embedder_client = config.embedder.create_client()

        # Initialize Graphiti client
        graphiti_client = Graphiti(
            uri=config.neo4j.uri,
            user=config.neo4j.user,
            password=config.neo4j.password,
            llm_client=llm_client,
            embedder=embedder_client,
            max_coroutines=SEMAPHORE_LIMIT,
        )

        # Destroy graph if requested
        if config.destroy_graph:
            logger.info('Destroying graph...')
            await clear_data(graphiti_client.driver)

        # Initialize the graph database with Graphiti's indices
        await graphiti_client.build_indices_and_constraints()
        logger.info('Graphiti client initialized successfully')

        # Log configuration details for transparency
        if llm_client:
            logger.info(f'Using OpenAI model: {config.llm.model}')
            logger.info(f'Using temperature: {config.llm.temperature}')
        else:
            logger.info('No LLM client configured - entity extraction will be limited')

        logger.info(f'Using group_id: {config.group_id}')
        logger.info(
            f'Custom entity extraction: {"enabled" if config.use_custom_entities else "disabled"}'
        )
        logger.info(f'Using concurrency limit: {SEMAPHORE_LIMIT}')

    except Exception as e:
        logger.error(f'Failed to initialize Graphiti: {str(e)}')
        raise


def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
    """Format an entity edge into a readable result.

    Since EntityEdge is a Pydantic BaseModel, we can use its built-in serialization capabilities.

    Args:
        edge: The EntityEdge to format

    Returns:
        A dictionary representation of the edge with serialized dates and excluded embeddings
    """
    result = edge.model_dump(
        mode='json',
        exclude={
            'fact_embedding',
        },
    )
    result.get('attributes', {}).pop('fact_embedding', None)
    return result


# Dictionary to store queues for each group_id
# Each queue is a list of tasks to be processed sequentially
episode_queues: dict[str, asyncio.Queue] = {}
# Dictionary to track if a worker is running for each group_id
queue_workers: dict[str, bool] = {}


async def process_episode_queue(group_id: str):
    """Process episodes for a specific group_id sequentially.

    This function runs as a long-lived task that processes episodes
    from the queue one at a time.
    """
    global queue_workers

    logger.info(f'Starting episode queue worker for group_id: {group_id}')
    queue_workers[group_id] = True

    try:
        while True:
            # Get the next episode processing function from the queue
            # This will wait if the queue is empty
            process_func = await episode_queues[group_id].get()

            try:
                # Process the episode
                await process_func()
            except Exception as e:
                logger.error(f'Error processing queued episode for group_id {group_id}: {str(e)}')
            finally:
                # Mark the task as done regardless of success/failure
                episode_queues[group_id].task_done()
    except asyncio.CancelledError:
        logger.info(f'Episode queue worker for group_id {group_id} was cancelled')
    except Exception as e:
        logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}')
    finally:
        queue_workers[group_id] = False
        logger.info(f'Stopped episode queue worker for group_id: {group_id}')
# Global services
graphiti_service: GraphitiService | None = None
queue_service: QueueService | None = None


@mcp.tool()
@@ -752,10 +166,13 @@ async def add_memory(
    - Entities will be created from appropriate JSON properties
    - Relationships between entities will be established based on the JSON structure
    """
    global graphiti_client, episode_queues, queue_workers
    global graphiti_service, queue_service, config

    if graphiti_client is None:
        return ErrorResponse(error='Graphiti client not initialized')
    if not graphiti_service or not graphiti_service.is_initialized():
        return ErrorResponse(error='Graphiti service not initialized')

    if not queue_service:
        return ErrorResponse(error='Queue service not initialized')

    try:
        # Map string source to EpisodeType enum
@@ -772,13 +189,6 @@ async def add_memory(
        # The Graphiti client expects a str for group_id, not Optional[str]
        group_id_str = str(effective_group_id) if effective_group_id is not None else ''

        # We've already checked that graphiti_client is not None above
        # This assert statement helps type checkers understand that graphiti_client is defined
        assert graphiti_client is not None, 'graphiti_client should not be None here'

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)

        # Define the episode processing function
        async def process_episode():
            try:
@@ -786,7 +196,7 @@ async def add_memory(
                # Use all entity types if use_custom_entities is enabled, otherwise use empty dict
                entity_types = ENTITY_TYPES if config.use_custom_entities else {}

                await client.add_episode(
                await graphiti_service.client.add_episode(
                    name=name,
                    episode_body=episode_body,
                    source=source_type,
@@ -796,8 +206,6 @@ async def add_memory(
                    reference_time=datetime.now(timezone.utc),
                    entity_types=entity_types,
                )
                logger.info(f"Episode '{name}' added successfully")

                logger.info(f"Episode '{name}' processed successfully")
            except Exception as e:
                error_msg = str(e)
@@ -805,20 +213,12 @@ async def add_memory(
                    f"Error processing episode '{name}' for group_id {group_id_str}: {error_msg}"
                )

        # Initialize queue for this group_id if it doesn't exist
        if group_id_str not in episode_queues:
            episode_queues[group_id_str] = asyncio.Queue()

        # Add the episode processing function to the queue
        await episode_queues[group_id_str].put(process_episode)

        # Start a worker for this queue if one isn't already running
        if not queue_workers.get(group_id_str, False):
            asyncio.create_task(process_episode_queue(group_id_str))
        queue_position = await queue_service.add_episode_task(group_id_str, process_episode)

        # Return immediately with a success message
        return SuccessResponse(
            message=f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})"
            message=f"Episode '{name}' queued for processing (position: {queue_position})"
        )
    except Exception as e:
        error_msg = str(e)
@@ -846,10 +246,10 @@ async def search_memory_nodes(
        center_node_uuid: Optional UUID of a node to center the search around
        entity: Optional single entity type to filter results (permitted: "Preference", "Procedure")
    """
    global graphiti_client
    global graphiti_service, config

    if graphiti_client is None:
        return ErrorResponse(error='Graphiti client not initialized')
    if not graphiti_service or not graphiti_service.is_initialized():
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Use the provided group_ids or fall back to the default from config if none provided
@@ -868,11 +268,7 @@ async def search_memory_nodes(
        if entity != '':
            filters.node_labels = [entity]

        # We've already checked that graphiti_client is not None above
        assert graphiti_client is not None

        # Use cast to help the type checker understand that graphiti_client is not None
        client = cast(Graphiti, graphiti_client)
        client = graphiti_service.client

        # Perform the search using the _search method
        search_results = await client._search(
@@ -1218,8 +614,11 @@ async def initialize_server() -> MCPConfig:
    else:
        logger.info('Entity extraction disabled (no custom entities will be used)')

    # Initialize Graphiti
    await initialize_graphiti()
    # Initialize services
    global graphiti_service, queue_service
    graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
    queue_service = QueueService()
    await graphiti_service.initialize()

    if args.host:
        logger.info(f'Setting MCP server host to: {args.host}')
110 mcp_server/graphiti_service.py Normal file
@@ -0,0 +1,110 @@
"""Graphiti service for managing client lifecycle and operations."""

import logging

from config_manager import GraphitiConfig

from graphiti_core import Graphiti
from graphiti_core.utils.maintenance.graph_data_operations import clear_data

logger = logging.getLogger(__name__)


class GraphitiService:
    """Service for managing Graphiti client operations."""

    def __init__(self, config: GraphitiConfig, semaphore_limit: int = 10):
        """Initialize the Graphiti service with configuration.

        Args:
            config: The Graphiti configuration
            semaphore_limit: Maximum concurrent operations
        """
        self.config = config
        self.semaphore_limit = semaphore_limit
        self._client: Graphiti | None = None

    @property
    def client(self) -> Graphiti:
        """Get the Graphiti client instance.

        Raises:
            RuntimeError: If the client hasn't been initialized
        """
        if self._client is None:
            raise RuntimeError('Graphiti client not initialized. Call initialize() first.')
        return self._client

    async def initialize(self) -> None:
        """Initialize the Graphiti client with the configured settings."""
        try:
            # Create LLM client if possible
            llm_client = self.config.llm.create_client()
            if not llm_client and self.config.use_custom_entities:
                # If custom entities are enabled, we must have an LLM client
                raise ValueError('OPENAI_API_KEY must be set when custom entities are enabled')

            # Validate Neo4j configuration
            if (
                not self.config.neo4j.uri
                or not self.config.neo4j.user
                or not self.config.neo4j.password
            ):
                raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')

            embedder_client = self.config.embedder.create_client()

            # Initialize Graphiti client
            self._client = Graphiti(
                uri=self.config.neo4j.uri,
                user=self.config.neo4j.user,
                password=self.config.neo4j.password,
                llm_client=llm_client,
                embedder=embedder_client,
                max_coroutines=self.semaphore_limit,
            )

            # Destroy graph if requested
            if self.config.destroy_graph:
                logger.info('Destroying graph...')
                await clear_data(self._client.driver)

            # Initialize the graph database with Graphiti's indices
            await self._client.build_indices_and_constraints()
            logger.info('Graphiti client initialized successfully')

            # Log configuration details for transparency
            if llm_client:
                logger.info(f'Using OpenAI model: {self.config.llm.model}')
                logger.info(f'Using temperature: {self.config.llm.temperature}')
            else:
                logger.info('No LLM client configured - entity extraction will be limited')

            logger.info(f'Using group_id: {self.config.group_id}')
            logger.info(
                f'Custom entity extraction: {"enabled" if self.config.use_custom_entities else "disabled"}'
            )
            logger.info(f'Using concurrency limit: {self.semaphore_limit}')

        except Exception as e:
            logger.error(f'Failed to initialize Graphiti: {str(e)}')
            raise

    async def clear_graph(self) -> None:
        """Clear all data from the graph and rebuild indices."""
        if self._client is None:
            raise RuntimeError('Graphiti client not initialized')

        await clear_data(self._client.driver)
        await self._client.build_indices_and_constraints()

    async def verify_connection(self) -> None:
        """Verify the database connection."""
        if self._client is None:
            raise RuntimeError('Graphiti client not initialized')

        await self._client.driver.client.verify_connectivity()  # type: ignore

    def is_initialized(self) -> bool:
        """Check if the client is initialized."""
        return self._client is not None
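A minimal lifecycle sketch (editor's addition; assumes a Neo4j instance reachable at the configured URI and a valid OPENAI_API_KEY in the environment):

import asyncio

from config_manager import GraphitiConfig
from graphiti_service import GraphitiService


async def main() -> None:
    config = GraphitiConfig.from_env()
    service = GraphitiService(config, semaphore_limit=10)
    await service.initialize()          # builds the Graphiti client, indices and constraints
    print(service.is_initialized())     # True
    await service.verify_connection()   # raises if Neo4j is unreachable
    # service.client is the underlying Graphiti instance the MCP tools use


asyncio.run(main())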
182 mcp_server/llm_config.py Normal file
@@ -0,0 +1,182 @@
"""LLM configuration for Graphiti MCP Server."""

import argparse
import logging
import os
from typing import TYPE_CHECKING

from openai import AsyncAzureOpenAI
from pydantic import BaseModel
from utils import create_azure_credential_token_provider

from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_client import OpenAIClient

if TYPE_CHECKING:
    pass

logger = logging.getLogger(__name__)

DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
SMALL_LLM_MODEL = 'gpt-4.1-nano'


class GraphitiLLMConfig(BaseModel):
    """Configuration for the LLM client.

    Centralizes all LLM-specific configuration parameters including API keys and model selection.
    """

    api_key: str | None = None
    model: str = DEFAULT_LLM_MODEL
    small_model: str = SMALL_LLM_MODEL
    temperature: float = 0.0
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiLLMConfig':
        """Create LLM configuration from environment variables."""
        # Get model from environment, or use default if not set or empty
        model_env = os.environ.get('MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_LLM_MODEL

        # Get small_model from environment, or use default if not set or empty
        small_model_env = os.environ.get('SMALL_MODEL_NAME', '')
        small_model = small_model_env if small_model_env.strip() else SMALL_LLM_MODEL

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_API_VERSION', None)
        azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME', None)
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )

        if azure_openai_endpoint is None:
            # Setup for OpenAI API
            # Log if empty model was provided
            if model_env == '':
                logger.debug(
                    f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}'
                )
            elif not model_env.strip():
                logger.warning(
                    f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}'
                )

            return cls(
                api_key=os.environ.get('OPENAI_API_KEY'),
                model=model,
                small_model=small_model,
                temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
            )
        else:
            # Setup for Azure OpenAI API
            # Log if empty deployment name was provided
            if azure_openai_deployment_name is None:
                logger.error('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
                raise ValueError('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')

            if not azure_openai_use_managed_identity:
                # api key
                api_key = os.environ.get('OPENAI_API_KEY', None)
            else:
                # Managed identity
                api_key = None

            return cls(
                azure_openai_use_managed_identity=azure_openai_use_managed_identity,
                azure_openai_endpoint=azure_openai_endpoint,
                api_key=api_key,
                azure_openai_api_version=azure_openai_api_version,
                azure_openai_deployment_name=azure_openai_deployment_name,
                model=model,
                small_model=small_model,
                temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
            )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig':
        """Create LLM configuration from CLI arguments, falling back to environment variables."""
        # Start with environment-based config
        config = cls.from_env()

        # CLI arguments override environment variables when provided
        if hasattr(args, 'model') and args.model:
            # Only use CLI model if it's not empty
            if args.model.strip():
                config.model = args.model
            else:
                # Log that empty model was provided and default is used
                logger.warning(f'Empty model name provided, using default: {DEFAULT_LLM_MODEL}')

        if hasattr(args, 'small_model') and args.small_model:
            if args.small_model.strip():
                config.small_model = args.small_model
            else:
                logger.warning(f'Empty small_model name provided, using default: {SMALL_LLM_MODEL}')

        if hasattr(args, 'temperature') and args.temperature is not None:
            config.temperature = args.temperature

        return config

    def create_client(self) -> LLMClient:
        """Create an LLM client based on this configuration.

        Returns:
            LLMClient instance
        """
        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        azure_ad_token_provider=token_provider,
                    ),
                    config=LLMConfig(
                        api_key=self.api_key,
                        model=self.model,
                        small_model=self.small_model,
                        temperature=self.temperature,
                    ),
                )
            elif self.api_key:
                # Use API key for authentication
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        api_key=self.api_key,
                    ),
                    config=LLMConfig(
                        api_key=self.api_key,
                        model=self.model,
                        small_model=self.small_model,
                        temperature=self.temperature,
                    ),
                )
            else:
                raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')

        if not self.api_key:
            raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')

        llm_client_config = LLMConfig(
            api_key=self.api_key, model=self.model, small_model=self.small_model
        )

        # Set temperature
        llm_client_config.temperature = self.temperature

        return OpenAIClient(config=llm_client_config)
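A small sketch of the intended precedence (editor's addition; the key and model strings are example values only):

import argparse
import os

from llm_config import GraphitiLLMConfig

os.environ['OPENAI_API_KEY'] = 'sk-placeholder'   # placeholder value
os.environ['MODEL_NAME'] = 'gpt-4.1-mini'
os.environ.pop('AZURE_OPENAI_ENDPOINT', None)     # force the plain OpenAI path

args = argparse.Namespace(model='gpt-4.1', small_model=None, temperature=0.2)
config = GraphitiLLMConfig.from_cli_and_env(args)
print(config.model)        # 'gpt-4.1' - the CLI value overrides MODEL_NAME
print(config.temperature)  # 0.2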
22 mcp_server/neo4j_config.py Normal file
@@ -0,0 +1,22 @@
"""Neo4j database configuration for Graphiti MCP Server."""

import os

from pydantic import BaseModel


class Neo4jConfig(BaseModel):
    """Configuration for Neo4j database connection."""

    uri: str = 'bolt://localhost:7687'
    user: str = 'neo4j'
    password: str = 'password'

    @classmethod
    def from_env(cls) -> 'Neo4jConfig':
        """Create Neo4j configuration from environment variables."""
        return cls(
            uri=os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
            user=os.environ.get('NEO4J_USER', 'neo4j'),
            password=os.environ.get('NEO4J_PASSWORD', 'password'),
        )
@@ -11,3 +11,9 @@ dependencies = [
    "azure-identity>=1.21.0",
    "graphiti-core",
]

[dependency-groups]
dev = [
    "httpx>=0.28.1",
    "mcp>=1.9.4",
]
86 mcp_server/queue_service.py Normal file
@@ -0,0 +1,86 @@
|
|||
"""Queue service for managing episode processing."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QueueService:
|
||||
"""Service for managing sequential episode processing queues by group_id."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the queue service."""
|
||||
# Dictionary to store queues for each group_id
|
||||
self._episode_queues: dict[str, asyncio.Queue] = {}
|
||||
# Dictionary to track if a worker is running for each group_id
|
||||
self._queue_workers: dict[str, bool] = {}
|
||||
|
||||
async def add_episode_task(
|
||||
self, group_id: str, process_func: Callable[[], Awaitable[None]]
|
||||
) -> int:
|
||||
"""Add an episode processing task to the queue.
|
||||
|
||||
Args:
|
||||
group_id: The group ID for the episode
|
||||
process_func: The async function to process the episode
|
||||
|
||||
Returns:
|
||||
The position in the queue
|
||||
"""
|
||||
# Initialize queue for this group_id if it doesn't exist
|
||||
if group_id not in self._episode_queues:
|
||||
self._episode_queues[group_id] = asyncio.Queue()
|
||||
|
||||
# Add the episode processing function to the queue
|
||||
await self._episode_queues[group_id].put(process_func)
|
||||
|
||||
# Start a worker for this queue if one isn't already running
|
||||
if not self._queue_workers.get(group_id, False):
|
||||
asyncio.create_task(self._process_episode_queue(group_id))
|
||||
|
||||
return self._episode_queues[group_id].qsize()
|
||||
|
||||
async def _process_episode_queue(self, group_id: str) -> None:
|
||||
"""Process episodes for a specific group_id sequentially.
|
||||
|
||||
This function runs as a long-lived task that processes episodes
|
||||
from the queue one at a time.
|
||||
"""
|
||||
logger.info(f'Starting episode queue worker for group_id: {group_id}')
|
||||
self._queue_workers[group_id] = True
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get the next episode processing function from the queue
|
||||
# This will wait if the queue is empty
|
||||
process_func = await self._episode_queues[group_id].get()
|
||||
|
||||
try:
|
||||
# Process the episode
|
||||
await process_func()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Error processing queued episode for group_id {group_id}: {str(e)}'
|
||||
)
|
||||
finally:
|
||||
# Mark the task as done regardless of success/failure
|
||||
self._episode_queues[group_id].task_done()
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f'Episode queue worker for group_id {group_id} was cancelled')
|
||||
except Exception as e:
|
||||
logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}')
|
||||
finally:
|
||||
self._queue_workers[group_id] = False
|
||||
logger.info(f'Stopped episode queue worker for group_id: {group_id}')
|
||||
|
||||
def get_queue_size(self, group_id: str) -> int:
|
||||
"""Get the current queue size for a group_id."""
|
||||
if group_id not in self._episode_queues:
|
||||
return 0
|
||||
return self._episode_queues[group_id].qsize()
|
||||
|
||||
def is_worker_running(self, group_id: str) -> bool:
|
||||
"""Check if a worker is running for a group_id."""
|
||||
return self._queue_workers.get(group_id, False)
|
||||
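
Usage sketch for QueueService (illustrative only; the group id and the no-op coroutine below are made up for the example):

import asyncio

from queue_service import QueueService


async def demo() -> None:
    queue_service = QueueService()

    async def process_episode() -> None:
        # Stand-in for the real episode-processing coroutine.
        await asyncio.sleep(0.1)

    position = await queue_service.add_episode_task('example_group', process_episode)
    print(f'queued at position {position}')
    print('worker running:', queue_service.is_worker_running('example_group'))

    # Give the background worker a moment to drain the queue.
    await asyncio.sleep(0.5)
    print('remaining:', queue_service.get_queue_size('example_group'))


asyncio.run(demo())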

41
mcp_server/response_types.py
Normal file

@ -0,0 +1,41 @@
"""Response type definitions for Graphiti MCP Server."""

from typing import Any, TypedDict


class ErrorResponse(TypedDict):
    error: str


class SuccessResponse(TypedDict):
    message: str


class NodeResult(TypedDict):
    uuid: str
    name: str
    summary: str
    labels: list[str]
    group_id: str
    created_at: str
    attributes: dict[str, Any]


class NodeSearchResponse(TypedDict):
    message: str
    nodes: list[NodeResult]


class FactSearchResponse(TypedDict):
    message: str
    facts: list[dict[str, Any]]


class EpisodeSearchResponse(TypedDict):
    message: str
    episodes: list[dict[str, Any]]


class StatusResponse(TypedDict):
    status: str
    message: str
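
Since these are TypedDicts, responses are plain dicts at runtime; a small illustrative example (all values are placeholders):

from response_types import ErrorResponse, NodeResult, NodeSearchResponse

node: NodeResult = {
    'uuid': '00000000-0000-0000-0000-000000000000',
    'name': 'Example Entity',
    'summary': 'Placeholder summary.',
    'labels': ['Entity'],
    'group_id': 'example_group',
    'created_at': '2024-01-01T00:00:00Z',
    'attributes': {},
}

ok: NodeSearchResponse = {'message': 'Nodes retrieved successfully', 'nodes': [node]}
err: ErrorResponse = {'error': 'example error message'}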

16
mcp_server/server_config.py
Normal file

@ -0,0 +1,16 @@
"""Server configuration for Graphiti MCP Server."""

import argparse

from pydantic import BaseModel


class MCPConfig(BaseModel):
    """Configuration for MCP server."""

    transport: str = 'sse'  # Default to SSE transport

    @classmethod
    def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig':
        """Create MCP configuration from CLI arguments."""
        return cls(transport=args.transport)
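
Sketch of how MCPConfig.from_cli is fed (the argparse setup here is an assumption for illustration; the actual flag definitions in graphiti_mcp_server.py may differ):

import argparse

from server_config import MCPConfig

parser = argparse.ArgumentParser()
parser.add_argument('--transport', choices=['sse', 'stdio'], default='sse')
args = parser.parse_args(['--transport', 'stdio'])

mcp_config = MCPConfig.from_cli(args)
print(mcp_config.transport)  # -> stdio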

363
mcp_server/test_integration.py
Normal file

@ -0,0 +1,363 @@
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration test for the refactored Graphiti MCP Server.
|
||||
Tests all major MCP tools and handles episode processing latency.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class MCPIntegrationTest:
|
||||
"""Integration test client for Graphiti MCP Server."""
|
||||
|
||||
def __init__(self, base_url: str = 'http://localhost:8000'):
|
||||
self.base_url = base_url
|
||||
self.client = httpx.AsyncClient(timeout=30.0)
|
||||
self.test_group_id = f'test_group_{int(time.time())}'
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.client.aclose()
|
||||
|
||||
async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call an MCP tool via the SSE endpoint."""
|
||||
# MCP protocol message structure
|
||||
message = {
|
||||
'jsonrpc': '2.0',
|
||||
'id': int(time.time() * 1000),
|
||||
'method': 'tools/call',
|
||||
'params': {'name': tool_name, 'arguments': arguments},
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self.client.post(
|
||||
f'{self.base_url}/message',
|
||||
json=message,
|
||||
headers={'Content-Type': 'application/json'},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
return {'error': f'HTTP {response.status_code}: {response.text}'}
|
||||
|
||||
result = response.json()
|
||||
return result.get('result', result)
|
||||
|
||||
except Exception as e:
|
||||
return {'error': str(e)}
|
||||
|
||||
async def test_server_status(self) -> bool:
|
||||
"""Test the get_status resource."""
|
||||
print('🔍 Testing server status...')
|
||||
|
||||
try:
|
||||
response = await self.client.get(f'{self.base_url}/resources/http://graphiti/status')
|
||||
if response.status_code == 200:
|
||||
status = response.json()
|
||||
print(f' ✅ Server status: {status.get("status", "unknown")}')
|
||||
return status.get('status') == 'ok'
|
||||
else:
|
||||
print(f' ❌ Status check failed: HTTP {response.status_code}')
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f' ❌ Status check failed: {e}')
|
||||
return False
|
||||
|
||||
async def test_add_memory(self) -> dict[str, str]:
|
||||
"""Test adding various types of memory episodes."""
|
||||
print('📝 Testing add_memory functionality...')
|
||||
|
||||
episode_results = {}
|
||||
|
||||
# Test 1: Add text episode
|
||||
print(' Testing text episode...')
|
||||
result = await self.call_mcp_tool(
|
||||
'add_memory',
|
||||
{
|
||||
'name': 'Test Company News',
|
||||
'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
|
||||
'source': 'text',
|
||||
'source_description': 'news article',
|
||||
'group_id': self.test_group_id,
|
||||
},
|
||||
)
|
||||
|
||||
if 'error' in result:
|
||||
print(f' ❌ Text episode failed: {result["error"]}')
|
||||
else:
|
||||
print(f' ✅ Text episode queued: {result.get("message", "Success")}')
|
||||
episode_results['text'] = 'success'
|
||||
|
||||
# Test 2: Add JSON episode
|
||||
print(' Testing JSON episode...')
|
||||
json_data = {
|
||||
'company': {'name': 'TechCorp', 'founded': 2010},
|
||||
'products': [
|
||||
{'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
|
||||
{'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
|
||||
],
|
||||
'employees': 150,
|
||||
}
|
||||
|
||||
result = await self.call_mcp_tool(
|
||||
'add_memory',
|
||||
{
|
||||
'name': 'Company Profile',
|
||||
'episode_body': json.dumps(json_data),
|
||||
'source': 'json',
|
||||
'source_description': 'CRM data',
|
||||
'group_id': self.test_group_id,
|
||||
},
|
||||
)
|
||||
|
||||
if 'error' in result:
|
||||
print(f' ❌ JSON episode failed: {result["error"]}')
|
||||
else:
|
||||
print(f' ✅ JSON episode queued: {result.get("message", "Success")}')
|
||||
episode_results['json'] = 'success'
|
||||
|
||||
# Test 3: Add message episode
|
||||
print(' Testing message episode...')
|
||||
result = await self.call_mcp_tool(
|
||||
'add_memory',
|
||||
{
|
||||
'name': 'Customer Support Chat',
|
||||
'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
|
||||
'source': 'message',
|
||||
'source_description': 'support chat log',
|
||||
'group_id': self.test_group_id,
|
||||
},
|
||||
)
|
||||
|
||||
if 'error' in result:
|
||||
print(f' ❌ Message episode failed: {result["error"]}')
|
||||
else:
|
||||
print(f' ✅ Message episode queued: {result.get("message", "Success")}')
|
||||
episode_results['message'] = 'success'
|
||||
|
||||
return episode_results
|
||||
|
||||
async def wait_for_processing(self, max_wait: int = 30) -> None:
|
||||
"""Wait for episode processing to complete."""
|
||||
print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')
|
||||
|
||||
for i in range(max_wait):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check if we have any episodes
|
||||
result = await self.call_mcp_tool(
|
||||
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
|
||||
)
|
||||
|
||||
            # Error responses come back as dicts with an 'error' key; skip and retry.
            if isinstance(result, dict) and 'error' in result:
                continue

            if isinstance(result, list) and len(result) > 0:
                print(f' ✅ Found {len(result)} processed episodes after {i + 1} seconds')
                return
|
||||
|
||||
print(f' ⚠️ Still waiting after {max_wait} seconds...')
|
||||
|
||||
async def test_search_functions(self) -> dict[str, bool]:
|
||||
"""Test search functionality."""
|
||||
print('🔍 Testing search functions...')
|
||||
|
||||
results = {}
|
||||
|
||||
# Test search_memory_nodes
|
||||
print(' Testing search_memory_nodes...')
|
||||
result = await self.call_mcp_tool(
|
||||
'search_memory_nodes',
|
||||
{
|
||||
'query': 'Acme Corp product launch',
|
||||
'group_ids': [self.test_group_id],
|
||||
'max_nodes': 5,
|
||||
},
|
||||
)
|
||||
|
||||
if 'error' in result:
|
||||
print(f' ❌ Node search failed: {result["error"]}')
|
||||
results['nodes'] = False
|
||||
else:
|
||||
nodes = result.get('nodes', [])
|
||||
print(f' ✅ Node search returned {len(nodes)} nodes')
|
||||
results['nodes'] = True
|
||||
|
||||
# Test search_memory_facts
|
||||
print(' Testing search_memory_facts...')
|
||||
result = await self.call_mcp_tool(
|
||||
'search_memory_facts',
|
||||
{
|
||||
'query': 'company products software',
|
||||
'group_ids': [self.test_group_id],
|
||||
'max_facts': 5,
|
||||
},
|
||||
)
|
||||
|
||||
if 'error' in result:
|
||||
print(f' ❌ Fact search failed: {result["error"]}')
|
||||
results['facts'] = False
|
||||
else:
|
||||
facts = result.get('facts', [])
|
||||
print(f' ✅ Fact search returned {len(facts)} facts')
|
||||
results['facts'] = True
|
||||
|
||||
return results
|
||||
|
||||
async def test_episode_retrieval(self) -> bool:
|
||||
"""Test episode retrieval."""
|
||||
print('📚 Testing episode retrieval...')
|
||||
|
||||
result = await self.call_mcp_tool(
|
||||
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
|
||||
)
|
||||
|
||||
if 'error' in result:
|
||||
print(f' ❌ Episode retrieval failed: {result["error"]}')
|
||||
return False
|
||||
|
||||
if isinstance(result, list):
|
||||
print(f' ✅ Retrieved {len(result)} episodes')
|
||||
|
||||
# Print episode details
|
||||
for i, episode in enumerate(result[:3]): # Show first 3
|
||||
name = episode.get('name', 'Unknown')
|
||||
source = episode.get('source', 'unknown')
|
||||
print(f' Episode {i + 1}: {name} (source: {source})')
|
||||
|
||||
return len(result) > 0
|
||||
else:
|
||||
print(f' ❌ Unexpected result format: {type(result)}')
|
||||
return False
|
||||
|
||||
async def test_edge_cases(self) -> dict[str, bool]:
|
||||
"""Test edge cases and error handling."""
|
||||
print('🧪 Testing edge cases...')
|
||||
|
||||
results = {}
|
||||
|
||||
# Test with invalid group_id
|
||||
print(' Testing invalid group_id...')
|
||||
result = await self.call_mcp_tool(
|
||||
'search_memory_nodes',
|
||||
{'query': 'nonexistent data', 'group_ids': ['nonexistent_group'], 'max_nodes': 5},
|
||||
)
|
||||
|
||||
# Should not error, just return empty results
|
||||
if 'error' not in result:
|
||||
nodes = result.get('nodes', [])
|
||||
print(f' ✅ Invalid group_id handled gracefully (returned {len(nodes)} nodes)')
|
||||
results['invalid_group'] = True
|
||||
else:
|
||||
print(f' ❌ Invalid group_id caused error: {result["error"]}')
|
||||
results['invalid_group'] = False
|
||||
|
||||
# Test empty query
|
||||
print(' Testing empty query...')
|
||||
result = await self.call_mcp_tool(
|
||||
'search_memory_nodes', {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5}
|
||||
)
|
||||
|
||||
if 'error' not in result:
|
||||
print(' ✅ Empty query handled gracefully')
|
||||
results['empty_query'] = True
|
||||
else:
|
||||
print(f' ❌ Empty query caused error: {result["error"]}')
|
||||
results['empty_query'] = False
|
||||
|
||||
return results
|
||||
|
||||
async def run_full_test_suite(self) -> dict[str, Any]:
|
||||
"""Run the complete integration test suite."""
|
||||
print('🚀 Starting Graphiti MCP Server Integration Test')
|
||||
print(f' Test group ID: {self.test_group_id}')
|
||||
print('=' * 60)
|
||||
|
||||
results = {
|
||||
'server_status': False,
|
||||
'add_memory': {},
|
||||
'search': {},
|
||||
'episodes': False,
|
||||
'edge_cases': {},
|
||||
'overall_success': False,
|
||||
}
|
||||
|
||||
# Test 1: Server Status
|
||||
results['server_status'] = await self.test_server_status()
|
||||
if not results['server_status']:
|
||||
print('❌ Server not responding, aborting tests')
|
||||
return results
|
||||
|
||||
print()
|
||||
|
||||
# Test 2: Add Memory
|
||||
results['add_memory'] = await self.test_add_memory()
|
||||
print()
|
||||
|
||||
# Test 3: Wait for processing
|
||||
await self.wait_for_processing()
|
||||
print()
|
||||
|
||||
# Test 4: Search Functions
|
||||
results['search'] = await self.test_search_functions()
|
||||
print()
|
||||
|
||||
# Test 5: Episode Retrieval
|
||||
results['episodes'] = await self.test_episode_retrieval()
|
||||
print()
|
||||
|
||||
# Test 6: Edge Cases
|
||||
results['edge_cases'] = await self.test_edge_cases()
|
||||
print()
|
||||
|
||||
# Calculate overall success
|
||||
memory_success = len(results['add_memory']) > 0
|
||||
search_success = any(results['search'].values())
|
||||
edge_case_success = any(results['edge_cases'].values())
|
||||
|
||||
results['overall_success'] = (
|
||||
results['server_status']
|
||||
and memory_success
|
||||
and results['episodes']
|
||||
and (search_success or edge_case_success) # At least some functionality working
|
||||
)
|
||||
|
||||
# Print summary
|
||||
print('=' * 60)
|
||||
print('📊 TEST SUMMARY')
|
||||
print(f' Server Status: {"✅" if results["server_status"] else "❌"}')
|
||||
print(
|
||||
f' Memory Operations: {"✅" if memory_success else "❌"} ({len(results["add_memory"])} types)'
|
||||
)
|
||||
print(f' Search Functions: {"✅" if search_success else "❌"}')
|
||||
print(f' Episode Retrieval: {"✅" if results["episodes"] else "❌"}')
|
||||
print(f' Edge Cases: {"✅" if edge_case_success else "❌"}')
|
||||
print()
|
||||
print(f'🎯 OVERALL: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')
|
||||
|
||||
if results['overall_success']:
|
||||
print(' The refactored MCP server is working correctly!')
|
||||
else:
|
||||
print(' Some issues detected. Check individual test results above.')
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run the integration test."""
|
||||
async with MCPIntegrationTest() as test:
|
||||
results = await test.run_full_test_suite()
|
||||
|
||||
# Exit with appropriate code
|
||||
exit_code = 0 if results['overall_success'] else 1
|
||||
exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||

502
mcp_server/test_mcp_integration.py
Normal file

@ -0,0 +1,502 @@
#!/usr/bin/env python3
|
||||
"""
|
||||
Integration test for the refactored Graphiti MCP Server using the official MCP Python SDK.
|
||||
Tests all major MCP tools and handles episode processing latency.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
|
||||
class GraphitiMCPIntegrationTest:
|
||||
"""Integration test client for Graphiti MCP Server using official MCP SDK."""
|
||||
|
||||
def __init__(self):
|
||||
self.test_group_id = f'test_group_{int(time.time())}'
|
||||
self.session = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Start the MCP client session."""
|
||||
# Configure server parameters to run our refactored server
|
||||
server_params = StdioServerParameters(
|
||||
command='uv',
|
||||
args=['run', 'graphiti_mcp_server.py', '--transport', 'stdio'],
|
||||
env={
|
||||
'NEO4J_URI': 'bolt://localhost:7687',
|
||||
'NEO4J_USER': 'neo4j',
|
||||
'NEO4J_PASSWORD': 'demodemo',
|
||||
'OPENAI_API_KEY': 'dummy_key_for_testing', # Will use existing .env
|
||||
},
|
||||
)
|
||||
|
||||
print(f'🚀 Starting MCP client session with test group: {self.test_group_id}')
|
||||
|
||||
# Use the async context manager properly
|
||||
self.client_context = stdio_client(server_params)
|
||||
read, write = await self.client_context.__aenter__()
|
||||
self.session = ClientSession(read, write)
|
||||
await self.session.initialize()
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Close the MCP client session."""
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
if hasattr(self, 'client_context'):
|
||||
await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""Call an MCP tool and return the result."""
|
||||
try:
|
||||
result = await self.session.call_tool(tool_name, arguments)
|
||||
return result.content[0].text if result.content else {'error': 'No content returned'}
|
||||
except Exception as e:
|
||||
return {'error': str(e)}
|
||||
|
||||
async def test_server_initialization(self) -> bool:
|
||||
"""Test that the server initializes properly."""
|
||||
print('🔍 Testing server initialization...')
|
||||
|
||||
try:
|
||||
# List available tools to verify server is responding
|
||||
tools_result = await self.session.list_tools()
|
||||
tools = [tool.name for tool in tools_result.tools]
|
||||
|
||||
expected_tools = [
|
||||
'add_memory',
|
||||
'search_memory_nodes',
|
||||
'search_memory_facts',
|
||||
'get_episodes',
|
||||
'delete_episode',
|
||||
'delete_entity_edge',
|
||||
'get_entity_edge',
|
||||
'clear_graph',
|
||||
]
|
||||
|
||||
available_tools = len([tool for tool in expected_tools if tool in tools])
|
||||
print(
|
||||
f' ✅ Server responding with {len(tools)} tools ({available_tools}/{len(expected_tools)} expected)'
|
||||
)
|
||||
print(f' Available tools: {", ".join(sorted(tools))}')
|
||||
|
||||
return available_tools >= len(expected_tools) * 0.8 # 80% of expected tools
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Server initialization failed: {e}')
|
||||
return False
|
||||
|
||||
async def test_add_memory_operations(self) -> dict[str, bool]:
|
||||
"""Test adding various types of memory episodes."""
|
||||
print('📝 Testing add_memory operations...')
|
||||
|
||||
results = {}
|
||||
|
||||
# Test 1: Add text episode
|
||||
print(' Testing text episode...')
|
||||
try:
|
||||
result = await self.call_tool(
|
||||
'add_memory',
|
||||
{
|
||||
'name': 'Test Company News',
|
||||
'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
|
||||
'source': 'text',
|
||||
'source_description': 'news article',
|
||||
'group_id': self.test_group_id,
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(result, str) and 'queued' in result.lower():
|
||||
print(f' ✅ Text episode: {result}')
|
||||
results['text'] = True
|
||||
else:
|
||||
print(f' ❌ Text episode failed: {result}')
|
||||
results['text'] = False
|
||||
except Exception as e:
|
||||
print(f' ❌ Text episode error: {e}')
|
||||
results['text'] = False
|
||||
|
||||
# Test 2: Add JSON episode
|
||||
print(' Testing JSON episode...')
|
||||
try:
|
||||
json_data = {
|
||||
'company': {'name': 'TechCorp', 'founded': 2010},
|
||||
'products': [
|
||||
{'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
|
||||
{'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
|
||||
],
|
||||
'employees': 150,
|
||||
}
|
||||
|
||||
result = await self.call_tool(
|
||||
'add_memory',
|
||||
{
|
||||
'name': 'Company Profile',
|
||||
'episode_body': json.dumps(json_data),
|
||||
'source': 'json',
|
||||
'source_description': 'CRM data',
|
||||
'group_id': self.test_group_id,
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(result, str) and 'queued' in result.lower():
|
||||
print(f' ✅ JSON episode: {result}')
|
||||
results['json'] = True
|
||||
else:
|
||||
print(f' ❌ JSON episode failed: {result}')
|
||||
results['json'] = False
|
||||
except Exception as e:
|
||||
print(f' ❌ JSON episode error: {e}')
|
||||
results['json'] = False
|
||||
|
||||
# Test 3: Add message episode
|
||||
print(' Testing message episode...')
|
||||
try:
|
||||
result = await self.call_tool(
|
||||
'add_memory',
|
||||
{
|
||||
'name': 'Customer Support Chat',
|
||||
'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
|
||||
'source': 'message',
|
||||
'source_description': 'support chat log',
|
||||
'group_id': self.test_group_id,
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(result, str) and 'queued' in result.lower():
|
||||
print(f' ✅ Message episode: {result}')
|
||||
results['message'] = True
|
||||
else:
|
||||
print(f' ❌ Message episode failed: {result}')
|
||||
results['message'] = False
|
||||
except Exception as e:
|
||||
print(f' ❌ Message episode error: {e}')
|
||||
results['message'] = False
|
||||
|
||||
return results
|
||||
|
||||
async def wait_for_processing(self, max_wait: int = 45) -> bool:
|
||||
"""Wait for episode processing to complete."""
|
||||
print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')
|
||||
|
||||
for i in range(max_wait):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
try:
|
||||
# Check if we have any episodes
|
||||
result = await self.call_tool(
|
||||
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
|
||||
)
|
||||
|
||||
# Parse the JSON result if it's a string
|
||||
if isinstance(result, str):
|
||||
try:
|
||||
parsed_result = json.loads(result)
|
||||
if isinstance(parsed_result, list) and len(parsed_result) > 0:
|
||||
print(
|
||||
f' ✅ Found {len(parsed_result)} processed episodes after {i + 1} seconds'
|
||||
)
|
||||
return True
|
||||
except json.JSONDecodeError:
|
||||
if 'episodes' in result.lower():
|
||||
print(f' ✅ Episodes detected after {i + 1} seconds')
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
if i == 0: # Only log first error to avoid spam
|
||||
print(f' ⚠️ Waiting for processing... ({e})')
|
||||
continue
|
||||
|
||||
print(f' ⚠️ Still waiting after {max_wait} seconds...')
|
||||
return False
|
||||
|
||||
async def test_search_operations(self) -> dict[str, bool]:
|
||||
"""Test search functionality."""
|
||||
print('🔍 Testing search operations...')
|
||||
|
||||
results = {}
|
||||
|
||||
# Test search_memory_nodes
|
||||
print(' Testing search_memory_nodes...')
|
||||
try:
|
||||
result = await self.call_tool(
|
||||
'search_memory_nodes',
|
||||
{
|
||||
'query': 'Acme Corp product launch AI',
|
||||
'group_ids': [self.test_group_id],
|
||||
'max_nodes': 5,
|
||||
},
|
||||
)
|
||||
|
||||
success = False
|
||||
if isinstance(result, str):
|
||||
try:
|
||||
parsed = json.loads(result)
|
||||
nodes = parsed.get('nodes', [])
|
||||
success = isinstance(nodes, list)
|
||||
print(f' ✅ Node search returned {len(nodes)} nodes')
|
||||
except json.JSONDecodeError:
|
||||
success = 'nodes' in result.lower() and 'successfully' in result.lower()
|
||||
if success:
|
||||
print(' ✅ Node search completed successfully')
|
||||
|
||||
results['nodes'] = success
|
||||
if not success:
|
||||
print(f' ❌ Node search failed: {result}')
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Node search error: {e}')
|
||||
results['nodes'] = False
|
||||
|
||||
# Test search_memory_facts
|
||||
print(' Testing search_memory_facts...')
|
||||
try:
|
||||
result = await self.call_tool(
|
||||
'search_memory_facts',
|
||||
{
|
||||
'query': 'company products software TechCorp',
|
||||
'group_ids': [self.test_group_id],
|
||||
'max_facts': 5,
|
||||
},
|
||||
)
|
||||
|
||||
success = False
|
||||
if isinstance(result, str):
|
||||
try:
|
||||
parsed = json.loads(result)
|
||||
facts = parsed.get('facts', [])
|
||||
success = isinstance(facts, list)
|
||||
print(f' ✅ Fact search returned {len(facts)} facts')
|
||||
except json.JSONDecodeError:
|
||||
success = 'facts' in result.lower() and 'successfully' in result.lower()
|
||||
if success:
|
||||
print(' ✅ Fact search completed successfully')
|
||||
|
||||
results['facts'] = success
|
||||
if not success:
|
||||
print(f' ❌ Fact search failed: {result}')
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Fact search error: {e}')
|
||||
results['facts'] = False
|
||||
|
||||
return results
|
||||
|
||||
async def test_episode_retrieval(self) -> bool:
|
||||
"""Test episode retrieval."""
|
||||
print('📚 Testing episode retrieval...')
|
||||
|
||||
try:
|
||||
result = await self.call_tool(
|
||||
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
|
||||
)
|
||||
|
||||
if isinstance(result, str):
|
||||
try:
|
||||
parsed = json.loads(result)
|
||||
if isinstance(parsed, list):
|
||||
print(f' ✅ Retrieved {len(parsed)} episodes')
|
||||
|
||||
# Show episode details
|
||||
for i, episode in enumerate(parsed[:3]):
|
||||
name = episode.get('name', 'Unknown')
|
||||
source = episode.get('source', 'unknown')
|
||||
print(f' Episode {i + 1}: {name} (source: {source})')
|
||||
|
||||
return len(parsed) > 0
|
||||
except json.JSONDecodeError:
|
||||
# Check if response indicates success
|
||||
if 'episode' in result.lower():
|
||||
print(' ✅ Episode retrieval completed')
|
||||
return True
|
||||
|
||||
print(f' ❌ Unexpected result format: {result}')
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Episode retrieval failed: {e}')
|
||||
return False
|
||||
|
||||
async def test_error_handling(self) -> dict[str, bool]:
|
||||
"""Test error handling and edge cases."""
|
||||
print('🧪 Testing error handling...')
|
||||
|
||||
results = {}
|
||||
|
||||
# Test with nonexistent group
|
||||
print(' Testing nonexistent group handling...')
|
||||
try:
|
||||
result = await self.call_tool(
|
||||
'search_memory_nodes',
|
||||
{
|
||||
'query': 'nonexistent data',
|
||||
'group_ids': ['nonexistent_group_12345'],
|
||||
'max_nodes': 5,
|
||||
},
|
||||
)
|
||||
|
||||
# Should handle gracefully, not crash
|
||||
success = (
|
||||
'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
|
||||
)
|
||||
if success:
|
||||
print(' ✅ Nonexistent group handled gracefully')
|
||||
else:
|
||||
print(f' ❌ Nonexistent group caused issues: {result}')
|
||||
|
||||
results['nonexistent_group'] = success
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Nonexistent group test failed: {e}')
|
||||
results['nonexistent_group'] = False
|
||||
|
||||
# Test empty query
|
||||
print(' Testing empty query handling...')
|
||||
try:
|
||||
result = await self.call_tool(
|
||||
'search_memory_nodes',
|
||||
{'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5},
|
||||
)
|
||||
|
||||
# Should handle gracefully
|
||||
success = (
|
||||
'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
|
||||
)
|
||||
if success:
|
||||
print(' ✅ Empty query handled gracefully')
|
||||
else:
|
||||
print(f' ❌ Empty query caused issues: {result}')
|
||||
|
||||
results['empty_query'] = success
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Empty query test failed: {e}')
|
||||
results['empty_query'] = False
|
||||
|
||||
return results
|
||||
|
||||
async def run_comprehensive_test(self) -> dict[str, Any]:
|
||||
"""Run the complete integration test suite."""
|
||||
print('🚀 Starting Comprehensive Graphiti MCP Server Integration Test')
|
||||
print(f' Test group ID: {self.test_group_id}')
|
||||
print('=' * 70)
|
||||
|
||||
results = {
|
||||
'server_init': False,
|
||||
'add_memory': {},
|
||||
'processing_wait': False,
|
||||
'search': {},
|
||||
'episodes': False,
|
||||
'error_handling': {},
|
||||
'overall_success': False,
|
||||
}
|
||||
|
||||
# Test 1: Server Initialization
|
||||
results['server_init'] = await self.test_server_initialization()
|
||||
if not results['server_init']:
|
||||
print('❌ Server initialization failed, aborting remaining tests')
|
||||
return results
|
||||
|
||||
print()
|
||||
|
||||
# Test 2: Add Memory Operations
|
||||
results['add_memory'] = await self.test_add_memory_operations()
|
||||
print()
|
||||
|
||||
# Test 3: Wait for Processing
|
||||
results['processing_wait'] = await self.wait_for_processing()
|
||||
print()
|
||||
|
||||
# Test 4: Search Operations
|
||||
results['search'] = await self.test_search_operations()
|
||||
print()
|
||||
|
||||
# Test 5: Episode Retrieval
|
||||
results['episodes'] = await self.test_episode_retrieval()
|
||||
print()
|
||||
|
||||
# Test 6: Error Handling
|
||||
results['error_handling'] = await self.test_error_handling()
|
||||
print()
|
||||
|
||||
# Calculate overall success
|
||||
memory_success = any(results['add_memory'].values())
|
||||
search_success = any(results['search'].values()) if results['search'] else False
|
||||
error_success = (
|
||||
any(results['error_handling'].values()) if results['error_handling'] else True
|
||||
)
|
||||
|
||||
results['overall_success'] = (
|
||||
results['server_init']
|
||||
and memory_success
|
||||
and (results['episodes'] or results['processing_wait'])
|
||||
and error_success
|
||||
)
|
||||
|
||||
# Print comprehensive summary
|
||||
print('=' * 70)
|
||||
print('📊 COMPREHENSIVE TEST SUMMARY')
|
||||
print('-' * 35)
|
||||
print(f'Server Initialization: {"✅ PASS" if results["server_init"] else "❌ FAIL"}')
|
||||
|
||||
memory_stats = f'({sum(results["add_memory"].values())}/{len(results["add_memory"])} types)'
|
||||
print(
|
||||
f'Memory Operations: {"✅ PASS" if memory_success else "❌ FAIL"} {memory_stats}'
|
||||
)
|
||||
|
||||
print(f'Processing Pipeline: {"✅ PASS" if results["processing_wait"] else "❌ FAIL"}')
|
||||
|
||||
search_stats = (
|
||||
f'({sum(results["search"].values())}/{len(results["search"])} types)'
|
||||
if results['search']
|
||||
else '(0/0 types)'
|
||||
)
|
||||
print(
|
||||
f'Search Operations: {"✅ PASS" if search_success else "❌ FAIL"} {search_stats}'
|
||||
)
|
||||
|
||||
print(f'Episode Retrieval: {"✅ PASS" if results["episodes"] else "❌ FAIL"}')
|
||||
|
||||
error_stats = (
|
||||
f'({sum(results["error_handling"].values())}/{len(results["error_handling"])} cases)'
|
||||
if results['error_handling']
|
||||
else '(0/0 cases)'
|
||||
)
|
||||
print(
|
||||
f'Error Handling: {"✅ PASS" if error_success else "❌ FAIL"} {error_stats}'
|
||||
)
|
||||
|
||||
print('-' * 35)
|
||||
print(f'🎯 OVERALL RESULT: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')
|
||||
|
||||
if results['overall_success']:
|
||||
print('\n🎉 The refactored Graphiti MCP server is working correctly!')
|
||||
print(' All core functionality has been successfully tested.')
|
||||
else:
|
||||
print('\n⚠️ Some issues were detected. Review the test results above.')
|
||||
print(' The refactoring may need additional attention.')
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run the integration test."""
|
||||
try:
|
||||
async with GraphitiMCPIntegrationTest() as test:
|
||||
results = await test.run_comprehensive_test()
|
||||
|
||||
# Exit with appropriate code
|
||||
exit_code = 0 if results['overall_success'] else 1
|
||||
exit(exit_code)
|
||||
except Exception as e:
|
||||
print(f'❌ Test setup failed: {e}')
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
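
For comparison with the manual __aenter__/__aexit__ handling above, a minimal sketch of the nested async-with form the MCP SDK supports (same calls, just context-managed):

import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def run_session(server_params: StdioServerParameters) -> None:
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools = await session.list_tools()
            print([tool.name for tool in tools.tools])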

199
mcp_server/test_simple_validation.py
Normal file

@ -0,0 +1,199 @@
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple validation test for the refactored Graphiti MCP Server.
|
||||
Tests basic functionality quickly without timeouts.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
def test_server_startup():
|
||||
"""Test that the refactored server starts up successfully."""
|
||||
print('🚀 Testing Graphiti MCP Server Startup...')
|
||||
|
||||
try:
|
||||
# Start the server and capture output
|
||||
process = subprocess.Popen(
|
||||
['uv', 'run', 'graphiti_mcp_server.py', '--transport', 'stdio'],
|
||||
env={
|
||||
'NEO4J_URI': 'bolt://localhost:7687',
|
||||
'NEO4J_USER': 'neo4j',
|
||||
'NEO4J_PASSWORD': 'demodemo',
|
||||
},
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
|
||||
        # Wait for startup logs
        startup_output = ''
        success = False  # ensure defined even if the process exits before logging success
        for _ in range(50):  # Wait up to 5 seconds
|
||||
if process.poll() is not None:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
# Check if we have output
|
||||
try:
|
||||
line = process.stderr.readline()
|
||||
if line:
|
||||
startup_output += line
|
||||
print(f' 📋 {line.strip()}')
|
||||
|
||||
# Check for success indicators
|
||||
if 'Graphiti client initialized successfully' in line:
|
||||
print(' ✅ Graphiti service initialization: SUCCESS')
|
||||
success = True
|
||||
break
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
print(' ⚠️ Timeout waiting for initialization')
|
||||
success = False
|
||||
|
||||
# Clean shutdown
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
print(f' ❌ Server startup failed: {e}')
|
||||
return False
|
||||
|
||||
|
||||
def test_import_validation():
|
||||
"""Test that all refactored modules import correctly."""
|
||||
print('\n🔍 Testing Module Import Validation...')
|
||||
|
||||
modules_to_test = [
|
||||
'config_manager',
|
||||
'llm_config',
|
||||
'embedder_config',
|
||||
'neo4j_config',
|
||||
'server_config',
|
||||
'graphiti_service',
|
||||
'queue_service',
|
||||
'entity_types',
|
||||
'response_types',
|
||||
'formatting',
|
||||
'utils',
|
||||
]
|
||||
|
||||
success_count = 0
|
||||
|
||||
for module in modules_to_test:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
['python', '-c', f"import {module}; print('✅ {module}')"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
print(f' ✅ {module}: Import successful')
|
||||
success_count += 1
|
||||
else:
|
||||
print(f' ❌ {module}: Import failed - {result.stderr.strip()}')
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f' ❌ {module}: Import timeout')
|
||||
except Exception as e:
|
||||
print(f' ❌ {module}: Import error - {e}')
|
||||
|
||||
print(f' 📊 Import Results: {success_count}/{len(modules_to_test)} modules successful')
|
||||
return success_count == len(modules_to_test)
|
||||
|
||||
|
||||
def test_syntax_validation():
|
||||
"""Test that all Python files have valid syntax."""
|
||||
print('\n🔧 Testing Syntax Validation...')
|
||||
|
||||
files_to_test = [
|
||||
'graphiti_mcp_server.py',
|
||||
'config_manager.py',
|
||||
'llm_config.py',
|
||||
'embedder_config.py',
|
||||
'neo4j_config.py',
|
||||
'server_config.py',
|
||||
'graphiti_service.py',
|
||||
'queue_service.py',
|
||||
'entity_types.py',
|
||||
'response_types.py',
|
||||
'formatting.py',
|
||||
'utils.py',
|
||||
]
|
||||
|
||||
success_count = 0
|
||||
|
||||
for file in files_to_test:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
['python', '-m', 'py_compile', file], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
print(f' ✅ {file}: Syntax valid')
|
||||
success_count += 1
|
||||
else:
|
||||
print(f' ❌ {file}: Syntax error - {result.stderr.strip()}')
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f' ❌ {file}: Syntax check timeout')
|
||||
except Exception as e:
|
||||
print(f' ❌ {file}: Syntax check error - {e}')
|
||||
|
||||
print(f' 📊 Syntax Results: {success_count}/{len(files_to_test)} files valid')
|
||||
return success_count == len(files_to_test)
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the validation tests."""
|
||||
print('🧪 Graphiti MCP Server Refactoring Validation')
|
||||
print('=' * 55)
|
||||
|
||||
results = {}
|
||||
|
||||
# Test 1: Syntax validation
|
||||
results['syntax'] = test_syntax_validation()
|
||||
|
||||
# Test 2: Import validation
|
||||
results['imports'] = test_import_validation()
|
||||
|
||||
# Test 3: Server startup
|
||||
results['startup'] = test_server_startup()
|
||||
|
||||
# Summary
|
||||
print('\n' + '=' * 55)
|
||||
print('📊 VALIDATION SUMMARY')
|
||||
print('-' * 25)
|
||||
print(f'Syntax Validation: {"✅ PASS" if results["syntax"] else "❌ FAIL"}')
|
||||
print(f'Import Validation: {"✅ PASS" if results["imports"] else "❌ FAIL"}')
|
||||
print(f'Startup Validation: {"✅ PASS" if results["startup"] else "❌ FAIL"}')
|
||||
|
||||
overall_success = all(results.values())
|
||||
print('-' * 25)
|
||||
print(f'🎯 OVERALL: {"✅ SUCCESS" if overall_success else "❌ FAILED"}')
|
||||
|
||||
if overall_success:
|
||||
print('\n🎉 Refactoring validation successful!')
|
||||
print(' ✅ All modules have valid syntax')
|
||||
print(' ✅ All imports work correctly')
|
||||
print(' ✅ Server initializes successfully')
|
||||
print(' ✅ The refactored MCP server is ready for use!')
|
||||
else:
|
||||
print('\n⚠️ Some validation issues detected.')
|
||||
print(' Please review the failed tests above.')
|
||||
|
||||
return 0 if overall_success else 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
exit_code = main()
|
||||
sys.exit(exit_code)
|
||||

14
mcp_server/utils.py
Normal file

@ -0,0 +1,14 @@
"""Utility functions for Graphiti MCP Server."""

from collections.abc import Callable

from azure.identity import DefaultAzureCredential, get_bearer_token_provider


def create_azure_credential_token_provider() -> Callable[[], str]:
    """Create Azure credential token provider for managed identity authentication."""
    credential = DefaultAzureCredential()
    token_provider = get_bearer_token_provider(
        credential, 'https://cognitiveservices.azure.com/.default'
    )
    return token_provider
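
Sketch of how this token provider is consumed by the Azure OpenAI client (endpoint and API version are placeholders; the exact wiring in the config modules may differ):

from openai import AsyncAzureOpenAI

from utils import create_azure_credential_token_provider

token_provider = create_azure_credential_token_provider()
azure_client = AsyncAzureOpenAI(
    azure_endpoint='https://example-resource.openai.azure.com',
    api_version='2024-02-01',
    azure_ad_token_provider=token_provider,
)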

14
mcp_server/uv.lock
generated

@ -457,7 +457,7 @@ wheels = [

[[package]]
name = "mcp-server"
version = "0.1.0"
version = "0.2.1"
source = { virtual = "." }
dependencies = [
    { name = "azure-identity" },

@ -466,6 +466,12 @@ dependencies = [
    { name = "openai" },
]

[package.dev-dependencies]
dev = [
    { name = "httpx" },
    { name = "mcp" },
]

[package.metadata]
requires-dist = [
    { name = "azure-identity", specifier = ">=1.21.0" },

@ -475,6 +481,12 @@ requires-dist = [
    { name = "openai", specifier = ">=1.68.2" },
]

[package.metadata.requires-dev]
dev = [
    { name = "httpx", specifier = ">=0.28.1" },
    { name = "mcp", specifier = ">=1.9.4" },
]

[[package]]
name = "msal"
version = "1.32.3"