route server tools calls through cognee_client
This commit is contained in:
parent
806298d508
commit
3e23c96595
1 changed files with 185 additions and 78 deletions
|
|
@ -2,28 +2,25 @@ import json
|
|||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import cognee
|
||||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location
|
||||
import importlib.util
|
||||
from contextlib import redirect_stdout
|
||||
import mcp.types as types
|
||||
from mcp.server import FastMCP
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
# Import the new CogneeClient abstraction
|
||||
from cognee_client import CogneeClient
|
||||
|
||||
|
||||
try:
|
||||
from cognee.tasks.codingagents.coding_rule_associations import (
|
||||
|
|
@ -41,6 +38,9 @@ mcp = FastMCP("Cognee")
|
|||
|
||||
logger = get_logger()
|
||||
|
||||
# Global CogneeClient instance (will be initialized in main())
|
||||
cognee_client: Optional[CogneeClient] = None
|
||||
|
||||
|
||||
async def run_sse_with_cors():
|
||||
"""Custom SSE transport with CORS middleware."""
|
||||
|
|
@ -141,11 +141,20 @@ async def cognee_add_developer_rules(
|
|||
with redirect_stdout(sys.stderr):
|
||||
logger.info(f"Starting cognify for: {file_path}")
|
||||
try:
|
||||
await cognee.add(file_path, node_set=["developer_rules"])
|
||||
model = KnowledgeGraph
|
||||
await cognee_client.add(file_path, node_set=["developer_rules"])
|
||||
|
||||
model = None
|
||||
if graph_model_file and graph_model_name:
|
||||
if cognee_client.use_api:
|
||||
logger.warning(
|
||||
"Custom graph models are not supported in API mode, ignoring."
|
||||
)
|
||||
else:
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
|
||||
model = load_class(graph_model_file, graph_model_name)
|
||||
await cognee.cognify(graph_model=model)
|
||||
|
||||
await cognee_client.cognify(graph_model=model)
|
||||
logger.info(f"Cognify finished for: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Cognify failed for {file_path}: {str(e)}")
|
||||
|
|
@ -293,15 +302,20 @@ async def cognify(
|
|||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
logger.info("Cognify process starting.")
|
||||
if graph_model_file and graph_model_name:
|
||||
graph_model = load_class(graph_model_file, graph_model_name)
|
||||
else:
|
||||
graph_model = KnowledgeGraph
|
||||
|
||||
await cognee.add(data)
|
||||
graph_model = None
|
||||
if graph_model_file and graph_model_name:
|
||||
if cognee_client.use_api:
|
||||
logger.warning("Custom graph models are not supported in API mode, ignoring.")
|
||||
else:
|
||||
from cognee.shared.data_models import KnowledgeGraph
|
||||
|
||||
graph_model = load_class(graph_model_file, graph_model_name)
|
||||
|
||||
await cognee_client.add(data)
|
||||
|
||||
try:
|
||||
await cognee.cognify(graph_model=graph_model, custom_prompt=custom_prompt)
|
||||
await cognee_client.cognify(custom_prompt=custom_prompt, graph_model=graph_model)
|
||||
logger.info("Cognify process finished.")
|
||||
except Exception as e:
|
||||
logger.error("Cognify process failed.")
|
||||
|
|
@ -354,16 +368,19 @@ async def save_interaction(data: str) -> list:
|
|||
with redirect_stdout(sys.stderr):
|
||||
logger.info("Save interaction process starting.")
|
||||
|
||||
await cognee.add(data, node_set=["user_agent_interaction"])
|
||||
await cognee_client.add(data, node_set=["user_agent_interaction"])
|
||||
|
||||
try:
|
||||
await cognee.cognify()
|
||||
await cognee_client.cognify()
|
||||
logger.info("Save interaction process finished.")
|
||||
|
||||
# Rule associations only work in direct mode
|
||||
if not cognee_client.use_api:
|
||||
logger.info("Generating associated rules from interaction data.")
|
||||
|
||||
await add_rule_associations(data=data, rules_nodeset_name="coding_agent_rules")
|
||||
|
||||
logger.info("Associated rules generated from interaction data.")
|
||||
else:
|
||||
logger.warning("Rule associations are not available in API mode, skipping.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Save interaction process failed.")
|
||||
|
|
@ -420,11 +437,18 @@ async def codify(repo_path: str) -> list:
|
|||
- All stdout is redirected to stderr to maintain MCP communication integrity
|
||||
"""
|
||||
|
||||
if cognee_client.use_api:
|
||||
error_msg = "❌ Codify operation is not available in API mode. Please use direct mode for code graph pipeline."
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
async def codify_task(repo_path: str):
|
||||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
logger.info("Codify process starting.")
|
||||
from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
||||
|
||||
results = []
|
||||
async for result in run_code_graph_pipeline(repo_path, False):
|
||||
results.append(result)
|
||||
|
|
@ -574,14 +598,31 @@ async def search(search_query: str, search_type: str) -> list:
|
|||
# NOTE: MCP uses stdout to communicate, we must redirect all output
|
||||
# going to stdout ( like the print function ) to stderr.
|
||||
with redirect_stdout(sys.stderr):
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType[search_type.upper()], query_text=search_query
|
||||
search_results = await cognee_client.search(
|
||||
query_text=search_query, query_type=search_type
|
||||
)
|
||||
|
||||
# Handle different result formats based on API vs direct mode
|
||||
if cognee_client.use_api:
|
||||
# API mode returns JSON-serialized results
|
||||
if isinstance(search_results, str):
|
||||
return search_results
|
||||
elif isinstance(search_results, list):
|
||||
if (
|
||||
search_type.upper() in ["GRAPH_COMPLETION", "RAG_COMPLETION"]
|
||||
and len(search_results) > 0
|
||||
):
|
||||
return str(search_results[0])
|
||||
return str(search_results)
|
||||
else:
|
||||
return json.dumps(search_results, cls=JSONEncoder)
|
||||
else:
|
||||
# Direct mode processing
|
||||
if search_type.upper() == "CODE":
|
||||
return json.dumps(search_results, cls=JSONEncoder)
|
||||
elif (
|
||||
search_type.upper() == "GRAPH_COMPLETION" or search_type.upper() == "RAG_COMPLETION"
|
||||
search_type.upper() == "GRAPH_COMPLETION"
|
||||
or search_type.upper() == "RAG_COMPLETION"
|
||||
):
|
||||
return str(search_results[0])
|
||||
elif search_type.upper() == "CHUNKS":
|
||||
|
|
@ -623,6 +664,10 @@ async def get_developer_rules() -> list:
|
|||
async def fetch_rules_from_cognee() -> str:
|
||||
"""Collect all developer rules from Cognee"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
if cognee_client.use_api:
|
||||
logger.warning("Developer rules retrieval is not available in API mode")
|
||||
return "Developer rules retrieval is not available in API mode"
|
||||
|
||||
developer_rules = await get_existing_rules(rules_nodeset_name="coding_agent_rules")
|
||||
return developer_rules
|
||||
|
||||
|
|
@ -662,16 +707,24 @@ async def list_data(dataset_id: str = None) -> list:
|
|||
|
||||
with redirect_stdout(sys.stderr):
|
||||
try:
|
||||
user = await get_default_user()
|
||||
output_lines = []
|
||||
|
||||
if dataset_id:
|
||||
# List data for specific dataset
|
||||
# Detailed data listing for specific dataset is only available in direct mode
|
||||
if cognee_client.use_api:
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text="❌ Detailed data listing for specific datasets is not available in API mode.\nPlease use the API directly or use direct mode.",
|
||||
)
|
||||
]
|
||||
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.data.methods import get_dataset, get_dataset_data
|
||||
|
||||
logger.info(f"Listing data for dataset: {dataset_id}")
|
||||
dataset_uuid = UUID(dataset_id)
|
||||
|
||||
# Get the dataset information
|
||||
from cognee.modules.data.methods import get_dataset, get_dataset_data
|
||||
user = await get_default_user()
|
||||
|
||||
dataset = await get_dataset(user.id, dataset_uuid)
|
||||
|
||||
|
|
@ -700,11 +753,9 @@ async def list_data(dataset_id: str = None) -> list:
|
|||
output_lines.append(" (No data items in this dataset)")
|
||||
|
||||
else:
|
||||
# List all datasets
|
||||
# List all datasets - works in both modes
|
||||
logger.info("Listing all datasets")
|
||||
from cognee.modules.data.methods import get_datasets
|
||||
|
||||
datasets = await get_datasets(user.id)
|
||||
datasets = await cognee_client.list_datasets()
|
||||
|
||||
if not datasets:
|
||||
return [
|
||||
|
|
@ -719,17 +770,18 @@ async def list_data(dataset_id: str = None) -> list:
|
|||
output_lines.append("")
|
||||
|
||||
for i, dataset in enumerate(datasets, 1):
|
||||
# Get data count for each dataset
|
||||
from cognee.modules.data.methods import get_dataset_data
|
||||
|
||||
data_items = await get_dataset_data(dataset.id)
|
||||
|
||||
# In API mode, dataset is a dict; in direct mode, it's formatted as dict
|
||||
if isinstance(dataset, dict):
|
||||
output_lines.append(f"{i}. 📁 {dataset.get('name', 'Unnamed')}")
|
||||
output_lines.append(f" Dataset ID: {dataset.get('id')}")
|
||||
output_lines.append(f" Created: {dataset.get('created_at', 'N/A')}")
|
||||
else:
|
||||
output_lines.append(f"{i}. 📁 {dataset.name}")
|
||||
output_lines.append(f" Dataset ID: {dataset.id}")
|
||||
output_lines.append(f" Created: {dataset.created_at}")
|
||||
output_lines.append(f" Data items: {len(data_items)}")
|
||||
output_lines.append("")
|
||||
|
||||
if not cognee_client.use_api:
|
||||
output_lines.append("💡 To see data items in a specific dataset, use:")
|
||||
output_lines.append(' list_data(dataset_id="your-dataset-id-here")')
|
||||
output_lines.append("")
|
||||
|
|
@ -801,12 +853,9 @@ async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
|
|||
data_uuid = UUID(data_id)
|
||||
dataset_uuid = UUID(dataset_id)
|
||||
|
||||
# Get default user for the operation
|
||||
user = await get_default_user()
|
||||
|
||||
# Call the cognee delete function
|
||||
result = await cognee.delete(
|
||||
data_id=data_uuid, dataset_id=dataset_uuid, mode=mode, user=user
|
||||
# Call the cognee delete function via client
|
||||
result = await cognee_client.delete(
|
||||
data_id=data_uuid, dataset_id=dataset_uuid, mode=mode
|
||||
)
|
||||
|
||||
logger.info(f"Delete operation completed successfully: {result}")
|
||||
|
|
@ -853,11 +902,21 @@ async def prune():
|
|||
-----
|
||||
- This operation cannot be undone. All memory data will be permanently deleted.
|
||||
- The function prunes both data content (using prune_data) and system metadata (using prune_system)
|
||||
- This operation is not available in API mode
|
||||
"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
try:
|
||||
await cognee_client.prune_data()
|
||||
await cognee_client.prune_system(metadata=True)
|
||||
return [types.TextContent(type="text", text="Pruned")]
|
||||
except NotImplementedError:
|
||||
error_msg = "❌ Prune operation is not available in API mode"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Prune operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
|
@ -880,13 +939,26 @@ async def cognify_status():
|
|||
- The function retrieves pipeline status specifically for the "cognify_pipeline" on the "main_dataset"
|
||||
- Status information includes job progress, execution time, and completion status
|
||||
- The status is returned in string format for easy reading
|
||||
- This operation is not available in API mode
|
||||
"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
try:
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
user = await get_default_user()
|
||||
status = await get_pipeline_status(
|
||||
status = await cognee_client.get_pipeline_status(
|
||||
[await get_unique_dataset_id("main_dataset", user)], "cognify_pipeline"
|
||||
)
|
||||
return [types.TextContent(type="text", text=str(status))]
|
||||
except NotImplementedError:
|
||||
error_msg = "❌ Pipeline status is not available in API mode"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Failed to get cognify status: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
|
|
@ -909,13 +981,26 @@ async def codify_status():
|
|||
- The function retrieves pipeline status specifically for the "cognify_code_pipeline" on the "codebase" dataset
|
||||
- Status information includes job progress, execution time, and completion status
|
||||
- The status is returned in string format for easy reading
|
||||
- This operation is not available in API mode
|
||||
"""
|
||||
with redirect_stdout(sys.stderr):
|
||||
try:
|
||||
from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
||||
user = await get_default_user()
|
||||
status = await get_pipeline_status(
|
||||
status = await cognee_client.get_pipeline_status(
|
||||
[await get_unique_dataset_id("codebase", user)], "cognify_code_pipeline"
|
||||
)
|
||||
return [types.TextContent(type="text", text=str(status))]
|
||||
except NotImplementedError:
|
||||
error_msg = "❌ Pipeline status is not available in API mode"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
except Exception as e:
|
||||
error_msg = f"❌ Failed to get codify status: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
|
||||
def node_to_string(node):
|
||||
|
|
@ -949,6 +1034,8 @@ def load_class(model_file, model_name):
|
|||
|
||||
|
||||
async def main():
|
||||
global cognee_client
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -992,12 +1079,30 @@ async def main():
|
|||
help="Argument stops database migration from being attempted",
|
||||
)
|
||||
|
||||
# Cognee API connection options
|
||||
parser.add_argument(
|
||||
"--api-url",
|
||||
default=None,
|
||||
help="Base URL of a running Cognee FastAPI server (e.g., http://localhost:8000). "
|
||||
"If provided, the MCP server will connect to the API instead of using cognee directly.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--api-token",
|
||||
default=None,
|
||||
help="Authentication token for the Cognee API. Required if --api-url is provided.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize the global CogneeClient
|
||||
cognee_client = CogneeClient(api_url=args.api_url, api_token=args.api_token)
|
||||
|
||||
mcp.settings.host = args.host
|
||||
mcp.settings.port = args.port
|
||||
|
||||
if not args.no_migration:
|
||||
# Skip migrations when in API mode (the API server handles its own database)
|
||||
if not args.no_migration and not args.api_url:
|
||||
# Run Alembic migrations from the main cognee directory where alembic.ini is located
|
||||
logger.info("Running database migrations...")
|
||||
migration_result = subprocess.run(
|
||||
|
|
@ -1020,6 +1125,8 @@ async def main():
|
|||
sys.exit(1)
|
||||
|
||||
logger.info("Database migrations done.")
|
||||
elif args.api_url:
|
||||
logger.info("Skipping database migrations (using API mode)")
|
||||
|
||||
logger.info(f"Starting MCP server with transport: {args.transport}")
|
||||
if args.transport == "stdio":
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue