diff --git a/mcp_server/src/graphiti_mcp_server.py b/mcp_server/src/graphiti_mcp_server.py index c611a2ae..9568f96a 100644 --- a/mcp_server/src/graphiti_mcp_server.py +++ b/mcp_server/src/graphiti_mcp_server.py @@ -13,6 +13,7 @@ from typing import Any, Optional from dotenv import load_dotenv from graphiti_core import Graphiti +from graphiti_core.driver.neo4j_driver import Neo4jDriver from graphiti_core.edges import EntityEdge from graphiti_core.nodes import EpisodeType, EpisodicNode from graphiti_core.search.search_filters import SearchFilters @@ -229,11 +230,16 @@ class GraphitiService: max_coroutines=self.semaphore_limit, ) else: - # For Neo4j (default), use the original approach - self.client = Graphiti( + # For Neo4j (default), create a Neo4jDriver instance with database parameter + neo4j_driver = Neo4jDriver( uri=db_config['uri'], user=db_config['user'], password=db_config['password'], + database=db_config.get('database', 'neo4j'), + ) + + self.client = Graphiti( + graph_driver=neo4j_driver, llm_client=llm_client, embedder=embedder_client, max_coroutines=self.semaphore_limit, diff --git a/mcp_server/src/services/factories.py b/mcp_server/src/services/factories.py index 65f710b7..af1d3f17 100644 --- a/mcp_server/src/services/factories.py +++ b/mcp_server/src/services/factories.py @@ -389,13 +389,13 @@ class DatabaseDriverFactory: uri = os.environ.get('NEO4J_URI', neo4j_config.uri) username = os.environ.get('NEO4J_USER', neo4j_config.username) password = os.environ.get('NEO4J_PASSWORD', neo4j_config.password) + database = os.environ.get('NEO4J_DATABASE', neo4j_config.database) return { 'uri': uri, 'user': username, 'password': password, - # Note: database and use_parallel_runtime would need to be passed - # to the driver after initialization if supported + 'database': database, } case 'falkordb': diff --git a/mcp_server/tests/test_database_param.py b/mcp_server/tests/test_database_param.py new file mode 100644 index 00000000..971a5d5f --- /dev/null +++ b/mcp_server/tests/test_database_param.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +"""Test Neo4j database parameter configuration.""" + +import os +import sys +from pathlib import Path + +# Setup path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +def test_neo4j_database_parameter(): + """Test that Neo4j database parameter is included in configuration.""" + from src.config.schema import GraphitiConfig + from src.services.factories import DatabaseDriverFactory + + print('\n' + '='*70) + print('Testing Neo4j Database Parameter Configuration') + print('='*70 + '\n') + + # Test 1: Default database value + print('Test 1: Default database value') + config = GraphitiConfig() + db_config = DatabaseDriverFactory.create_config(config.database) + + assert 'database' in db_config, 'Database parameter missing from config!' + print(f' ✓ Database parameter present in config') + print(f' ✓ Default database value: {db_config["database"]}') + assert db_config['database'] == 'neo4j', f'Expected default "neo4j", got {db_config["database"]}' + print(f' ✓ Default value matches expected: neo4j\n') + + # Test 2: Environment variable override + print('Test 2: Environment variable override') + os.environ['NEO4J_DATABASE'] = 'graphiti' + config2 = GraphitiConfig() + db_config2 = DatabaseDriverFactory.create_config(config2.database) + + assert 'database' in db_config2, 'Database parameter missing from config!' + print(f' ✓ Database parameter present in config') + print(f' ✓ Overridden database value: {db_config2["database"]}') + assert db_config2['database'] == 'graphiti', f'Expected "graphiti", got {db_config2["database"]}' + print(f' ✓ Environment override works correctly\n') + + # Clean up + del os.environ['NEO4J_DATABASE'] + + # Test 3: Verify all required parameters are present + print('Test 3: Verify all required Neo4j parameters') + required_params = ['uri', 'user', 'password', 'database'] + for param in required_params: + assert param in db_config, f'Required parameter "{param}" missing!' + print(f' ✓ {param}: present') + + print('\n' + '='*70) + print('✅ All database parameter tests passed!') + print('='*70) + print('\nSummary:') + print(' - database parameter is included in Neo4j config') + print(' - Default value is "neo4j"') + print(' - Environment variable NEO4J_DATABASE override works') + print(' - All required parameters (uri, user, password, database) present') + print() + + +if __name__ == '__main__': + try: + test_neo4j_database_parameter() + except AssertionError as e: + print(f'\n❌ Test failed: {e}\n') + sys.exit(1) + except Exception as e: + print(f'\n❌ Unexpected error: {e}\n') + import traceback + traceback.print_exc() + sys.exit(1)