feat: add CORS middleware support for SSE and HTTP transports in MCP server
This commit is contained in:
parent
143d9433b1
commit
dc1669a948
1 changed files with 57 additions and 3 deletions
|
|
@ -20,6 +20,9 @@ from cognee.modules.search.types import SearchType
|
|||
from cognee.shared.data_models import KnowledgeGraph
|
||||
from cognee.modules.storage.utils import JSONEncoder
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
|
||||
try:
|
||||
|
|
@ -39,8 +42,59 @@ mcp = FastMCP("Cognee")
|
|||
logger = get_logger()
|
||||
|
||||
|
||||
cors_middleware = Middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
async def run_sse_with_cors():
|
||||
"""Custom SSE transport with CORS middleware."""
|
||||
sse_app = mcp.sse_app()
|
||||
sse_app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
config = uvicorn.Config(
|
||||
sse_app,
|
||||
host=mcp.settings.host,
|
||||
port=mcp.settings.port,
|
||||
log_level=mcp.settings.log_level.lower(),
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
async def run_http_with_cors():
|
||||
"""Custom HTTP transport with CORS middleware."""
|
||||
http_app = mcp.streamable_http_app()
|
||||
http_app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
config = uvicorn.Config(
|
||||
http_app,
|
||||
host=mcp.settings.host,
|
||||
port=mcp.settings.port,
|
||||
log_level=mcp.settings.log_level.lower(),
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
@mcp.custom_route("/health", methods=["GET"])
|
||||
async def health_check(request) -> dict:
|
||||
async def health_check(request):
|
||||
return JSONResponse({"status": "ok"})
|
||||
|
||||
|
||||
|
|
@ -981,12 +1035,12 @@ async def main():
|
|||
await mcp.run_stdio_async()
|
||||
elif args.transport == "sse":
|
||||
logger.info(f"Running MCP server with SSE transport on {args.host}:{args.port}")
|
||||
await mcp.run_sse_async()
|
||||
await run_sse_with_cors()
|
||||
elif args.transport == "http":
|
||||
logger.info(
|
||||
f"Running MCP server with Streamable HTTP transport on {args.host}:{args.port}{args.path}"
|
||||
)
|
||||
await mcp.run_streamable_http_async()
|
||||
await run_http_with_cors()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue