feat: Add multi-tenancy (#1560)

<!-- .github/pull_request_template.md -->

## Description
Add multi-tenancy to Cognee

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Vasilije 2025-11-11 12:55:27 +01:00 committed by GitHub
commit 78b825f338
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 814 additions and 151 deletions

View file

@ -226,7 +226,7 @@ jobs:
- name: Dependencies already installed
run: echo "Dependencies already installed in setup"
- name: Run parallel databases test
- name: Run permissions test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
@ -239,6 +239,31 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_permissions.py
test-multi-tenancy:
name: Test multi tenancy with different situations in Cognee
runs-on: ubuntu-22.04
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
- name: Run multi tenancy test
env:
ENV: 'dev'
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_multi_tenancy.py
test-graph-edges:
name: Test graph edge ingestion
runs-on: ubuntu-22.04

View file

@ -87,11 +87,6 @@ db_engine = get_relational_engine()
print("Using database:", db_engine.db_uri)
if "sqlite" in db_engine.db_uri:
from cognee.infrastructure.utils.run_sync import run_sync
run_sync(db_engine.create_database())
config.set_section_option(
config.config_ini_section,
"SQLALCHEMY_DATABASE_URI",

View file

@ -10,6 +10,7 @@ from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
@ -26,7 +27,34 @@ def upgrade() -> None:
connection = op.get_bind()
inspector = sa.inspect(connection)
if op.get_context().dialect.name == "postgresql":
syncstatus_enum = postgresql.ENUM(
"STARTED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED", name="syncstatus"
)
syncstatus_enum.create(op.get_bind(), checkfirst=True)
if "sync_operations" not in inspector.get_table_names():
if op.get_context().dialect.name == "postgresql":
syncstatus = postgresql.ENUM(
"STARTED",
"IN_PROGRESS",
"COMPLETED",
"FAILED",
"CANCELLED",
name="syncstatus",
create_type=False,
)
else:
syncstatus = sa.Enum(
"STARTED",
"IN_PROGRESS",
"COMPLETED",
"FAILED",
"CANCELLED",
name="syncstatus",
create_type=False,
)
# Table doesn't exist, create it normally
op.create_table(
"sync_operations",
@ -34,15 +62,7 @@ def upgrade() -> None:
sa.Column("run_id", sa.Text(), nullable=True),
sa.Column(
"status",
sa.Enum(
"STARTED",
"IN_PROGRESS",
"COMPLETED",
"FAILED",
"CANCELLED",
name="syncstatus",
create_type=False,
),
syncstatus,
nullable=True,
),
sa.Column("progress_percentage", sa.Integer(), nullable=True),

View file

@ -23,11 +23,8 @@ depends_on: Union[str, Sequence[str], None] = "8057ae7329c2"
def upgrade() -> None:
try:
await_only(create_default_user())
except UserAlreadyExists:
pass # It's fine if the default user already exists
pass
def downgrade() -> None:
await_only(delete_user("default_user@example.com"))
pass

View file

@ -18,11 +18,8 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
db_engine = get_relational_engine()
# we might want to delete this
await_only(db_engine.create_database())
pass
def downgrade() -> None:
db_engine = get_relational_engine()
await_only(db_engine.delete_database())
pass

View file

@ -144,44 +144,58 @@ def _create_data_permission(conn, user_id, data_id, permission_name):
)
def _get_column(inspector, table, name, schema=None):
for col in inspector.get_columns(table, schema=schema):
if col["name"] == name:
return col
return None
def upgrade() -> None:
conn = op.get_bind()
insp = sa.inspect(conn)
# Recreate ACLs table with default permissions set to datasets instead of documents
op.drop_table("acls")
dataset_id_column = _get_column(insp, "acls", "dataset_id")
if not dataset_id_column:
# Recreate ACLs table with default permissions set to datasets instead of documents
op.drop_table("acls")
acls_table = op.create_table(
"acls",
sa.Column("id", UUID, primary_key=True, default=uuid4),
sa.Column(
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
),
sa.Column(
"updated_at", sa.DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)
),
sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
)
acls_table = op.create_table(
"acls",
sa.Column("id", UUID, primary_key=True, default=uuid4),
sa.Column(
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
onupdate=lambda: datetime.now(timezone.utc),
),
sa.Column("principal_id", UUID, sa.ForeignKey("principals.id")),
sa.Column("permission_id", UUID, sa.ForeignKey("permissions.id")),
sa.Column("dataset_id", UUID, sa.ForeignKey("datasets.id", ondelete="CASCADE")),
)
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
# definition or load what is in the database
dataset_table = _define_dataset_table()
datasets = conn.execute(sa.select(dataset_table)).fetchall()
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
# definition or load what is in the database
dataset_table = _define_dataset_table()
datasets = conn.execute(sa.select(dataset_table)).fetchall()
if not datasets:
return
if not datasets:
return
acl_list = []
acl_list = []
for dataset in datasets:
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete"))
for dataset in datasets:
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "read"))
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "write"))
acl_list.append(_create_dataset_permission(conn, dataset.owner_id, dataset.id, "share"))
acl_list.append(
_create_dataset_permission(conn, dataset.owner_id, dataset.id, "delete")
)
if acl_list:
op.bulk_insert(acls_table, acl_list)
if acl_list:
op.bulk_insert(acls_table, acl_list)
def downgrade() -> None:

View file

@ -0,0 +1,137 @@
"""Multi Tenant Support
Revision ID: c946955da633
Revises: 211ab850ef3d
Create Date: 2025-11-04 18:11:09.325158
"""
from typing import Sequence, Union
from datetime import datetime, timezone
from uuid import uuid4
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "c946955da633"
down_revision: Union[str, None] = "211ab850ef3d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def _now():
return datetime.now(timezone.utc)
def _define_user_table() -> sa.Table:
table = sa.Table(
"users",
sa.MetaData(),
sa.Column(
"id",
sa.UUID,
sa.ForeignKey("principals.id", ondelete="CASCADE"),
primary_key=True,
nullable=False,
),
sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), index=True, nullable=True),
)
return table
def _define_dataset_table() -> sa.Table:
# Note: We can't use any Cognee model info to gather data (as it can change) in database so we must use our own table
# definition or load what is in the database
table = sa.Table(
"datasets",
sa.MetaData(),
sa.Column("id", sa.UUID, primary_key=True, default=uuid4),
sa.Column("name", sa.Text),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
onupdate=lambda: datetime.now(timezone.utc),
),
sa.Column("owner_id", sa.UUID(), sa.ForeignKey("principals.id"), index=True),
sa.Column("tenant_id", sa.UUID(), sa.ForeignKey("tenants.id"), index=True, nullable=True),
)
return table
def _get_column(inspector, table, name, schema=None):
for col in inspector.get_columns(table, schema=schema):
if col["name"] == name:
return col
return None
def upgrade() -> None:
conn = op.get_bind()
insp = sa.inspect(conn)
dataset = _define_dataset_table()
user = _define_user_table()
if "user_tenants" not in insp.get_table_names():
# Define table with all necessary columns including primary key
user_tenants = op.create_table(
"user_tenants",
sa.Column("user_id", sa.UUID, sa.ForeignKey("users.id"), primary_key=True),
sa.Column("tenant_id", sa.UUID, sa.ForeignKey("tenants.id"), primary_key=True),
sa.Column(
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
),
)
# Get all users with their tenant_id
user_data = conn.execute(
sa.select(user.c.id, user.c.tenant_id).where(user.c.tenant_id.isnot(None))
).fetchall()
# Insert into user_tenants table
if user_data:
op.bulk_insert(
user_tenants,
[
{"user_id": user_id, "tenant_id": tenant_id, "created_at": _now()}
for user_id, tenant_id in user_data
],
)
tenant_id_column = _get_column(insp, "datasets", "tenant_id")
if not tenant_id_column:
op.add_column("datasets", sa.Column("tenant_id", sa.UUID(), nullable=True))
# Build subquery, select users.tenant_id for each dataset.owner_id
tenant_id_from_dataset_owner = (
sa.select(user.c.tenant_id).where(user.c.id == dataset.c.owner_id).scalar_subquery()
)
if op.get_context().dialect.name == "sqlite":
# If column doesn't exist create new original_extension column and update from values of extension column
with op.batch_alter_table("datasets") as batch_op:
batch_op.execute(
dataset.update().values(
tenant_id=tenant_id_from_dataset_owner,
)
)
else:
conn = op.get_bind()
conn.execute(dataset.update().values(tenant_id=tenant_id_from_dataset_owner))
op.create_index(op.f("ix_datasets_tenant_id"), "datasets", ["tenant_id"])
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("user_tenants")
op.drop_index(op.f("ix_datasets_tenant_id"), table_name="datasets")
op.drop_column("datasets", "tenant_id")
# ### end Alembic commands ###

View file

@ -1096,6 +1096,10 @@ async def main():
# Skip migrations when in API mode (the API server handles its own database)
if not args.no_migration and not args.api_url:
from cognee.modules.engine.operations.setup import setup
await setup()
# Run Alembic migrations from the main cognee directory where alembic.ini is located
logger.info("Running database migrations...")
migration_result = subprocess.run(

View file

@ -1,15 +1,20 @@
from uuid import UUID
from typing import List
from typing import List, Union
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from cognee.modules.users.models import User
from cognee.api.DTO import InDTO
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee import __version__ as cognee_version
class SelectTenantDTO(InDTO):
tenant_id: UUID | None = None
def get_permissions_router() -> APIRouter:
permissions_router = APIRouter()
@ -226,4 +231,39 @@ def get_permissions_router() -> APIRouter:
status_code=200, content={"message": "Tenant created.", "tenant_id": str(tenant_id)}
)
@permissions_router.post("/tenants/select")
async def select_tenant(payload: SelectTenantDTO, user: User = Depends(get_authenticated_user)):
"""
Select current tenant.
This endpoint selects a tenant with the specified UUID. Tenants are used
to organize users and resources in multi-tenant environments, providing
isolation and access control between different groups or organizations.
Sending a null/None value as tenant_id selects his default single user tenant
## Request Parameters
- **tenant_id** (Union[UUID, None]): UUID of the tenant to select, If null/None is provided use the default single user tenant
## Response
Returns a success message along with selected tenant id.
"""
send_telemetry(
"Permissions API Endpoint Invoked",
user.id,
additional_properties={
"endpoint": f"POST /v1/permissions/tenants/{str(payload.tenant_id)}",
"tenant_id": str(payload.tenant_id),
},
)
from cognee.modules.users.tenants.methods import select_tenant as select_tenant_method
await select_tenant_method(user_id=user.id, tenant_id=payload.tenant_id)
return JSONResponse(
status_code=200,
content={"message": "Tenant selected.", "tenant_id": str(payload.tenant_id)},
)
return permissions_router

View file

@ -10,6 +10,7 @@ from .get_authorized_dataset import get_authorized_dataset
from .get_authorized_dataset_by_name import get_authorized_dataset_by_name
from .get_data import get_data
from .get_unique_dataset_id import get_unique_dataset_id
from .get_unique_data_id import get_unique_data_id
from .get_authorized_existing_datasets import get_authorized_existing_datasets
from .get_dataset_ids import get_dataset_ids

View file

@ -16,14 +16,16 @@ async def create_dataset(dataset_name: str, user: User, session: AsyncSession) -
.options(joinedload(Dataset.data))
.filter(Dataset.name == dataset_name)
.filter(Dataset.owner_id == owner_id)
.filter(Dataset.tenant_id == user.tenant_id)
)
).first()
if dataset is None:
# Dataset id should be generated based on dataset_name and owner_id/user so multiple users can use the same dataset_name
dataset_id = await get_unique_dataset_id(dataset_name=dataset_name, user=user)
dataset = Dataset(id=dataset_id, name=dataset_name, data=[])
dataset.owner_id = owner_id
dataset = Dataset(
id=dataset_id, name=dataset_name, data=[], owner_id=owner_id, tenant_id=user.tenant_id
)
session.add(dataset)

View file

@ -27,7 +27,11 @@ async def get_dataset_ids(datasets: Union[list[str], list[UUID]], user):
# Get all user owned dataset objects (If a user wants to write to a dataset he is not the owner of it must be provided through UUID.)
user_datasets = await get_datasets(user.id)
# Filter out non name mentioned datasets
dataset_ids = [dataset.id for dataset in user_datasets if dataset.name in datasets]
dataset_ids = [dataset for dataset in user_datasets if dataset.name in datasets]
# Filter out non current tenant datasets
dataset_ids = [
dataset.id for dataset in dataset_ids if dataset.tenant_id == user.tenant_id
]
else:
raise DatasetTypeError(
f"One or more of the provided dataset types is not handled: f{datasets}"

View file

@ -0,0 +1,68 @@
from uuid import uuid5, NAMESPACE_OID, UUID
from sqlalchemy import select
from cognee.modules.data.models.Data import Data
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.models import User
async def get_unique_data_id(data_identifier: str, user: User) -> UUID:
"""
Function returns a unique UUID for data based on data identifier, user id and tenant id.
If data with legacy ID exists, return that ID to maintain compatibility.
Args:
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
user: User object adding the data
tenant_id: UUID of the tenant for which data is being added
Returns:
UUID: Unique identifier for the data
"""
def _get_deprecated_unique_data_id(data_identifier: str, user: User) -> UUID:
"""
Deprecated function, returns a unique UUID for data based on data identifier and user id.
Needed to support legacy data without tenant information.
Args:
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
user: User object adding the data
Returns:
UUID: Unique identifier for the data
"""
# return UUID hash of file contents + owner id + tenant_id
return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}")
def _get_modern_unique_data_id(data_identifier: str, user: User) -> UUID:
"""
Function returns a unique UUID for data based on data identifier, user id and tenant id.
Args:
data_identifier: A way to uniquely identify data (e.g. file hash, data name, etc.)
user: User object adding the data
tenant_id: UUID of the tenant for which data is being added
Returns:
UUID: Unique identifier for the data
"""
# return UUID hash of file contents + owner id + tenant_id
return uuid5(NAMESPACE_OID, f"{data_identifier}{str(user.id)}{str(user.tenant_id)}")
# Get all possible data_id values
data_id = {
"modern_data_id": _get_modern_unique_data_id(data_identifier=data_identifier, user=user),
"legacy_data_id": _get_deprecated_unique_data_id(
data_identifier=data_identifier, user=user
),
}
# Check if data item with legacy_data_id exists, if so use that one, else use modern_data_id
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
legacy_data_point = (
await session.execute(select(Data).filter(Data.id == data_id["legacy_data_id"]))
).scalar_one_or_none()
if not legacy_data_point:
return data_id["modern_data_id"]
return data_id["legacy_data_id"]

View file

@ -1,9 +1,71 @@
from uuid import UUID, uuid5, NAMESPACE_OID
from cognee.modules.users.models import User
from typing import Union
from sqlalchemy import select
from cognee.modules.data.models.Dataset import Dataset
from cognee.modules.users.models import User
from cognee.infrastructure.databases.relational import get_relational_engine
async def get_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
if isinstance(dataset_name, UUID):
return dataset_name
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}")
"""
Function returns a unique UUID for dataset based on dataset name, user id and tenant id.
If dataset with legacy ID exists, return that ID to maintain compatibility.
Args:
dataset_name: string representing the dataset name
user: User object adding the dataset
tenant_id: UUID of the tenant for which dataset is being added
Returns:
UUID: Unique identifier for the dataset
"""
def _get_legacy_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
"""
Legacy function, returns a unique UUID for dataset based on dataset name and user id.
Needed to support legacy datasets without tenant information.
Args:
dataset_name: string representing the dataset name
user: Current User object adding the dataset
Returns:
UUID: Unique identifier for the dataset
"""
if isinstance(dataset_name, UUID):
return dataset_name
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}")
def _get_modern_unique_dataset_id(dataset_name: Union[str, UUID], user: User) -> UUID:
"""
Returns a unique UUID for dataset based on dataset name, user id and tenant_id.
Args:
dataset_name: string representing the dataset name
user: Current User object adding the dataset
tenant_id: UUID of the tenant for which dataset is being added
Returns:
UUID: Unique identifier for the dataset
"""
if isinstance(dataset_name, UUID):
return dataset_name
return uuid5(NAMESPACE_OID, f"{dataset_name}{str(user.id)}{str(user.tenant_id)}")
# Get all possible dataset_id values
dataset_id = {
"modern_dataset_id": _get_modern_unique_dataset_id(dataset_name=dataset_name, user=user),
"legacy_dataset_id": _get_legacy_unique_dataset_id(dataset_name=dataset_name, user=user),
}
# Check if dataset with legacy_dataset_id exists, if so use that one, else use modern_dataset_id
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
legacy_dataset = (
await session.execute(
select(Dataset).filter(Dataset.id == dataset_id["legacy_dataset_id"])
)
).scalar_one_or_none()
if not legacy_dataset:
return dataset_id["modern_dataset_id"]
return dataset_id["legacy_dataset_id"]

View file

@ -18,6 +18,7 @@ class Dataset(Base):
updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
owner_id = Column(UUID, index=True)
tenant_id = Column(UUID, index=True, nullable=True)
acls = relationship("ACL", back_populates="dataset", cascade="all, delete-orphan")
@ -36,5 +37,6 @@ class Dataset(Base):
"createdAt": self.created_at.isoformat(),
"updatedAt": self.updated_at.isoformat() if self.updated_at else None,
"ownerId": str(self.owner_id),
"tenantId": str(self.tenant_id),
"data": [data.to_json() for data in self.data],
}

View file

@ -1,11 +1,11 @@
from uuid import uuid5, NAMESPACE_OID
from uuid import UUID
from .data_types import IngestionData
from cognee.modules.users.models import User
from cognee.modules.data.methods import get_unique_data_id
def identify(data: IngestionData, user: User) -> str:
async def identify(data: IngestionData, user: User) -> UUID:
data_content_hash: str = data.get_identifier()
# return UUID hash of file contents + owner id
return uuid5(NAMESPACE_OID, f"{data_content_hash}{user.id}")
return await get_unique_data_id(data_identifier=data_content_hash, user=user)

View file

@ -69,7 +69,7 @@ async def run_tasks_data_item_incremental(
async with open_data_file(file_path) as file:
classified_data = ingestion.classify(file)
# data_id is the hash of file contents + owner id to avoid duplicate data
data_id = ingestion.identify(classified_data, user)
data_id = await ingestion.identify(classified_data, user)
else:
# If data was already processed by Cognee get data id
data_id = data_item.id

View file

@ -172,6 +172,7 @@ async def search(
"search_result": [context] if context else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
@ -181,6 +182,7 @@ async def search(
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)

View file

@ -18,7 +18,6 @@ from typing import Optional
async def create_user(
email: str,
password: str,
tenant_id: Optional[str] = None,
is_superuser: bool = False,
is_active: bool = True,
is_verified: bool = False,
@ -30,37 +29,23 @@ async def create_user(
async with relational_engine.get_async_session() as session:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
if tenant_id:
# Check if the tenant already exists
result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
tenant = result.scalars().first()
if not tenant:
raise TenantNotFoundError
user = await user_manager.create(
UserCreate(
email=email,
password=password,
tenant_id=tenant.id,
is_superuser=is_superuser,
is_active=is_active,
is_verified=is_verified,
)
)
else:
user = await user_manager.create(
UserCreate(
email=email,
password=password,
is_superuser=is_superuser,
is_active=is_active,
is_verified=is_verified,
)
user = await user_manager.create(
UserCreate(
email=email,
password=password,
is_superuser=is_superuser,
is_active=is_active,
is_verified=is_verified,
)
)
if auto_login:
await session.refresh(user)
# Update tenants and roles information for User object
_ = await user.awaitable_attrs.tenants
_ = await user.awaitable_attrs.roles
return user
except UserAlreadyExists as error:
print(f"User {email} already exists")

View file

@ -18,7 +18,9 @@ async def get_default_user() -> User:
try:
async with db_engine.get_async_session() as session:
query = (
select(User).options(selectinload(User.roles)).where(User.email == default_email)
select(User)
.options(selectinload(User.roles), selectinload(User.tenants))
.where(User.email == default_email)
)
result = await session.execute(query)

View file

@ -14,7 +14,7 @@ async def get_user(user_id: UUID):
user = (
await session.execute(
select(User)
.options(selectinload(User.roles), selectinload(User.tenant))
.options(selectinload(User.roles), selectinload(User.tenants))
.where(User.id == user_id)
)
).scalar()

View file

@ -13,7 +13,7 @@ async def get_user_by_email(user_email: str):
user = (
await session.execute(
select(User)
.options(joinedload(User.roles), joinedload(User.tenant))
.options(joinedload(User.roles), joinedload(User.tenants))
.where(User.email == user_email)
)
).scalar()

View file

@ -1,7 +1,7 @@
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship, Mapped
from sqlalchemy import Column, String, ForeignKey, UUID
from .Principal import Principal
from .User import User
from .UserTenant import UserTenant
from .Role import Role
@ -13,14 +13,13 @@ class Tenant(Principal):
owner_id = Column(UUID, index=True)
# One-to-Many relationship with User; specify the join via User.tenant_id
users = relationship(
users: Mapped[list["User"]] = relationship( # noqa: F821
"User",
back_populates="tenant",
foreign_keys=lambda: [User.tenant_id],
secondary=UserTenant.__tablename__,
back_populates="tenants",
)
# One-to-Many relationship with Role (if needed; similar fix)
# One-to-Many relationship with Role
roles = relationship(
"Role",
back_populates="tenant",

View file

@ -6,8 +6,10 @@ from sqlalchemy import ForeignKey, Column, UUID
from sqlalchemy.orm import relationship, Mapped
from .Principal import Principal
from .UserTenant import UserTenant
from .UserRole import UserRole
from .Role import Role
from .Tenant import Tenant
class User(SQLAlchemyBaseUserTableUUID, Principal):
@ -15,7 +17,7 @@ class User(SQLAlchemyBaseUserTableUUID, Principal):
id = Column(UUID, ForeignKey("principals.id", ondelete="CASCADE"), primary_key=True)
# Foreign key to Tenant (Many-to-One relationship)
# Foreign key to current Tenant (Many-to-One relationship)
tenant_id = Column(UUID, ForeignKey("tenants.id"))
# Many-to-Many Relationship with Roles
@ -25,11 +27,11 @@ class User(SQLAlchemyBaseUserTableUUID, Principal):
back_populates="users",
)
# Relationship to Tenant
tenant = relationship(
# Many-to-Many Relationship with Tenants user is a part of
tenants: Mapped[list["Tenant"]] = relationship(
"Tenant",
secondary=UserTenant.__tablename__,
back_populates="users",
foreign_keys=[tenant_id],
)
# ACL Relationship (One-to-Many)
@ -46,7 +48,6 @@ class UserRead(schemas.BaseUser[uuid_UUID]):
class UserCreate(schemas.BaseUserCreate):
tenant_id: Optional[uuid_UUID] = None
is_verified: bool = True

View file

@ -0,0 +1,12 @@
from datetime import datetime, timezone
from sqlalchemy import Column, ForeignKey, DateTime, UUID
from cognee.infrastructure.databases.relational import Base
class UserTenant(Base):
__tablename__ = "user_tenants"
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
user_id = Column(UUID, ForeignKey("users.id"), primary_key=True)
tenant_id = Column(UUID, ForeignKey("tenants.id"), primary_key=True)

View file

@ -1,6 +1,7 @@
from .User import User
from .Role import Role
from .UserRole import UserRole
from .UserTenant import UserTenant
from .DatasetDatabase import DatasetDatabase
from .RoleDefaultPermissions import RoleDefaultPermissions
from .UserDefaultPermissions import UserDefaultPermissions

View file

@ -1,11 +1,8 @@
from types import SimpleNamespace
from cognee.shared.logging_utils import get_logger
from ...models.User import User
from cognee.modules.data.models.Dataset import Dataset
from cognee.modules.users.permissions.methods import get_principal_datasets
from cognee.modules.users.permissions.methods import get_role, get_tenant
logger = get_logger()
@ -25,17 +22,14 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
# Get all datasets User has explicit access to
datasets.extend(await get_principal_datasets(user, permission_type))
if user.tenant_id:
# Get all datasets all tenants have access to
tenant = await get_tenant(user.tenant_id)
# Get all tenants user is a part of
tenants = await user.awaitable_attrs.tenants
for tenant in tenants:
# Get all datasets all tenant members have access to
datasets.extend(await get_principal_datasets(tenant, permission_type))
# Get all datasets Users roles have access to
if isinstance(user, SimpleNamespace):
# If simple namespace use roles defined in user
roles = user.roles
else:
roles = await user.awaitable_attrs.roles
# Get all datasets accessible by roles user is a part of
roles = await user.awaitable_attrs.roles
for role in roles:
datasets.extend(await get_principal_datasets(role, permission_type))
@ -45,4 +39,10 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
# If the dataset id key already exists, leave the dictionary unchanged.
unique.setdefault(dataset.id, dataset)
return list(unique.values())
# Filter out dataset that aren't part of the selected user's tenant
filtered_datasets = []
for dataset in list(unique.values()):
if dataset.tenant_id == user.tenant_id:
filtered_datasets.append(dataset)
return filtered_datasets

View file

@ -42,11 +42,13 @@ async def add_user_to_role(user_id: UUID, role_id: UUID, owner_id: UUID):
.first()
)
user_tenants = await user.awaitable_attrs.tenants
if not user:
raise UserNotFoundError
elif not role:
raise RoleNotFoundError
elif user.tenant_id != role.tenant_id:
elif role.tenant_id not in [tenant.id for tenant in user_tenants]:
raise TenantNotFoundError(
message="User tenant does not match role tenant. User cannot be added to role."
)

View file

@ -1,2 +1,3 @@
from .create_tenant import create_tenant
from .add_user_to_tenant import add_user_to_tenant
from .select_tenant import select_tenant

View file

@ -1,8 +1,11 @@
from typing import Optional
from uuid import UUID
from sqlalchemy.exc import IntegrityError
from sqlalchemy import insert
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.models.UserTenant import UserTenant
from cognee.modules.users.methods import get_user
from cognee.modules.users.permissions.methods import get_tenant
from cognee.modules.users.exceptions import (
@ -12,14 +15,19 @@ from cognee.modules.users.exceptions import (
)
async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID):
async def add_user_to_tenant(
user_id: UUID, tenant_id: UUID, owner_id: UUID, set_as_active_tenant: Optional[bool] = False
):
"""
Add a user with the given id to the tenant with the given id.
This can only be successful if the request owner with the given id is the tenant owner.
If set_as_active_tenant is true it will automatically set the users active tenant to provided tenant.
Args:
user_id: Id of the user.
tenant_id: Id of the tenant.
owner_id: Id of the request owner.
set_as_active_tenant: If set_as_active_tenant is true it will automatically set the users active tenant to provided tenant.
Returns:
None
@ -40,17 +48,18 @@ async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID):
message="Only tenant owner can add other users to organization."
)
try:
if user.tenant_id is None:
user.tenant_id = tenant_id
elif user.tenant_id == tenant_id:
return
else:
raise IntegrityError
if set_as_active_tenant:
user.tenant_id = tenant_id
await session.merge(user)
await session.commit()
except IntegrityError:
raise EntityAlreadyExistsError(
message="User is already part of a tenant. Only one tenant can be assigned to user."
try:
# Add association directly to the association table
create_user_tenant_statement = insert(UserTenant).values(
user_id=user_id, tenant_id=tenant_id
)
await session.execute(create_user_tenant_statement)
await session.commit()
except IntegrityError:
raise EntityAlreadyExistsError(message="User is already part of group.")

View file

@ -1,19 +1,25 @@
from uuid import UUID
from sqlalchemy import insert
from sqlalchemy.exc import IntegrityError
from typing import Optional
from cognee.modules.users.models.UserTenant import UserTenant
from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.models import Tenant
from cognee.modules.users.methods import get_user
async def create_tenant(tenant_name: str, user_id: UUID) -> UUID:
async def create_tenant(
tenant_name: str, user_id: UUID, set_as_active_tenant: Optional[bool] = True
) -> UUID:
"""
Create a new tenant with the given name, for the user with the given id.
This user is the owner of the tenant.
Args:
tenant_name: Name of the new tenant.
user_id: Id of the user.
set_as_active_tenant: If true, set the newly created tenant as the active tenant for the user.
Returns:
None
@ -22,18 +28,26 @@ async def create_tenant(tenant_name: str, user_id: UUID) -> UUID:
async with db_engine.get_async_session() as session:
try:
user = await get_user(user_id)
if user.tenant_id:
raise EntityAlreadyExistsError(
message="User already has a tenant. New tenant cannot be created."
)
tenant = Tenant(name=tenant_name, owner_id=user_id)
session.add(tenant)
await session.flush()
user.tenant_id = tenant.id
await session.merge(user)
await session.commit()
if set_as_active_tenant:
user.tenant_id = tenant.id
await session.merge(user)
await session.commit()
try:
# Add association directly to the association table
create_user_tenant_statement = insert(UserTenant).values(
user_id=user_id, tenant_id=tenant.id
)
await session.execute(create_user_tenant_statement)
await session.commit()
except IntegrityError:
raise EntityAlreadyExistsError(message="User is already part of tenant.")
return tenant.id
except IntegrityError as e:
raise EntityAlreadyExistsError(message="Tenant already exists.") from e

View file

@ -0,0 +1,62 @@
from uuid import UUID
from typing import Union
import sqlalchemy.exc
from sqlalchemy import select
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.users.methods.get_user import get_user
from cognee.modules.users.models.UserTenant import UserTenant
from cognee.modules.users.models.User import User
from cognee.modules.users.permissions.methods import get_tenant
from cognee.modules.users.exceptions import UserNotFoundError, TenantNotFoundError
async def select_tenant(user_id: UUID, tenant_id: Union[UUID, None]) -> User:
"""
Set the users active tenant to provided tenant.
If None tenant_id is provided set current Tenant to the default single user-tenant
Args:
user_id: UUID of the user.
tenant_id: Id of the tenant.
Returns:
None
"""
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
user = await get_user(user_id)
if tenant_id is None:
# If no tenant_id is provided set current Tenant to the single user-tenant
user.tenant_id = None
await session.merge(user)
await session.commit()
return user
tenant = await get_tenant(tenant_id)
if not user:
raise UserNotFoundError
elif not tenant:
raise TenantNotFoundError
# Check if User is part of Tenant
result = await session.execute(
select(UserTenant)
.where(UserTenant.user_id == user.id)
.where(UserTenant.tenant_id == tenant_id)
)
try:
result = result.scalar_one()
except sqlalchemy.exc.NoResultFound as e:
raise TenantNotFoundError("User is not part of the tenant.") from e
if result:
# If user is part of tenant update current tenant of user
user.tenant_id = tenant_id
await session.merge(user)
await session.commit()
return user

View file

@ -99,7 +99,7 @@ async def ingest_data(
# data_id is the hash of original file contents + owner id to avoid duplicate data
data_id = ingestion.identify(classified_data, user)
data_id = await ingestion.identify(classified_data, user)
original_file_metadata = classified_data.get_metadata()
# Find metadata from Cognee data storage text file

View file

@ -55,7 +55,7 @@ async def main():
classified_data = ingestion.classify(file)
# data_id is the hash of original file contents + owner id to avoid duplicate data
data_id = ingestion.identify(classified_data, await get_default_user())
data_id = await ingestion.identify(classified_data, await get_default_user())
await cognee.add(file_path)

View file

@ -0,0 +1,165 @@
import cognee
import pytest
from cognee.modules.users.exceptions import PermissionDeniedError
from cognee.modules.users.tenants.methods import select_tenant
from cognee.modules.users.methods import get_user
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import create_user
from cognee.modules.users.permissions.methods import authorized_give_permission_on_datasets
from cognee.modules.users.roles.methods import add_user_to_role
from cognee.modules.users.roles.methods import create_role
from cognee.modules.users.tenants.methods import create_tenant
from cognee.modules.users.tenants.methods import add_user_to_tenant
from cognee.modules.engine.operations.setup import setup
from cognee.shared.logging_utils import setup_logging, CRITICAL
logger = get_logger()
async def main():
# Create a clean slate for cognee -- reset data and system state
print("Resetting cognee data...")
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
print("Data reset complete.\n")
# Set up the necessary databases and tables for user management.
await setup()
# Add document for user_1, add it under dataset name AI
text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena.
At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages
this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the
preparation and manipulation of quantum state"""
print("Creating user_1: user_1@example.com")
user_1 = await create_user("user_1@example.com", "example")
await cognee.add([text], dataset_name="AI", user=user_1)
print("\nCreating user_2: user_2@example.com")
user_2 = await create_user("user_2@example.com", "example")
# Run cognify for both datasets as the appropriate user/owner
print("\nCreating different datasets for user_1 (AI dataset) and user_2 (QUANTUM dataset)")
ai_cognify_result = await cognee.cognify(["AI"], user=user_1)
# Extract dataset_ids from cognify results
def extract_dataset_id_from_cognify(cognify_result):
"""Extract dataset_id from cognify output dictionary"""
for dataset_id, pipeline_result in cognify_result.items():
return dataset_id # Return the first dataset_id
return None
# Get dataset IDs from cognify results
# Note: When we want to work with datasets from other users (search, add, cognify and etc.) we must supply dataset
# information through dataset_id using dataset name only looks for datasets owned by current user
ai_dataset_id = extract_dataset_id_from_cognify(ai_cognify_result)
# We can see here that user_1 can read his own dataset (AI dataset)
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="What is in the document?",
user=user_1,
datasets=[ai_dataset_id],
)
# Verify that user_2 cannot access user_1's dataset without permission
with pytest.raises(PermissionDeniedError):
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="What is in the document?",
user=user_2,
datasets=[ai_dataset_id],
)
# Create new tenant and role, add user_2 to tenant and role
tenant_id = await create_tenant("CogneeLab", user_1.id)
await select_tenant(user_id=user_1.id, tenant_id=tenant_id)
role_id = await create_role(role_name="Researcher", owner_id=user_1.id)
await add_user_to_tenant(
user_id=user_2.id, tenant_id=tenant_id, owner_id=user_1.id, set_as_active_tenant=True
)
await add_user_to_role(user_id=user_2.id, role_id=role_id, owner_id=user_1.id)
# Assert that user_1 cannot give permissions on his dataset to role before switching to the correct tenant
# AI dataset was made with default tenant and not CogneeLab tenant
with pytest.raises(PermissionDeniedError):
await authorized_give_permission_on_datasets(
role_id,
[ai_dataset_id],
"read",
user_1.id,
)
# We need to refresh the user object with changes made when switching tenants
user_1 = await get_user(user_1.id)
await cognee.add([text], dataset_name="AI_COGNEE_LAB", user=user_1)
ai_cognee_lab_cognify_result = await cognee.cognify(["AI_COGNEE_LAB"], user=user_1)
ai_cognee_lab_dataset_id = extract_dataset_id_from_cognify(ai_cognee_lab_cognify_result)
await authorized_give_permission_on_datasets(
role_id,
[ai_cognee_lab_dataset_id],
"read",
user_1.id,
)
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="What is in the document?",
user=user_2,
dataset_ids=[ai_cognee_lab_dataset_id],
)
for result in search_results:
print(f"{result}\n")
# Let's test changing tenants
tenant_id = await create_tenant("CogneeLab2", user_1.id)
await select_tenant(user_id=user_1.id, tenant_id=tenant_id)
user_1 = await get_user(user_1.id)
await cognee.add([text], dataset_name="AI_COGNEE_LAB", user=user_1)
await cognee.cognify(["AI_COGNEE_LAB"], user=user_1)
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="What is in the document?",
user=user_1,
)
# Assert only AI_COGNEE_LAB dataset from CogneeLab2 tenant is visible as the currently selected tenant
assert len(search_results) == 1, (
f"Search results must only contain one dataset from current tenant: {search_results}"
)
assert search_results[0]["dataset_name"] == "AI_COGNEE_LAB", (
f"Dict must contain dataset name 'AI_COGNEE_LAB': {search_results[0]}"
)
assert search_results[0]["dataset_tenant_id"] == user_1.tenant_id, (
f"Dataset tenant_id must be same as user_1 tenant_id: {search_results[0]}"
)
# Switch back to no tenant (default tenant)
await select_tenant(user_id=user_1.id, tenant_id=None)
# Refresh user_1 object
user_1 = await get_user(user_1.id)
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="What is in the document?",
user=user_1,
)
assert len(search_results) == 1, (
f"Search results must only contain one dataset from default tenant: {search_results}"
)
assert search_results[0]["dataset_name"] == "AI", (
f"Dict must contain dataset name 'AI': {search_results[0]}"
)
if __name__ == "__main__":
import asyncio
logger = setup_logging(log_level=CRITICAL)
asyncio.run(main())

View file

@ -3,6 +3,7 @@ import cognee
import pathlib
from cognee.modules.users.exceptions import PermissionDeniedError
from cognee.modules.users.tenants.methods import select_tenant
from cognee.shared.logging_utils import get_logger
from cognee.modules.search.types import SearchType
from cognee.modules.users.methods import create_user
@ -116,6 +117,7 @@ async def main():
print(
"\nOperation started as user_2 to give read permission to user_1 for the dataset owned by user_2"
)
await authorized_give_permission_on_datasets(
user_1.id,
[quantum_dataset_id],
@ -142,6 +144,9 @@ async def main():
print("User 2 is creating CogneeLab tenant/organization")
tenant_id = await create_tenant("CogneeLab", user_2.id)
print("User 2 is selecting CogneeLab tenant/organization as active tenant")
await select_tenant(user_id=user_2.id, tenant_id=tenant_id)
print("\nUser 2 is creating Researcher role")
role_id = await create_role(role_name="Researcher", owner_id=user_2.id)
@ -157,23 +162,59 @@ async def main():
)
await add_user_to_role(user_id=user_3.id, role_id=role_id, owner_id=user_2.id)
print("\nOperation as user_3 to select CogneeLab tenant/organization as active tenant")
await select_tenant(user_id=user_3.id, tenant_id=tenant_id)
print(
"\nOperation started as user_2 to give read permission to Researcher role for the dataset owned by user_2"
"\nOperation started as user_2, with CogneeLab as its active tenant, to give read permission to Researcher role for the dataset QUANTUM owned by user_2"
)
# Even though the dataset owner is user_2, the dataset doesn't belong to the tenant/organization CogneeLab.
# So we can't assign permissions to it when we're acting in the CogneeLab tenant.
try:
await authorized_give_permission_on_datasets(
role_id,
[quantum_dataset_id],
"read",
user_2.id,
)
except PermissionDeniedError:
print(
"User 2 could not give permission to the role as the QUANTUM dataset is not part of the CogneeLab tenant"
)
print(
"We will now create a new QUANTUM dataset with the QUANTUM_COGNEE_LAB name in the CogneeLab tenant so that permissions can be assigned to the Researcher role inside the tenant/organization"
)
# We can re-create the QUANTUM dataset in the CogneeLab tenant. The old QUANTUM dataset is still owned by user_2 personally
# and can still be accessed by selecting the personal tenant for user 2.
from cognee.modules.users.methods import get_user
# Note: We need to update user_2 from the database to refresh its tenant context changes
user_2 = await get_user(user_2.id)
await cognee.add([text], dataset_name="QUANTUM_COGNEE_LAB", user=user_2)
quantum_cognee_lab_cognify_result = await cognee.cognify(["QUANTUM_COGNEE_LAB"], user=user_2)
# The recreated Quantum dataset will now have a different dataset_id as it's a new dataset in a different organization
quantum_cognee_lab_dataset_id = extract_dataset_id_from_cognify(
quantum_cognee_lab_cognify_result
)
print(
"\nOperation started as user_2, with CogneeLab as its active tenant, to give read permission to Researcher role for the dataset QUANTUM owned by the CogneeLab tenant"
)
await authorized_give_permission_on_datasets(
role_id,
[quantum_dataset_id],
[quantum_cognee_lab_dataset_id],
"read",
user_2.id,
)
# Now user_3 can read from QUANTUM dataset as part of the Researcher role after proper permissions have been assigned by the QUANTUM dataset owner, user_2.
print("\nSearch result as user_3 on the dataset owned by user_2:")
print("\nSearch result as user_3 on the QUANTUM dataset owned by the CogneeLab organization:")
search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION,
query_text="What is in the document?",
user=user_1,
dataset_ids=[quantum_dataset_id],
user=user_3,
dataset_ids=[quantum_cognee_lab_dataset_id],
)
for result in search_results:
print(f"{result}\n")

View file

@ -167,7 +167,6 @@ exclude = [
"/dist",
"/.data",
"/.github",
"/alembic",
"/deployment",
"/cognee-mcp",
"/cognee-frontend",