diff --git a/cognee-mcp/src/server.py b/cognee-mcp/src/server.py index f249f1d08..33cd26cb1 100755 --- a/cognee-mcp/src/server.py +++ b/cognee-mcp/src/server.py @@ -20,6 +20,9 @@ from cognee.modules.search.types import SearchType from cognee.shared.data_models import KnowledgeGraph from cognee.modules.storage.utils import JSONEncoder from starlette.responses import JSONResponse +from starlette.middleware import Middleware +from starlette.middleware.cors import CORSMiddleware +import uvicorn try: @@ -39,8 +42,59 @@ mcp = FastMCP("Cognee") logger = get_logger() +cors_middleware = Middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["*"], +) + + +async def run_sse_with_cors(): + """Custom SSE transport with CORS middleware.""" + sse_app = mcp.sse_app() + sse_app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["*"], + ) + + config = uvicorn.Config( + sse_app, + host=mcp.settings.host, + port=mcp.settings.port, + log_level=mcp.settings.log_level.lower(), + ) + server = uvicorn.Server(config) + await server.serve() + + +async def run_http_with_cors(): + """Custom HTTP transport with CORS middleware.""" + http_app = mcp.streamable_http_app() + http_app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["*"], + ) + + config = uvicorn.Config( + http_app, + host=mcp.settings.host, + port=mcp.settings.port, + log_level=mcp.settings.log_level.lower(), + ) + server = uvicorn.Server(config) + await server.serve() + + @mcp.custom_route("/health", methods=["GET"]) -async def health_check(request) -> dict: +async def health_check(request): return JSONResponse({"status": "ok"}) @@ -981,12 +1035,12 @@ async def main(): await mcp.run_stdio_async() elif args.transport == "sse": logger.info(f"Running MCP server with SSE transport on {args.host}:{args.port}") - await mcp.run_sse_async() + await run_sse_with_cors() elif args.transport == "http": logger.info( f"Running MCP server with Streamable HTTP transport on {args.host}:{args.port}{args.path}" ) - await mcp.run_streamable_http_async() + await run_http_with_cors() if __name__ == "__main__":