route server tool calls through cognee_client

This commit is contained in:
Daulet Amirkhanov 2025-10-08 18:36:44 +01:00
parent 806298d508
commit 3e23c96595

View file

@ -2,28 +2,25 @@ import json
import os
import sys
import argparse
import cognee
import asyncio
import subprocess
from pathlib import Path
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location
import importlib.util
from contextlib import redirect_stdout
import mcp.types as types
from mcp.server import FastMCP
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
from cognee.modules.users.methods import get_default_user
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.search.types import SearchType
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.storage.utils import JSONEncoder
from starlette.responses import JSONResponse
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
import uvicorn
# Import the new CogneeClient abstraction
from cognee_client import CogneeClient
try:
from cognee.tasks.codingagents.coding_rule_associations import (
@ -41,6 +38,9 @@ mcp = FastMCP("Cognee")
logger = get_logger()
# Global CogneeClient instance (will be initialized in main())
cognee_client: Optional[CogneeClient] = None
async def run_sse_with_cors():
"""Custom SSE transport with CORS middleware."""
@ -141,11 +141,20 @@ async def cognee_add_developer_rules(
with redirect_stdout(sys.stderr):
logger.info(f"Starting cognify for: {file_path}")
try:
await cognee.add(file_path, node_set=["developer_rules"])
model = KnowledgeGraph
await cognee_client.add(file_path, node_set=["developer_rules"])
model = None
if graph_model_file and graph_model_name:
if cognee_client.use_api:
logger.warning(
"Custom graph models are not supported in API mode, ignoring."
)
else:
from cognee.shared.data_models import KnowledgeGraph
model = load_class(graph_model_file, graph_model_name)
await cognee.cognify(graph_model=model)
await cognee_client.cognify(graph_model=model)
logger.info(f"Cognify finished for: {file_path}")
except Exception as e:
logger.error(f"Cognify failed for {file_path}: {str(e)}")
@ -293,15 +302,20 @@ async def cognify(
# going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr):
logger.info("Cognify process starting.")
if graph_model_file and graph_model_name:
graph_model = load_class(graph_model_file, graph_model_name)
else:
graph_model = KnowledgeGraph
await cognee.add(data)
graph_model = None
if graph_model_file and graph_model_name:
if cognee_client.use_api:
logger.warning("Custom graph models are not supported in API mode, ignoring.")
else:
from cognee.shared.data_models import KnowledgeGraph
graph_model = load_class(graph_model_file, graph_model_name)
await cognee_client.add(data)
try:
await cognee.cognify(graph_model=graph_model, custom_prompt=custom_prompt)
await cognee_client.cognify(custom_prompt=custom_prompt, graph_model=graph_model)
logger.info("Cognify process finished.")
except Exception as e:
logger.error("Cognify process failed.")
@ -354,16 +368,19 @@ async def save_interaction(data: str) -> list:
with redirect_stdout(sys.stderr):
logger.info("Save interaction process starting.")
await cognee.add(data, node_set=["user_agent_interaction"])
await cognee_client.add(data, node_set=["user_agent_interaction"])
try:
await cognee.cognify()
await cognee_client.cognify()
logger.info("Save interaction process finished.")
# Rule associations only work in direct mode
if not cognee_client.use_api:
logger.info("Generating associated rules from interaction data.")
await add_rule_associations(data=data, rules_nodeset_name="coding_agent_rules")
logger.info("Associated rules generated from interaction data.")
else:
logger.warning("Rule associations are not available in API mode, skipping.")
except Exception as e:
logger.error("Save interaction process failed.")
@ -420,11 +437,18 @@ async def codify(repo_path: str) -> list:
- All stdout is redirected to stderr to maintain MCP communication integrity
"""
if cognee_client.use_api:
error_msg = "❌ Codify operation is not available in API mode. Please use direct mode for code graph pipeline."
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
async def codify_task(repo_path: str):
# NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr):
logger.info("Codify process starting.")
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
results = []
async for result in run_code_graph_pipeline(repo_path, False):
results.append(result)
@ -574,14 +598,31 @@ async def search(search_query: str, search_type: str) -> list:
# NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr):
search_results = await cognee.search(
query_type=SearchType[search_type.upper()], query_text=search_query
search_results = await cognee_client.search(
query_text=search_query, query_type=search_type
)
# Handle different result formats based on API vs direct mode
if cognee_client.use_api:
# API mode returns JSON-serialized results
if isinstance(search_results, str):
return search_results
elif isinstance(search_results, list):
if (
search_type.upper() in ["GRAPH_COMPLETION", "RAG_COMPLETION"]
and len(search_results) > 0
):
return str(search_results[0])
return str(search_results)
else:
return json.dumps(search_results, cls=JSONEncoder)
else:
# Direct mode processing
if search_type.upper() == "CODE":
return json.dumps(search_results, cls=JSONEncoder)
elif (
search_type.upper() == "GRAPH_COMPLETION" or search_type.upper() == "RAG_COMPLETION"
search_type.upper() == "GRAPH_COMPLETION"
or search_type.upper() == "RAG_COMPLETION"
):
return str(search_results[0])
elif search_type.upper() == "CHUNKS":
@ -623,6 +664,10 @@ async def get_developer_rules() -> list:
async def fetch_rules_from_cognee() -> str:
"""Collect all developer rules from Cognee"""
with redirect_stdout(sys.stderr):
if cognee_client.use_api:
logger.warning("Developer rules retrieval is not available in API mode")
return "Developer rules retrieval is not available in API mode"
developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules")
return developer_rules
@ -662,16 +707,24 @@ async def list_data(dataset_id: str = None) -> list:
with redirect_stdout(sys.stderr):
try:
user = await get_default_user()
output_lines = []
if dataset_id:
# List data for specific dataset
# Detailed data listing for specific dataset is only available in direct mode
if cognee_client.use_api:
return [
types.TextContent(
type="text",
text="❌ Detailed data listing for specific datasets is not available in API mode.\nPlease use the API directly or use direct mode.",
)
]
from cognee.modules.users.methods import get_default_user
from cognee.modules.data.methods import get_dataset, get_dataset_data
logger.info(f"Listing data for dataset: {dataset_id}")
dataset_uuid = UUID(dataset_id)
# Get the dataset information
from cognee.modules.data.methods import get_dataset, get_dataset_data
user = await get_default_user()
dataset = await get_dataset(user.id, dataset_uuid)
@ -700,11 +753,9 @@ async def list_data(dataset_id: str = None) -> list:
output_lines.append(" (No data items in this dataset)")
else:
# List all datasets
# List all datasets - works in both modes
logger.info("Listing all datasets")
from cognee.modules.data.methods import get_datasets
datasets = await get_datasets(user.id)
datasets = await cognee_client.list_datasets()
if not datasets:
return [
@ -719,17 +770,18 @@ async def list_data(dataset_id: str = None) -> list:
output_lines.append("")
for i, dataset in enumerate(datasets, 1):
# Get data count for each dataset
from cognee.modules.data.methods import get_dataset_data
data_items = await get_dataset_data(dataset.id)
# In API mode, dataset is a dict; in direct mode, it's formatted as dict
if isinstance(dataset, dict):
output_lines.append(f"{i}. 📁 {dataset.get('name', 'Unnamed')}")
output_lines.append(f" Dataset ID: {dataset.get('id')}")
output_lines.append(f" Created: {dataset.get('created_at', 'N/A')}")
else:
output_lines.append(f"{i}. 📁 {dataset.name}")
output_lines.append(f" Dataset ID: {dataset.id}")
output_lines.append(f" Created: {dataset.created_at}")
output_lines.append(f" Data items: {len(data_items)}")
output_lines.append("")
if not cognee_client.use_api:
output_lines.append("💡 To see data items in a specific dataset, use:")
output_lines.append(' list_data(dataset_id="your-dataset-id-here")')
output_lines.append("")
@ -801,12 +853,9 @@ async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
data_uuid = UUID(data_id)
dataset_uuid = UUID(dataset_id)
# Get default user for the operation
user = await get_default_user()
# Call the cognee delete function
result = await cognee.delete(
data_id=data_uuid, dataset_id=dataset_uuid, mode=mode, user=user
# Call the cognee delete function via client
result = await cognee_client.delete(
data_id=data_uuid, dataset_id=dataset_uuid, mode=mode
)
logger.info(f"Delete operation completed successfully: {result}")
@ -853,11 +902,21 @@ async def prune():
-----
- This operation cannot be undone. All memory data will be permanently deleted.
- The function prunes both data content (using prune_data) and system metadata (using prune_system)
- This operation is not available in API mode
"""
with redirect_stdout(sys.stderr):
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
try:
await cognee_client.prune_data()
await cognee_client.prune_system(metadata=True)
return [types.TextContent(type="text", text="Pruned")]
except NotImplementedError:
error_msg = "❌ Prune operation is not available in API mode"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
except Exception as e:
error_msg = f"❌ Prune operation failed: {str(e)}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
@mcp.tool()
@ -880,13 +939,26 @@ async def cognify_status():
- The function retrieves pipeline status specifically for the "cognify_pipeline" on the "main_dataset"
- Status information includes job progress, execution time, and completion status
- The status is returned in string format for easy reading
- This operation is not available in API mode
"""
with redirect_stdout(sys.stderr):
try:
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
from cognee.modules.users.methods import get_default_user
user = await get_default_user()
status = await get_pipeline_status(
status = await cognee_client.get_pipeline_status(
[await get_unique_dataset_id("main_dataset", user)], "cognify_pipeline"
)
return [types.TextContent(type="text", text=str(status))]
except NotImplementedError:
error_msg = "❌ Pipeline status is not available in API mode"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
except Exception as e:
error_msg = f"❌ Failed to get cognify status: {str(e)}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
@mcp.tool()
@ -909,13 +981,26 @@ async def codify_status():
- The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset
- Status information includes job progress, execution time, and completion status
- The status is returned in string format for easy reading
- This operation is not available in API mode
"""
with redirect_stdout(sys.stderr):
try:
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
from cognee.modules.users.methods import get_default_user
user = await get_default_user()
status = await get_pipeline_status(
status = await cognee_client.get_pipeline_status(
[await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
)
return [types.TextContent(type="text", text=str(status))]
except NotImplementedError:
error_msg = "❌ Pipeline status is not available in API mode"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
except Exception as e:
error_msg = f"❌ Failed to get codify status: {str(e)}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
def node_to_string(node):
@ -949,6 +1034,8 @@ def load_class(model_file, model_name):
async def main():
global cognee_client
parser = argparse.ArgumentParser()
parser.add_argument(
@ -992,12 +1079,30 @@ async def main():
help="Argument stops database migration from being attempted",
)
# Cognee API connection options
parser.add_argument(
"--api-url",
default=None,
help="Base URL of a running Cognee FastAPI server (e.g., http://localhost:8000). "
"If provided, the MCP server will connect to the API instead of using cognee directly.",
)
parser.add_argument(
"--api-token",
default=None,
help="Authentication token for the Cognee API. Required if --api-url is provided.",
)
args = parser.parse_args()
# Initialize the global CogneeClient
cognee_client = CogneeClient(api_url=args.api_url, api_token=args.api_token)
mcp.settings.host = args.host
mcp.settings.port = args.port
if not args.no_migration:
# Skip migrations when in API mode (the API server handles its own database)
if not args.no_migration and not args.api_url:
# Run Alembic migrations from the main cognee directory where alembic.ini is located
logger.info("Running database migrations...")
migration_result = subprocess.run(
@ -1020,6 +1125,8 @@ async def main():
sys.exit(1)
logger.info("Database migrations done.")
elif args.api_url:
logger.info("Skipping database migrations (using API mode)")
logger.info(f"Starting MCP server with transport: {args.transport}")
if args.transport == "stdio":