feat: add CORS middleware support for SSE and HTTP transports in MCP server
This commit is contained in:
parent
143d9433b1
commit
dc1669a948
1 changed files with 57 additions and 3 deletions
|
|
@ -20,6 +20,9 @@ from cognee.modules.search.types import SearchType
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
from cognee.modules.storage.utils import JSONEncoder
|
from cognee.modules.storage.utils import JSONEncoder
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.middleware import Middleware
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -39,8 +42,59 @@ mcp = FastMCP("Cognee")
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
cors_middleware = Middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_sse_with_cors():
|
||||||
|
"""Custom SSE transport with CORS middleware."""
|
||||||
|
sse_app = mcp.sse_app()
|
||||||
|
sse_app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
config = uvicorn.Config(
|
||||||
|
sse_app,
|
||||||
|
host=mcp.settings.host,
|
||||||
|
port=mcp.settings.port,
|
||||||
|
log_level=mcp.settings.log_level.lower(),
|
||||||
|
)
|
||||||
|
server = uvicorn.Server(config)
|
||||||
|
await server.serve()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_http_with_cors():
|
||||||
|
"""Custom HTTP transport with CORS middleware."""
|
||||||
|
http_app = mcp.streamable_http_app()
|
||||||
|
http_app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
config = uvicorn.Config(
|
||||||
|
http_app,
|
||||||
|
host=mcp.settings.host,
|
||||||
|
port=mcp.settings.port,
|
||||||
|
log_level=mcp.settings.log_level.lower(),
|
||||||
|
)
|
||||||
|
server = uvicorn.Server(config)
|
||||||
|
await server.serve()
|
||||||
|
|
||||||
|
|
||||||
@mcp.custom_route("/health", methods=["GET"])
|
@mcp.custom_route("/health", methods=["GET"])
|
||||||
async def health_check(request) -> dict:
|
async def health_check(request):
|
||||||
return JSONResponse({"status": "ok"})
|
return JSONResponse({"status": "ok"})
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -981,12 +1035,12 @@ async def main():
|
||||||
await mcp.run_stdio_async()
|
await mcp.run_stdio_async()
|
||||||
elif args.transport == "sse":
|
elif args.transport == "sse":
|
||||||
logger.info(f"Running MCP server with SSE transport on {args.host}:{args.port}")
|
logger.info(f"Running MCP server with SSE transport on {args.host}:{args.port}")
|
||||||
await mcp.run_sse_async()
|
await run_sse_with_cors()
|
||||||
elif args.transport == "http":
|
elif args.transport == "http":
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Running MCP server with Streamable HTTP transport on {args.host}:{args.port}{args.path}"
|
f"Running MCP server with Streamable HTTP transport on {args.host}:{args.port}{args.path}"
|
||||||
)
|
)
|
||||||
await mcp.run_streamable_http_async()
|
await run_http_with_cors()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue