diff --git a/cognee/api/client.py b/cognee/api/client.py index 196582b5d..2b2128d2d 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -43,6 +43,77 @@ app.add_middleware( allow_headers=["*"], ) +from contextlib import asynccontextmanager + +from fastapi import Depends, FastAPI + +from cognee.infrastructure.databases.relational.user_authentication.authentication_db import User, create_db_and_tables +from cognee.infrastructure.databases.relational.user_authentication.schemas import UserCreate, UserRead, UserUpdate +from cognee.infrastructure.databases.relational.user_authentication.users import auth_backend, current_active_user, fastapi_users + +app.include_router( + fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"] +) +app.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users.get_reset_password_router(), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users.get_verify_router(UserRead), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], +) + + +@app.get("/authenticated-route") +async def authenticated_route(user: User = Depends(current_active_user)): + return {"message": f"Hello {user.email}!"} + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Not needed if you setup a migration system like Alembic + await create_db_and_tables() + yield +app.include_router( + fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"] +) +app.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users.get_reset_password_router(), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users.get_verify_router(UserRead), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], +) + + +@app.get("/authenticated-route") +async def authenticated_route(user: User = Depends(current_active_user)): + return {"message": f"Hello {user.email}!"} + @app.get("/") async def root(): """ diff --git a/cognee/infrastructure/databases/relational/user_authentication/authentication_db.py b/cognee/infrastructure/databases/relational/user_authentication/authentication_db.py new file mode 100644 index 000000000..63bc2e94a --- /dev/null +++ b/cognee/infrastructure/databases/relational/user_authentication/authentication_db.py @@ -0,0 +1,36 @@ +from typing import AsyncGenerator + +from fastapi import Depends +from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase + +DATABASE_URL = "sqlite+aiosqlite:///./test.db" + + +class Base(DeclarativeBase): + pass + + +class User(SQLAlchemyBaseUserTableUUID, Base): + pass + + + + +engine = create_async_engine(DATABASE_URL) +async_session_maker = async_sessionmaker(engine, expire_on_commit=False) + + +async def create_db_and_tables(): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + +async def get_async_session() -> AsyncGenerator[AsyncSession, None]: + async with async_session_maker() as session: + yield session + + +async def get_user_db(session: AsyncSession = Depends(get_async_session)): + yield SQLAlchemyUserDatabase(session, User) \ No newline at end of file diff --git a/cognee/infrastructure/databases/relational/user_authentication/schemas.py b/cognee/infrastructure/databases/relational/user_authentication/schemas.py new file mode 100644 index 000000000..d7156223f --- /dev/null +++ b/cognee/infrastructure/databases/relational/user_authentication/schemas.py @@ -0,0 +1,15 @@ +import uuid + +from fastapi_users import schemas + + +class UserRead(schemas.BaseUser[uuid.UUID]): + pass + + +class UserCreate(schemas.BaseUserCreate): + pass + + +class UserUpdate(schemas.BaseUserUpdate): + pass \ No newline at end of file diff --git a/cognee/infrastructure/databases/relational/user_authentication/users.py b/cognee/infrastructure/databases/relational/user_authentication/users.py new file mode 100644 index 000000000..3df38c4fe --- /dev/null +++ b/cognee/infrastructure/databases/relational/user_authentication/users.py @@ -0,0 +1,55 @@ +import uuid +from typing import Optional + +from fastapi import Depends, Request +from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models +from fastapi_users.authentication import ( + AuthenticationBackend, + BearerTransport, + JWTStrategy, +) +from fastapi_users.db import SQLAlchemyUserDatabase + +from app.db import User, get_user_db + +SECRET = "SECRET" + + +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): + reset_password_token_secret = SECRET + verification_token_secret = SECRET + + async def on_after_register(self, user: User, request: Optional[Request] = None): + print(f"User {user.id} has registered.") + + async def on_after_forgot_password( + self, user: User, token: str, request: Optional[Request] = None + ): + print(f"User {user.id} has forgot their password. Reset token: {token}") + + async def on_after_request_verify( + self, user: User, token: str, request: Optional[Request] = None + ): + print(f"Verification requested for user {user.id}. Verification token: {token}") + + +async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): + yield UserManager(user_db) + + +bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") + + +def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: + return JWTStrategy(secret=SECRET, lifetime_seconds=3600) + + +auth_backend = AuthenticationBackend( + name="jwt", + transport=bearer_transport, + get_strategy=get_jwt_strategy, +) + +fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) + +current_active_user = fastapi_users.current_user(active=True) \ No newline at end of file