wip
This commit is contained in:
parent
119a43b8e4
commit
452a45cb4e
18 changed files with 1877 additions and 641 deletions
|
|
@ -34,7 +34,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
uv sync --frozen --no-dev
|
uv sync --frozen --no-dev
|
||||||
|
|
||||||
# Copy application code
|
# Copy application code
|
||||||
COPY graphiti_mcp_server.py ./
|
COPY *.py ./
|
||||||
|
|
||||||
# Change ownership to app user
|
# Change ownership to app user
|
||||||
RUN chown -Rv app:app /app
|
RUN chown -Rv app:app /app
|
||||||
|
|
|
||||||
51
mcp_server/config_manager.py
Normal file
51
mcp_server/config_manager.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
"""Unified configuration manager for Graphiti MCP Server."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from embedder_config import GraphitiEmbedderConfig
|
||||||
|
from llm_config import GraphitiLLMConfig
|
||||||
|
from neo4j_config import Neo4jConfig
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class GraphitiConfig(BaseModel):
    """Configuration for Graphiti client.

    Centralizes all configuration parameters for the Graphiti client.
    """

    llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig)
    embedder: GraphitiEmbedderConfig = Field(default_factory=GraphitiEmbedderConfig)
    neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig)
    group_id: str | None = None
    use_custom_entities: bool = False
    destroy_graph: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiConfig':
        """Create a configuration instance from environment variables."""
        # Each sub-config knows how to read its own environment variables.
        return cls(
            llm=GraphitiLLMConfig.from_env(),
            embedder=GraphitiEmbedderConfig.from_env(),
            neo4j=Neo4jConfig.from_env(),
        )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiConfig':
        """Create configuration from CLI arguments, falling back to environment variables."""
        config = cls.from_env()

        # CLI flags take precedence over environment-derived values.
        config.group_id = args.group_id if args.group_id else 'default'
        config.use_custom_entities = args.use_custom_entities
        config.destroy_graph = args.destroy_graph

        # The LLM sub-config applies its own CLI overrides.
        config.llm = GraphitiLLMConfig.from_cli_and_env(args)

        return config
|
||||||
124
mcp_server/embedder_config.py
Normal file
124
mcp_server/embedder_config.py
Normal file
|
|
@ -0,0 +1,124 @@
|
||||||
|
"""Embedder configuration for Graphiti MCP Server."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from openai import AsyncAzureOpenAI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from utils import create_azure_credential_token_provider
|
||||||
|
|
||||||
|
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
|
||||||
|
from graphiti_core.embedder.client import EmbedderClient
|
||||||
|
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'
|
||||||
|
|
||||||
|
|
||||||
|
class GraphitiEmbedderConfig(BaseModel):
    """Configuration for the embedder client.

    Centralizes all embedding-related configuration parameters.
    """

    model: str = DEFAULT_EMBEDDER_MODEL
    api_key: str | None = None
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiEmbedderConfig':
        """Create embedder configuration from environment variables.

        Raises:
            ValueError: if an Azure endpoint is configured without a deployment name.
        """
        # Get model from environment, or use default if not set or empty
        model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
        azure_openai_deployment_name = os.environ.get(
            'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
        )
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )

        if azure_openai_endpoint is not None:
            # Azure OpenAI requires an explicit deployment name.
            # (The value was already read above; the original re-read the same
            # environment variable here redundantly.)
            if azure_openai_deployment_name is None:
                logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set')
                raise ValueError(
                    'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set'
                )

            if not azure_openai_use_managed_identity:
                # API-key auth: prefer the embedding-specific key, fall back to OPENAI_API_KEY.
                api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
                    'OPENAI_API_KEY', None
                )
            else:
                # Managed identity: a token provider is used instead of a key.
                api_key = None

            return cls(
                azure_openai_use_managed_identity=azure_openai_use_managed_identity,
                azure_openai_endpoint=azure_openai_endpoint,
                api_key=api_key,
                azure_openai_api_version=azure_openai_api_version,
                azure_openai_deployment_name=azure_openai_deployment_name,
                model=model,
            )
        else:
            return cls(
                model=model,
                api_key=os.environ.get('OPENAI_API_KEY'),
            )

    def create_client(self) -> EmbedderClient | None:
        """Create an embedder client based on this configuration.

        Returns:
            EmbedderClient instance, or None if the configuration is incomplete
            (no API key available for the selected provider).
        """
        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                return AzureOpenAIEmbedderClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        azure_ad_token_provider=token_provider,
                    ),
                    model=self.model,
                )
            elif self.api_key:
                # Use API key for authentication
                return AzureOpenAIEmbedderClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        api_key=self.api_key,
                    ),
                    model=self.model,
                )
            else:
                logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
                return None
        else:
            # OpenAI API setup
            if not self.api_key:
                return None

            embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model)
            return OpenAIEmbedder(config=embedder_config)
|
||||||
83
mcp_server/entity_types.py
Normal file
83
mcp_server/entity_types.py
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
"""Entity type definitions for Graphiti MCP Server."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Requirement(BaseModel):
    """A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.

    Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
    edge that the requirement is a requirement.

    Instructions for identifying and extracting requirements:
    1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
    2. Identify functional specifications that describe what the system should do
    3. Pay attention to non-functional requirements like performance, security, or usability criteria
    4. Extract constraints or limitations that must be adhered to
    5. Focus on clear, specific, and measurable requirements rather than vague wishes
    6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
    7. Include any dependencies between requirements when explicitly stated
    8. Preserve the original intent and scope of the requirement
    9. Categorize requirements appropriately based on their domain or function
    """

    # NOTE(review): the docstring above and the Field descriptions below are
    # presumably fed to the extraction LLM (this class is registered in
    # ENTITY_TYPES) — wording kept verbatim; confirm before editing.
    project_name: str = Field(
        ...,
        description='The name of the project to which the requirement belongs.',
    )
    description: str = Field(
        ...,
        description='Description of the requirement. Only use information mentioned in the context to write this description.',
    )
|
||||||
|
|
||||||
|
|
||||||
|
class Preference(BaseModel):
    """A Preference represents a user's expressed like, dislike, or preference for something.

    Instructions for identifying and extracting preferences:
    1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X"
    2. Pay attention to comparative statements ("I prefer X over Y")
    3. Consider the emotional tone when users mention certain topics
    4. Extract only preferences that are clearly expressed, not assumptions
    5. Categorize the preference appropriately based on its domain (food, music, brands, etc.)
    6. Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food")
    7. Only extract preferences directly stated by the user, not preferences of others they mention
    8. Provide a concise but specific description that captures the nature of the preference
    """

    # NOTE(review): docstring and Field descriptions are presumably LLM-facing
    # prompt text (registered in ENTITY_TYPES) — kept verbatim; confirm before editing.
    category: str = Field(
        ...,
        description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')",
    )
    description: str = Field(
        ...,
        description='Brief description of the preference. Only use information mentioned in the context to write this description.',
    )
|
||||||
|
|
||||||
|
|
||||||
|
class Procedure(BaseModel):
    """A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.

    Instructions for identifying and extracting procedures:
    1. Look for sequential instructions or steps ("First do X, then do Y")
    2. Identify explicit directives or commands ("Always do X when Y happens")
    3. Pay attention to conditional statements ("If X occurs, then do Y")
    4. Extract procedures that have clear beginning and end points
    5. Focus on actionable instructions rather than general information
    6. Preserve the original sequence and dependencies between steps
    7. Include any specified conditions or triggers for the procedure
    8. Capture any stated purpose or goal of the procedure
    9. Summarize complex procedures while maintaining critical details
    """

    # NOTE(review): docstring and Field description are presumably LLM-facing
    # prompt text (registered in ENTITY_TYPES) — kept verbatim; confirm before editing.
    description: str = Field(
        ...,
        description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Registry of custom entity types exposed to Graphiti's extraction pipeline.
# The values are the model classes themselves (not instances), so the correct
# annotation is dict[str, type[BaseModel]] — this also removes the need for the
# per-entry `# type: ignore` comments the original carried.
ENTITY_TYPES: dict[str, type[BaseModel]] = {
    'Requirement': Requirement,
    'Preference': Preference,
    'Procedure': Procedure,
}
|
||||||
26
mcp_server/formatting.py
Normal file
26
mcp_server/formatting.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Formatting utilities for Graphiti MCP Server."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from graphiti_core.edges import EntityEdge
|
||||||
|
|
||||||
|
|
||||||
|
def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
    """Format an entity edge into a readable result.

    Since EntityEdge is a Pydantic BaseModel, we can use its built-in serialization capabilities.

    Args:
        edge: The EntityEdge to format

    Returns:
        A dictionary representation of the edge with serialized dates and excluded embeddings
    """
    # Dump to JSON-compatible types, dropping the top-level embedding field.
    serialized = edge.model_dump(mode='json', exclude={'fact_embedding'})
    # Nested attributes may carry their own copy of the embedding; drop it too.
    serialized.get('attributes', {}).pop('fact_embedding', None)
    return serialized
|
||||||
|
|
@ -8,25 +8,30 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Callable
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, TypedDict, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
from config_manager import GraphitiConfig
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from entity_types import ENTITY_TYPES
|
||||||
|
from formatting import format_fact_result
|
||||||
|
from graphiti_service import GraphitiService
|
||||||
|
from llm_config import DEFAULT_LLM_MODEL, SMALL_LLM_MODEL
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.fastmcp import FastMCP
|
||||||
from openai import AsyncAzureOpenAI
|
from queue_service import QueueService
|
||||||
from pydantic import BaseModel, Field
|
from response_types import (
|
||||||
|
EpisodeSearchResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
FactSearchResponse,
|
||||||
|
NodeResult,
|
||||||
|
NodeSearchResponse,
|
||||||
|
StatusResponse,
|
||||||
|
SuccessResponse,
|
||||||
|
)
|
||||||
|
from server_config import MCPConfig
|
||||||
|
|
||||||
from graphiti_core import Graphiti
|
from graphiti_core import Graphiti
|
||||||
from graphiti_core.edges import EntityEdge
|
from graphiti_core.edges import EntityEdge
|
||||||
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
|
|
||||||
from graphiti_core.embedder.client import EmbedderClient
|
|
||||||
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
|
|
||||||
from graphiti_core.llm_client import LLMClient
|
|
||||||
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
|
|
||||||
from graphiti_core.llm_client.config import LLMConfig
|
|
||||||
from graphiti_core.llm_client.openai_client import OpenAIClient
|
|
||||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||||
from graphiti_core.search.search_config_recipes import (
|
from graphiti_core.search.search_config_recipes import (
|
||||||
NODE_HYBRID_SEARCH_NODE_DISTANCE,
|
NODE_HYBRID_SEARCH_NODE_DISTANCE,
|
||||||
|
|
@ -38,488 +43,12 @@ from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
|
|
||||||
SMALL_LLM_MODEL = 'gpt-4.1-nano'
|
|
||||||
DEFAULT_EMBEDDER_MODEL = 'text-embedding-3-small'
|
|
||||||
|
|
||||||
# Semaphore limit for concurrent Graphiti operations.
|
# Semaphore limit for concurrent Graphiti operations.
|
||||||
# Decrease this if you're experiencing 429 rate limit errors from your LLM provider.
|
# Decrease this if you're experiencing 429 rate limit errors from your LLM provider.
|
||||||
# Increase if you have high rate limits.
|
# Increase if you have high rate limits.
|
||||||
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 10))
|
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 10))
|
||||||
|
|
||||||
|
|
||||||
class Requirement(BaseModel):
|
|
||||||
"""A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.
|
|
||||||
|
|
||||||
Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
|
|
||||||
edge that the requirement is a requirement.
|
|
||||||
|
|
||||||
Instructions for identifying and extracting requirements:
|
|
||||||
1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
|
|
||||||
2. Identify functional specifications that describe what the system should do
|
|
||||||
3. Pay attention to non-functional requirements like performance, security, or usability criteria
|
|
||||||
4. Extract constraints or limitations that must be adhered to
|
|
||||||
5. Focus on clear, specific, and measurable requirements rather than vague wishes
|
|
||||||
6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
|
|
||||||
7. Include any dependencies between requirements when explicitly stated
|
|
||||||
8. Preserve the original intent and scope of the requirement
|
|
||||||
9. Categorize requirements appropriately based on their domain or function
|
|
||||||
"""
|
|
||||||
|
|
||||||
project_name: str = Field(
|
|
||||||
...,
|
|
||||||
description='The name of the project to which the requirement belongs.',
|
|
||||||
)
|
|
||||||
description: str = Field(
|
|
||||||
...,
|
|
||||||
description='Description of the requirement. Only use information mentioned in the context to write this description.',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Preference(BaseModel):
|
|
||||||
"""A Preference represents a user's expressed like, dislike, or preference for something.
|
|
||||||
|
|
||||||
Instructions for identifying and extracting preferences:
|
|
||||||
1. Look for explicit statements of preference such as "I like/love/enjoy/prefer X" or "I don't like/hate/dislike X"
|
|
||||||
2. Pay attention to comparative statements ("I prefer X over Y")
|
|
||||||
3. Consider the emotional tone when users mention certain topics
|
|
||||||
4. Extract only preferences that are clearly expressed, not assumptions
|
|
||||||
5. Categorize the preference appropriately based on its domain (food, music, brands, etc.)
|
|
||||||
6. Include relevant qualifiers (e.g., "likes spicy food" rather than just "likes food")
|
|
||||||
7. Only extract preferences directly stated by the user, not preferences of others they mention
|
|
||||||
8. Provide a concise but specific description that captures the nature of the preference
|
|
||||||
"""
|
|
||||||
|
|
||||||
category: str = Field(
|
|
||||||
...,
|
|
||||||
description="The category of the preference. (e.g., 'Brands', 'Food', 'Music')",
|
|
||||||
)
|
|
||||||
description: str = Field(
|
|
||||||
...,
|
|
||||||
description='Brief description of the preference. Only use information mentioned in the context to write this description.',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Procedure(BaseModel):
|
|
||||||
"""A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.
|
|
||||||
|
|
||||||
Instructions for identifying and extracting procedures:
|
|
||||||
1. Look for sequential instructions or steps ("First do X, then do Y")
|
|
||||||
2. Identify explicit directives or commands ("Always do X when Y happens")
|
|
||||||
3. Pay attention to conditional statements ("If X occurs, then do Y")
|
|
||||||
4. Extract procedures that have clear beginning and end points
|
|
||||||
5. Focus on actionable instructions rather than general information
|
|
||||||
6. Preserve the original sequence and dependencies between steps
|
|
||||||
7. Include any specified conditions or triggers for the procedure
|
|
||||||
8. Capture any stated purpose or goal of the procedure
|
|
||||||
9. Summarize complex procedures while maintaining critical details
|
|
||||||
"""
|
|
||||||
|
|
||||||
description: str = Field(
|
|
||||||
...,
|
|
||||||
description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
ENTITY_TYPES: dict[str, BaseModel] = {
|
|
||||||
'Requirement': Requirement, # type: ignore
|
|
||||||
'Preference': Preference, # type: ignore
|
|
||||||
'Procedure': Procedure, # type: ignore
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Typed payloads returned by the MCP tool endpoints.


class ErrorResponse(TypedDict):
    error: str


class SuccessResponse(TypedDict):
    message: str


class NodeResult(TypedDict):
    uuid: str
    name: str
    summary: str
    labels: list[str]
    group_id: str
    created_at: str
    attributes: dict[str, Any]


class NodeSearchResponse(TypedDict):
    message: str
    nodes: list[NodeResult]


class FactSearchResponse(TypedDict):
    message: str
    facts: list[dict[str, Any]]


class EpisodeSearchResponse(TypedDict):
    message: str
    episodes: list[dict[str, Any]]


class StatusResponse(TypedDict):
    status: str
    message: str
|
|
||||||
|
|
||||||
|
|
||||||
def create_azure_credential_token_provider() -> Callable[[], str]:
    """Build a bearer-token provider backed by Azure's DefaultAzureCredential.

    Returns:
        A zero-argument callable that yields a token for the Cognitive Services scope.
    """
    return get_bearer_token_provider(
        DefaultAzureCredential(), 'https://cognitiveservices.azure.com/.default'
    )
|
|
||||||
|
|
||||||
|
|
||||||
# Server configuration classes
|
|
||||||
# The configuration system has a hierarchy:
|
|
||||||
# - GraphitiConfig is the top-level configuration
|
|
||||||
# - LLMConfig handles all OpenAI/LLM related settings
|
|
||||||
# - EmbedderConfig manages embedding settings
|
|
||||||
# - Neo4jConfig manages database connection details
|
|
||||||
# - Various other settings like group_id and feature flags
|
|
||||||
# Configuration values are loaded from:
|
|
||||||
# 1. Default values in the class definitions
|
|
||||||
# 2. Environment variables (loaded via load_dotenv())
|
|
||||||
# 3. Command line arguments (which override environment variables)
|
|
||||||
class GraphitiLLMConfig(BaseModel):
    """Configuration for the LLM client.

    Centralizes all LLM-specific configuration parameters including API keys and model selection.
    """

    api_key: str | None = None
    model: str = DEFAULT_LLM_MODEL
    small_model: str = SMALL_LLM_MODEL
    temperature: float = 0.0
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiLLMConfig':
        """Create LLM configuration from environment variables."""
        # Fall back to the defaults when the model env vars are unset or blank.
        model_env = os.environ.get('MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_LLM_MODEL

        small_model_env = os.environ.get('SMALL_MODEL_NAME', '')
        small_model = small_model_env if small_model_env.strip() else SMALL_LLM_MODEL

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_API_VERSION', None)
        azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME', None)
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )

        if azure_openai_endpoint is None:
            # Plain OpenAI setup; surface how the model default was chosen.
            if model_env == '':
                logger.debug(
                    f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}'
                )
            elif not model_env.strip():
                logger.warning(
                    f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}'
                )

            return cls(
                api_key=os.environ.get('OPENAI_API_KEY'),
                model=model,
                small_model=small_model,
                temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
            )

        # Azure OpenAI setup: a deployment name is mandatory.
        if azure_openai_deployment_name is None:
            logger.error('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
            raise ValueError('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')

        if not azure_openai_use_managed_identity:
            # Key-based auth.
            api_key = os.environ.get('OPENAI_API_KEY', None)
        else:
            # Managed identity: no key needed.
            api_key = None

        return cls(
            azure_openai_use_managed_identity=azure_openai_use_managed_identity,
            azure_openai_endpoint=azure_openai_endpoint,
            api_key=api_key,
            azure_openai_api_version=azure_openai_api_version,
            azure_openai_deployment_name=azure_openai_deployment_name,
            model=model,
            small_model=small_model,
            temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
        )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig':
        """Create LLM configuration from CLI arguments, falling back to environment variables."""
        config = cls.from_env()

        # Non-empty CLI values override the environment; blank values are
        # rejected with a warning so the defaults still apply.
        if hasattr(args, 'model') and args.model:
            if args.model.strip():
                config.model = args.model
            else:
                logger.warning(f'Empty model name provided, using default: {DEFAULT_LLM_MODEL}')

        if hasattr(args, 'small_model') and args.small_model:
            if args.small_model.strip():
                config.small_model = args.small_model
            else:
                logger.warning(f'Empty small_model name provided, using default: {SMALL_LLM_MODEL}')

        if hasattr(args, 'temperature') and args.temperature is not None:
            config.temperature = args.temperature

        return config

    def create_client(self) -> LLMClient:
        """Create an LLM client based on this configuration.

        Returns:
            LLMClient instance

        Raises:
            ValueError: when no API key is available for the selected provider.
        """
        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        azure_ad_token_provider=token_provider,
                    ),
                    config=LLMConfig(
                        api_key=self.api_key,
                        model=self.model,
                        small_model=self.small_model,
                        temperature=self.temperature,
                    ),
                )
            if self.api_key:
                # Use API key for authentication
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        api_key=self.api_key,
                    ),
                    config=LLMConfig(
                        api_key=self.api_key,
                        model=self.model,
                        small_model=self.small_model,
                        temperature=self.temperature,
                    ),
                )
            raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')

        if not self.api_key:
            raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')

        llm_client_config = LLMConfig(
            api_key=self.api_key, model=self.model, small_model=self.small_model
        )

        # Set temperature
        llm_client_config.temperature = self.temperature

        return OpenAIClient(config=llm_client_config)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphitiEmbedderConfig(BaseModel):
    """Configuration for the embedder client.

    Centralizes all embedding-related configuration parameters, covering both
    the plain OpenAI API and Azure OpenAI (API-key or managed-identity auth).
    """

    # Embedding model name; overridden by EMBEDDER_MODEL_NAME when set.
    model: str = DEFAULT_EMBEDDER_MODEL
    api_key: str | None = None
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiEmbedderConfig':
        """Create embedder configuration from environment variables.

        Returns:
            A populated GraphitiEmbedderConfig.

        Raises:
            ValueError: If AZURE_OPENAI_EMBEDDING_ENDPOINT is set without
                AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME.
        """
        # Get model from environment, or use default if not set or empty
        model_env = os.environ.get('EMBEDDER_MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_EMBEDDER_MODEL

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_EMBEDDING_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_EMBEDDING_API_VERSION', None)
        # Read the deployment name once (previously fetched twice).
        azure_openai_deployment_name = os.environ.get(
            'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME', None
        )
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )
        if azure_openai_endpoint is not None:
            # Setup for Azure OpenAI API: a deployment name is mandatory.
            if azure_openai_deployment_name is None:
                logger.error('AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set')
                raise ValueError(
                    'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME environment variable not set'
                )

            if not azure_openai_use_managed_identity:
                # API-key auth: prefer the embedding-specific key, fall back to the shared key.
                api_key = os.environ.get('AZURE_OPENAI_EMBEDDING_API_KEY', None) or os.environ.get(
                    'OPENAI_API_KEY', None
                )
            else:
                # Managed identity: no API key is used.
                api_key = None

            return cls(
                # Fix: previously the model was dropped on the Azure path, so
                # EMBEDDER_MODEL_NAME was silently ignored for Azure deployments.
                model=model,
                azure_openai_use_managed_identity=azure_openai_use_managed_identity,
                azure_openai_endpoint=azure_openai_endpoint,
                api_key=api_key,
                azure_openai_api_version=azure_openai_api_version,
                azure_openai_deployment_name=azure_openai_deployment_name,
            )
        else:
            # Plain OpenAI API setup.
            return cls(
                model=model,
                api_key=os.environ.get('OPENAI_API_KEY'),
            )

    def create_client(self) -> EmbedderClient | None:
        """Create an embedder client based on this configuration.

        Returns:
            An Azure or OpenAI embedder client, or None when no usable
            credentials are configured (callers treat None as "no embedder").
        """
        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                return AzureOpenAIEmbedderClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        azure_ad_token_provider=token_provider,
                    ),
                    model=self.model,
                )
            elif self.api_key:
                # Use API key for authentication
                return AzureOpenAIEmbedderClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        api_key=self.api_key,
                    ),
                    model=self.model,
                )
            else:
                # Unlike the LLM path, this is a soft failure: log and disable embeddings.
                logger.error('OPENAI_API_KEY must be set when using Azure OpenAI API')
                return None
        else:
            # OpenAI API setup
            if not self.api_key:
                return None

            embedder_config = OpenAIEmbedderConfig(api_key=self.api_key, embedding_model=self.model)

            return OpenAIEmbedder(config=embedder_config)
|
|
||||||
|
|
||||||
|
|
||||||
class Neo4jConfig(BaseModel):
    """Configuration for Neo4j database connection."""

    uri: str = 'bolt://localhost:7687'
    user: str = 'neo4j'
    password: str = 'password'

    @classmethod
    def from_env(cls) -> 'Neo4jConfig':
        """Create Neo4j configuration from environment variables."""
        # Collect the NEO4J_* settings, falling back to the field defaults.
        settings = {
            'uri': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
            'user': os.environ.get('NEO4J_USER', 'neo4j'),
            'password': os.environ.get('NEO4J_PASSWORD', 'password'),
        }
        return cls(**settings)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphitiConfig(BaseModel):
    """Configuration for Graphiti client.

    Centralizes all configuration parameters for the Graphiti client.
    """

    llm: GraphitiLLMConfig = Field(default_factory=GraphitiLLMConfig)
    embedder: GraphitiEmbedderConfig = Field(default_factory=GraphitiEmbedderConfig)
    neo4j: Neo4jConfig = Field(default_factory=Neo4jConfig)
    group_id: str | None = None
    use_custom_entities: bool = False
    destroy_graph: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiConfig':
        """Create a configuration instance from environment variables."""
        return cls(
            llm=GraphitiLLMConfig.from_env(),
            embedder=GraphitiEmbedderConfig.from_env(),
            neo4j=Neo4jConfig.from_env(),
        )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiConfig':
        """Create configuration from CLI arguments, falling back to environment variables."""
        # Environment variables provide the baseline configuration.
        config = cls.from_env()

        # A CLI-supplied group_id wins; otherwise use the fixed fallback.
        config.group_id = args.group_id if args.group_id else 'default'
        config.use_custom_entities = args.use_custom_entities
        config.destroy_graph = args.destroy_graph

        # LLM settings get their own CLI/env merge.
        config.llm = GraphitiLLMConfig.from_cli_and_env(args)

        return config
|
|
||||||
|
|
||||||
|
|
||||||
class MCPConfig(BaseModel):
    """Configuration for MCP server."""

    # SSE is the default transport.
    transport: str = 'sse'

    @classmethod
    def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig':
        """Create MCP configuration from CLI arguments."""
        return cls(transport=args.transport)
|
|
||||||
|
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
|
|
@ -568,124 +97,9 @@ mcp = FastMCP(
|
||||||
instructions=GRAPHITI_MCP_INSTRUCTIONS,
|
instructions=GRAPHITI_MCP_INSTRUCTIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize Graphiti client
|
# Global services
|
||||||
graphiti_client: Graphiti | None = None
|
graphiti_service: GraphitiService | None = None
|
||||||
|
queue_service: QueueService | None = None
|
||||||
|
|
||||||
async def initialize_graphiti():
    """Initialize the Graphiti client with the configured settings.

    Builds the LLM and embedder clients from the module-level ``config``,
    connects to Neo4j, optionally wipes the graph, builds indices, and stores
    the client in the module-level ``graphiti_client``.

    Raises:
        ValueError: If custom entities are enabled without an LLM client, or
            if any Neo4j connection setting is missing.
        Exception: Re-raised after logging if any initialization step fails.
    """
    global graphiti_client, config

    try:
        # Create LLM client if possible; create_client may return a falsy
        # value when no API key is configured.
        llm_client = config.llm.create_client()
        if not llm_client and config.use_custom_entities:
            # If custom entities are enabled, we must have an LLM client
            raise ValueError('OPENAI_API_KEY must be set when custom entities are enabled')

        # Validate Neo4j configuration before attempting a connection.
        if not config.neo4j.uri or not config.neo4j.user or not config.neo4j.password:
            raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')

        embedder_client = config.embedder.create_client()

        # Initialize Graphiti client
        graphiti_client = Graphiti(
            uri=config.neo4j.uri,
            user=config.neo4j.user,
            password=config.neo4j.password,
            llm_client=llm_client,
            embedder=embedder_client,
            max_coroutines=SEMAPHORE_LIMIT,
        )

        # Destroy graph if requested — this deletes all data before reindexing.
        if config.destroy_graph:
            logger.info('Destroying graph...')
            await clear_data(graphiti_client.driver)

        # Initialize the graph database with Graphiti's indices
        await graphiti_client.build_indices_and_constraints()
        logger.info('Graphiti client initialized successfully')

        # Log configuration details for transparency
        if llm_client:
            logger.info(f'Using OpenAI model: {config.llm.model}')
            logger.info(f'Using temperature: {config.llm.temperature}')
        else:
            logger.info('No LLM client configured - entity extraction will be limited')

        logger.info(f'Using group_id: {config.group_id}')
        logger.info(
            f'Custom entity extraction: {"enabled" if config.use_custom_entities else "disabled"}'
        )
        logger.info(f'Using concurrency limit: {SEMAPHORE_LIMIT}')

    except Exception as e:
        # Log and re-raise so the caller can abort startup.
        logger.error(f'Failed to initialize Graphiti: {str(e)}')
        raise
|
|
||||||
|
|
||||||
|
|
||||||
def format_fact_result(edge: EntityEdge) -> dict[str, Any]:
    """Format an entity edge into a readable result.

    Since EntityEdge is a Pydantic BaseModel, serialization is delegated to
    its built-in ``model_dump``; embedding vectors are stripped from the
    output at both the top level and inside ``attributes``.

    Args:
        edge: The EntityEdge to format

    Returns:
        A dictionary representation of the edge with serialized dates and
        excluded embeddings
    """
    serialized = edge.model_dump(mode='json', exclude={'fact_embedding'})
    # Embeddings may also be nested under 'attributes'; drop them there too.
    attributes = serialized.get('attributes', {})
    attributes.pop('fact_embedding', None)
    return serialized
|
|
||||||
|
|
||||||
|
|
||||||
# Dictionary to store queues for each group_id
# Each queue is a list of tasks to be processed sequentially
# (values are asyncio.Queue instances of zero-arg async callables)
episode_queues: dict[str, asyncio.Queue] = {}
# Dictionary to track if a worker is running for each group_id
# (set True by process_episode_queue on start, False when it exits)
queue_workers: dict[str, bool] = {}
|
|
||||||
|
|
||||||
|
|
||||||
async def process_episode_queue(group_id: str):
    """Process episodes for a specific group_id sequentially.

    This function runs as a long-lived task that processes episodes
    from the queue one at a time.

    Args:
        group_id: Key into the module-level ``episode_queues`` dict.
            NOTE(review): assumes the queue for this group_id already exists
            when this worker starts — confirm against the caller.
    """
    global queue_workers

    logger.info(f'Starting episode queue worker for group_id: {group_id}')
    queue_workers[group_id] = True

    try:
        while True:
            # Get the next episode processing function from the queue
            # This will wait if the queue is empty
            process_func = await episode_queues[group_id].get()

            try:
                # Process the episode
                await process_func()
            except Exception as e:
                # Errors are logged but never stop the worker loop.
                logger.error(f'Error processing queued episode for group_id {group_id}: {str(e)}')
            finally:
                # Mark the task as done regardless of success/failure
                episode_queues[group_id].task_done()
    except asyncio.CancelledError:
        logger.info(f'Episode queue worker for group_id {group_id} was cancelled')
    except Exception as e:
        logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}')
    finally:
        # Always clear the running flag so a new worker can be started later.
        queue_workers[group_id] = False
        logger.info(f'Stopped episode queue worker for group_id: {group_id}')
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
|
|
@ -752,10 +166,13 @@ async def add_memory(
|
||||||
- Entities will be created from appropriate JSON properties
|
- Entities will be created from appropriate JSON properties
|
||||||
- Relationships between entities will be established based on the JSON structure
|
- Relationships between entities will be established based on the JSON structure
|
||||||
"""
|
"""
|
||||||
global graphiti_client, episode_queues, queue_workers
|
global graphiti_service, queue_service, config
|
||||||
|
|
||||||
if graphiti_client is None:
|
if not graphiti_service or not graphiti_service.is_initialized():
|
||||||
return ErrorResponse(error='Graphiti client not initialized')
|
return ErrorResponse(error='Graphiti service not initialized')
|
||||||
|
|
||||||
|
if not queue_service:
|
||||||
|
return ErrorResponse(error='Queue service not initialized')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Map string source to EpisodeType enum
|
# Map string source to EpisodeType enum
|
||||||
|
|
@ -772,13 +189,6 @@ async def add_memory(
|
||||||
# The Graphiti client expects a str for group_id, not Optional[str]
|
# The Graphiti client expects a str for group_id, not Optional[str]
|
||||||
group_id_str = str(effective_group_id) if effective_group_id is not None else ''
|
group_id_str = str(effective_group_id) if effective_group_id is not None else ''
|
||||||
|
|
||||||
# We've already checked that graphiti_client is not None above
|
|
||||||
# This assert statement helps type checkers understand that graphiti_client is defined
|
|
||||||
assert graphiti_client is not None, 'graphiti_client should not be None here'
|
|
||||||
|
|
||||||
# Use cast to help the type checker understand that graphiti_client is not None
|
|
||||||
client = cast(Graphiti, graphiti_client)
|
|
||||||
|
|
||||||
# Define the episode processing function
|
# Define the episode processing function
|
||||||
async def process_episode():
|
async def process_episode():
|
||||||
try:
|
try:
|
||||||
|
|
@ -786,7 +196,7 @@ async def add_memory(
|
||||||
# Use all entity types if use_custom_entities is enabled, otherwise use empty dict
|
# Use all entity types if use_custom_entities is enabled, otherwise use empty dict
|
||||||
entity_types = ENTITY_TYPES if config.use_custom_entities else {}
|
entity_types = ENTITY_TYPES if config.use_custom_entities else {}
|
||||||
|
|
||||||
await client.add_episode(
|
await graphiti_service.client.add_episode(
|
||||||
name=name,
|
name=name,
|
||||||
episode_body=episode_body,
|
episode_body=episode_body,
|
||||||
source=source_type,
|
source=source_type,
|
||||||
|
|
@ -796,8 +206,6 @@ async def add_memory(
|
||||||
reference_time=datetime.now(timezone.utc),
|
reference_time=datetime.now(timezone.utc),
|
||||||
entity_types=entity_types,
|
entity_types=entity_types,
|
||||||
)
|
)
|
||||||
logger.info(f"Episode '{name}' added successfully")
|
|
||||||
|
|
||||||
logger.info(f"Episode '{name}' processed successfully")
|
logger.info(f"Episode '{name}' processed successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e)
|
error_msg = str(e)
|
||||||
|
|
@ -805,20 +213,12 @@ async def add_memory(
|
||||||
f"Error processing episode '{name}' for group_id {group_id_str}: {error_msg}"
|
f"Error processing episode '{name}' for group_id {group_id_str}: {error_msg}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize queue for this group_id if it doesn't exist
|
|
||||||
if group_id_str not in episode_queues:
|
|
||||||
episode_queues[group_id_str] = asyncio.Queue()
|
|
||||||
|
|
||||||
# Add the episode processing function to the queue
|
# Add the episode processing function to the queue
|
||||||
await episode_queues[group_id_str].put(process_episode)
|
queue_position = await queue_service.add_episode_task(group_id_str, process_episode)
|
||||||
|
|
||||||
# Start a worker for this queue if one isn't already running
|
|
||||||
if not queue_workers.get(group_id_str, False):
|
|
||||||
asyncio.create_task(process_episode_queue(group_id_str))
|
|
||||||
|
|
||||||
# Return immediately with a success message
|
# Return immediately with a success message
|
||||||
return SuccessResponse(
|
return SuccessResponse(
|
||||||
message=f"Episode '{name}' queued for processing (position: {episode_queues[group_id_str].qsize()})"
|
message=f"Episode '{name}' queued for processing (position: {queue_position})"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e)
|
error_msg = str(e)
|
||||||
|
|
@ -846,10 +246,10 @@ async def search_memory_nodes(
|
||||||
center_node_uuid: Optional UUID of a node to center the search around
|
center_node_uuid: Optional UUID of a node to center the search around
|
||||||
entity: Optional single entity type to filter results (permitted: "Preference", "Procedure")
|
entity: Optional single entity type to filter results (permitted: "Preference", "Procedure")
|
||||||
"""
|
"""
|
||||||
global graphiti_client
|
global graphiti_service, config
|
||||||
|
|
||||||
if graphiti_client is None:
|
if not graphiti_service or not graphiti_service.is_initialized():
|
||||||
return ErrorResponse(error='Graphiti client not initialized')
|
return ErrorResponse(error='Graphiti service not initialized')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use the provided group_ids or fall back to the default from config if none provided
|
# Use the provided group_ids or fall back to the default from config if none provided
|
||||||
|
|
@ -868,11 +268,7 @@ async def search_memory_nodes(
|
||||||
if entity != '':
|
if entity != '':
|
||||||
filters.node_labels = [entity]
|
filters.node_labels = [entity]
|
||||||
|
|
||||||
# We've already checked that graphiti_client is not None above
|
client = graphiti_service.client
|
||||||
assert graphiti_client is not None
|
|
||||||
|
|
||||||
# Use cast to help the type checker understand that graphiti_client is not None
|
|
||||||
client = cast(Graphiti, graphiti_client)
|
|
||||||
|
|
||||||
# Perform the search using the _search method
|
# Perform the search using the _search method
|
||||||
search_results = await client._search(
|
search_results = await client._search(
|
||||||
|
|
@ -1218,8 +614,11 @@ async def initialize_server() -> MCPConfig:
|
||||||
else:
|
else:
|
||||||
logger.info('Entity extraction disabled (no custom entities will be used)')
|
logger.info('Entity extraction disabled (no custom entities will be used)')
|
||||||
|
|
||||||
# Initialize Graphiti
|
# Initialize services
|
||||||
await initialize_graphiti()
|
global graphiti_service, queue_service
|
||||||
|
graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
|
||||||
|
queue_service = QueueService()
|
||||||
|
await graphiti_service.initialize()
|
||||||
|
|
||||||
if args.host:
|
if args.host:
|
||||||
logger.info(f'Setting MCP server host to: {args.host}')
|
logger.info(f'Setting MCP server host to: {args.host}')
|
||||||
|
|
|
||||||
110
mcp_server/graphiti_service.py
Normal file
110
mcp_server/graphiti_service.py
Normal file
|
|
@ -0,0 +1,110 @@
|
||||||
|
"""Graphiti service for managing client lifecycle and operations."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from config_manager import GraphitiConfig
|
||||||
|
|
||||||
|
from graphiti_core import Graphiti
|
||||||
|
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphitiService:
    """Service for managing Graphiti client operations.

    Owns the Graphiti client lifecycle: construction from configuration,
    graph clearing, connectivity checks, and initialization state.
    """

    def __init__(self, config: GraphitiConfig, semaphore_limit: int = 10):
        """Initialize the Graphiti service with configuration.

        Args:
            config: The Graphiti configuration
            semaphore_limit: Maximum concurrent operations
        """
        self.config = config
        self.semaphore_limit = semaphore_limit
        # Lazily created in initialize(); None until then.
        self._client: Graphiti | None = None

    @property
    def client(self) -> Graphiti:
        """Get the Graphiti client instance.

        Raises:
            RuntimeError: If the client hasn't been initialized
        """
        if self._client is None:
            raise RuntimeError('Graphiti client not initialized. Call initialize() first.')
        return self._client

    async def initialize(self) -> None:
        """Initialize the Graphiti client with the configured settings.

        Raises:
            ValueError: If custom entities are enabled without an LLM client,
                or any Neo4j connection setting is missing.
            Exception: Re-raised after logging if initialization fails.
        """
        try:
            # Create LLM client if possible
            llm_client = self.config.llm.create_client()
            if not llm_client and self.config.use_custom_entities:
                # If custom entities are enabled, we must have an LLM client
                raise ValueError('OPENAI_API_KEY must be set when custom entities are enabled')

            # Validate Neo4j configuration
            if (
                not self.config.neo4j.uri
                or not self.config.neo4j.user
                or not self.config.neo4j.password
            ):
                raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')

            embedder_client = self.config.embedder.create_client()

            # Initialize Graphiti client
            self._client = Graphiti(
                uri=self.config.neo4j.uri,
                user=self.config.neo4j.user,
                password=self.config.neo4j.password,
                llm_client=llm_client,
                embedder=embedder_client,
                max_coroutines=self.semaphore_limit,
            )

            # Destroy graph if requested (deletes all data before reindexing)
            if self.config.destroy_graph:
                logger.info('Destroying graph...')
                await clear_data(self._client.driver)

            # Initialize the graph database with Graphiti's indices
            await self._client.build_indices_and_constraints()
            logger.info('Graphiti client initialized successfully')

            # Log configuration details for transparency
            if llm_client:
                logger.info(f'Using OpenAI model: {self.config.llm.model}')
                logger.info(f'Using temperature: {self.config.llm.temperature}')
            else:
                logger.info('No LLM client configured - entity extraction will be limited')

            logger.info(f'Using group_id: {self.config.group_id}')
            logger.info(
                f'Custom entity extraction: {"enabled" if self.config.use_custom_entities else "disabled"}'
            )
            logger.info(f'Using concurrency limit: {self.semaphore_limit}')

        except Exception as e:
            logger.error(f'Failed to initialize Graphiti: {str(e)}')
            raise

    async def clear_graph(self) -> None:
        """Clear all data from the graph and rebuild indices.

        Raises:
            RuntimeError: If the client hasn't been initialized.
        """
        if self._client is None:
            raise RuntimeError('Graphiti client not initialized')

        await clear_data(self._client.driver)
        await self._client.build_indices_and_constraints()

    async def verify_connection(self) -> None:
        """Verify the database connection.

        Raises:
            RuntimeError: If the client hasn't been initialized.
        """
        if self._client is None:
            raise RuntimeError('Graphiti client not initialized')

        # NOTE(review): reaches into the driver's underlying Neo4j client —
        # presumably verify_connectivity() raises on failure; confirm API.
        await self._client.driver.client.verify_connectivity()  # type: ignore

    def is_initialized(self) -> bool:
        """Check if the client is initialized."""
        return self._client is not None
|
||||||
182
mcp_server/llm_config.py
Normal file
182
mcp_server/llm_config.py
Normal file
|
|
@ -0,0 +1,182 @@
|
||||||
|
"""LLM configuration for Graphiti MCP Server."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from openai import AsyncAzureOpenAI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from utils import create_azure_credential_token_provider
|
||||||
|
|
||||||
|
from graphiti_core.llm_client import LLMClient
|
||||||
|
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
|
||||||
|
from graphiti_core.llm_client.config import LLMConfig
|
||||||
|
from graphiti_core.llm_client.openai_client import OpenAIClient
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_LLM_MODEL = 'gpt-4.1-mini'
|
||||||
|
SMALL_LLM_MODEL = 'gpt-4.1-nano'
|
||||||
|
|
||||||
|
|
||||||
|
class GraphitiLLMConfig(BaseModel):
    """Configuration for the LLM client.

    Centralizes all LLM-specific configuration parameters including API keys and model selection.
    Supports plain OpenAI as well as Azure OpenAI with API-key or managed-identity auth.
    """

    api_key: str | None = None
    model: str = DEFAULT_LLM_MODEL
    small_model: str = SMALL_LLM_MODEL
    temperature: float = 0.0
    azure_openai_endpoint: str | None = None
    azure_openai_deployment_name: str | None = None
    azure_openai_api_version: str | None = None
    azure_openai_use_managed_identity: bool = False

    @classmethod
    def from_env(cls) -> 'GraphitiLLMConfig':
        """Create LLM configuration from environment variables.

        Raises:
            ValueError: If AZURE_OPENAI_ENDPOINT is set without
                AZURE_OPENAI_DEPLOYMENT_NAME.
        """
        # Get model from environment, or use default if not set or empty
        model_env = os.environ.get('MODEL_NAME', '')
        model = model_env if model_env.strip() else DEFAULT_LLM_MODEL

        # Get small_model from environment, or use default if not set or empty
        small_model_env = os.environ.get('SMALL_MODEL_NAME', '')
        small_model = small_model_env if small_model_env.strip() else SMALL_LLM_MODEL

        azure_openai_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', None)
        azure_openai_api_version = os.environ.get('AZURE_OPENAI_API_VERSION', None)
        azure_openai_deployment_name = os.environ.get('AZURE_OPENAI_DEPLOYMENT_NAME', None)
        azure_openai_use_managed_identity = (
            os.environ.get('AZURE_OPENAI_USE_MANAGED_IDENTITY', 'false').lower() == 'true'
        )

        if azure_openai_endpoint is None:
            # Setup for OpenAI API
            # Log if empty model was provided; '' means the variable was unset
            # (debug), whitespace-only means it was set but blank (warning).
            if model_env == '':
                logger.debug(
                    f'MODEL_NAME environment variable not set, using default: {DEFAULT_LLM_MODEL}'
                )
            elif not model_env.strip():
                logger.warning(
                    f'Empty MODEL_NAME environment variable, using default: {DEFAULT_LLM_MODEL}'
                )

            return cls(
                api_key=os.environ.get('OPENAI_API_KEY'),
                model=model,
                small_model=small_model,
                temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
            )
        else:
            # Setup for Azure OpenAI API
            # Log if empty deployment name was provided
            if azure_openai_deployment_name is None:
                logger.error('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')
                raise ValueError('AZURE_OPENAI_DEPLOYMENT_NAME environment variable not set')

            if not azure_openai_use_managed_identity:
                # api key
                api_key = os.environ.get('OPENAI_API_KEY', None)
            else:
                # Managed identity
                api_key = None

            return cls(
                azure_openai_use_managed_identity=azure_openai_use_managed_identity,
                azure_openai_endpoint=azure_openai_endpoint,
                api_key=api_key,
                azure_openai_api_version=azure_openai_api_version,
                azure_openai_deployment_name=azure_openai_deployment_name,
                model=model,
                small_model=small_model,
                temperature=float(os.environ.get('LLM_TEMPERATURE', '0.0')),
            )

    @classmethod
    def from_cli_and_env(cls, args: argparse.Namespace) -> 'GraphitiLLMConfig':
        """Create LLM configuration from CLI arguments, falling back to environment variables."""
        # Start with environment-based config
        config = cls.from_env()

        # CLI arguments override environment variables when provided
        if hasattr(args, 'model') and args.model:
            # Only use CLI model if it's not empty
            if args.model.strip():
                config.model = args.model
            else:
                # Log that empty model was provided and default is used
                logger.warning(f'Empty model name provided, using default: {DEFAULT_LLM_MODEL}')

        if hasattr(args, 'small_model') and args.small_model:
            if args.small_model.strip():
                config.small_model = args.small_model
            else:
                logger.warning(f'Empty small_model name provided, using default: {SMALL_LLM_MODEL}')

        if hasattr(args, 'temperature') and args.temperature is not None:
            config.temperature = args.temperature

        return config

    def create_client(self) -> LLMClient:
        """Create an LLM client based on this configuration.

        Returns:
            LLMClient instance

        Raises:
            ValueError: If no API key is configured for the selected backend
                (unless Azure managed identity is in use).
        """
        if self.azure_openai_endpoint is not None:
            # Azure OpenAI API setup
            if self.azure_openai_use_managed_identity:
                # Use managed identity for authentication
                token_provider = create_azure_credential_token_provider()
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        azure_ad_token_provider=token_provider,
                    ),
                    config=LLMConfig(
                        api_key=self.api_key,
                        model=self.model,
                        small_model=self.small_model,
                        temperature=self.temperature,
                    ),
                )
            elif self.api_key:
                # Use API key for authentication
                return AzureOpenAILLMClient(
                    azure_client=AsyncAzureOpenAI(
                        azure_endpoint=self.azure_openai_endpoint,
                        azure_deployment=self.azure_openai_deployment_name,
                        api_version=self.azure_openai_api_version,
                        api_key=self.api_key,
                    ),
                    config=LLMConfig(
                        api_key=self.api_key,
                        model=self.model,
                        small_model=self.small_model,
                        temperature=self.temperature,
                    ),
                )
            else:
                raise ValueError('OPENAI_API_KEY must be set when using Azure OpenAI API')

        if not self.api_key:
            raise ValueError('OPENAI_API_KEY must be set when using OpenAI API')

        llm_client_config = LLMConfig(
            api_key=self.api_key, model=self.model, small_model=self.small_model
        )

        # Set temperature
        llm_client_config.temperature = self.temperature

        return OpenAIClient(config=llm_client_config)
|
||||||
22
mcp_server/neo4j_config.py
Normal file
22
mcp_server/neo4j_config.py
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
"""Neo4j database configuration for Graphiti MCP Server."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class Neo4jConfig(BaseModel):
    """Configuration for Neo4j database connection."""

    uri: str = 'bolt://localhost:7687'
    user: str = 'neo4j'
    password: str = 'password'

    @classmethod
    def from_env(cls) -> 'Neo4jConfig':
        """Create Neo4j configuration from environment variables."""
        # Bind the lookup once; each setting falls back to its field default.
        get = os.environ.get
        return cls(
            uri=get('NEO4J_URI', 'bolt://localhost:7687'),
            user=get('NEO4J_USER', 'neo4j'),
            password=get('NEO4J_PASSWORD', 'password'),
        )
|
||||||
|
|
@ -11,3 +11,9 @@ dependencies = [
|
||||||
"azure-identity>=1.21.0",
|
"azure-identity>=1.21.0",
|
||||||
"graphiti-core",
|
"graphiti-core",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"httpx>=0.28.1",
|
||||||
|
"mcp>=1.9.4",
|
||||||
|
]
|
||||||
|
|
|
||||||
86
mcp_server/queue_service.py
Normal file
86
mcp_server/queue_service.py
Normal file
|
|
@ -0,0 +1,86 @@
|
||||||
|
"""Queue service for managing episode processing."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QueueService:
    """Service for managing sequential episode processing queues by group_id.

    Each group_id gets its own FIFO queue and a single long-lived worker
    task, so episodes within a group are processed strictly one at a time.
    """

    def __init__(self):
        """Initialize the queue service."""
        # One FIFO queue of pending async processing callables per group_id.
        self._episode_queues: dict[str, asyncio.Queue] = {}
        # Whether a worker coroutine is currently active for each group_id.
        self._queue_workers: dict[str, bool] = {}
        # Strong references to worker tasks: asyncio keeps only weak
        # references to tasks, so an unreferenced task could be GC'd mid-run.
        self._worker_tasks: dict[str, asyncio.Task] = {}

    async def add_episode_task(
        self, group_id: str, process_func: Callable[[], Awaitable[None]]
    ) -> int:
        """Add an episode processing task to the queue.

        Args:
            group_id: The group ID for the episode
            process_func: The async function to process the episode

        Returns:
            The position in the queue
        """
        # Lazily create the queue for this group_id.
        if group_id not in self._episode_queues:
            self._episode_queues[group_id] = asyncio.Queue()

        # Enqueue the episode processing function.
        await self._episode_queues[group_id].put(process_func)

        # Start a worker for this queue if one isn't already running.
        # The flag is flipped *before* the task is scheduled: previously the
        # worker set it from inside the task, so two quick successive calls
        # could both observe False and spawn duplicate workers for the group.
        if not self._queue_workers.get(group_id, False):
            self._queue_workers[group_id] = True
            self._worker_tasks[group_id] = asyncio.create_task(
                self._process_episode_queue(group_id)
            )

        return self._episode_queues[group_id].qsize()

    async def _process_episode_queue(self, group_id: str) -> None:
        """Process episodes for a specific group_id sequentially.

        This function runs as a long-lived task that processes episodes
        from the queue one at a time.
        """
        logger.info(f'Starting episode queue worker for group_id: {group_id}')
        self._queue_workers[group_id] = True

        try:
            while True:
                # Get the next episode processing function from the queue;
                # this waits if the queue is empty.
                process_func = await self._episode_queues[group_id].get()

                try:
                    await process_func()
                except Exception as e:
                    logger.error(
                        f'Error processing queued episode for group_id {group_id}: {str(e)}'
                    )
                finally:
                    # Mark the task as done regardless of success/failure.
                    self._episode_queues[group_id].task_done()
        except asyncio.CancelledError:
            logger.info(f'Episode queue worker for group_id {group_id} was cancelled')
        except Exception as e:
            logger.error(f'Unexpected error in queue worker for group_id {group_id}: {str(e)}')
        finally:
            self._queue_workers[group_id] = False
            # Drop the task reference so a future add_episode_task can
            # spawn a fresh worker for this group.
            self._worker_tasks.pop(group_id, None)
            logger.info(f'Stopped episode queue worker for group_id: {group_id}')

    def get_queue_size(self, group_id: str) -> int:
        """Get the current queue size for a group_id."""
        if group_id not in self._episode_queues:
            return 0
        return self._episode_queues[group_id].qsize()

    def is_worker_running(self, group_id: str) -> bool:
        """Check if a worker is running for a group_id."""
        return self._queue_workers.get(group_id, False)
|
||||||
41
mcp_server/response_types.py
Normal file
41
mcp_server/response_types.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
"""Response type definitions for Graphiti MCP Server."""
|
||||||
|
|
||||||
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponse(TypedDict):
    """Payload returned by a tool when the operation failed."""

    # Human-readable description of the failure.
    error: str
|
||||||
|
|
||||||
|
|
||||||
|
class SuccessResponse(TypedDict):
    """Payload returned by a tool when the operation succeeded."""

    # Human-readable confirmation message.
    message: str
|
||||||
|
|
||||||
|
|
||||||
|
class NodeResult(TypedDict):
    """Serialized graph node as returned by node searches."""

    uuid: str
    name: str
    summary: str
    labels: list[str]
    group_id: str
    # Serialized timestamp (string form, not a datetime object).
    created_at: str
    attributes: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class NodeSearchResponse(TypedDict):
    """Result envelope for a node search: status message plus matches."""

    message: str
    nodes: list[NodeResult]
|
||||||
|
|
||||||
|
|
||||||
|
class FactSearchResponse(TypedDict):
    """Result envelope for a fact (edge) search."""

    message: str
    facts: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeSearchResponse(TypedDict):
    """Result envelope for an episode search."""

    message: str
    episodes: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class StatusResponse(TypedDict):
    """Server health/status report (e.g. for the status resource)."""

    status: str
    message: str
|
||||||
16
mcp_server/server_config.py
Normal file
16
mcp_server/server_config.py
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Server configuration for Graphiti MCP Server."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConfig(BaseModel):
    """Runtime options for the MCP server itself."""

    # Transport used to serve MCP; SSE is the default.
    transport: str = 'sse'

    @classmethod
    def from_cli(cls, args: argparse.Namespace) -> 'MCPConfig':
        """Build the server configuration from parsed CLI arguments."""
        selected_transport = args.transport
        return cls(transport=selected_transport)
|
||||||
363
mcp_server/test_integration.py
Normal file
363
mcp_server/test_integration.py
Normal file
|
|
@ -0,0 +1,363 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Integration test for the refactored Graphiti MCP Server.
|
||||||
|
Tests all major MCP tools and handles episode processing latency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
class MCPIntegrationTest:
    """Integration test client for Graphiti MCP Server.

    Talks to a running server over HTTP/SSE (default http://localhost:8000)
    and exercises the main MCP tools end to end. Use as an async context
    manager so the underlying HTTP client is closed cleanly.
    """

    def __init__(self, base_url: str = 'http://localhost:8000'):
        self.base_url = base_url
        self.client = httpx.AsyncClient(timeout=30.0)
        # Unique group per run so concurrent/repeated runs don't collide.
        self.test_group_id = f'test_group_{int(time.time())}'

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.client.aclose()

    async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Call an MCP tool via the SSE endpoint.

        Returns the JSON-RPC 'result' payload on success, or a dict with an
        'error' key on any HTTP or transport failure (never raises).
        """
        # MCP protocol message structure (JSON-RPC 2.0).
        message = {
            'jsonrpc': '2.0',
            'id': int(time.time() * 1000),
            'method': 'tools/call',
            'params': {'name': tool_name, 'arguments': arguments},
        }

        try:
            response = await self.client.post(
                f'{self.base_url}/message',
                json=message,
                headers={'Content-Type': 'application/json'},
            )

            if response.status_code != 200:
                return {'error': f'HTTP {response.status_code}: {response.text}'}

            result = response.json()
            # Unwrap the JSON-RPC envelope when present.
            return result.get('result', result)

        except Exception as e:
            return {'error': str(e)}

    async def test_server_status(self) -> bool:
        """Test the get_status resource; True when the server reports 'ok'."""
        print('🔍 Testing server status...')

        try:
            response = await self.client.get(f'{self.base_url}/resources/http://graphiti/status')
            if response.status_code == 200:
                status = response.json()
                print(f' ✅ Server status: {status.get("status", "unknown")}')
                return status.get('status') == 'ok'
            else:
                print(f' ❌ Status check failed: HTTP {response.status_code}')
                return False
        except Exception as e:
            print(f' ❌ Status check failed: {e}')
            return False

    async def test_add_memory(self) -> dict[str, str]:
        """Test adding text, JSON, and message episodes.

        Returns a dict mapping episode kind -> 'success' for each kind that
        was queued without error.
        """
        print('📝 Testing add_memory functionality...')

        episode_results = {}

        # Test 1: Add text episode
        print(' Testing text episode...')
        result = await self.call_mcp_tool(
            'add_memory',
            {
                'name': 'Test Company News',
                'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
                'source': 'text',
                'source_description': 'news article',
                'group_id': self.test_group_id,
            },
        )

        if 'error' in result:
            print(f' ❌ Text episode failed: {result["error"]}')
        else:
            print(f' ✅ Text episode queued: {result.get("message", "Success")}')
            episode_results['text'] = 'success'

        # Test 2: Add JSON episode
        print(' Testing JSON episode...')
        json_data = {
            'company': {'name': 'TechCorp', 'founded': 2010},
            'products': [
                {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
                {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
            ],
            'employees': 150,
        }

        result = await self.call_mcp_tool(
            'add_memory',
            {
                'name': 'Company Profile',
                'episode_body': json.dumps(json_data),
                'source': 'json',
                'source_description': 'CRM data',
                'group_id': self.test_group_id,
            },
        )

        if 'error' in result:
            print(f' ❌ JSON episode failed: {result["error"]}')
        else:
            print(f' ✅ JSON episode queued: {result.get("message", "Success")}')
            episode_results['json'] = 'success'

        # Test 3: Add message episode
        print(' Testing message episode...')
        result = await self.call_mcp_tool(
            'add_memory',
            {
                'name': 'Customer Support Chat',
                'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
                'source': 'message',
                'source_description': 'support chat log',
                'group_id': self.test_group_id,
            },
        )

        if 'error' in result:
            print(f' ❌ Message episode failed: {result["error"]}')
        else:
            print(f' ✅ Message episode queued: {result.get("message", "Success")}')
            episode_results['message'] = 'success'

        return episode_results

    async def wait_for_processing(self, max_wait: int = 30) -> None:
        """Poll get_episodes once per second until episodes appear or timeout."""
        print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')

        for i in range(max_wait):
            await asyncio.sleep(1)

            # Check if we have any episodes
            result = await self.call_mcp_tool(
                'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
            )

            # Skip error responses only. (BUGFIX: the previous guard
            # `not isinstance(result, dict) or 'error' in result` also
            # skipped list results, which made the success branch below
            # unreachable.)
            if isinstance(result, dict) and 'error' in result:
                continue

            if isinstance(result, list) and len(result) > 0:
                print(f' ✅ Found {len(result)} processed episodes after {i + 1} seconds')
                return

        print(f' ⚠️ Still waiting after {max_wait} seconds...')

    async def test_search_functions(self) -> dict[str, bool]:
        """Test node and fact search; returns per-search success flags."""
        print('🔍 Testing search functions...')

        results = {}

        # Test search_memory_nodes
        print(' Testing search_memory_nodes...')
        result = await self.call_mcp_tool(
            'search_memory_nodes',
            {
                'query': 'Acme Corp product launch',
                'group_ids': [self.test_group_id],
                'max_nodes': 5,
            },
        )

        if 'error' in result:
            print(f' ❌ Node search failed: {result["error"]}')
            results['nodes'] = False
        else:
            nodes = result.get('nodes', [])
            print(f' ✅ Node search returned {len(nodes)} nodes')
            results['nodes'] = True

        # Test search_memory_facts
        print(' Testing search_memory_facts...')
        result = await self.call_mcp_tool(
            'search_memory_facts',
            {
                'query': 'company products software',
                'group_ids': [self.test_group_id],
                'max_facts': 5,
            },
        )

        if 'error' in result:
            print(f' ❌ Fact search failed: {result["error"]}')
            results['facts'] = False
        else:
            facts = result.get('facts', [])
            print(f' ✅ Fact search returned {len(facts)} facts')
            results['facts'] = True

        return results

    async def test_episode_retrieval(self) -> bool:
        """Test episode retrieval; True when at least one episode is returned."""
        print('📚 Testing episode retrieval...')

        result = await self.call_mcp_tool(
            'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
        )

        # Note: when the tool succeeds this is a list, so `'error' in result`
        # is a membership test against its elements and is harmlessly False.
        if 'error' in result:
            print(f' ❌ Episode retrieval failed: {result["error"]}')
            return False

        if isinstance(result, list):
            print(f' ✅ Retrieved {len(result)} episodes')

            # Print episode details
            for i, episode in enumerate(result[:3]):  # Show first 3
                name = episode.get('name', 'Unknown')
                source = episode.get('source', 'unknown')
                print(f' Episode {i + 1}: {name} (source: {source})')

            return len(result) > 0
        else:
            print(f' ❌ Unexpected result format: {type(result)}')
            return False

    async def test_edge_cases(self) -> dict[str, bool]:
        """Test edge cases (unknown group, empty query); returns success flags."""
        print('🧪 Testing edge cases...')

        results = {}

        # Test with invalid group_id
        print(' Testing invalid group_id...')
        result = await self.call_mcp_tool(
            'search_memory_nodes',
            {'query': 'nonexistent data', 'group_ids': ['nonexistent_group'], 'max_nodes': 5},
        )

        # Should not error, just return empty results
        if 'error' not in result:
            nodes = result.get('nodes', [])
            print(f' ✅ Invalid group_id handled gracefully (returned {len(nodes)} nodes)')
            results['invalid_group'] = True
        else:
            print(f' ❌ Invalid group_id caused error: {result["error"]}')
            results['invalid_group'] = False

        # Test empty query
        print(' Testing empty query...')
        result = await self.call_mcp_tool(
            'search_memory_nodes', {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5}
        )

        if 'error' not in result:
            print(' ✅ Empty query handled gracefully')
            results['empty_query'] = True
        else:
            print(f' ❌ Empty query caused error: {result["error"]}')
            results['empty_query'] = False

        return results

    async def run_full_test_suite(self) -> dict[str, Any]:
        """Run the complete integration test suite and print a summary.

        Returns a results dict with per-area outcomes plus 'overall_success'.
        """
        print('🚀 Starting Graphiti MCP Server Integration Test')
        print(f' Test group ID: {self.test_group_id}')
        print('=' * 60)

        results = {
            'server_status': False,
            'add_memory': {},
            'search': {},
            'episodes': False,
            'edge_cases': {},
            'overall_success': False,
        }

        # Test 1: Server Status — abort early if the server is unreachable.
        results['server_status'] = await self.test_server_status()
        if not results['server_status']:
            print('❌ Server not responding, aborting tests')
            return results

        print()

        # Test 2: Add Memory
        results['add_memory'] = await self.test_add_memory()
        print()

        # Test 3: Wait for processing
        await self.wait_for_processing()
        print()

        # Test 4: Search Functions
        results['search'] = await self.test_search_functions()
        print()

        # Test 5: Episode Retrieval
        results['episodes'] = await self.test_episode_retrieval()
        print()

        # Test 6: Edge Cases
        results['edge_cases'] = await self.test_edge_cases()
        print()

        # Calculate overall success
        memory_success = len(results['add_memory']) > 0
        search_success = any(results['search'].values())
        edge_case_success = any(results['edge_cases'].values())

        results['overall_success'] = (
            results['server_status']
            and memory_success
            and results['episodes']
            and (search_success or edge_case_success)  # At least some functionality working
        )

        # Print summary
        print('=' * 60)
        print('📊 TEST SUMMARY')
        print(f' Server Status: {"✅" if results["server_status"] else "❌"}')
        print(
            f' Memory Operations: {"✅" if memory_success else "❌"} ({len(results["add_memory"])} types)'
        )
        print(f' Search Functions: {"✅" if search_success else "❌"}')
        print(f' Episode Retrieval: {"✅" if results["episodes"] else "❌"}')
        print(f' Edge Cases: {"✅" if edge_case_success else "❌"}')
        print()
        print(f'🎯 OVERALL: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')

        if results['overall_success']:
            print(' The refactored MCP server is working correctly!')
        else:
            print(' Some issues detected. Check individual test results above.')

        return results
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Run the integration test suite and exit the process.

    Exits with status 0 when the suite reports overall success, 1 otherwise.
    """
    import sys  # local import: only needed to set the process exit status

    async with MCPIntegrationTest() as test:
        results = await test.run_full_test_suite()

    # Exit with appropriate code. sys.exit is used instead of the
    # site-module exit() helper, which is not guaranteed to exist (e.g.
    # under `python -I` or frozen builds). Raising after the `async with`
    # lets the HTTP client close cleanly first.
    sys.exit(0 if results['overall_success'] else 1)
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the async HTTP/SSE integration suite against a
# locally running Graphiti MCP server.
if __name__ == '__main__':
    asyncio.run(main())
|
||||||
502
mcp_server/test_mcp_integration.py
Normal file
502
mcp_server/test_mcp_integration.py
Normal file
|
|
@ -0,0 +1,502 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Integration test for the refactored Graphiti MCP Server using the official MCP Python SDK.
|
||||||
|
Tests all major MCP tools and handles episode processing latency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from mcp import ClientSession, StdioServerParameters
|
||||||
|
from mcp.client.stdio import stdio_client
|
||||||
|
|
||||||
|
|
||||||
|
class GraphitiMCPIntegrationTest:
|
||||||
|
"""Integration test client for Graphiti MCP Server using official MCP SDK."""
|
||||||
|
|
||||||
|
def __init__(self):
    # Unique group id per run so repeated runs don't collide in the graph.
    self.test_group_id = f'test_group_{int(time.time())}'
    # MCP ClientSession; populated by __aenter__.
    self.session = None
|
||||||
|
|
||||||
|
async def __aenter__(self):
    """Start the MCP client session.

    Launches the server under test as a subprocess over stdio (via `uv run`)
    and initializes an MCP client session against it.
    """
    # Configure server parameters to run our refactored server
    server_params = StdioServerParameters(
        command='uv',
        args=['run', 'graphiti_mcp_server.py', '--transport', 'stdio'],
        env={
            'NEO4J_URI': 'bolt://localhost:7687',
            'NEO4J_USER': 'neo4j',
            'NEO4J_PASSWORD': 'demodemo',
            'OPENAI_API_KEY': 'dummy_key_for_testing',  # Will use existing .env
        },
    )

    print(f'🚀 Starting MCP client session with test group: {self.test_group_id}')

    # Use the async context manager properly
    # NOTE(review): stdio_client's __aenter__/__aexit__ are driven manually
    # here so cleanup can happen in our own __aexit__. Also, ClientSession is
    # itself an async context manager in the MCP SDK — confirm that calling
    # initialize() without entering it starts the session's receive loop.
    self.client_context = stdio_client(server_params)
    read, write = await self.client_context.__aenter__()
    self.session = ClientSession(read, write)
    await self.session.initialize()

    return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Close the MCP client session and tear down the stdio transport."""
    # NOTE(review): confirm ClientSession exposes a close() method in the
    # pinned MCP SDK version — it is primarily an async context manager,
    # and an AttributeError here would mask the original exception.
    if self.session:
        await self.session.close()
    # client_context only exists if __aenter__ got far enough to set it.
    if hasattr(self, 'client_context'):
        await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
|
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||||
|
"""Call an MCP tool and return the result."""
|
||||||
|
try:
|
||||||
|
result = await self.session.call_tool(tool_name, arguments)
|
||||||
|
return result.content[0].text if result.content else {'error': 'No content returned'}
|
||||||
|
except Exception as e:
|
||||||
|
return {'error': str(e)}
|
||||||
|
|
||||||
|
async def test_server_initialization(self) -> bool:
    """Test that the server initializes properly.

    Lists the server's tools and passes when at least 80% of the expected
    tool names are present.
    """
    print('🔍 Testing server initialization...')

    try:
        # List available tools to verify server is responding
        tools_result = await self.session.list_tools()
        tools = [tool.name for tool in tools_result.tools]

        # Tool names the refactored server is expected to expose.
        expected_tools = [
            'add_memory',
            'search_memory_nodes',
            'search_memory_facts',
            'get_episodes',
            'delete_episode',
            'delete_entity_edge',
            'get_entity_edge',
            'clear_graph',
        ]

        # Count how many of the expected tools are actually available.
        available_tools = len([tool for tool in expected_tools if tool in tools])
        print(
            f' ✅ Server responding with {len(tools)} tools ({available_tools}/{len(expected_tools)} expected)'
        )
        print(f' Available tools: {", ".join(sorted(tools))}')

        return available_tools >= len(expected_tools) * 0.8  # 80% of expected tools

    except Exception as e:
        print(f' ❌ Server initialization failed: {e}')
        return False
|
||||||
|
|
||||||
|
async def test_add_memory_operations(self) -> dict[str, bool]:
    """Test adding various types of memory episodes.

    Exercises text, JSON, and message episodes; each is considered a
    success when the tool returns a string mentioning 'queued'.
    Returns a dict mapping episode kind -> success flag.
    """
    print('📝 Testing add_memory operations...')

    results = {}

    # Test 1: Add text episode
    print(' Testing text episode...')
    try:
        result = await self.call_tool(
            'add_memory',
            {
                'name': 'Test Company News',
                'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
                'source': 'text',
                'source_description': 'news article',
                'group_id': self.test_group_id,
            },
        )

        if isinstance(result, str) and 'queued' in result.lower():
            print(f' ✅ Text episode: {result}')
            results['text'] = True
        else:
            print(f' ❌ Text episode failed: {result}')
            results['text'] = False
    except Exception as e:
        print(f' ❌ Text episode error: {e}')
        results['text'] = False

    # Test 2: Add JSON episode
    print(' Testing JSON episode...')
    try:
        json_data = {
            'company': {'name': 'TechCorp', 'founded': 2010},
            'products': [
                {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
                {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
            ],
            'employees': 150,
        }

        result = await self.call_tool(
            'add_memory',
            {
                'name': 'Company Profile',
                'episode_body': json.dumps(json_data),
                'source': 'json',
                'source_description': 'CRM data',
                'group_id': self.test_group_id,
            },
        )

        if isinstance(result, str) and 'queued' in result.lower():
            print(f' ✅ JSON episode: {result}')
            results['json'] = True
        else:
            print(f' ❌ JSON episode failed: {result}')
            results['json'] = False
    except Exception as e:
        print(f' ❌ JSON episode error: {e}')
        results['json'] = False

    # Test 3: Add message episode
    print(' Testing message episode...')
    try:
        result = await self.call_tool(
            'add_memory',
            {
                'name': 'Customer Support Chat',
                'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
                'source': 'message',
                'source_description': 'support chat log',
                'group_id': self.test_group_id,
            },
        )

        if isinstance(result, str) and 'queued' in result.lower():
            print(f' ✅ Message episode: {result}')
            results['message'] = True
        else:
            print(f' ❌ Message episode failed: {result}')
            results['message'] = False
    except Exception as e:
        print(f' ❌ Message episode error: {e}')
        results['message'] = False

    return results
|
||||||
|
|
||||||
|
async def wait_for_processing(self, max_wait: int = 45) -> bool:
    """Wait for episode processing to complete.

    Polls get_episodes once per second for up to ``max_wait`` seconds.
    Returns True as soon as at least one episode is visible, else False.
    """
    print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')

    for i in range(max_wait):
        await asyncio.sleep(1)

        try:
            # Check if we have any episodes
            result = await self.call_tool(
                'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
            )

            # Parse the JSON result if it's a string
            if isinstance(result, str):
                try:
                    parsed_result = json.loads(result)
                    if isinstance(parsed_result, list) and len(parsed_result) > 0:
                        print(
                            f' ✅ Found {len(parsed_result)} processed episodes after {i + 1} seconds'
                        )
                        return True
                except json.JSONDecodeError:
                    # Non-JSON text response: fall back to a substring check.
                    if 'episodes' in result.lower():
                        print(f' ✅ Episodes detected after {i + 1} seconds')
                        return True

        except Exception as e:
            if i == 0:  # Only log first error to avoid spam
                print(f' ⚠️ Waiting for processing... ({e})')
            continue

    print(f' ⚠️ Still waiting after {max_wait} seconds...')
    return False
|
||||||
|
|
||||||
|
async def test_search_operations(self) -> dict[str, bool]:
    """Test search functionality.

    Runs a node search and a fact search; each succeeds when the response
    parses as JSON with the expected list field, or (for plain-text
    responses) mentions the field name and 'successfully'.
    Returns a dict mapping search kind -> success flag.
    """
    print('🔍 Testing search operations...')

    results = {}

    # Test search_memory_nodes
    print(' Testing search_memory_nodes...')
    try:
        result = await self.call_tool(
            'search_memory_nodes',
            {
                'query': 'Acme Corp product launch AI',
                'group_ids': [self.test_group_id],
                'max_nodes': 5,
            },
        )

        success = False
        if isinstance(result, str):
            try:
                parsed = json.loads(result)
                nodes = parsed.get('nodes', [])
                success = isinstance(nodes, list)
                print(f' ✅ Node search returned {len(nodes)} nodes')
            except json.JSONDecodeError:
                # Plain-text response: heuristic success check.
                success = 'nodes' in result.lower() and 'successfully' in result.lower()
                if success:
                    print(' ✅ Node search completed successfully')

        results['nodes'] = success
        if not success:
            print(f' ❌ Node search failed: {result}')

    except Exception as e:
        print(f' ❌ Node search error: {e}')
        results['nodes'] = False

    # Test search_memory_facts
    print(' Testing search_memory_facts...')
    try:
        result = await self.call_tool(
            'search_memory_facts',
            {
                'query': 'company products software TechCorp',
                'group_ids': [self.test_group_id],
                'max_facts': 5,
            },
        )

        success = False
        if isinstance(result, str):
            try:
                parsed = json.loads(result)
                facts = parsed.get('facts', [])
                success = isinstance(facts, list)
                print(f' ✅ Fact search returned {len(facts)} facts')
            except json.JSONDecodeError:
                # Plain-text response: heuristic success check.
                success = 'facts' in result.lower() and 'successfully' in result.lower()
                if success:
                    print(' ✅ Fact search completed successfully')

        results['facts'] = success
        if not success:
            print(f' ❌ Fact search failed: {result}')

    except Exception as e:
        print(f' ❌ Fact search error: {e}')
        results['facts'] = False

    return results
|
||||||
|
|
||||||
|
async def test_episode_retrieval(self) -> bool:
    """Test episode retrieval.

    Returns True when get_episodes yields a non-empty JSON list, or a
    plain-text response that at least mentions episodes.
    """
    print('📚 Testing episode retrieval...')

    try:
        result = await self.call_tool(
            'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
        )

        if isinstance(result, str):
            try:
                parsed = json.loads(result)
                if isinstance(parsed, list):
                    print(f' ✅ Retrieved {len(parsed)} episodes')

                    # Show episode details (first 3 only)
                    for i, episode in enumerate(parsed[:3]):
                        name = episode.get('name', 'Unknown')
                        source = episode.get('source', 'unknown')
                        print(f' Episode {i + 1}: {name} (source: {source})')

                    return len(parsed) > 0
            except json.JSONDecodeError:
                # Check if response indicates success
                if 'episode' in result.lower():
                    print(' ✅ Episode retrieval completed')
                    return True

        print(f' ❌ Unexpected result format: {result}')
        return False

    except Exception as e:
        print(f' ❌ Episode retrieval failed: {e}')
        return False
|
||||||
|
|
||||||
|
    async def test_error_handling(self) -> dict[str, bool]:
        """Test error handling and edge cases.

        Exercises two edge cases against ``search_memory_nodes`` — a search
        scoped to a group that does not exist, and an empty query string —
        and records whether each was handled gracefully.

        Returns:
            Mapping of case name ('nonexistent_group', 'empty_query') to a
            bool indicating graceful handling.
        """
        print('🧪 Testing error handling...')

        results = {}

        # Test with nonexistent group
        print('   Testing nonexistent group handling...')
        try:
            result = await self.call_tool(
                'search_memory_nodes',
                {
                    'query': 'nonexistent data',
                    'group_ids': ['nonexistent_group_12345'],
                    'max_nodes': 5,
                },
            )

            # Should handle gracefully, not crash
            # NOTE(review): with `or`, this is False only when the result
            # contains BOTH 'error' AND 'not initialized' — confirm that
            # `and` was not intended here.
            success = (
                'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
            )
            if success:
                print('   ✅ Nonexistent group handled gracefully')
            else:
                print(f'   ❌ Nonexistent group caused issues: {result}')

            results['nonexistent_group'] = success

        except Exception as e:
            print(f'   ❌ Nonexistent group test failed: {e}')
            results['nonexistent_group'] = False

        # Test empty query
        print('   Testing empty query handling...')
        try:
            result = await self.call_tool(
                'search_memory_nodes',
                {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5},
            )

            # Should handle gracefully
            # NOTE(review): same `or` construction as above — verify intent.
            success = (
                'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
            )
            if success:
                print('   ✅ Empty query handled gracefully')
            else:
                print(f'   ❌ Empty query caused issues: {result}')

            results['empty_query'] = success

        except Exception as e:
            print(f'   ❌ Empty query test failed: {e}')
            results['empty_query'] = False

        return results
|
    async def run_comprehensive_test(self) -> dict[str, Any]:
        """Run the complete integration test suite.

        Runs the six test phases in order (initialization, memory writes,
        processing wait, search, episode retrieval, error handling), then
        derives an overall verdict and prints a summary table.

        Returns:
            Dict of per-phase results plus an 'overall_success' bool.
        """
        print('🚀 Starting Comprehensive Graphiti MCP Server Integration Test')
        print(f'   Test group ID: {self.test_group_id}')
        print('=' * 70)

        # Pre-populate all keys so the summary below can safely index them
        # even when an early phase aborts the run.
        results = {
            'server_init': False,
            'add_memory': {},
            'processing_wait': False,
            'search': {},
            'episodes': False,
            'error_handling': {},
            'overall_success': False,
        }

        # Test 1: Server Initialization — everything else depends on this,
        # so bail out early on failure.
        results['server_init'] = await self.test_server_initialization()
        if not results['server_init']:
            print('❌ Server initialization failed, aborting remaining tests')
            return results

        print()

        # Test 2: Add Memory Operations
        results['add_memory'] = await self.test_add_memory_operations()
        print()

        # Test 3: Wait for Processing
        results['processing_wait'] = await self.wait_for_processing()
        print()

        # Test 4: Search Operations
        results['search'] = await self.test_search_operations()
        print()

        # Test 5: Episode Retrieval
        results['episodes'] = await self.test_episode_retrieval()
        print()

        # Test 6: Error Handling
        results['error_handling'] = await self.test_error_handling()
        print()

        # Calculate overall success. `any(...)` means a single passing
        # sub-case counts the phase as passed; empty error-handling results
        # default to True (nothing failed).
        memory_success = any(results['add_memory'].values())
        search_success = any(results['search'].values()) if results['search'] else False
        error_success = (
            any(results['error_handling'].values()) if results['error_handling'] else True
        )

        # NOTE: search_success is intentionally not part of the verdict —
        # episodes/processing stand in for end-to-end data flow.
        results['overall_success'] = (
            results['server_init']
            and memory_success
            and (results['episodes'] or results['processing_wait'])
            and error_success
        )

        # Print comprehensive summary
        print('=' * 70)
        print('📊 COMPREHENSIVE TEST SUMMARY')
        print('-' * 35)
        print(f'Server Initialization: {"✅ PASS" if results["server_init"] else "❌ FAIL"}')

        memory_stats = f'({sum(results["add_memory"].values())}/{len(results["add_memory"])} types)'
        print(
            f'Memory Operations: {"✅ PASS" if memory_success else "❌ FAIL"} {memory_stats}'
        )

        print(f'Processing Pipeline: {"✅ PASS" if results["processing_wait"] else "❌ FAIL"}')

        search_stats = (
            f'({sum(results["search"].values())}/{len(results["search"])} types)'
            if results['search']
            else '(0/0 types)'
        )
        print(
            f'Search Operations: {"✅ PASS" if search_success else "❌ FAIL"} {search_stats}'
        )

        print(f'Episode Retrieval: {"✅ PASS" if results["episodes"] else "❌ FAIL"}')

        error_stats = (
            f'({sum(results["error_handling"].values())}/{len(results["error_handling"])} cases)'
            if results['error_handling']
            else '(0/0 cases)'
        )
        print(
            f'Error Handling: {"✅ PASS" if error_success else "❌ FAIL"} {error_stats}'
        )

        print('-' * 35)
        print(f'🎯 OVERALL RESULT: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')

        if results['overall_success']:
            print('\n🎉 The refactored Graphiti MCP server is working correctly!')
            print('   All core functionality has been successfully tested.')
        else:
            print('\n⚠️ Some issues were detected. Review the test results above.')
            print('   The refactoring may need additional attention.')

        return results
|
async def main():
    """Run the integration test.

    Creates the integration-test harness, runs the comprehensive suite,
    and terminates the process with exit code 0 on overall success or 1
    on failure (including harness setup errors).
    """
    try:
        async with GraphitiMCPIntegrationTest() as test:
            results = await test.run_comprehensive_test()

            # Exit with appropriate code. Raise SystemExit directly instead
            # of calling the site-injected exit() builtin: exit() is meant
            # for interactive sessions and is absent under `python -S` and
            # in frozen/embedded interpreters.
            exit_code = 0 if results['overall_success'] else 1
            raise SystemExit(exit_code)
    except Exception as e:
        # SystemExit derives from BaseException, so the success-path exit
        # above is not swallowed by this handler.
        print(f'❌ Test setup failed: {e}')
        raise SystemExit(1) from e
|
# Script entry point: run the async integration suite under asyncio.
if __name__ == '__main__':
    asyncio.run(main())
199
mcp_server/test_simple_validation.py
Normal file
199
mcp_server/test_simple_validation.py
Normal file
|
|
@ -0,0 +1,199 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Simple validation test for the refactored Graphiti MCP Server.
|
||||||
|
Tests basic functionality quickly without timeouts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
import subprocess
import sys
import time
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_startup():
    """Test that the refactored server starts up successfully.

    Launches the MCP server as a subprocess (stdio transport) and watches
    its stderr for the Graphiti initialization success message.

    Returns:
        True when the server logs successful initialization within the
        startup window, False otherwise (timeout, early exit, or launch
        failure).
    """
    print('🚀 Testing Graphiti MCP Server Startup...')

    try:
        # Start the server and capture output. Extend — rather than
        # replace — the current environment: passing only the NEO4J_*
        # variables would drop PATH/HOME and break the `uv` launcher.
        process = subprocess.Popen(
            ['uv', 'run', 'graphiti_mcp_server.py', '--transport', 'stdio'],
            env={
                **os.environ,
                'NEO4J_URI': 'bolt://localhost:7687',
                'NEO4J_USER': 'neo4j',
                'NEO4J_PASSWORD': 'demodemo',
            },
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )

        # Wait for startup logs. Initialize `success` up front so an early
        # server exit (break below) cannot leave it unbound at `return`.
        success = False
        startup_output = ''
        for _ in range(50):  # Wait up to 5 seconds
            if process.poll() is not None:
                # Server process already exited — no point polling further.
                break
            time.sleep(0.1)

            # Check if we have output
            try:
                line = process.stderr.readline()
                if line:
                    startup_output += line
                    print(f'   📋 {line.strip()}')

                    # Check for success indicators
                    if 'Graphiti client initialized successfully' in line:
                        print('   ✅ Graphiti service initialization: SUCCESS')
                        success = True
                        break
            except Exception:
                continue

        if not success:
            print('   ⚠️ Timeout waiting for initialization')

        # Clean shutdown: ask politely first, then force-kill.
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()

        return success

    except Exception as e:
        print(f'   ❌ Server startup failed: {e}')
        return False
|
def test_import_validation():
    """Test that all refactored modules import correctly.

    Imports each refactored module in a fresh subprocess so one module's
    import-time failure cannot poison the others.

    Returns:
        True when every module imported cleanly, False otherwise.
    """
    print('\n🔍 Testing Module Import Validation...')

    modules_to_test = [
        'config_manager',
        'llm_config',
        'embedder_config',
        'neo4j_config',
        'server_config',
        'graphiti_service',
        'queue_service',
        'entity_types',
        'response_types',
        'formatting',
        'utils',
    ]

    success_count = 0

    for module in modules_to_test:
        try:
            # Use the current interpreter (sys.executable) rather than a
            # bare 'python', which may be missing from PATH or resolve to
            # a different Python than the one running this script.
            result = subprocess.run(
                [sys.executable, '-c', f"import {module}; print('✅ {module}')"],
                capture_output=True,
                text=True,
                timeout=10,
            )

            if result.returncode == 0:
                print(f'   ✅ {module}: Import successful')
                success_count += 1
            else:
                print(f'   ❌ {module}: Import failed - {result.stderr.strip()}')

        except subprocess.TimeoutExpired:
            print(f'   ❌ {module}: Import timeout')
        except Exception as e:
            print(f'   ❌ {module}: Import error - {e}')

    print(f'   📊 Import Results: {success_count}/{len(modules_to_test)} modules successful')
    return success_count == len(modules_to_test)
|
def test_syntax_validation():
    """Test that all Python files have valid syntax.

    Byte-compiles each refactored source file with ``py_compile`` in a
    subprocess and tallies the results.

    Returns:
        True when every file compiled without syntax errors, False
        otherwise.
    """
    print('\n🔧 Testing Syntax Validation...')

    files_to_test = [
        'graphiti_mcp_server.py',
        'config_manager.py',
        'llm_config.py',
        'embedder_config.py',
        'neo4j_config.py',
        'server_config.py',
        'graphiti_service.py',
        'queue_service.py',
        'entity_types.py',
        'response_types.py',
        'formatting.py',
        'utils.py',
    ]

    success_count = 0

    for file in files_to_test:
        try:
            # Use sys.executable so the check runs under the same
            # interpreter version as this script ('python' on PATH may
            # differ or be absent).
            result = subprocess.run(
                [sys.executable, '-m', 'py_compile', file],
                capture_output=True,
                text=True,
                timeout=10,
            )

            if result.returncode == 0:
                print(f'   ✅ {file}: Syntax valid')
                success_count += 1
            else:
                print(f'   ❌ {file}: Syntax error - {result.stderr.strip()}')

        except subprocess.TimeoutExpired:
            print(f'   ❌ {file}: Syntax check timeout')
        except Exception as e:
            print(f'   ❌ {file}: Syntax check error - {e}')

    print(f'   📊 Syntax Results: {success_count}/{len(files_to_test)} files valid')
    return success_count == len(files_to_test)
|
def main():
    """Run the validation tests.

    Executes the syntax, import, and startup validations in that order,
    prints a summary table, and computes a process exit status.

    Returns:
        0 when every validation passed, 1 otherwise.
    """
    print('🧪 Graphiti MCP Server Refactoring Validation')
    print('=' * 55)

    # Run the three validation stages in order; dict-literal values are
    # evaluated left-to-right, preserving the original sequence.
    results = {
        'syntax': test_syntax_validation(),
        'imports': test_import_validation(),
        'startup': test_server_startup(),
    }

    # Summary table
    print('\n' + '=' * 55)
    print('📊 VALIDATION SUMMARY')
    print('-' * 25)
    print(f'Syntax Validation: {"✅ PASS" if results["syntax"] else "❌ FAIL"}')
    print(f'Import Validation: {"✅ PASS" if results["imports"] else "❌ FAIL"}')
    print(f'Startup Validation: {"✅ PASS" if results["startup"] else "❌ FAIL"}')

    overall_success = all(results.values())
    print('-' * 25)
    print(f'🎯 OVERALL: {"✅ SUCCESS" if overall_success else "❌ FAILED"}')

    if overall_success:
        print('\n🎉 Refactoring validation successful!')
        print('   ✅ All modules have valid syntax')
        print('   ✅ All imports work correctly')
        print('   ✅ Server initializes successfully')
        print('   ✅ The refactored MCP server is ready for use!')
    else:
        print('\n⚠️ Some validation issues detected.')
        print('   Please review the failed tests above.')

    return 0 if overall_success else 1
|
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == '__main__':
    sys.exit(main())
|
||||||
14
mcp_server/utils.py
Normal file
14
mcp_server/utils.py
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
"""Utility functions for Graphiti MCP Server."""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
||||||
|
|
||||||
|
|
||||||
|
def create_azure_credential_token_provider() -> Callable[[], str]:
    """Create Azure credential token provider for managed identity authentication.

    Returns:
        A zero-argument callable that yields a bearer token for the Azure
        Cognitive Services scope, backed by ``DefaultAzureCredential``.
    """
    return get_bearer_token_provider(
        DefaultAzureCredential(), 'https://cognitiveservices.azure.com/.default'
    )
||||||
14
mcp_server/uv.lock
generated
14
mcp_server/uv.lock
generated
|
|
@ -457,7 +457,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mcp-server"
|
name = "mcp-server"
|
||||||
version = "0.1.0"
|
version = "0.2.1"
|
||||||
source = { virtual = "." }
|
source = { virtual = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "azure-identity" },
|
{ name = "azure-identity" },
|
||||||
|
|
@ -466,6 +466,12 @@ dependencies = [
|
||||||
{ name = "openai" },
|
{ name = "openai" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[package.dev-dependencies]
|
||||||
|
dev = [
|
||||||
|
{ name = "httpx" },
|
||||||
|
{ name = "mcp" },
|
||||||
|
]
|
||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "azure-identity", specifier = ">=1.21.0" },
|
{ name = "azure-identity", specifier = ">=1.21.0" },
|
||||||
|
|
@ -475,6 +481,12 @@ requires-dist = [
|
||||||
{ name = "openai", specifier = ">=1.68.2" },
|
{ name = "openai", specifier = ">=1.68.2" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[package.metadata.requires-dev]
|
||||||
|
dev = [
|
||||||
|
{ name = "httpx", specifier = ">=0.28.1" },
|
||||||
|
{ name = "mcp", specifier = ">=1.9.4" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "msal"
|
name = "msal"
|
||||||
version = "1.32.3"
|
version = "1.32.3"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue