diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py index cc6eac09e..c390d9ac6 100755 --- a/cognee-mcp/src/server.py +++ b/cognee-mcp/src/server.py @@ -2,28 +2,25 @@ import json import os import sys import argparse -import cognee import asyncio import subprocess from pathlib import Path +from typing import Optional from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location import importlib.util from contextlib import redirect_stdout import mcp.types as types from mcp.server import FastMCP -from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status -from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id -from cognee.modules.users.methods import get_default_user -from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline -from cognee.modules.search.types import SearchType -from cognee.shared.data_models import KnowledgeGraph from cognee.modules.storage.utils import JSONEncoder from starlette.responses import JSONResponse from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware import uvicorn +# Import the new CogneeClient abstraction +from cognee_client import CogneeClient + try: from cognee.tasks.codingagents.coding_rule_associations import ( @@ -41,6 +38,9 @@ mcp = FastMCP("Cognee") logger = get_logger() +# Global CogneeClient instance (will be initialized in main()) +cognee_client: Optional[CogneeClient] = None + async def run_sse_with_cors(): """Custom SSE transport with CORS middleware.""" @@ -141,11 +141,20 @@ async def cognee_add_developer_rules( with redirect_stdout(sys.stderr): logger.info(f"Starting cognify for: {file_path}") try: - await cognee.add(file_path, node_set=["developer_rules"]) - model = KnowledgeGraph + await cognee_client.add(file_path, node_set=["developer_rules"]) + + model = None if graph_model_file and graph_model_name: - model = load_class(graph_model_file, graph_model_name) - await cognee.cognify(graph_model=model) + if cognee_client.use_api: + logger.warning( + "Custom graph models are not supported in API mode, ignoring." + ) + else: + from cognee.shared.data_models import KnowledgeGraph + + model = load_class(graph_model_file, graph_model_name) + + await cognee_client.cognify(graph_model=model) logger.info(f"Cognify finished for: {file_path}") except Exception as e: logger.error(f"Cognify failed for {file_path}: {str(e)}") @@ -293,15 +302,20 @@ async def cognify( # going to stdout ( like the print function ) to stderr. with redirect_stdout(sys.stderr): logger.info("Cognify process starting.") - if graph_model_file and graph_model_name: - graph_model = load_class(graph_model_file, graph_model_name) - else: - graph_model = KnowledgeGraph - await cognee.add(data) + graph_model = None + if graph_model_file and graph_model_name: + if cognee_client.use_api: + logger.warning("Custom graph models are not supported in API mode, ignoring.") + else: + from cognee.shared.data_models import KnowledgeGraph + + graph_model = load_class(graph_model_file, graph_model_name) + + await cognee_client.add(data) try: - await cognee.cognify(graph_model=graph_model, custom_prompt=custom_prompt) + await cognee_client.cognify(custom_prompt=custom_prompt, graph_model=graph_model) logger.info("Cognify process finished.") except Exception as e: logger.error("Cognify process failed.") @@ -354,16 +368,19 @@ async def save_interaction(data: str) -> list: with redirect_stdout(sys.stderr): logger.info("Save interaction process starting.") - await cognee.add(data, node_set=["user_agent_interaction"]) + await cognee_client.add(data, node_set=["user_agent_interaction"]) try: - await cognee.cognify() + await cognee_client.cognify() logger.info("Save interaction process finished.") - logger.info("Generating associated rules from interaction data.") - await add_rule_associations(data=data, rules_nodeset_name="coding_agent_rules") - - logger.info("Associated rules generated from interaction data.") + # Rule associations only work in direct mode + if not cognee_client.use_api: + logger.info("Generating associated rules from interaction data.") + await add_rule_associations(data=data, rules_nodeset_name="coding_agent_rules") + logger.info("Associated rules generated from interaction data.") + else: + logger.warning("Rule associations are not available in API mode, skipping.") except Exception as e: logger.error("Save interaction process failed.") @@ -420,11 +437,18 @@ async def codify(repo_path: str) -> list: - All stdout is redirected to stderr to maintain MCP communication integrity """ + if cognee_client.use_api: + error_msg = "❌ Codify operation is not available in API mode. Please use direct mode for code graph pipeline." + logger.error(error_msg) + return [types.TextContent(type="text", text=error_msg)] + async def codify_task(repo_path: str): # NOTE: MCP uses stdout to communicate, we must redirect all output # going to stdout ( like the print function ) to stderr. with redirect_stdout(sys.stderr): logger.info("Codify process starting.") + from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline + results = [] async for result in run_code_graph_pipeline(repo_path, False): results.append(result) @@ -574,23 +598,40 @@ async def search(search_query: str, search_type: str) -> list: # NOTE: MCP uses stdout to communicate, we must redirect all output # going to stdout ( like the print function ) to stderr. with redirect_stdout(sys.stderr): - search_results = await cognee.search( - query_type=SearchType[search_type.upper()], query_text=search_query + search_results = await cognee_client.search( + query_text=search_query, query_type=search_type ) - if search_type.upper() == "CODE": - return json.dumps(search_results, cls=JSONEncoder) - elif ( - search_type.upper() == "GRAPH_COMPLETION" or search_type.upper() == "RAG_COMPLETION" - ): - return str(search_results[0]) - elif search_type.upper() == "CHUNKS": - return str(search_results) - elif search_type.upper() == "INSIGHTS": - results = retrieved_edges_to_string(search_results) - return results + # Handle different result formats based on API vs direct mode + if cognee_client.use_api: + # API mode returns JSON-serialized results + if isinstance(search_results, str): + return search_results + elif isinstance(search_results, list): + if ( + search_type.upper() in ["GRAPH_COMPLETION", "RAG_COMPLETION"] + and len(search_results) > 0 + ): + return str(search_results[0]) + return str(search_results) + else: + return json.dumps(search_results, cls=JSONEncoder) else: - return str(search_results) + # Direct mode processing + if search_type.upper() == "CODE": + return json.dumps(search_results, cls=JSONEncoder) + elif ( + search_type.upper() == "GRAPH_COMPLETION" + or search_type.upper() == "RAG_COMPLETION" + ): + return str(search_results[0]) + elif search_type.upper() == "CHUNKS": + return str(search_results) + elif search_type.upper() == "INSIGHTS": + results = retrieved_edges_to_string(search_results) + return results + else: + return str(search_results) search_results = await search_task(search_query, search_type) return [types.TextContent(type="text", text=search_results)] @@ -623,6 +664,10 @@ async def get_developer_rules() -> list: async def fetch_rules_from_cognee() -> str: """Collect all developer rules from Cognee""" with redirect_stdout(sys.stderr): + if cognee_client.use_api: + logger.warning("Developer rules retrieval is not available in API mode") + return "Developer rules retrieval is not available in API mode" + developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules") return developer_rules @@ -662,16 +707,24 @@ async def list_data(dataset_id: str = None) -> list: with redirect_stdout(sys.stderr): try: - user = await get_default_user() output_lines = [] if dataset_id: - # List data for specific dataset + # Detailed data listing for specific dataset is only available in direct mode + if cognee_client.use_api: + return [ + types.TextContent( + type="text", + text="❌ Detailed data listing for specific datasets is not available in API mode.\nPlease use the API directly or use direct mode.", + ) + ] + + from cognee.modules.users.methods import get_default_user + from cognee.modules.data.methods import get_dataset, get_dataset_data + logger.info(f"Listing data for dataset: {dataset_id}") dataset_uuid = UUID(dataset_id) - - # Get the dataset information - from cognee.modules.data.methods import get_dataset, get_dataset_data + user = await get_default_user() dataset = await get_dataset(user.id, dataset_uuid) @@ -700,11 +753,9 @@ async def list_data(dataset_id: str = None) -> list: output_lines.append(" (No data items in this dataset)") else: - # List all datasets + # List all datasets - works in both modes logger.info("Listing all datasets") - from cognee.modules.data.methods import get_datasets - - datasets = await get_datasets(user.id) + datasets = await cognee_client.list_datasets() if not datasets: return [ @@ -719,20 +770,21 @@ async def list_data(dataset_id: str = None) -> list: output_lines.append("") for i, dataset in enumerate(datasets, 1): - # Get data count for each dataset - from cognee.modules.data.methods import get_dataset_data - - data_items = await get_dataset_data(dataset.id) - - output_lines.append(f"{i}. 📁 {dataset.name}") - output_lines.append(f" Dataset ID: {dataset.id}") - output_lines.append(f" Created: {dataset.created_at}") - output_lines.append(f" Data items: {len(data_items)}") + # In API mode, dataset is a dict; in direct mode, it's formatted as dict + if isinstance(dataset, dict): + output_lines.append(f"{i}. 📁 {dataset.get('name', 'Unnamed')}") + output_lines.append(f" Dataset ID: {dataset.get('id')}") + output_lines.append(f" Created: {dataset.get('created_at', 'N/A')}") + else: + output_lines.append(f"{i}. 📁 {dataset.name}") + output_lines.append(f" Dataset ID: {dataset.id}") + output_lines.append(f" Created: {dataset.created_at}") output_lines.append("") - output_lines.append("💡 To see data items in a specific dataset, use:") - output_lines.append(' list_data(dataset_id="your-dataset-id-here")') - output_lines.append("") + if not cognee_client.use_api: + output_lines.append("💡 To see data items in a specific dataset, use:") + output_lines.append(' list_data(dataset_id="your-dataset-id-here")') + output_lines.append("") output_lines.append("🗑️ To delete specific data, use:") output_lines.append(' delete(data_id="data-id", dataset_id="dataset-id")') @@ -801,12 +853,9 @@ async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list: data_uuid = UUID(data_id) dataset_uuid = UUID(dataset_id) - # Get default user for the operation - user = await get_default_user() - - # Call the cognee delete function - result = await cognee.delete( - data_id=data_uuid, dataset_id=dataset_uuid, mode=mode, user=user + # Call the cognee delete function via client + result = await cognee_client.delete( + data_id=data_uuid, dataset_id=dataset_uuid, mode=mode ) logger.info(f"Delete operation completed successfully: {result}") @@ -853,11 +902,21 @@ async def prune(): ----- - This operation cannot be undone. All memory data will be permanently deleted. - The function prunes both data content (using prune_data) and system metadata (using prune_system) + - This operation is not available in API mode """ with redirect_stdout(sys.stderr): - await cognee.prune.prune_data() - await cognee.prune.prune_system(metadata=True) - return [types.TextContent(type="text", text="Pruned")] + try: + await cognee_client.prune_data() + await cognee_client.prune_system(metadata=True) + return [types.TextContent(type="text", text="Pruned")] + except NotImplementedError: + error_msg = "❌ Prune operation is not available in API mode" + logger.error(error_msg) + return [types.TextContent(type="text", text=error_msg)] + except Exception as e: + error_msg = f"❌ Prune operation failed: {str(e)}" + logger.error(error_msg) + return [types.TextContent(type="text", text=error_msg)] @mcp.tool() @@ -880,13 +939,26 @@ async def cognify_status(): - The function retrieves pipeline status specifically for the "cognify_pipeline" on the "main_dataset" - Status information includes job progress, execution time, and completion status - The status is returned in string format for easy reading + - This operation is not available in API mode """ with redirect_stdout(sys.stderr): - user = await get_default_user() - status = await get_pipeline_status( - [await get_unique_dataset_id("main_dataset", user)], "cognify_pipeline" - ) - return [types.TextContent(type="text", text=str(status))] + try: + from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id + from cognee.modules.users.methods import get_default_user + + user = await get_default_user() + status = await cognee_client.get_pipeline_status( + [await get_unique_dataset_id("main_dataset", user)], "cognify_pipeline" + ) + return [types.TextContent(type="text", text=str(status))] + except NotImplementedError: + error_msg = "❌ Pipeline status is not available in API mode" + logger.error(error_msg) + return [types.TextContent(type="text", text=error_msg)] + except Exception as e: + error_msg = f"❌ Failed to get cognify status: {str(e)}" + logger.error(error_msg) + return [types.TextContent(type="text", text=error_msg)] @mcp.tool() @@ -909,13 +981,26 @@ async def codify_status(): - The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset - Status information includes job progress, execution time, and completion status - The status is returned in string format for easy reading + - This operation is not available in API mode """ with redirect_stdout(sys.stderr): - user = await get_default_user() - status = await get_pipeline_status( - [await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline" - ) - return [types.TextContent(type="text", text=str(status))] + try: + from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id + from cognee.modules.users.methods import get_default_user + + user = await get_default_user() + status = await cognee_client.get_pipeline_status( + [await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline" + ) + return [types.TextContent(type="text", text=str(status))] + except NotImplementedError: + error_msg = "❌ Pipeline status is not available in API mode" + logger.error(error_msg) + return [types.TextContent(type="text", text=error_msg)] + except Exception as e: + error_msg = f"❌ Failed to get codify status: {str(e)}" + logger.error(error_msg) + return [types.TextContent(type="text", text=error_msg)] def node_to_string(node): @@ -949,6 +1034,8 @@ def load_class(model_file, model_name): async def main(): + global cognee_client + parser = argparse.ArgumentParser() parser.add_argument( @@ -992,12 +1079,30 @@ async def main(): help="Argument stops database migration from being attempted", ) + # Cognee API connection options + parser.add_argument( + "--api-url", + default=None, + help="Base URL of a running Cognee FastAPI server (e.g., http://localhost:8000). " + "If provided, the MCP server will connect to the API instead of using cognee directly.", + ) + + parser.add_argument( + "--api-token", + default=None, + help="Authentication token for the Cognee API. Required if --api-url is provided.", + ) + args = parser.parse_args() + # Initialize the global CogneeClient + cognee_client = CogneeClient(api_url=args.api_url, api_token=args.api_token) + mcp.settings.host = args.host mcp.settings.port = args.port - if not args.no_migration: + # Skip migrations when in API mode (the API server handles its own database) + if not args.no_migration and not args.api_url: # Run Alembic migrations from the main cognee directory where alembic.ini is located logger.info("Running database migrations...") migration_result = subprocess.run( @@ -1020,6 +1125,8 @@ async def main(): sys.exit(1) logger.info("Database migrations done.") + elif args.api_url: + logger.info("Skipping database migrations (using API mode)") logger.info(f"Starting MCP server with transport: {args.transport}") if args.transport == "stdio":