diff --git a/cognee/api/client.py b/cognee/api/client.py index 4bdc89f10..1ffa85ba2 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -12,6 +12,7 @@ from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.exceptions import RequestValidationError +from fastapi.openapi.utils import get_openapi from cognee.exceptions import CogneeApiError from cognee.shared.logging_utils import get_logger, setup_logging @@ -75,6 +76,40 @@ app.add_middleware( ) +def custom_openapi(): + if app.openapi_schema: + return app.openapi_schema + + openapi_schema = get_openapi( + title="Cognee API", + version="1.0.0", + description="Cognee API with Bearer token and Cookie auth", + routes=app.routes, + ) + + openapi_schema["components"]["securitySchemes"] = { + "BearerAuth": { + "type": "http", + "scheme": "bearer" + }, + "CookieAuth": { + "type": "apiKey", + "in": "cookie", + "name": os.getenv("AUTH_TOKEN_COOKIE_NAME", "auth_token") + } + } + + openapi_schema["security"] = [ + {"BearerAuth": []}, + {"CookieAuth": []} + ] + + app.openapi_schema = openapi_schema + + return app.openapi_schema + +app.openapi = custom_openapi + @app.exception_handler(RequestValidationError) async def request_validation_exception_handler(request: Request, exc: RequestValidationError): if request.url.path == "/api/v1/auth/login": diff --git a/cognee/modules/users/methods/get_authenticated_user.py b/cognee/modules/users/methods/get_authenticated_user.py index b60ddfe28..486f92c1b 100644 --- a/cognee/modules/users/methods/get_authenticated_user.py +++ b/cognee/modules/users/methods/get_authenticated_user.py @@ -1,9 +1,17 @@ +from cognee.modules.users.authentication.get_api_auth_backend import get_api_auth_backend +from cognee.modules.users.authentication.get_client_auth_backend import get_client_auth_backend from ..get_fastapi_users import get_fastapi_users fastapi_users = get_fastapi_users() -get_authenticated_user = fastapi_users.current_user(active=True) +def get_enabled_backends(): + api_auth_backend = get_api_auth_backend() + client_auth_backend = get_client_auth_backend() + + return [api_auth_backend, client_auth_backend] + +get_authenticated_user = fastapi_users.current_user(active=True, get_enabled_backends=get_enabled_backends) # from types import SimpleNamespace