feat: Add multi-tenancy (#1560)
<!-- .github/pull_request_template.md --> ## Description Add multi-tenancy to Cognee ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [ ] **I have tested my changes thoroughly before submitting this PR** - [ ] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
commit
78b825f338
37 changed files with 814 additions and 151 deletions
27
.github/workflows/e2e_tests.yml
vendored
27
.github/workflows/e2e_tests.yml
vendored
|
|
@ -226,7 +226,7 @@ jobs:
|
|||
- name: Dependencies already installed
|
||||
run: echo "Dependencies already installed in setup"
|
||||
|
||||
- name: Run parallel databases test
|
||||
- name: Run permissions test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
|
|
@ -239,6 +239,31 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_permissions.py
|
||||
|
||||
test-multi-tenancy:
|
||||
name: Test multi tenancy with different situations in Cognee
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cognee Setup
|
||||
uses: ./.github/actions/cognee_setup
|
||||
with:
|
||||
python-version: '3.11.x'
|
||||
|
||||
- name: Run multi tenancy test
|
||||
env:
|
||||
ENV: 'dev'
|
||||
LLM_MODEL: ${{ secrets.LLM_MODEL }}
|
||||
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
|
||||
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
|
||||
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
|
||||
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
|
||||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
run: uv run python ./cognee/tests/test_multi_tenancy.py
|
||||
|
||||
test-graph-edges:
|
||||
name: Test graph edge ingestion
|
||||
runs-on: ubuntu-22.04
|
||||
|
|
|
|||
|
|
@ -87,11 +87,6 @@ db_engine = get_relational_engine()
|
|||
|
||||
print("Using database:", db_engine.db_uri)
|
||||
|
||||
if "sqlite" in db_engine.db_uri:
|
||||
from cognee.infrastructure.utils.run_sync import run_sync
|
||||
|
||||
run_sync(db_engine.create_database())
|
||||
|
||||
config.set_section_option(
|
||||
config.config_ini_section,
|
||||
"SQLALCHEMY_DATABASE_URI",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Sequence, Union
|
|||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
|
@ -26,7 +27,34 @@ def upgrade() -> None:
|
|||
connection = op.get_bind()
|
||||
inspector = sa.inspect(connection)
|
||||
|
||||
if op.get_context().dialect.name == "postgresql":
|
||||
syncstatus_enum = postgresql.ENUM(
|
||||
"STARTED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED", name="syncstatus"
|
||||
)
|
||||
syncstatus_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
if "sync_operations" not in inspector.get_table_names():
|
||||
if op.get_context().dialect.name == "postgresql":
|
||||
syncstatus = postgresql.ENUM(
|
||||
"STARTED",
|
||||
"IN_PROGRESS",
|
||||
"COMPLETED",
|
||||
"FAILED",
|
||||
"CANCELLED",
|
||||
name="syncstatus",
|
||||
create_type=False,
|
||||
)
|
||||
else:
|
||||
syncstatus = sa.Enum(
|
||||
"STARTED",
|
||||
"IN_PROGRESS",
|
||||
"COMPLETED",
|
||||
"FAILED",
|
||||
"CANCELLED",
|
||||
name="syncstatus",
|
||||
create_type=False,
|
||||
)
|
||||
|
||||
# Table doesn't exist, create it normally
|
||||
op.create_table(
|
||||
"sync_operations",
|
||||
|
|
@ -34,15 +62,7 @@ def upgrade() -> None:
|
|||
sa.Column("run_id", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"STARTED",
|
||||
"IN_PROGRESS",
|
||||
"COMPLETED",
|
||||
"FAILED",
|
||||
"CANCELLED",
|
||||
name="syncstatus",
|
||||
create_type=False,
|
||||
),
|
||||
syncstatus,
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("progress_percentage", sa.Integer(), nullable=True),
|
||||
|
|
|
|||
|
|
@ -23,11 +23,8 @@ depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"
|
|||
|
||||
|
||||
def upgrade() -> None:
|
||||
try:
|
||||
await_only(create_default_user())
|
||||
except UserAlreadyExists:
|
||||
pass # It's fine if the default user already exists
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
await_only(delete_user("default_user@example.com"))
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -18,11 +18,8 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||
|
||||
|
||||
def upgrade() -> None:
|
||||
db_engine = get_relational_engine()
|
||||
# we might want to delete this
|
||||
await_only(db_engine.create_database())
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
db_engine = get_relational_engine()
|
||||
await_only(db_engine.delete_database())
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -144,44 +144,58 @@ def _create_data_permission(conn, user_id, data_id, permission_name):
|
|||
)
|
||||
|
||||
|
||||
def _get_column(inspector, table, name, schema=None):
|
||||
for col in inspector.get_columns(table, schema=schema):
|
||||
if col["name"] == name:
|
||||
return col
|
||||
return None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
insp = sa.inspect(conn)
|
||||
|
||||
# Recreate ACLs table with default permissions set to datasets instead of documents
|
||||
op.drop_table("acls")
|
||||
dataset_id_column = _get_column(insp, "acls", "dataset_id")
|
||||
if not dataset_id_column:
|
||||
# Recreate ACLs table with default permissions set to datasets instead of documents
|
||||
op.drop_table("acls")
|
||||
|
||||
acls_table = op.create_table(
|
||||
"acls",
|
||||
sa.Column("id", UUID, primary_key=True, default=uuid4),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at", sa.DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)
|
||||
),
|
||||
sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
|
||||
sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
|
||||
sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
|
||||
)
|
||||
acls_table = op.create_table(
|
||||
"acls",
|
||||
sa.Column("id", UUID, primary_key=True, default=uuid4),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
),
|
||||
sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
|
||||
sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
|
||||
sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
|
||||
)
|
||||
|
||||
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
|
||||
# definition or load what is in the database
|
||||
dataset_table = _define_dataset_table()
|
||||
datasets = conn.execute(sa.select(dataset_table)).fetchall()
|
||||
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
|
||||
# definition or load what is in the database
|
||||
dataset_table = _define_dataset_table()
|
||||
datasets = conn.execute(sa.select(dataset_table)).fetchall()
|
||||
|
||||
if not datasets:
|
||||
return
|
||||
if not datasets:
|
||||
return
|
||||
|
||||
acl_list = []
|
||||
acl_list = []
|
||||
|
||||
for dataset in datasets:
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete"))
|
||||
for dataset in datasets:
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
|
||||
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
|
||||
acl_list.append(
|
||||
_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete")
|
||||
)
|
||||
|
||||
if acl_list:
|
||||
op.bulk_insert(acls_table, acl_list)
|
||||
if acl_list:
|
||||
op.bulk_insert(acls_table, acl_list)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
|
|
|||
137
alembic/versions/c946955da633_multi_tenant_support.py
Normal file
137
alembic/versions/c946955da633_multi_tenant_support.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
"""Multi Tenant Support
|
||||
|
||||
Revision ID: c946955da633
|
||||
Revises: 211ab850ef3d
|
||||
Create Date: 2025-11-04 18:11:09.325158
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c946955da633"
|
||||
down_revision: Union[str, None] = "211ab850ef3d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _now():
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _define_user_table() -> sa.Table:
|
||||
table = sa.Table(
|
||||
"users",
|
||||
sa.MetaData(),
|
||||
sa.Column(
|
||||
"id",
|
||||
sa.UUID,
|
||||
sa.ForeignKey("principals.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), index=True, nullable=True),
|
||||
)
|
||||
return table
|
||||
|
||||
|
||||
def _define_dataset_table() -> sa.Table:
|
||||
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
|
||||
# definition or load what is in the database
|
||||
table = sa.Table(
|
||||
"datasets",
|
||||
sa.MetaData(),
|
||||
sa.Column("id", sa.UUID, primary_key=True, default=uuid4),
|
||||
sa.Column("name", sa.Text),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
),
|
||||
sa.Column("owner_id", sa.UUID(), sa.ForeignKey("principals.id"), index=True),
|
||||
sa.Column("tenant_id", sa.UUID(), sa.ForeignKey("tenants.id"), index=True, nullable=True),
|
||||
)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def _get_column(inspector, table, name, schema=None):
|
||||
for col in inspector.get_columns(table, schema=schema):
|
||||
if col["name"] == name:
|
||||
return col
|
||||
return None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
insp = sa.inspect(conn)
|
||||
|
||||
dataset = _define_dataset_table()
|
||||
user = _define_user_table()
|
||||
|
||||
if "user_tenants" not in insp.get_table_names():
|
||||
# Define table with all necessary columns including primary key
|
||||
user_tenants = op.create_table(
|
||||
"user_tenants",
|
||||
sa.Column("user_id", sa.UUID, sa.ForeignKey("users.id"), primary_key=True),
|
||||
sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), primary_key=True),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
),
|
||||
)
|
||||
|
||||
# Get all users with their tenant_id
|
||||
user_data = conn.execute(
|
||||
sa.select(user.c.id, user.c.tenant_id).where(user.c.tenant_id.isnot(None))
|
||||
).fetchall()
|
||||
|
||||
# Insert into user_tenants table
|
||||
if user_data:
|
||||
op.bulk_insert(
|
||||
user_tenants,
|
||||
[
|
||||
{"user_id": user_id, "tenant_id": tenant_id, "created_at": _now()}
|
||||
for user_id, tenant_id in user_data
|
||||
],
|
||||
)
|
||||
|
||||
tenant_id_column = _get_column(insp, "datasets", "tenant_id")
|
||||
if not tenant_id_column:
|
||||
op.add_column("datasets", sa.Column("tenant_id", sa.UUID(), nullable=True))
|
||||
|
||||
# Build subquery, select users.tenant_id for each dataset.owner_id
|
||||
tenant_id_from_dataset_owner = (
|
||||
sa.select(user.c.tenant_id).where(user.c.id == dataset.c.owner_id).scalar_subquery()
|
||||
)
|
||||
|
||||
if op.get_context().dialect.name == "sqlite":
|
||||
# If column doesn't exist create new original_extension column and update from values of extension column
|
||||
with op.batch_alter_table("datasets") as batch_op:
|
||||
batch_op.execute(
|
||||
dataset.update().values(
|
||||
tenant_id=tenant_id_from_dataset_owner,
|
||||
)
|
||||
)
|
||||
else:
|
||||
conn = op.get_bind()
|
||||
conn.execute(dataset.update().values(tenant_id=tenant_id_from_dataset_owner))
|
||||
|
||||
op.create_index(op.f("ix_datasets_tenant_id"), "datasets", ["tenant_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("user_tenants")
|
||||
op.drop_index(op.f("ix_datasets_tenant_id"), table_name="datasets")
|
||||
op.drop_column("datasets", "tenant_id")
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -1096,6 +1096,10 @@ async def main():
|
|||
|
||||
# Skip migrations when in API mode (the API server handles its own database)
|
||||
if not args.no_migration and not args.api_url:
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
|
||||
await setup()
|
||||
|
||||
# Run Alembic migrations from the main cognee directory where alembic.ini is located
|
||||
logger.info("Running database migrations...")
|
||||
migration_result = subprocess.run(
|
||||
|
|
|
|||
|
|
@ -1,15 +1,20 @@
|
|||
from uuid import UUID
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.api.DTO import InDTO
|
||||
from cognee.modules.users.methods import get_authenticated_user
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee import __version__ as cognee_version
|
||||
|
||||
|
||||
class SelectTenantDTO(InDTO):
|
||||
tenant_id: UUID | None = None
|
||||
|
||||
|
||||
def get_permissions_router() -> APIRouter:
|
||||
permissions_router = APIRouter()
|
||||
|
||||
|
|
@ -226,4 +231,39 @@ def get_permissions_router() -> APIRouter:
|
|||
status_code=200, content={"message": "Tenant created.", "tenant_id": str(tenant_id)}
|
||||
)
|
||||
|
||||
@permissions_router.post("/tenants/select")
|
||||
async def select_tenant(payload: SelectTenantDTO, user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Select current tenant.
|
||||
|
||||
This endpoint selects a tenant with the specified UUID. Tenants are used
|
||||
to organize users and resources in multi-tenant environments, providing
|
||||
isolation and access control between different groups or organizations.
|
||||
|
||||
Sending a null/None value as tenant_id selects his default single user tenant
|
||||
|
||||
## Request Parameters
|
||||
- **tenant_id** (Union[UUID, None]): UUID of the tenant to select, If null/None is provided use the default single user tenant
|
||||
|
||||
## Response
|
||||
Returns a success message along with selected tenant id.
|
||||
"""
|
||||
send_telemetry(
|
||||
"Permissions API Endpoint Invoked",
|
||||
user.id,
|
||||
additional_properties={
|
||||
"endpoint": f"POST /v1/permissions/tenants/{str(payload.tenant_id)}",
|
||||
"tenant_id": str(payload.tenant_id),
|
||||
},
|
||||
)
|
||||
|
||||
from cognee.modules.users.tenants.methods import select_tenant as select_tenant_method
|
||||
|
||||
await select_tenant_method(user_id=user.id, tenant_id=payload.tenant_id)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"message": "Tenant selected.", "tenant_id": str(payload.tenant_id)},
|
||||
)
|
||||
|
||||
return permissions_router
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from .get_authorized_dataset import get_authorized_dataset
|
|||
from .get_authorized_dataset_by_name import get_authorized_dataset_by_name
|
||||
from .get_data import get_data
|
||||
from .get_unique_dataset_id import get_unique_dataset_id
|
||||
from .get_unique_data_id import get_unique_data_id
|
||||
from .get_authorized_existing_datasets import get_authorized_existing_datasets
|
||||
from .get_dataset_ids import get_dataset_ids
|
||||
|
||||
|
|
|
|||
|
|
@ -16,14 +16,16 @@ async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -
|
|||
.options(joinedload(Dataset.data))
|
||||
.filter(Dataset.name == dataset_name)
|
||||
.filter(Dataset.owner_id == owner_id)
|
||||
.filter(Dataset.tenant_id == user.tenant_id)
|
||||
)
|
||||
).first()
|
||||
|
||||
if dataset is None:
|
||||
# Dataset id should be generated based on dataset_name and owner_id/user so multiple users can use the same dataset_name
|
||||
dataset_id = await get_unique_dataset_id(dataset_name=dataset_name, user=user)
|
||||
dataset = Dataset(id=dataset_id, name=dataset_name, data=[])
|
||||
dataset.owner_id = owner_id
|
||||
dataset = Dataset(
|
||||
id=dataset_id, name=dataset_name, data=[], owner_id=owner_id, tenant_id=user.tenant_id
|
||||
)
|
||||
|
||||
session.add(dataset)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,11 @@ async def get_dataset_ids(datasets: Union[list[str], list[UUID]], user):
|
|||
# Get all user owned dataset objects (If a user wants to write to a dataset he is not the owner of it must be provided through UUID.)
|
||||
user_datasets = await get_datasets(user.id)
|
||||
# Filter out non name mentioned datasets
|
||||
dataset_ids = [dataset.id for dataset in user_datasets if dataset.name in datasets]
|
||||
dataset_ids = [dataset for dataset in user_datasets if dataset.name in datasets]
|
||||
# Filter out non current tenant datasets
|
||||
dataset_ids = [
|
||||
dataset.id for dataset in dataset_ids if dataset.tenant_id == user.tenant_id
|
||||
]
|
||||
else:
|
||||
raise DatasetTypeError(
|
||||
f"One or more of the provided dataset types is not handled: f{datasets}"
|
||||
|
|
|
|||
68
cognee/modules/data/methods/get_unique_data_id.py
Normal file
68
cognee/modules/data/methods/get_unique_data_id.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from uuid import uuid5, NAMESPACE_OID, UUID
|
||||
from sqlalchemy import select
|
||||
|
||||
from cognee.modules.data.models.Data import Data
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.models import User
|
||||
|
||||
|
||||
async def get_unique_data_id(data_identifier: str, user: User) -> UUID:
|
||||
"""
|
||||
Function returns a unique UUID for data based on data identifier, user id and tenant id.
|
||||
If data with legacy ID exists, return that ID to maintain compatibility.
|
||||
|
||||
Args:
|
||||
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
|
||||
user: User object adding the data
|
||||
tenant_id: UUID of the tenant for which data is being added
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the data
|
||||
"""
|
||||
|
||||
def _get_deprecated_unique_data_id(data_identifier: str, user: User) -> UUID:
|
||||
"""
|
||||
Deprecated function, returns a unique UUID for data based on data identifier and user id.
|
||||
Needed to support legacy data without tenant information.
|
||||
Args:
|
||||
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
|
||||
user: User object adding the data
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the data
|
||||
"""
|
||||
# return UUID hash of file contents + owner id + tenant_id
|
||||
return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}")
|
||||
|
||||
def _get_modern_unique_data_id(data_identifier: str, user: User) -> UUID:
|
||||
"""
|
||||
Function returns a unique UUID for data based on data identifier, user id and tenant id.
|
||||
Args:
|
||||
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
|
||||
user: User object adding the data
|
||||
tenant_id: UUID of the tenant for which data is being added
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the data
|
||||
"""
|
||||
# return UUID hash of file contents + owner id + tenant_id
|
||||
return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}{str(user.tenant_id)}")
|
||||
|
||||
# Get all possible data_id values
|
||||
data_id = {
|
||||
"modern_data_id": _get_modern_unique_data_id(data_identifier=data_identifier, user=user),
|
||||
"legacy_data_id": _get_deprecated_unique_data_id(
|
||||
data_identifier=data_identifier, user=user
|
||||
),
|
||||
}
|
||||
|
||||
# Check if data item with legacy_data_id exists, if so use that one, else use modern_data_id
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
legacy_data_point = (
|
||||
await session.execute(select(Data).filter(Data.id == data_id["legacy_data_id"]))
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not legacy_data_point:
|
||||
return data_id["modern_data_id"]
|
||||
return data_id["legacy_data_id"]
|
||||
|
|
@ -1,9 +1,71 @@
|
|||
from uuid import UUID, uuid5, NAMESPACE_OID
|
||||
from cognee.modules.users.models import User
|
||||
from typing import Union
|
||||
from sqlalchemy import select
|
||||
|
||||
from cognee.modules.data.models.Dataset import Dataset
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
|
||||
|
||||
async def get_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
|
||||
if isinstance(dataset_name, UUID):
|
||||
return dataset_name
|
||||
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}")
|
||||
"""
|
||||
Function returns a unique UUID for dataset based on dataset name, user id and tenant id.
|
||||
If dataset with legacy ID exists, return that ID to maintain compatibility.
|
||||
|
||||
Args:
|
||||
dataset_name: string representing the dataset name
|
||||
user: User object adding the dataset
|
||||
tenant_id: UUID of the tenant for which dataset is being added
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the dataset
|
||||
"""
|
||||
|
||||
def _get_legacy_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
|
||||
"""
|
||||
Legacy function, returns a unique UUID for dataset based on dataset name and user id.
|
||||
Needed to support legacy datasets without tenant information.
|
||||
Args:
|
||||
dataset_name: string representing the dataset name
|
||||
user: Current User object adding the dataset
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the dataset
|
||||
"""
|
||||
if isinstance(dataset_name, UUID):
|
||||
return dataset_name
|
||||
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}")
|
||||
|
||||
def _get_modern_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
|
||||
"""
|
||||
Returns a unique UUID for dataset based on dataset name, user id and tenant_id.
|
||||
Args:
|
||||
dataset_name: string representing the dataset name
|
||||
user: Current User object adding the dataset
|
||||
tenant_id: UUID of the tenant for which dataset is being added
|
||||
|
||||
Returns:
|
||||
UUID: Unique identifier for the dataset
|
||||
"""
|
||||
if isinstance(dataset_name, UUID):
|
||||
return dataset_name
|
||||
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}{str(user.tenant_id)}")
|
||||
|
||||
# Get all possible dataset_id values
|
||||
dataset_id = {
|
||||
"modern_dataset_id": _get_modern_unique_dataset_id(dataset_name=dataset_name, user=user),
|
||||
"legacy_dataset_id": _get_legacy_unique_dataset_id(dataset_name=dataset_name, user=user),
|
||||
}
|
||||
|
||||
# Check if dataset with legacy_dataset_id exists, if so use that one, else use modern_dataset_id
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
legacy_dataset = (
|
||||
await session.execute(
|
||||
select(Dataset).filter(Dataset.id == dataset_id["legacy_dataset_id"])
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not legacy_dataset:
|
||||
return dataset_id["modern_dataset_id"]
|
||||
return dataset_id["legacy_dataset_id"]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ class Dataset(Base):
|
|||
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
owner_id = Column(UUID, index=True)
|
||||
tenant_id = Column(UUID, index=True, nullable=True)
|
||||
|
||||
acls = relationship("ACL", back_populates="dataset", cascade="all, delete-orphan")
|
||||
|
||||
|
|
@ -36,5 +37,6 @@ class Dataset(Base):
|
|||
"createdAt": self.created_at.isoformat(),
|
||||
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"ownerId": str(self.owner_id),
|
||||
"tenantId": str(self.tenant_id),
|
||||
"data": [data.to_json() for data in self.data],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from uuid import uuid5, NAMESPACE_OID
|
||||
from uuid import UUID
|
||||
from .data_types import IngestionData
|
||||
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.data.methods import get_unique_data_id
|
||||
|
||||
|
||||
def identify(data: IngestionData, user: User) -> str:
|
||||
async def identify(data: IngestionData, user: User) -> UUID:
|
||||
data_content_hash: str = data.get_identifier()
|
||||
|
||||
# return UUID hash of file contents + owner id
|
||||
return uuid5(NAMESPACE_OID, f"{data_content_hash}{user.id}")
|
||||
return await get_unique_data_id(data_identifier=data_content_hash, user=user)
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ async def run_tasks_data_item_incremental(
|
|||
async with open_data_file(file_path) as file:
|
||||
classified_data = ingestion.classify(file)
|
||||
# data_id is the hash of file contents + owner id to avoid duplicate data
|
||||
data_id = ingestion.identify(classified_data, user)
|
||||
data_id = await ingestion.identify(classified_data, user)
|
||||
else:
|
||||
# If data was already processed by Cognee get data id
|
||||
data_id = data_item.id
|
||||
|
|
|
|||
|
|
@ -172,6 +172,7 @@ async def search(
|
|||
"search_result": [context] if context else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
"graphs": graphs,
|
||||
}
|
||||
)
|
||||
|
|
@ -181,6 +182,7 @@ async def search(
|
|||
"search_result": [result] if result else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
"graphs": graphs,
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ from typing import Optional
|
|||
async def create_user(
|
||||
email: str,
|
||||
password: str,
|
||||
tenant_id: Optional[str] = None,
|
||||
is_superuser: bool = False,
|
||||
is_active: bool = True,
|
||||
is_verified: bool = False,
|
||||
|
|
@ -30,37 +29,23 @@ async def create_user(
|
|||
async with relational_engine.get_async_session() as session:
|
||||
async with get_user_db_context(session) as user_db:
|
||||
async with get_user_manager_context(user_db) as user_manager:
|
||||
if tenant_id:
|
||||
# Check if the tenant already exists
|
||||
result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
|
||||
tenant = result.scalars().first()
|
||||
if not tenant:
|
||||
raise TenantNotFoundError
|
||||
|
||||
user = await user_manager.create(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
tenant_id=tenant.id,
|
||||
is_superuser=is_superuser,
|
||||
is_active=is_active,
|
||||
is_verified=is_verified,
|
||||
)
|
||||
)
|
||||
else:
|
||||
user = await user_manager.create(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
is_superuser=is_superuser,
|
||||
is_active=is_active,
|
||||
is_verified=is_verified,
|
||||
)
|
||||
user = await user_manager.create(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
is_superuser=is_superuser,
|
||||
is_active=is_active,
|
||||
is_verified=is_verified,
|
||||
)
|
||||
)
|
||||
|
||||
if auto_login:
|
||||
await session.refresh(user)
|
||||
|
||||
# Update tenants and roles information for User object
|
||||
_ = await user.awaitable_attrs.tenants
|
||||
_ = await user.awaitable_attrs.roles
|
||||
|
||||
return user
|
||||
except UserAlreadyExists as error:
|
||||
print(f"User {email} already exists")
|
||||
|
|
|
|||
|
|
@ -18,7 +18,9 @@ async def get_default_user() -> User:
|
|||
try:
|
||||
async with db_engine.get_async_session() as session:
|
||||
query = (
|
||||
select(User).options(selectinload(User.roles)).where(User.email == default_email)
|
||||
select(User)
|
||||
.options(selectinload(User.roles), selectinload(User.tenants))
|
||||
.where(User.email == default_email)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ async def get_user(user_id: UUID):
|
|||
user = (
|
||||
await session.execute(
|
||||
select(User)
|
||||
.options(selectinload(User.roles), selectinload(User.tenant))
|
||||
.options(selectinload(User.roles), selectinload(User.tenants))
|
||||
.where(User.id == user_id)
|
||||
)
|
||||
).scalar()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ async def get_user_by_email(user_email: str):
|
|||
user = (
|
||||
await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.roles), joinedload(User.tenant))
|
||||
.options(joinedload(User.roles), joinedload(User.tenants))
|
||||
.where(User.email == user_email)
|
||||
)
|
||||
).scalar()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import relationship, Mapped
|
||||
from sqlalchemy import Column, String, ForeignKey, UUID
|
||||
from .Principal import Principal
|
||||
from .User import User
|
||||
from .UserTenant import UserTenant
|
||||
from .Role import Role
|
||||
|
||||
|
||||
|
|
@ -13,14 +13,13 @@ class Tenant(Principal):
|
|||
|
||||
owner_id = Column(UUID, index=True)
|
||||
|
||||
# One-to-Many relationship with User; specify the join via User.tenant_id
|
||||
users = relationship(
|
||||
users: Mapped[list["User"]] = relationship( # noqa: F821
|
||||
"User",
|
||||
back_populates="tenant",
|
||||
foreign_keys=lambda: [User.tenant_id],
|
||||
secondary=UserTenant.__tablename__,
|
||||
back_populates="tenants",
|
||||
)
|
||||
|
||||
# One-to-Many relationship with Role (if needed; similar fix)
|
||||
# One-to-Many relationship with Role
|
||||
roles = relationship(
|
||||
"Role",
|
||||
back_populates="tenant",
|
||||
|
|
|
|||
|
|
@ -6,8 +6,10 @@ from sqlalchemy import ForeignKey, Column, UUID
|
|||
from sqlalchemy.orm import relationship, Mapped
|
||||
|
||||
from .Principal import Principal
|
||||
from .UserTenant import UserTenant
|
||||
from .UserRole import UserRole
|
||||
from .Role import Role
|
||||
from .Tenant import Tenant
|
||||
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Principal):
|
||||
|
|
@ -15,7 +17,7 @@ class User(SQLAlchemyBaseUserTableUUID, Principal):
|
|||
|
||||
id = Column(UUID, ForeignKey("principals.id", ondelete="CASCADE"), primary_key=True)
|
||||
|
||||
# Foreign key to Tenant (Many-to-One relationship)
|
||||
# Foreign key to current Tenant (Many-to-One relationship)
|
||||
tenant_id = Column(UUID, ForeignKey("tenants.id"))
|
||||
|
||||
# Many-to-Many Relationship with Roles
|
||||
|
|
@ -25,11 +27,11 @@ class User(SQLAlchemyBaseUserTableUUID, Principal):
|
|||
back_populates="users",
|
||||
)
|
||||
|
||||
# Relationship to Tenant
|
||||
tenant = relationship(
|
||||
# Many-to-Many Relationship with Tenants user is a part of
|
||||
tenants: Mapped[list["Tenant"]] = relationship(
|
||||
"Tenant",
|
||||
secondary=UserTenant.__tablename__,
|
||||
back_populates="users",
|
||||
foreign_keys=[tenant_id],
|
||||
)
|
||||
|
||||
# ACL Relationship (One-to-Many)
|
||||
|
|
@ -46,7 +48,6 @@ class UserRead(schemas.BaseUser[uuid_UUID]):
|
|||
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
tenant_id: Optional[uuid_UUID] = None
|
||||
is_verified: bool = True
|
||||
|
||||
|
||||
|
|
|
|||
12
cognee/modules/users/models/UserTenant.py
Normal file
12
cognee/modules/users/models/UserTenant.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, ForeignKey, DateTime, UUID
|
||||
from cognee.infrastructure.databases.relational import Base
|
||||
|
||||
|
||||
class UserTenant(Base):
|
||||
__tablename__ = "user_tenants"
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
user_id = Column(UUID, ForeignKey("users.id"), primary_key=True)
|
||||
tenant_id = Column(UUID, ForeignKey("tenants.id"), primary_key=True)
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from .User import User
|
||||
from .Role import Role
|
||||
from .UserRole import UserRole
|
||||
from .UserTenant import UserTenant
|
||||
from .DatasetDatabase import DatasetDatabase
|
||||
from .RoleDefaultPermissions import RoleDefaultPermissions
|
||||
from .UserDefaultPermissions import UserDefaultPermissions
|
||||
|
|
|
|||
|
|
@ -1,11 +1,8 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from ...models.User import User
|
||||
from cognee.modules.data.models.Dataset import Dataset
|
||||
from cognee.modules.users.permissions.methods import get_principal_datasets
|
||||
from cognee.modules.users.permissions.methods import get_role, get_tenant
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
|
@ -25,17 +22,14 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
|
|||
# Get all datasets User has explicit access to
|
||||
datasets.extend(await get_principal_datasets(user, permission_type))
|
||||
|
||||
if user.tenant_id:
|
||||
# Get all datasets all tenants have access to
|
||||
tenant = await get_tenant(user.tenant_id)
|
||||
# Get all tenants user is a part of
|
||||
tenants = await user.awaitable_attrs.tenants
|
||||
for tenant in tenants:
|
||||
# Get all datasets all tenant members have access to
|
||||
datasets.extend(await get_principal_datasets(tenant, permission_type))
|
||||
|
||||
# Get all datasets Users roles have access to
|
||||
if isinstance(user, SimpleNamespace):
|
||||
# If simple namespace use roles defined in user
|
||||
roles = user.roles
|
||||
else:
|
||||
roles = await user.awaitable_attrs.roles
|
||||
# Get all datasets accessible by roles user is a part of
|
||||
roles = await user.awaitable_attrs.roles
|
||||
for role in roles:
|
||||
datasets.extend(await get_principal_datasets(role, permission_type))
|
||||
|
||||
|
|
@ -45,4 +39,10 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
|
|||
# If the dataset id key already exists, leave the dictionary unchanged.
|
||||
unique.setdefault(dataset.id, dataset)
|
||||
|
||||
return list(unique.values())
|
||||
# Filter out dataset that aren't part of the selected user's tenant
|
||||
filtered_datasets = []
|
||||
for dataset in list(unique.values()):
|
||||
if dataset.tenant_id == user.tenant_id:
|
||||
filtered_datasets.append(dataset)
|
||||
|
||||
return filtered_datasets
|
||||
|
|
|
|||
|
|
@ -42,11 +42,13 @@ async def add_user_to_role(user_id: UUID, role_id: UUID, owner_id: UUID):
|
|||
.first()
|
||||
)
|
||||
|
||||
user_tenants = await user.awaitable_attrs.tenants
|
||||
|
||||
if not user:
|
||||
raise UserNotFoundError
|
||||
elif not role:
|
||||
raise RoleNotFoundError
|
||||
elif user.tenant_id != role.tenant_id:
|
||||
elif role.tenant_id not in [tenant.id for tenant in user_tenants]:
|
||||
raise TenantNotFoundError(
|
||||
message="User tenant does not match role tenant. User cannot be added to role."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
from .create_tenant import create_tenant
|
||||
from .add_user_to_tenant import add_user_to_tenant
|
||||
from .select_tenant import select_tenant
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy import insert
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.models.UserTenant import UserTenant
|
||||
from cognee.modules.users.methods import get_user
|
||||
from cognee.modules.users.permissions.methods import get_tenant
|
||||
from cognee.modules.users.exceptions import (
|
||||
|
|
@ -12,14 +15,19 @@ from cognee.modules.users.exceptions import (
|
|||
)
|
||||
|
||||
|
||||
async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID):
|
||||
async def add_user_to_tenant(
|
||||
user_id: UUID, tenant_id: UUID, owner_id: UUID, set_as_active_tenant: Optional[bool] = False
|
||||
):
|
||||
"""
|
||||
Add a user with the given id to the tenant with the given id.
|
||||
This can only be successful if the request owner with the given id is the tenant owner.
|
||||
|
||||
If set_as_active_tenant is true it will automatically set the users active tenant to provided tenant.
|
||||
Args:
|
||||
user_id: Id of the user.
|
||||
tenant_id: Id of the tenant.
|
||||
owner_id: Id of the request owner.
|
||||
set_as_active_tenant: If set_as_active_tenant is true it will automatically set the users active tenant to provided tenant.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
|
@ -40,17 +48,18 @@ async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID):
|
|||
message="Only tenant owner can add other users to organization."
|
||||
)
|
||||
|
||||
try:
|
||||
if user.tenant_id is None:
|
||||
user.tenant_id = tenant_id
|
||||
elif user.tenant_id == tenant_id:
|
||||
return
|
||||
else:
|
||||
raise IntegrityError
|
||||
|
||||
if set_as_active_tenant:
|
||||
user.tenant_id = tenant_id
|
||||
await session.merge(user)
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(
|
||||
message="User is already part of a tenant. Only one tenant can be assigned to user."
|
||||
|
||||
try:
|
||||
# Add association directly to the association table
|
||||
create_user_tenant_statement = insert(UserTenant).values(
|
||||
user_id=user_id, tenant_id=tenant_id
|
||||
)
|
||||
await session.execute(create_user_tenant_statement)
|
||||
await session.commit()
|
||||
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="User is already part of group.")
|
||||
|
|
|
|||
|
|
@ -1,19 +1,25 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from typing import Optional
|
||||
|
||||
from cognee.modules.users.models.UserTenant import UserTenant
|
||||
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.models import Tenant
|
||||
from cognee.modules.users.methods import get_user
|
||||
|
||||
|
||||
async def create_tenant(tenant_name: str, user_id: UUID) -> UUID:
|
||||
async def create_tenant(
|
||||
tenant_name: str, user_id: UUID, set_as_active_tenant: Optional[bool] = True
|
||||
) -> UUID:
|
||||
"""
|
||||
Create a new tenant with the given name, for the user with the given id.
|
||||
This user is the owner of the tenant.
|
||||
Args:
|
||||
tenant_name: Name of the new tenant.
|
||||
user_id: Id of the user.
|
||||
set_as_active_tenant: If true, set the newly created tenant as the active tenant for the user.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
|
@ -22,18 +28,26 @@ async def create_tenant(tenant_name: str, user_id: UUID) -> UUID:
|
|||
async with db_engine.get_async_session() as session:
|
||||
try:
|
||||
user = await get_user(user_id)
|
||||
if user.tenant_id:
|
||||
raise EntityAlreadyExistsError(
|
||||
message="User already has a tenant. New tenant cannot be created."
|
||||
)
|
||||
|
||||
tenant = Tenant(name=tenant_name, owner_id=user_id)
|
||||
session.add(tenant)
|
||||
await session.flush()
|
||||
|
||||
user.tenant_id = tenant.id
|
||||
await session.merge(user)
|
||||
await session.commit()
|
||||
if set_as_active_tenant:
|
||||
user.tenant_id = tenant.id
|
||||
await session.merge(user)
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
# Add association directly to the association table
|
||||
create_user_tenant_statement = insert(UserTenant).values(
|
||||
user_id=user_id, tenant_id=tenant.id
|
||||
)
|
||||
await session.execute(create_user_tenant_statement)
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
raise EntityAlreadyExistsError(message="User is already part of tenant.")
|
||||
|
||||
return tenant.id
|
||||
except IntegrityError as e:
|
||||
raise EntityAlreadyExistsError(message="Tenant already exists.") from e
|
||||
|
|
|
|||
62
cognee/modules/users/tenants/methods/select_tenant.py
Normal file
62
cognee/modules/users/tenants/methods/select_tenant.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from uuid import UUID
|
||||
from typing import Union
|
||||
|
||||
import sqlalchemy.exc
|
||||
from sqlalchemy import select
|
||||
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.users.methods.get_user import get_user
|
||||
from cognee.modules.users.models.UserTenant import UserTenant
|
||||
from cognee.modules.users.models.User import User
|
||||
from cognee.modules.users.permissions.methods import get_tenant
|
||||
from cognee.modules.users.exceptions import UserNotFoundError, TenantNotFoundError
|
||||
|
||||
|
||||
async def select_tenant(user_id: UUID, tenant_id: Union[UUID, None]) -> User:
|
||||
"""
|
||||
Set the users active tenant to provided tenant.
|
||||
|
||||
If None tenant_id is provided set current Tenant to the default single user-tenant
|
||||
Args:
|
||||
user_id: UUID of the user.
|
||||
tenant_id: Id of the tenant.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
db_engine = get_relational_engine()
|
||||
async with db_engine.get_async_session() as session:
|
||||
user = await get_user(user_id)
|
||||
if tenant_id is None:
|
||||
# If no tenant_id is provided set current Tenant to the single user-tenant
|
||||
user.tenant_id = None
|
||||
await session.merge(user)
|
||||
await session.commit()
|
||||
return user
|
||||
|
||||
tenant = await get_tenant(tenant_id)
|
||||
|
||||
if not user:
|
||||
raise UserNotFoundError
|
||||
elif not tenant:
|
||||
raise TenantNotFoundError
|
||||
|
||||
# Check if User is part of Tenant
|
||||
result = await session.execute(
|
||||
select(UserTenant)
|
||||
.where(UserTenant.user_id == user.id)
|
||||
.where(UserTenant.tenant_id == tenant_id)
|
||||
)
|
||||
|
||||
try:
|
||||
result = result.scalar_one()
|
||||
except sqlalchemy.exc.NoResultFound as e:
|
||||
raise TenantNotFoundError("User is not part of the tenant.") from e
|
||||
|
||||
if result:
|
||||
# If user is part of tenant update current tenant of user
|
||||
user.tenant_id = tenant_id
|
||||
await session.merge(user)
|
||||
await session.commit()
|
||||
return user
|
||||
|
|
@ -99,7 +99,7 @@ async def ingest_data(
|
|||
|
||||
# data_id is the hash of original file contents + owner id to avoid duplicate data
|
||||
|
||||
data_id = ingestion.identify(classified_data, user)
|
||||
data_id = await ingestion.identify(classified_data, user)
|
||||
original_file_metadata = classified_data.get_metadata()
|
||||
|
||||
# Find metadata from Cognee data storage text file
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ async def main():
|
|||
classified_data = ingestion.classify(file)
|
||||
|
||||
# data_id is the hash of original file contents + owner id to avoid duplicate data
|
||||
data_id = ingestion.identify(classified_data, await get_default_user())
|
||||
data_id = await ingestion.identify(classified_data, await get_default_user())
|
||||
|
||||
await cognee.add(file_path)
|
||||
|
||||
|
|
|
|||
165
cognee/tests/test_multi_tenancy.py
Normal file
165
cognee/tests/test_multi_tenancy.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
import cognee
|
||||
import pytest
|
||||
|
||||
from cognee.modules.users.exceptions import PermissionDeniedError
|
||||
from cognee.modules.users.tenants.methods import select_tenant
|
||||
from cognee.modules.users.methods import get_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.users.methods import create_user
|
||||
from cognee.modules.users.permissions.methods import authorized_give_permission_on_datasets
|
||||
from cognee.modules.users.roles.methods import add_user_to_role
|
||||
from cognee.modules.users.roles.methods import create_role
|
||||
from cognee.modules.users.tenants.methods import create_tenant
|
||||
from cognee.modules.users.tenants.methods import add_user_to_tenant
|
||||
from cognee.modules.engine.operations.setup import setup
|
||||
from cognee.shared.logging_utils import setup_logging, CRITICAL
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
async def main():
|
||||
# Create a clean slate for cognee -- reset data and system state
|
||||
print("Resetting cognee data...")
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
print("Data reset complete.\n")
|
||||
|
||||
# Set up the necessary databases and tables for user management.
|
||||
await setup()
|
||||
|
||||
# Add document for user_1, add it under dataset name AI
|
||||
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
|
||||
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages
|
||||
this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the
|
||||
preparation and manipulation of quantum state"""
|
||||
|
||||
print("Creating user_1: user_1@example.com")
|
||||
user_1 = await create_user("user_1@example.com", "example")
|
||||
await cognee.add([text], dataset_name="AI", user=user_1)
|
||||
|
||||
print("\nCreating user_2: user_2@example.com")
|
||||
user_2 = await create_user("user_2@example.com", "example")
|
||||
|
||||
# Run cognify for both datasets as the appropriate user/owner
|
||||
print("\nCreating different datasets for user_1 (AI dataset) and user_2 (QUANTUM dataset)")
|
||||
ai_cognify_result = await cognee.cognify(["AI"], user=user_1)
|
||||
|
||||
# Extract dataset_ids from cognify results
|
||||
def extract_dataset_id_from_cognify(cognify_result):
|
||||
"""Extract dataset_id from cognify output dictionary"""
|
||||
for dataset_id, pipeline_result in cognify_result.items():
|
||||
return dataset_id # Return the first dataset_id
|
||||
return None
|
||||
|
||||
# Get dataset IDs from cognify results
|
||||
# Note: When we want to work with datasets from other users (search, add, cognify and etc.) we must supply dataset
|
||||
# information through dataset_id using dataset name only looks for datasets owned by current user
|
||||
ai_dataset_id = extract_dataset_id_from_cognify(ai_cognify_result)
|
||||
|
||||
# We can see here that user_1 can read his own dataset (AI dataset)
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="What is in the document?",
|
||||
user=user_1,
|
||||
datasets=[ai_dataset_id],
|
||||
)
|
||||
|
||||
# Verify that user_2 cannot access user_1's dataset without permission
|
||||
with pytest.raises(PermissionDeniedError):
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="What is in the document?",
|
||||
user=user_2,
|
||||
datasets=[ai_dataset_id],
|
||||
)
|
||||
|
||||
# Create new tenant and role, add user_2 to tenant and role
|
||||
tenant_id = await create_tenant("CogneeLab", user_1.id)
|
||||
await select_tenant(user_id=user_1.id, tenant_id=tenant_id)
|
||||
role_id = await create_role(role_name="Researcher", owner_id=user_1.id)
|
||||
await add_user_to_tenant(
|
||||
user_id=user_2.id, tenant_id=tenant_id, owner_id=user_1.id, set_as_active_tenant=True
|
||||
)
|
||||
await add_user_to_role(user_id=user_2.id, role_id=role_id, owner_id=user_1.id)
|
||||
|
||||
# Assert that user_1 cannot give permissions on his dataset to role before switching to the correct tenant
|
||||
# AI dataset was made with default tenant and not CogneeLab tenant
|
||||
with pytest.raises(PermissionDeniedError):
|
||||
await authorized_give_permission_on_datasets(
|
||||
role_id,
|
||||
[ai_dataset_id],
|
||||
"read",
|
||||
user_1.id,
|
||||
)
|
||||
|
||||
# We need to refresh the user object with changes made when switching tenants
|
||||
user_1 = await get_user(user_1.id)
|
||||
await cognee.add([text], dataset_name="AI_COGNEE_LAB", user=user_1)
|
||||
ai_cognee_lab_cognify_result = await cognee.cognify(["AI_COGNEE_LAB"], user=user_1)
|
||||
|
||||
ai_cognee_lab_dataset_id = extract_dataset_id_from_cognify(ai_cognee_lab_cognify_result)
|
||||
|
||||
await authorized_give_permission_on_datasets(
|
||||
role_id,
|
||||
[ai_cognee_lab_dataset_id],
|
||||
"read",
|
||||
user_1.id,
|
||||
)
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="What is in the document?",
|
||||
user=user_2,
|
||||
dataset_ids=[ai_cognee_lab_dataset_id],
|
||||
)
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
||||
# Let's test changing tenants
|
||||
tenant_id = await create_tenant("CogneeLab2", user_1.id)
|
||||
await select_tenant(user_id=user_1.id, tenant_id=tenant_id)
|
||||
|
||||
user_1 = await get_user(user_1.id)
|
||||
await cognee.add([text], dataset_name="AI_COGNEE_LAB", user=user_1)
|
||||
await cognee.cognify(["AI_COGNEE_LAB"], user=user_1)
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="What is in the document?",
|
||||
user=user_1,
|
||||
)
|
||||
|
||||
# Assert only AI_COGNEE_LAB dataset from CogneeLab2 tenant is visible as the currently selected tenant
|
||||
assert len(search_results) == 1, (
|
||||
f"Search results must only contain one dataset from current tenant: {search_results}"
|
||||
)
|
||||
assert search_results[0]["dataset_name"] == "AI_COGNEE_LAB", (
|
||||
f"Dict must contain dataset name 'AI_COGNEE_LAB': {search_results[0]}"
|
||||
)
|
||||
assert search_results[0]["dataset_tenant_id"] == user_1.tenant_id, (
|
||||
f"Dataset tenant_id must be same as user_1 tenant_id: {search_results[0]}"
|
||||
)
|
||||
|
||||
# Switch back to no tenant (default tenant)
|
||||
await select_tenant(user_id=user_1.id, tenant_id=None)
|
||||
# Refresh user_1 object
|
||||
user_1 = await get_user(user_1.id)
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="What is in the document?",
|
||||
user=user_1,
|
||||
)
|
||||
assert len(search_results) == 1, (
|
||||
f"Search results must only contain one dataset from default tenant: {search_results}"
|
||||
)
|
||||
assert search_results[0]["dataset_name"] == "AI", (
|
||||
f"Dict must contain dataset name 'AI': {search_results[0]}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
logger = setup_logging(log_level=CRITICAL)
|
||||
asyncio.run(main())
|
||||
|
|
@ -3,6 +3,7 @@ import cognee
|
|||
import pathlib
|
||||
|
||||
from cognee.modules.users.exceptions import PermissionDeniedError
|
||||
from cognee.modules.users.tenants.methods import select_tenant
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.users.methods import create_user
|
||||
|
|
@ -116,6 +117,7 @@ async def main():
|
|||
print(
|
||||
"\nOperation started as user_2 to give read permission to user_1 for the dataset owned by user_2"
|
||||
)
|
||||
|
||||
await authorized_give_permission_on_datasets(
|
||||
user_1.id,
|
||||
[quantum_dataset_id],
|
||||
|
|
@ -142,6 +144,9 @@ async def main():
|
|||
print("User 2 is creating CogneeLab tenant/organization")
|
||||
tenant_id = await create_tenant("CogneeLab", user_2.id)
|
||||
|
||||
print("User 2 is selecting CogneeLab tenant/organization as active tenant")
|
||||
await select_tenant(user_id=user_2.id, tenant_id=tenant_id)
|
||||
|
||||
print("\nUser 2 is creating Researcher role")
|
||||
role_id = await create_role(role_name="Researcher", owner_id=user_2.id)
|
||||
|
||||
|
|
@ -157,23 +162,59 @@ async def main():
|
|||
)
|
||||
await add_user_to_role(user_id=user_3.id, role_id=role_id, owner_id=user_2.id)
|
||||
|
||||
print("\nOperation as user_3 to select CogneeLab tenant/organization as active tenant")
|
||||
await select_tenant(user_id=user_3.id, tenant_id=tenant_id)
|
||||
|
||||
print(
|
||||
"\nOperation started as user_2 to give read permission to Researcher role for the dataset owned by user_2"
|
||||
"\nOperation started as user_2, with CogneeLab as its active tenant, to give read permission to Researcher role for the dataset QUANTUM owned by user_2"
|
||||
)
|
||||
# Even though the dataset owner is user_2, the dataset doesn't belong to the tenant/organization CogneeLab.
|
||||
# So we can't assign permissions to it when we're acting in the CogneeLab tenant.
|
||||
try:
|
||||
await authorized_give_permission_on_datasets(
|
||||
role_id,
|
||||
[quantum_dataset_id],
|
||||
"read",
|
||||
user_2.id,
|
||||
)
|
||||
except PermissionDeniedError:
|
||||
print(
|
||||
"User 2 could not give permission to the role as the QUANTUM dataset is not part of the CogneeLab tenant"
|
||||
)
|
||||
|
||||
print(
|
||||
"We will now create a new QUANTUM dataset with the QUANTUM_COGNEE_LAB name in the CogneeLab tenant so that permissions can be assigned to the Researcher role inside the tenant/organization"
|
||||
)
|
||||
# We can re-create the QUANTUM dataset in the CogneeLab tenant. The old QUANTUM dataset is still owned by user_2 personally
|
||||
# and can still be accessed by selecting the personal tenant for user 2.
|
||||
from cognee.modules.users.methods import get_user
|
||||
|
||||
# Note: We need to update user_2 from the database to refresh its tenant context changes
|
||||
user_2 = await get_user(user_2.id)
|
||||
await cognee.add([text], dataset_name="QUANTUM_COGNEE_LAB", user=user_2)
|
||||
quantum_cognee_lab_cognify_result = await cognee.cognify(["QUANTUM_COGNEE_LAB"], user=user_2)
|
||||
|
||||
# The recreated Quantum dataset will now have a different dataset_id as it's a new dataset in a different organization
|
||||
quantum_cognee_lab_dataset_id = extract_dataset_id_from_cognify(
|
||||
quantum_cognee_lab_cognify_result
|
||||
)
|
||||
print(
|
||||
"\nOperation started as user_2, with CogneeLab as its active tenant, to give read permission to Researcher role for the dataset QUANTUM owned by the CogneeLab tenant"
|
||||
)
|
||||
await authorized_give_permission_on_datasets(
|
||||
role_id,
|
||||
[quantum_dataset_id],
|
||||
[quantum_cognee_lab_dataset_id],
|
||||
"read",
|
||||
user_2.id,
|
||||
)
|
||||
|
||||
# Now user_3 can read from QUANTUM dataset as part of the Researcher role after proper permissions have been assigned by the QUANTUM dataset owner, user_2.
|
||||
print("\nSearch result as user_3 on the dataset owned by user_2:")
|
||||
print("\nSearch result as user_3 on the QUANTUM dataset owned by the CogneeLab organization:")
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="What is in the document?",
|
||||
user=user_1,
|
||||
dataset_ids=[quantum_dataset_id],
|
||||
user=user_3,
|
||||
dataset_ids=[quantum_cognee_lab_dataset_id],
|
||||
)
|
||||
for result in search_results:
|
||||
print(f"{result}\n")
|
||||
|
|
|
|||
|
|
@ -167,7 +167,6 @@ exclude = [
|
|||
"/dist",
|
||||
"/.data",
|
||||
"/.github",
|
||||
"/alembic",
|
||||
"/deployment",
|
||||
"/cognee-mcp",
|
||||
"/cognee-frontend",
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue