feat: user authorization [COG-1189] (#593)
<!-- .github/pull_request_template.md --> ## Description Added user authorization through JWT header, reworked user and relevant RBAC models to accompany future User Permission system. ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced an automated workflow to validate server startup. - Added secure JWT token generation for improved session handling. - Enabled a new structure for permission management with role and tenant-based controls, including endpoints for creating roles, tenants, and assigning permissions. - Added methods for assigning default permissions to roles, tenants, and users. - Introduced new classes for managing default permissions for roles, tenants, and users. - **Refactor** - Streamlined authentication and user management flows with enhanced error handling. - **Tests** - Upgraded integration tests with improved database initialization and data pruning for a more stable environment. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
This commit is contained in:
parent
38d527ceac
commit
88ed411f03
41 changed files with 702 additions and 193 deletions
78
.github/workflows/test_cognee_server_start.yml
vendored
Normal file
78
.github/workflows/test_cognee_server_start.yml
vendored
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
name: test | test server start
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types: [labeled, synchronize]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
RUNTIME__LOG_LEVEL: ERROR
|
||||
ENV: 'dev'
|
||||
|
||||
jobs:
|
||||
|
||||
run_server:
|
||||
name: Test cognee server start
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Check out
|
||||
uses: actions/checkout@master
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12.x'
|
||||
|
||||
- name: Install Poetry
|
||||
# https://github.com/snok/install-poetry#running-on-windows
|
||||
uses: snok/install-poetry@v1.4.1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
installer-parallel: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
poetry install --no-interaction
|
||||
|
||||
- name: Run cognee server
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: |
|
||||
poetry run uvicorn cognee.api.client:app --host 0.0.0.0 --port 8000 &
|
||||
echo $! > server.pid
|
||||
sleep 10
|
||||
|
||||
- name: Check server process
|
||||
run: |
|
||||
if ! ps -p $(cat server.pid) > /dev/null; then
|
||||
echo "::error::Server failed to start"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Stop server
|
||||
run: |
|
||||
kill $(cat server.pid) || true
|
||||
|
||||
- name: Clean up disk space
|
||||
run: |
|
||||
sudo rm -rf ~/.cache
|
||||
sudo rm -rf /tmp/*
|
||||
sudo rm server.pid
|
||||
df -h
|
||||
|
|
@ -1,80 +1,67 @@
|
|||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
|
||||
from cognee.modules.users.exceptions import UserNotFoundError, GroupNotFoundError
|
||||
from cognee.modules.users import get_user_db
|
||||
from cognee.modules.users.models import User, Group, Permission, UserGroup, GroupPermission
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
def get_permissions_router() -> APIRouter:
|
||||
permissions_router = APIRouter()
|
||||
|
||||
@permissions_router.post("/groups/{group_id}/permissions")
|
||||
async def give_permission_to_group(
|
||||
group_id: str, permission: str, db: Session = Depends(get_user_db)
|
||||
@permissions_router.post("/roles/{role_id}/permissions")
|
||||
async def give_default_permission_to_role(role_id: UUID, permission_name: str):
|
||||
from cognee.modules.users.permissions.methods import (
|
||||
give_default_permission_to_role as set_default_permission_to_role,
|
||||
)
|
||||
|
||||
await set_default_permission_to_role(role_id, permission_name)
|
||||
|
||||
return JSONResponse(status_code=200, content={"message": "Permission assigned to role"})
|
||||
|
||||
@permissions_router.post("/tenants/{tenant_id}/permissions")
|
||||
async def give_default_permission_to_tenant(tenant_id: UUID, permission_name: str):
|
||||
from cognee.modules.users.permissions.methods import (
|
||||
give_default_permission_to_tenant as set_tenant_default_permissions,
|
||||
)
|
||||
|
||||
await set_tenant_default_permissions(tenant_id, permission_name)
|
||||
|
||||
return JSONResponse(status_code=200, content={"message": "Permission assigned to tenant"})
|
||||
|
||||
@permissions_router.post("/users/{user_id}/permissions")
|
||||
async def give_default_permission_to_user(user_id: UUID, permission_name: str):
|
||||
from cognee.modules.users.permissions.methods import (
|
||||
give_default_permission_to_user as set_default_permission_to_user,
|
||||
)
|
||||
|
||||
await set_default_permission_to_user(user_id, permission_name)
|
||||
|
||||
return JSONResponse(status_code=200, content={"message": "Permission assigned to user"})
|
||||
|
||||
@permissions_router.post("/roles")
|
||||
async def create_role(
|
||||
role_name: str,
|
||||
tenant_id: UUID,
|
||||
):
|
||||
group = (
|
||||
(await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first()
|
||||
)
|
||||
from cognee.modules.users.roles.methods import create_role as create_role_method
|
||||
|
||||
if not group:
|
||||
raise GroupNotFoundError
|
||||
await create_role_method(role_name=role_name, tenant_id=tenant_id)
|
||||
|
||||
permission_entity = (
|
||||
(await db.session.execute(select(Permission).where(Permission.name == permission)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
return JSONResponse(status_code=200, content={"message": "Role created for tenant"})
|
||||
|
||||
if not permission_entity:
|
||||
stmt = insert(Permission).values(name=permission)
|
||||
await db.session.execute(stmt)
|
||||
permission_entity = (
|
||||
(await db.session.execute(select(Permission).where(Permission.name == permission)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
@permissions_router.post("/users/{user_id}/roles")
|
||||
async def add_user_to_role(user_id: UUID, role_id: UUID):
|
||||
from cognee.modules.users.roles.methods import add_user_to_role as add_user_to_role_method
|
||||
|
||||
try:
|
||||
# add permission to group
|
||||
await db.session.execute(
|
||||
insert(GroupPermission).values(
|
||||
group_id=group.id, permission_id=permission_entity.id
|
||||
)
|
||||
)
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="Group permission already exists.")
|
||||
await add_user_to_role_method(user_id=user_id, role_id=role_id)
|
||||
|
||||
await db.session.commit()
|
||||
return JSONResponse(status_code=200, content={"message": "User added to role"})
|
||||
|
||||
return JSONResponse(status_code=200, content={"message": "Permission assigned to group"})
|
||||
@permissions_router.post("/tenants")
|
||||
async def create_tenant(tenant_name: str):
|
||||
from cognee.modules.users.tenants.methods import create_tenant as create_tenant_method
|
||||
|
||||
@permissions_router.post("/users/{user_id}/groups")
|
||||
async def add_user_to_group(user_id: str, group_id: str, db: Session = Depends(get_user_db)):
|
||||
user = (await db.session.execute(select(User).where(User.id == user_id))).scalars().first()
|
||||
group = (
|
||||
(await db.session.execute(select(Group).where(Group.id == group_id))).scalars().first()
|
||||
)
|
||||
await create_tenant_method(tenant_name=tenant_name)
|
||||
|
||||
if not user:
|
||||
raise UserNotFoundError
|
||||
elif not group:
|
||||
raise GroupNotFoundError
|
||||
|
||||
try:
|
||||
# Add association directly to the association table
|
||||
stmt = insert(UserGroup).values(user_id=user_id, group_id=group_id)
|
||||
await db.session.execute(stmt)
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="User is already part of group.")
|
||||
|
||||
await db.session.commit()
|
||||
|
||||
return JSONResponse(status_code=200, content={"message": "User added to group"})
|
||||
return JSONResponse(status_code=200, content={"message": "Tenant created."})
|
||||
|
||||
return permissions_router
|
||||
|
|
|
|||
21
cognee/get_token.py
Normal file
21
cognee/get_token.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
import jwt
|
||||
import os
|
||||
import datetime
|
||||
|
||||
SECRET_KEY = os.getenv("FASTAPI_USERS_JWT_SECRET", "super_secret")
|
||||
|
||||
|
||||
def create_jwt(user_id: str, tenant: str, roles: list[str]):
|
||||
payload = {
|
||||
"user_id": user_id,
|
||||
"tenant_id": tenant,
|
||||
"roles": roles,
|
||||
"exp": datetime.datetime.utcnow() + datetime.timedelta(hours=1), # 1 hour expiry
|
||||
}
|
||||
return jwt.encode(payload, SECRET_KEY, algorithm="HS256")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example token generation
|
||||
token = create_jwt("6763554c-91bd-432c-aba8-d42cd72ed659", "tenant_456", ["admin"])
|
||||
print(token)
|
||||
|
|
@ -32,6 +32,9 @@ class Data(Base):
|
|||
cascade="all, delete",
|
||||
)
|
||||
|
||||
# New relationship for ACLs with cascade deletion
|
||||
acls = relationship("ACL", back_populates="data", cascade="all, delete-orphan")
|
||||
|
||||
def to_json(self) -> dict:
|
||||
return {
|
||||
"id": str(self.id),
|
||||
|
|
|
|||
|
|
@ -1 +1,2 @@
|
|||
from .get_user_db import get_user_db
|
||||
from .get_user_db import get_async_session
|
||||
|
|
|
|||
|
|
@ -1,12 +1,31 @@
|
|||
import os
|
||||
from functools import lru_cache
|
||||
from fastapi_users import models
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
from fastapi_users.authentication import (
|
||||
AuthenticationBackend,
|
||||
BearerTransport,
|
||||
JWTStrategy,
|
||||
)
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_user
|
||||
|
||||
|
||||
class CustomJWTStrategy(JWTStrategy):
|
||||
async def write_token(self, user: User, lifetime_seconds: Optional[int] = None) -> str:
|
||||
# JoinLoad tenant and role information to user object
|
||||
user = await get_user(user.id)
|
||||
|
||||
if user.tenant:
|
||||
data = {"user_id": str(user.id), "tenant_id": str(user.tenant.id), "roles": user.roles}
|
||||
else:
|
||||
# The default tenant is None
|
||||
data = {"user_id": str(user.id), "tenant_id": None, "roles": user.roles}
|
||||
return generate_jwt(data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_auth_backend():
|
||||
|
|
@ -14,7 +33,7 @@ def get_auth_backend():
|
|||
|
||||
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||
secret = os.getenv("FASTAPI_USERS_JWT_SECRET", "super_secret")
|
||||
return JWTStrategy(secret, lifetime_seconds=3600)
|
||||
return CustomJWTStrategy(secret, lifetime_seconds=3600)
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="jwt",
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ This module defines a set of exceptions for handling various user errors
|
|||
"""
|
||||
|
||||
from .exceptions import (
|
||||
GroupNotFoundError,
|
||||
RoleNotFoundError,
|
||||
UserNotFoundError,
|
||||
PermissionDeniedError,
|
||||
TenantNotFoundError,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,13 +2,25 @@ from cognee.exceptions import CogneeApiError
|
|||
from fastapi import status
|
||||
|
||||
|
||||
class GroupNotFoundError(CogneeApiError):
|
||||
class RoleNotFoundError(CogneeApiError):
|
||||
"""User group not found"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "User group not found.",
|
||||
name: str = "GroupNotFoundError",
|
||||
message: str = "User role not found.",
|
||||
name: str = "RoleNotFoundError",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
||||
|
||||
class TenantNotFoundError(CogneeApiError):
|
||||
"""User group not found"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Tenant not found.",
|
||||
name: str = "TenantNotFoundError",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
):
|
||||
super().__init__(message, name, status_code)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,19 @@
|
|||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from cognee.modules.users.exceptions import TenantNotFoundError
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from ..get_user_manager import get_user_manager_context
|
||||
from ..get_user_db import get_user_db_context
|
||||
from ..models.User import UserCreate
|
||||
from cognee.modules.users.get_user_manager import get_user_manager_context
|
||||
from cognee.modules.users.get_user_db import get_user_db_context
|
||||
from cognee.modules.users.models.User import UserCreate
|
||||
from cognee.modules.users.models.Tenant import Tenant
|
||||
|
||||
from sqlalchemy import select
|
||||
from typing import Optional
|
||||
|
||||
|
||||
async def create_user(
|
||||
email: str,
|
||||
password: str,
|
||||
tenant_id: Optional[str] = None,
|
||||
is_superuser: bool = False,
|
||||
is_active: bool = True,
|
||||
is_verified: bool = False,
|
||||
|
|
@ -19,15 +25,33 @@ async def create_user(
|
|||
async with relational_engine.get_async_session() as session:
|
||||
async with get_user_db_context(session) as user_db:
|
||||
async with get_user_manager_context(user_db) as user_manager:
|
||||
user = await user_manager.create(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
is_superuser=is_superuser,
|
||||
is_active=is_active,
|
||||
is_verified=is_verified,
|
||||
if tenant_id:
|
||||
# Check if the tenant already exists
|
||||
result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
|
||||
tenant = result.scalars().first()
|
||||
if not tenant:
|
||||
raise TenantNotFoundError
|
||||
|
||||
user = await user_manager.create(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
tenant_id=tenant.id,
|
||||
is_superuser=is_superuser,
|
||||
is_active=is_active,
|
||||
is_verified=is_verified,
|
||||
)
|
||||
)
|
||||
else:
|
||||
user = await user_manager.create(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
is_superuser=is_superuser,
|
||||
is_active=is_active,
|
||||
is_verified=is_verified,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if auto_login:
|
||||
await session.refresh(user)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,31 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
from ..get_fastapi_users import get_fastapi_users
|
||||
from fastapi import HTTPException, Header
|
||||
import os
|
||||
import jwt
|
||||
|
||||
fastapi_users = get_fastapi_users()
|
||||
|
||||
get_authenticated_user = fastapi_users.current_user(active=True, verified=True)
|
||||
|
||||
async def get_authenticated_user(authorization: str = Header(...)) -> SimpleNamespace:
|
||||
"""Extract and validate JWT from Authorization header."""
|
||||
try:
|
||||
scheme, token = authorization.split()
|
||||
if scheme.lower() != "bearer":
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication scheme")
|
||||
|
||||
payload = jwt.decode(
|
||||
token, os.getenv("FASTAPI_USERS_JWT_SECRET", "super_secret"), algorithms=["HS256"]
|
||||
)
|
||||
|
||||
# SimpleNamespace lets us access dictionary elements like attributes
|
||||
auth_data = SimpleNamespace(
|
||||
id=payload["user_id"], tenant_id=payload["tenant_id"], roles=payload["roles"]
|
||||
)
|
||||
return auth_data
|
||||
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(status_code=401, detail="Token has expired")
|
||||
except jwt.InvalidTokenError:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
|
|
|||
|
|
@ -1,17 +1,19 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.future import select
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.models import User, Tenant
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from .create_default_user import create_default_user
|
||||
from cognee.modules.users.methods.create_default_user import create_default_user
|
||||
|
||||
|
||||
async def get_default_user():
|
||||
async def get_default_user() -> SimpleNamespace:
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
query = (
|
||||
select(User)
|
||||
.options(selectinload(User.groups))
|
||||
.options(selectinload(User.roles))
|
||||
.where(User.email == "default_user@example.com")
|
||||
)
|
||||
|
||||
|
|
@ -21,4 +23,7 @@ async def get_default_user():
|
|||
if user is None:
|
||||
return await create_default_user()
|
||||
|
||||
return user
|
||||
# We return a SimpleNamespace to have the same user type as our SaaS
|
||||
# SimpleNamespace is just a dictionary which can be accessed through attributes
|
||||
auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[])
|
||||
return auth_data
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ async def get_user(user_id: UUID):
|
|||
async with db_engine.get_async_session() as session:
|
||||
user = (
|
||||
await session.execute(
|
||||
select(User).options(joinedload(User.groups)).where(User.id == user_id)
|
||||
select(User)
|
||||
.options(joinedload(User.roles), joinedload(User.tenant))
|
||||
.where(User.id == user_id)
|
||||
)
|
||||
).scalar()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import relationship, Mapped
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, ForeignKey, DateTime, UUID
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
from .ACLResources import ACLResources
|
||||
|
||||
|
||||
class ACL(Base):
|
||||
|
|
@ -16,11 +15,8 @@ class ACL(Base):
|
|||
|
||||
principal_id = Column(UUID, ForeignKey("principals.id"))
|
||||
permission_id = Column(UUID, ForeignKey("permissions.id"))
|
||||
data_id = Column(UUID, ForeignKey("data.id", ondelete="CASCADE"))
|
||||
|
||||
principal = relationship("Principal")
|
||||
permission = relationship("Permission")
|
||||
resources: Mapped[list["Resource"]] = relationship(
|
||||
"Resource",
|
||||
secondary=ACLResources.__tablename__,
|
||||
back_populates="acls",
|
||||
)
|
||||
data = relationship("Data", back_populates="acls")
|
||||
|
|
|
|||
|
|
@ -1,12 +0,0 @@
|
|||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, ForeignKey, DateTime, UUID
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class ACLResources(Base):
|
||||
__tablename__ = "acl_resources"
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
acl_id = Column(UUID, ForeignKey("acls.id"), primary_key=True)
|
||||
resource_id = Column(UUID, ForeignKey("resources.id"), primary_key=True)
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
from sqlalchemy.orm import relationship, Mapped
|
||||
from sqlalchemy import Column, String, ForeignKey, UUID
|
||||
from .Principal import Principal
|
||||
from .UserGroup import UserGroup
|
||||
|
||||
|
||||
class Group(Principal):
|
||||
__tablename__ = "groups"
|
||||
|
||||
id = Column(UUID, ForeignKey("principals.id"), primary_key=True)
|
||||
|
||||
name = Column(String, unique=True, nullable=False, index=True)
|
||||
|
||||
users: Mapped[list["User"]] = relationship(
|
||||
"User",
|
||||
secondary=UserGroup.__tablename__,
|
||||
back_populates="groups",
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "group",
|
||||
}
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, ForeignKey, DateTime, UUID
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class GroupPermission(Base):
|
||||
__tablename__ = "group_permissions"
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
group_id = Column(UUID, ForeignKey("groups.id"), primary_key=True)
|
||||
permission_id = Column(UUID, ForeignKey("permissions.id"), primary_key=True)
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, DateTime, String, UUID
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
|
@ -15,5 +14,3 @@ class Permission(Base):
|
|||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
name = Column(String, unique=True, nullable=False, index=True)
|
||||
|
||||
# acls = relationship("ACL", back_populates = "permission")
|
||||
|
|
|
|||
|
|
@ -1,19 +0,0 @@
|
|||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, DateTime, UUID
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
from .ACLResources import ACLResources
|
||||
|
||||
|
||||
class Resource(Base):
|
||||
__tablename__ = "resources"
|
||||
|
||||
id = Column(UUID, primary_key=True, default=uuid4)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
resource_id = Column(UUID, nullable=False)
|
||||
|
||||
acls = relationship("ACL", secondary=ACLResources.__tablename__, back_populates="resources")
|
||||
31
cognee/modules/users/models/Role.py
Normal file
31
cognee/modules/users/models/Role.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from sqlalchemy.orm import relationship, Mapped
|
||||
from sqlalchemy import Column, String, ForeignKey, UUID, UniqueConstraint
|
||||
from .Principal import Principal
|
||||
from .UserRole import UserRole
|
||||
|
||||
|
||||
class Role(Principal):
|
||||
__tablename__ = "roles"
|
||||
|
||||
id = Column(UUID, ForeignKey("principals.id", ondelete="CASCADE"), primary_key=True)
|
||||
|
||||
name = Column(String, nullable=False, index=True)
|
||||
|
||||
users: Mapped[list["User"]] = relationship( # noqa: F821
|
||||
"User",
|
||||
secondary=UserRole.__tablename__,
|
||||
back_populates="roles",
|
||||
)
|
||||
|
||||
# Foreign key to Tenant (Many-to-One relationship)
|
||||
tenant_id = Column(UUID, ForeignKey("tenants.id"), nullable=False)
|
||||
|
||||
# Relationship to Tenant
|
||||
tenant = relationship("Tenant", back_populates="roles", foreign_keys=[tenant_id])
|
||||
|
||||
# Unique constraint on tenant_id and name
|
||||
__table_args__ = (UniqueConstraint("tenant_id", "name", name="uq_roles_tenant_id_name"),)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "role",
|
||||
}
|
||||
22
cognee/modules/users/models/RoleDefaultPermissions.py
Normal file
22
cognee/modules/users/models/RoleDefaultPermissions.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, ForeignKey, DateTime, UUID
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class RoleDefaultPermissions(Base):
|
||||
__tablename__ = "role_default_permissions"
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
role_id = Column(
|
||||
UUID,
|
||||
ForeignKey("roles.id", ondelete="CASCADE"), # cascade deletion when Role is deleted
|
||||
primary_key=True,
|
||||
)
|
||||
permission_id = Column(
|
||||
UUID,
|
||||
ForeignKey(
|
||||
"permissions.id", ondelete="CASCADE"
|
||||
), # cascade deletion when Permission is deleted
|
||||
primary_key=True,
|
||||
)
|
||||
30
cognee/modules/users/models/Tenant.py
Normal file
30
cognee/modules/users/models/Tenant.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, String, ForeignKey, UUID
|
||||
from .Principal import Principal
|
||||
from .User import User
|
||||
from .Role import Role
|
||||
|
||||
|
||||
class Tenant(Principal):
|
||||
__tablename__ = "tenants"
|
||||
|
||||
id = Column(UUID, ForeignKey("principals.id"), primary_key=True)
|
||||
name = Column(String, unique=True, nullable=False, index=True)
|
||||
|
||||
# One-to-Many relationship with User; specify the join via User.tenant_id
|
||||
users = relationship(
|
||||
"User",
|
||||
back_populates="tenant",
|
||||
foreign_keys=lambda: [User.tenant_id],
|
||||
)
|
||||
|
||||
# One-to-Many relationship with Role (if needed; similar fix)
|
||||
roles = relationship(
|
||||
"Role",
|
||||
back_populates="tenant",
|
||||
foreign_keys=lambda: [Role.tenant_id],
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "tenant",
|
||||
}
|
||||
19
cognee/modules/users/models/TenantDefaultPermissions.py
Normal file
19
cognee/modules/users/models/TenantDefaultPermissions.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, ForeignKey, DateTime, UUID
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class TenantDefaultPermissions(Base):
|
||||
__tablename__ = "tenant_default_permissions"
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
tenant_id = Column(UUID, ForeignKey("tenants.id", ondelete="CASCADE"), primary_key=True)
|
||||
|
||||
permission_id = Column(
|
||||
UUID,
|
||||
ForeignKey(
|
||||
"permissions.id", ondelete="CASCADE"
|
||||
), # cascade deletion when Permission is deleted
|
||||
primary_key=True,
|
||||
)
|
||||
|
|
@ -1,37 +1,51 @@
|
|||
from uuid import UUID as uuid_UUID
|
||||
from typing import Optional
|
||||
from sqlalchemy import ForeignKey, Column, UUID
|
||||
from sqlalchemy.orm import relationship, Mapped
|
||||
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
|
||||
from .Principal import Principal
|
||||
from .UserGroup import UserGroup
|
||||
from .Group import Group
|
||||
from .UserRole import UserRole
|
||||
from .Role import Role
|
||||
from fastapi_users import schemas
|
||||
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Principal):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(UUID, ForeignKey("principals.id"), primary_key=True)
|
||||
id = Column(UUID, ForeignKey("principals.id", ondelete="CASCADE"), primary_key=True)
|
||||
|
||||
groups: Mapped[list["Group"]] = relationship(
|
||||
secondary=UserGroup.__tablename__,
|
||||
# Foreign key to Tenant (Many-to-One relationship)
|
||||
tenant_id = Column(UUID, ForeignKey("tenants.id"))
|
||||
|
||||
# Many-to-Many Relationship with Roles
|
||||
roles: Mapped[list["Role"]] = relationship(
|
||||
"Role",
|
||||
secondary=UserRole.__tablename__,
|
||||
back_populates="users",
|
||||
)
|
||||
|
||||
# Relationship to Tenant
|
||||
tenant = relationship(
|
||||
"Tenant",
|
||||
back_populates="users",
|
||||
foreign_keys=[tenant_id],
|
||||
)
|
||||
|
||||
# ACL Relationship (One-to-Many)
|
||||
acls = relationship("ACL", back_populates="principal", cascade="all, delete")
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "user",
|
||||
}
|
||||
|
||||
|
||||
# Keep these schemas in sync with User model
|
||||
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid_UUID]):
|
||||
pass
|
||||
tenant_id: Optional[uuid_UUID] = None
|
||||
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
pass
|
||||
tenant_id: Optional[uuid_UUID] = None
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
|
|
|
|||
18
cognee/modules/users/models/UserDefaultPermissions.py
Normal file
18
cognee/modules/users/models/UserDefaultPermissions.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, ForeignKey, DateTime, UUID
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class UserDefaultPermissions(Base):
|
||||
__tablename__ = "user_default_permissions"
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
user_id = Column(UUID, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True)
|
||||
permission_id = Column(
|
||||
UUID,
|
||||
ForeignKey(
|
||||
"permissions.id", ondelete="CASCADE"
|
||||
), # cascade deletion when Permission is deleted
|
||||
primary_key=True,
|
||||
)
|
||||
|
|
@ -3,10 +3,10 @@ from sqlalchemy import Column, ForeignKey, DateTime, UUID
|
|||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class UserGroup(Base):
|
||||
__tablename__ = "user_groups"
|
||||
class UserRole(Base):
|
||||
__tablename__ = "user_roles"
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
user_id = Column(UUID, ForeignKey("users.id"), primary_key=True)
|
||||
group_id = Column(UUID, ForeignKey("groups.id"), primary_key=True)
|
||||
role_id = Column(UUID, ForeignKey("roles.id"), primary_key=True)
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
from .User import User
|
||||
from .Group import Group
|
||||
from .UserGroup import UserGroup
|
||||
from .GroupPermission import GroupPermission
|
||||
from .Resource import Resource
|
||||
from .Role import Role
|
||||
from .UserRole import UserRole
|
||||
from .RoleDefaultPermissions import RoleDefaultPermissions
|
||||
from .UserDefaultPermissions import UserDefaultPermissions
|
||||
from .TenantDefaultPermissions import TenantDefaultPermissions
|
||||
from .Permission import Permission
|
||||
from .Tenant import Tenant
|
||||
from .ACL import ACL
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
from .check_permission_on_documents import check_permission_on_documents
|
||||
from .give_permission_on_document import give_permission_on_document
|
||||
from .get_document_ids_for_user import get_document_ids_for_user
|
||||
from .give_default_permission_to_tenant import give_default_permission_to_tenant
|
||||
from .give_default_permission_to_role import give_default_permission_to_role
|
||||
from .give_default_permission_to_user import give_default_permission_to_user
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
async def check_permission_on_documents(user: User, permission_type: str, document_ids: list[UUID]):
|
||||
user_group_ids = [group.id for group in user.groups]
|
||||
user_roles_ids = [role.id for role in user.roles]
|
||||
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
|
|
@ -21,13 +21,13 @@ async def check_permission_on_documents(user: User, permission_type: str, docume
|
|||
result = await session.execute(
|
||||
select(ACL)
|
||||
.join(ACL.permission)
|
||||
.options(joinedload(ACL.resources))
|
||||
.where(ACL.principal_id.in_([user.id, *user_group_ids]))
|
||||
.options(joinedload(ACL.data))
|
||||
.where(ACL.principal_id.in_([user.id, *user_roles_ids]))
|
||||
.where(ACL.permission.has(name=permission_type))
|
||||
)
|
||||
acls = result.unique().scalars().all()
|
||||
resource_ids = [resource.resource_id for acl in acls for resource in acl.resources]
|
||||
has_permissions = all(document_id in resource_ids for document_id in document_ids)
|
||||
data_ids = [acl.data.id for acl in acls]
|
||||
has_permissions = all(document_id in data_ids for document_id in document_ids)
|
||||
|
||||
if not has_permissions:
|
||||
raise PermissionDeniedError(
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Dataset, DatasetData
|
||||
from ...models import ACL, Resource, Permission
|
||||
from cognee.modules.data.models import Dataset, DatasetData, Data
|
||||
from ...models import ACL, Permission
|
||||
|
||||
|
||||
async def get_document_ids_for_user(user_id: UUID, datasets: list[str] = None) -> list[str]:
|
||||
|
|
@ -12,8 +12,8 @@ async def get_document_ids_for_user(user_id: UUID, datasets: list[str] = None) -
|
|||
async with session.begin():
|
||||
document_ids = (
|
||||
await session.scalars(
|
||||
select(Resource.resource_id)
|
||||
.join(ACL.resources)
|
||||
select(Data.id)
|
||||
.join(ACL.data)
|
||||
.join(ACL.permission)
|
||||
.where(
|
||||
ACL.principal_id == user_id,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,56 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.exceptions import (
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from cognee.modules.users.models import (
|
||||
Permission,
|
||||
Role,
|
||||
RoleDefaultPermissions,
|
||||
)
|
||||
|
||||
|
||||
async def give_default_permission_to_role(role_id: UUID, permission_name: str):
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
role = (await session.execute(select(Role).where(Role.id == role_id))).scalars().first()
|
||||
|
||||
if not role:
|
||||
raise RoleNotFoundError
|
||||
|
||||
permission_entity = (
|
||||
(await session.execute(select(Permission).where(Permission.name == permission_name)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
if not permission_entity:
|
||||
stmt = insert(Permission).values(name=permission_name)
|
||||
await session.execute(stmt)
|
||||
permission_entity = (
|
||||
(
|
||||
await session.execute(
|
||||
select(Permission).where(Permission.name == permission_name)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
try:
|
||||
# add default permission to role
|
||||
await session.execute(
|
||||
insert(RoleDefaultPermissions).values(
|
||||
role_id=role.id, permission_id=permission_entity.id
|
||||
)
|
||||
)
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="Role permission already exists.")
|
||||
|
||||
await session.commit()
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.exceptions import (
|
||||
TenantNotFoundError,
|
||||
)
|
||||
from cognee.modules.users.models import (
|
||||
Permission,
|
||||
Tenant,
|
||||
TenantDefaultPermissions,
|
||||
)
|
||||
|
||||
|
||||
async def give_default_permission_to_tenant(tenant_id: UUID, permission_name: str):
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
tenant = (
|
||||
(await session.execute(select(Tenant).where(Tenant.id == tenant_id))).scalars().first()
|
||||
)
|
||||
|
||||
if not tenant:
|
||||
raise TenantNotFoundError
|
||||
|
||||
permission_entity = (
|
||||
(await session.execute(select(Permission).where(Permission.name == permission_name)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
if not permission_entity:
|
||||
create_permission_statement = insert(Permission).values(name=permission_name)
|
||||
await session.execute(create_permission_statement)
|
||||
permission_entity = (
|
||||
(
|
||||
await session.execute(
|
||||
select(Permission).where(Permission.name == permission_name)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
try:
|
||||
# add default permission to tenant
|
||||
await session.execute(
|
||||
insert(TenantDefaultPermissions).values(
|
||||
tenant_id=tenant.id, permission_id=permission_entity.id
|
||||
)
|
||||
)
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="Tenant permission already exists.")
|
||||
|
||||
await session.commit()
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.exceptions import (
|
||||
UserNotFoundError,
|
||||
)
|
||||
from cognee.modules.users.models import (
|
||||
Permission,
|
||||
User,
|
||||
UserDefaultPermissions,
|
||||
)
|
||||
|
||||
|
||||
async def give_default_permission_to_user(user_id: UUID, permission_name: str):
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
user = (await session.execute(select(User).where(User.id == user_id))).scalars().first()
|
||||
|
||||
if not user:
|
||||
raise UserNotFoundError
|
||||
|
||||
permission_entity = (
|
||||
(await session.execute(select(Permission).where(Permission.name == permission_name)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
if not permission_entity:
|
||||
create_permission_statement = insert(Permission).values(name=permission_name)
|
||||
await session.execute(create_permission_statement)
|
||||
permission_entity = (
|
||||
(
|
||||
await session.execute(
|
||||
select(Permission).where(Permission.name == permission_name)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
try:
|
||||
# add default permission to user
|
||||
await session.execute(
|
||||
insert(UserDefaultPermissions).values(
|
||||
user_id=user.id, permission_id=permission_entity.id
|
||||
)
|
||||
)
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="User permission already exists.")
|
||||
|
||||
await session.commit()
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
from sqlalchemy.future import select
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from ...models import User, ACL, Resource, Permission
|
||||
from ...models import User, ACL, Permission
|
||||
|
||||
|
||||
async def give_permission_on_document(
|
||||
|
|
@ -10,8 +10,6 @@ async def give_permission_on_document(
|
|||
):
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
document_resource = Resource(resource_id=document_id)
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
permission = (
|
||||
(await session.execute(select(Permission).filter(Permission.name == permission_name)))
|
||||
|
|
@ -22,9 +20,7 @@ async def give_permission_on_document(
|
|||
if permission is None:
|
||||
permission = Permission(name=permission_name)
|
||||
|
||||
acl = ACL(principal_id=user.id)
|
||||
acl.permission = permission
|
||||
acl.resources.append(document_resource)
|
||||
acl = ACL(principal_id=user.id, data_id=document_id, permission=permission)
|
||||
|
||||
session.add(acl)
|
||||
|
||||
|
|
|
|||
2
cognee/modules/users/roles/methods/__init__.py
Normal file
2
cognee/modules/users/roles/methods/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
from .create_role import create_role
|
||||
from .add_user_to_role import add_user_to_role
|
||||
38
cognee/modules/users/roles/methods/add_user_to_role.py
Normal file
38
cognee/modules/users/roles/methods/add_user_to_role.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.exceptions import (
|
||||
UserNotFoundError,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from cognee.modules.users.models import (
|
||||
User,
|
||||
Role,
|
||||
UserRole,
|
||||
)
|
||||
|
||||
|
||||
async def add_user_to_role(user_id: UUID, role_id: UUID):
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
user = (await session.execute(select(User).where(User.id == user_id))).scalars().first()
|
||||
role = (await session.execute(select(Role).where(Role.id == role_id))).scalars().first()
|
||||
|
||||
if not user:
|
||||
raise UserNotFoundError
|
||||
elif not role:
|
||||
raise RoleNotFoundError
|
||||
|
||||
try:
|
||||
# Add association directly to the association table
|
||||
create_user_role_statement = insert(UserRole).values(user_id=user_id, role_id=role_id)
|
||||
await session.execute(create_user_role_statement)
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="User is already part of group.")
|
||||
|
||||
await session.commit()
|
||||
26
cognee/modules/users/roles/methods/create_role.py
Normal file
26
cognee/modules/users/roles/methods/create_role.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.models import (
|
||||
Role,
|
||||
)
|
||||
|
||||
|
||||
async def create_role(
|
||||
role_name: str,
|
||||
tenant_id: UUID,
|
||||
):
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
try:
|
||||
# Add association directly to the association table
|
||||
role = Role(name=role_name, tenant_id=tenant_id)
|
||||
session.add(role)
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="Role already exists for tenant.")
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(role)
|
||||
1
cognee/modules/users/tenants/methods/__init__.py
Normal file
1
cognee/modules/users/tenants/methods/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from .create_tenant import create_tenant
|
||||
19
cognee/modules/users/tenants/methods/create_tenant.py
Normal file
19
cognee/modules/users/tenants/methods/create_tenant.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.models import Tenant
|
||||
|
||||
|
||||
async def create_tenant(tenant_name: str):
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
try:
|
||||
# Add association directly to the association table
|
||||
tenant = Tenant(name=tenant_name)
|
||||
session.add(tenant)
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="Tenant already exists.")
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(tenant)
|
||||
|
|
@ -15,7 +15,6 @@ from .save_data_item_to_storage import (
|
|||
|
||||
from typing import Union, BinaryIO
|
||||
import inspect
|
||||
import warnings
|
||||
|
||||
|
||||
async def ingest_data(data: Any, dataset_name: str, user: User):
|
||||
|
|
|
|||
|
|
@ -1,12 +1,17 @@
|
|||
import asyncio
|
||||
from queue import Queue
|
||||
|
||||
import cognee
|
||||
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
|
||||
async def pipeline(data_queue):
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
async def queue_consumer():
|
||||
while not data_queue.is_closed:
|
||||
if not data_queue.empty():
|
||||
|
|
@ -20,7 +25,9 @@ async def pipeline(data_queue):
|
|||
async def multiply_by_two(num):
|
||||
yield num * 2
|
||||
|
||||
await create_db_and_tables()
|
||||
user = await get_default_user()
|
||||
|
||||
tasks_run = run_tasks_base(
|
||||
[
|
||||
Task(queue_consumer),
|
||||
|
|
|
|||
|
|
@ -1,11 +1,16 @@
|
|||
import asyncio
|
||||
|
||||
import cognee
|
||||
from cognee.modules.pipelines.operations.run_tasks import run_tasks_base
|
||||
from cognee.modules.pipelines.tasks.Task import Task
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.infrastructure.databases.relational import create_db_and_tables
|
||||
|
||||
|
||||
async def run_and_check_tasks():
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
def number_generator(num):
|
||||
for i in range(num):
|
||||
yield i + 1
|
||||
|
|
@ -20,7 +25,9 @@ async def run_and_check_tasks():
|
|||
async def add_one_single(num):
|
||||
yield num + 1
|
||||
|
||||
await create_db_and_tables()
|
||||
user = await get_default_user()
|
||||
|
||||
pipeline = run_tasks_base(
|
||||
[
|
||||
Task(number_generator),
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue