route server tools calls through cognee_client

This commit is contained in:
Daulet Amirkhanov 2025-10-08 18:36:44 +01:00
parent 806298d508
commit 3e23c96595

View file

@ -2,28 +2,25 @@ import json
import os import os
import sys import sys
import argparse import argparse
import cognee
import asyncio import asyncio
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location
import importlib.util import importlib.util
from contextlib import redirect_stdout from contextlib import redirect_stdout
import mcp.types as types import mcp.types as types
from mcp.server import FastMCP from mcp.server import FastMCP
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
from cognee.modules.users.methods import get_default_user
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
from cognee.modules.search.types import SearchType
from cognee.shared.data_models import KnowledgeGraph
from cognee.modules.storage.utils import JSONEncoder from cognee.modules.storage.utils import JSONEncoder
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
import uvicorn import uvicorn
# Import the new CogneeClient abstraction
from cognee_client import CogneeClient
try: try:
from cognee.tasks.codingagents.coding_rule_associations import ( from cognee.tasks.codingagents.coding_rule_associations import (
@ -41,6 +38,9 @@ mcp = FastMCP("Cognee")
logger = get_logger() logger = get_logger()
# Global CogneeClient instance (will be initialized in main())
cognee_client: Optional[CogneeClient] = None
async def run_sse_with_cors(): async def run_sse_with_cors():
"""Custom SSE transport with CORS middleware.""" """Custom SSE transport with CORS middleware."""
@ -141,11 +141,20 @@ async def cognee_add_developer_rules(
with redirect_stdout(sys.stderr): with redirect_stdout(sys.stderr):
logger.info(f"Starting cognify for: {file_path}") logger.info(f"Starting cognify for: {file_path}")
try: try:
await cognee.add(file_path, node_set=["developer_rules"]) await cognee_client.add(file_path, node_set=["developer_rules"])
model = KnowledgeGraph
model = None
if graph_model_file and graph_model_name: if graph_model_file and graph_model_name:
model = load_class(graph_model_file, graph_model_name) if cognee_client.use_api:
await cognee.cognify(graph_model=model) logger.warning(
"Custom graph models are not supported in API mode, ignoring."
)
else:
from cognee.shared.data_models import KnowledgeGraph
model = load_class(graph_model_file, graph_model_name)
await cognee_client.cognify(graph_model=model)
logger.info(f"Cognify finished for: {file_path}") logger.info(f"Cognify finished for: {file_path}")
except Exception as e: except Exception as e:
logger.error(f"Cognify failed for {file_path}: {str(e)}") logger.error(f"Cognify failed for {file_path}: {str(e)}")
@ -293,15 +302,20 @@ async def cognify(
# going to stdout ( like the print function ) to stderr. # going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr): with redirect_stdout(sys.stderr):
logger.info("Cognify process starting.") logger.info("Cognify process starting.")
if graph_model_file and graph_model_name:
graph_model = load_class(graph_model_file, graph_model_name)
else:
graph_model = KnowledgeGraph
await cognee.add(data) graph_model = None
if graph_model_file and graph_model_name:
if cognee_client.use_api:
logger.warning("Custom graph models are not supported in API mode, ignoring.")
else:
from cognee.shared.data_models import KnowledgeGraph
graph_model = load_class(graph_model_file, graph_model_name)
await cognee_client.add(data)
try: try:
await cognee.cognify(graph_model=graph_model, custom_prompt=custom_prompt) await cognee_client.cognify(custom_prompt=custom_prompt, graph_model=graph_model)
logger.info("Cognify process finished.") logger.info("Cognify process finished.")
except Exception as e: except Exception as e:
logger.error("Cognify process failed.") logger.error("Cognify process failed.")
@ -354,16 +368,19 @@ async def save_interaction(data: str) -> list:
with redirect_stdout(sys.stderr): with redirect_stdout(sys.stderr):
logger.info("Save interaction process starting.") logger.info("Save interaction process starting.")
await cognee.add(data, node_set=["user_agent_interaction"]) await cognee_client.add(data, node_set=["user_agent_interaction"])
try: try:
await cognee.cognify() await cognee_client.cognify()
logger.info("Save interaction process finished.") logger.info("Save interaction process finished.")
logger.info("Generating associated rules from interaction data.")
await add_rule_associations(data=data, rules_nodeset_name="coding_agent_rules") # Rule associations only work in direct mode
if not cognee_client.use_api:
logger.info("Associated rules generated from interaction data.") logger.info("Generating associated rules from interaction data.")
await add_rule_associations(data=data, rules_nodeset_name="coding_agent_rules")
logger.info("Associated rules generated from interaction data.")
else:
logger.warning("Rule associations are not available in API mode, skipping.")
except Exception as e: except Exception as e:
logger.error("Save interaction process failed.") logger.error("Save interaction process failed.")
@ -420,11 +437,18 @@ async def codify(repo_path: str) -> list:
- All stdout is redirected to stderr to maintain MCP communication integrity - All stdout is redirected to stderr to maintain MCP communication integrity
""" """
if cognee_client.use_api:
error_msg = "❌ Codify operation is not available in API mode. Please use direct mode for code graph pipeline."
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
async def codify_task(repo_path: str): async def codify_task(repo_path: str):
# NOTE: MCP uses stdout to communicate, we must redirect all output # NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr. # going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr): with redirect_stdout(sys.stderr):
logger.info("Codify process starting.") logger.info("Codify process starting.")
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
results = [] results = []
async for result in run_code_graph_pipeline(repo_path, False): async for result in run_code_graph_pipeline(repo_path, False):
results.append(result) results.append(result)
@ -574,23 +598,40 @@ async def search(search_query: str, search_type: str) -> list:
# NOTE: MCP uses stdout to communicate, we must redirect all output # NOTE: MCP uses stdout to communicate, we must redirect all output
# going to stdout ( like the print function ) to stderr. # going to stdout ( like the print function ) to stderr.
with redirect_stdout(sys.stderr): with redirect_stdout(sys.stderr):
search_results = await cognee.search( search_results = await cognee_client.search(
query_type=SearchType[search_type.upper()], query_text=search_query query_text=search_query, query_type=search_type
) )
if search_type.upper() == "CODE": # Handle different result formats based on API vs direct mode
return json.dumps(search_results, cls=JSONEncoder) if cognee_client.use_api:
elif ( # API mode returns JSON-serialized results
search_type.upper() == "GRAPH_COMPLETION" or search_type.upper() == "RAG_COMPLETION" if isinstance(search_results, str):
): return search_results
return str(search_results[0]) elif isinstance(search_results, list):
elif search_type.upper() == "CHUNKS": if (
return str(search_results) search_type.upper() in ["GRAPH_COMPLETION", "RAG_COMPLETION"]
elif search_type.upper() == "INSIGHTS": and len(search_results) > 0
results = retrieved_edges_to_string(search_results) ):
return results return str(search_results[0])
return str(search_results)
else:
return json.dumps(search_results, cls=JSONEncoder)
else: else:
return str(search_results) # Direct mode processing
if search_type.upper() == "CODE":
return json.dumps(search_results, cls=JSONEncoder)
elif (
search_type.upper() == "GRAPH_COMPLETION"
or search_type.upper() == "RAG_COMPLETION"
):
return str(search_results[0])
elif search_type.upper() == "CHUNKS":
return str(search_results)
elif search_type.upper() == "INSIGHTS":
results = retrieved_edges_to_string(search_results)
return results
else:
return str(search_results)
search_results = await search_task(search_query, search_type) search_results = await search_task(search_query, search_type)
return [types.TextContent(type="text", text=search_results)] return [types.TextContent(type="text", text=search_results)]
@ -623,6 +664,10 @@ async def get_developer_rules() -> list:
async def fetch_rules_from_cognee() -> str: async def fetch_rules_from_cognee() -> str:
"""Collect all developer rules from Cognee""" """Collect all developer rules from Cognee"""
with redirect_stdout(sys.stderr): with redirect_stdout(sys.stderr):
if cognee_client.use_api:
logger.warning("Developer rules retrieval is not available in API mode")
return "Developer rules retrieval is not available in API mode"
developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules") developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules")
return developer_rules return developer_rules
@ -662,16 +707,24 @@ async def list_data(dataset_id: str = None) -> list:
with redirect_stdout(sys.stderr): with redirect_stdout(sys.stderr):
try: try:
user = await get_default_user()
output_lines = [] output_lines = []
if dataset_id: if dataset_id:
# List data for specific dataset # Detailed data listing for specific dataset is only available in direct mode
if cognee_client.use_api:
return [
types.TextContent(
type="text",
text="❌ Detailed data listing for specific datasets is not available in API mode.\nPlease use the API directly or use direct mode.",
)
]
from cognee.modules.users.methods import get_default_user
from cognee.modules.data.methods import get_dataset, get_dataset_data
logger.info(f"Listing data for dataset: {dataset_id}") logger.info(f"Listing data for dataset: {dataset_id}")
dataset_uuid = UUID(dataset_id) dataset_uuid = UUID(dataset_id)
user = await get_default_user()
# Get the dataset information
from cognee.modules.data.methods import get_dataset, get_dataset_data
dataset = await get_dataset(user.id, dataset_uuid) dataset = await get_dataset(user.id, dataset_uuid)
@ -700,11 +753,9 @@ async def list_data(dataset_id: str = None) -> list:
output_lines.append(" (No data items in this dataset)") output_lines.append(" (No data items in this dataset)")
else: else:
# List all datasets # List all datasets - works in both modes
logger.info("Listing all datasets") logger.info("Listing all datasets")
from cognee.modules.data.methods import get_datasets datasets = await cognee_client.list_datasets()
datasets = await get_datasets(user.id)
if not datasets: if not datasets:
return [ return [
@ -719,20 +770,21 @@ async def list_data(dataset_id: str = None) -> list:
output_lines.append("") output_lines.append("")
for i, dataset in enumerate(datasets, 1): for i, dataset in enumerate(datasets, 1):
# Get data count for each dataset # In API mode, dataset is a dict; in direct mode, it's formatted as dict
from cognee.modules.data.methods import get_dataset_data if isinstance(dataset, dict):
output_lines.append(f"{i}. 📁 {dataset.get('name', 'Unnamed')}")
data_items = await get_dataset_data(dataset.id) output_lines.append(f" Dataset ID: {dataset.get('id')}")
output_lines.append(f" Created: {dataset.get('created_at', 'N/A')}")
output_lines.append(f"{i}. 📁 {dataset.name}") else:
output_lines.append(f" Dataset ID: {dataset.id}") output_lines.append(f"{i}. 📁 {dataset.name}")
output_lines.append(f" Created: {dataset.created_at}") output_lines.append(f" Dataset ID: {dataset.id}")
output_lines.append(f" Data items: {len(data_items)}") output_lines.append(f" Created: {dataset.created_at}")
output_lines.append("") output_lines.append("")
output_lines.append("💡 To see data items in a specific dataset, use:") if not cognee_client.use_api:
output_lines.append(' list_data(dataset_id="your-dataset-id-here")') output_lines.append("💡 To see data items in a specific dataset, use:")
output_lines.append("") output_lines.append(' list_data(dataset_id="your-dataset-id-here")')
output_lines.append("")
output_lines.append("🗑️ To delete specific data, use:") output_lines.append("🗑️ To delete specific data, use:")
output_lines.append(' delete(data_id="data-id", dataset_id="dataset-id")') output_lines.append(' delete(data_id="data-id", dataset_id="dataset-id")')
@ -801,12 +853,9 @@ async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
data_uuid = UUID(data_id) data_uuid = UUID(data_id)
dataset_uuid = UUID(dataset_id) dataset_uuid = UUID(dataset_id)
# Get default user for the operation # Call the cognee delete function via client
user = await get_default_user() result = await cognee_client.delete(
data_id=data_uuid, dataset_id=dataset_uuid, mode=mode
# Call the cognee delete function
result = await cognee.delete(
data_id=data_uuid, dataset_id=dataset_uuid, mode=mode, user=user
) )
logger.info(f"Delete operation completed successfully: {result}") logger.info(f"Delete operation completed successfully: {result}")
@ -853,11 +902,21 @@ async def prune():
----- -----
- This operation cannot be undone. All memory data will be permanently deleted. - This operation cannot be undone. All memory data will be permanently deleted.
- The function prunes both data content (using prune_data) and system metadata (using prune_system) - The function prunes both data content (using prune_data) and system metadata (using prune_system)
- This operation is not available in API mode
""" """
with redirect_stdout(sys.stderr): with redirect_stdout(sys.stderr):
await cognee.prune.prune_data() try:
await cognee.prune.prune_system(metadata=True) await cognee_client.prune_data()
return [types.TextContent(type="text", text="Pruned")] await cognee_client.prune_system(metadata=True)
return [types.TextContent(type="text", text="Pruned")]
except NotImplementedError:
error_msg = "❌ Prune operation is not available in API mode"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
except Exception as e:
error_msg = f"❌ Prune operation failed: {str(e)}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
@mcp.tool() @mcp.tool()
@ -880,13 +939,26 @@ async def cognify_status():
- The function retrieves pipeline status specifically for the "cognify_pipeline" on the "main_dataset" - The function retrieves pipeline status specifically for the "cognify_pipeline" on the "main_dataset"
- Status information includes job progress, execution time, and completion status - Status information includes job progress, execution time, and completion status
- The status is returned in string format for easy reading - The status is returned in string format for easy reading
- This operation is not available in API mode
""" """
with redirect_stdout(sys.stderr): with redirect_stdout(sys.stderr):
user = await get_default_user() try:
status = await get_pipeline_status( from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
[await get_unique_dataset_id("main_dataset", user)], "cognify_pipeline" from cognee.modules.users.methods import get_default_user
)
return [types.TextContent(type="text", text=str(status))] user = await get_default_user()
status = await cognee_client.get_pipeline_status(
[await get_unique_dataset_id("main_dataset", user)], "cognify_pipeline"
)
return [types.TextContent(type="text", text=str(status))]
except NotImplementedError:
error_msg = "❌ Pipeline status is not available in API mode"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
except Exception as e:
error_msg = f"❌ Failed to get cognify status: {str(e)}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
@mcp.tool() @mcp.tool()
@ -909,13 +981,26 @@ async def codify_status():
- The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset - The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset
- Status information includes job progress, execution time, and completion status - Status information includes job progress, execution time, and completion status
- The status is returned in string format for easy reading - The status is returned in string format for easy reading
- This operation is not available in API mode
""" """
with redirect_stdout(sys.stderr): with redirect_stdout(sys.stderr):
user = await get_default_user() try:
status = await get_pipeline_status( from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
[await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline" from cognee.modules.users.methods import get_default_user
)
return [types.TextContent(type="text", text=str(status))] user = await get_default_user()
status = await cognee_client.get_pipeline_status(
[await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
)
return [types.TextContent(type="text", text=str(status))]
except NotImplementedError:
error_msg = "❌ Pipeline status is not available in API mode"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
except Exception as e:
error_msg = f"❌ Failed to get codify status: {str(e)}"
logger.error(error_msg)
return [types.TextContent(type="text", text=error_msg)]
def node_to_string(node): def node_to_string(node):
@ -949,6 +1034,8 @@ def load_class(model_file, model_name):
async def main(): async def main():
global cognee_client
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
@ -992,12 +1079,30 @@ async def main():
help="Argument stops database migration from being attempted", help="Argument stops database migration from being attempted",
) )
# Cognee API connection options
parser.add_argument(
"--api-url",
default=None,
help="Base URL of a running Cognee FastAPI server (e.g., http://localhost:8000). "
"If provided, the MCP server will connect to the API instead of using cognee directly.",
)
parser.add_argument(
"--api-token",
default=None,
help="Authentication token for the Cognee API. Required if --api-url is provided.",
)
args = parser.parse_args() args = parser.parse_args()
# Initialize the global CogneeClient
cognee_client = CogneeClient(api_url=args.api_url, api_token=args.api_token)
mcp.settings.host = args.host mcp.settings.host = args.host
mcp.settings.port = args.port mcp.settings.port = args.port
if not args.no_migration: # Skip migrations when in API mode (the API server handles its own database)
if not args.no_migration and not args.api_url:
# Run Alembic migrations from the main cognee directory where alembic.ini is located # Run Alembic migrations from the main cognee directory where alembic.ini is located
logger.info("Running database migrations...") logger.info("Running database migrations...")
migration_result = subprocess.run( migration_result = subprocess.run(
@ -1020,6 +1125,8 @@ async def main():
sys.exit(1) sys.exit(1)
logger.info("Database migrations done.") logger.info("Database migrations done.")
elif args.api_url:
logger.info("Skipping database migrations (using API mode)")
logger.info(f"Starting MCP server with transport: {args.transport}") logger.info(f"Starting MCP server with transport: {args.transport}")
if args.transport == "stdio": if args.transport == "stdio":