diff --git a/.github/workflows/e2e_tests.yml b/.github/workflows/e2e_tests.yml index 0596f22d3..715487372 100644 --- a/.github/workflows/e2e_tests.yml +++ b/.github/workflows/e2e_tests.yml @@ -226,7 +226,7 @@ jobs: - name: Dependencies already installed run: echo "Dependencies already installed in setup" - - name: Run parallel databases test + - name: Run permissions test env: ENV: 'dev' LLM_MODEL: ${{ secrets.LLM_MODEL }} @@ -239,6 +239,31 @@ jobs: EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} run: uv run python ./cognee/tests/test_permissions.py + test-multi-tenancy: + name: Test multi tenancy with different situations in Cognee + runs-on: ubuntu-22.04 + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Cognee Setup + uses: ./.github/actions/cognee_setup + with: + python-version: '3.11.x' + + - name: Run multi tenancy test + env: + ENV: 'dev' + LLM_MODEL: ${{ secrets.LLM_MODEL }} + LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }} + LLM_API_KEY: ${{ secrets.LLM_API_KEY }} + LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }} + EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }} + EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }} + EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }} + EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} + run: uv run python ./cognee/tests/test_multi_tenancy.py + test-graph-edges: name: Test graph edge ingestion runs-on: ubuntu-22.04 @@ -487,4 +512,4 @@ jobs: AWS_ENDPOINT_URL: https://s3-eu-west-1.amazonaws.com AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_DEV_USER_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_DEV_USER_SECRET_KEY }} - run: uv run python ./cognee/tests/test_load.py \ No newline at end of file + run: uv run python ./cognee/tests/test_load.py diff --git a/cognee/modules/data/methods/create_dataset.py b/cognee/modules/data/methods/create_dataset.py index 280c9e105..7e28a8255 100644 --- a/cognee/modules/data/methods/create_dataset.py +++ b/cognee/modules/data/methods/create_dataset.py @@ -16,6 +16,7 @@ async def create_dataset(dataset_name: str, user: User, session: AsyncSession) - .options(joinedload(Dataset.data)) .filter(Dataset.name == dataset_name) .filter(Dataset.owner_id == owner_id) + .filter(Dataset.tenant_id == user.tenant_id) ) ).first() diff --git a/cognee/modules/data/methods/get_dataset_ids.py b/cognee/modules/data/methods/get_dataset_ids.py index d4402ff36..a61e85310 100644 --- a/cognee/modules/data/methods/get_dataset_ids.py +++ b/cognee/modules/data/methods/get_dataset_ids.py @@ -27,7 +27,11 @@ async def get_dataset_ids(datasets: Union[list[str], list[UUID]], user): # Get all user owned dataset objects (If a user wants to write to a dataset he is not the owner of it must be provided through UUID.) user_datasets = await get_datasets(user.id) # Filter out non name mentioned datasets - dataset_ids = [dataset.id for dataset in user_datasets if dataset.name in datasets] + dataset_ids = [dataset for dataset in user_datasets if dataset.name in datasets] + # Filter out non current tenant datasets + dataset_ids = [ + dataset.id for dataset in dataset_ids if dataset.tenant_id == user.tenant_id + ] else: raise DatasetTypeError( f"One or more of the provided dataset types is not handled: f{datasets}" diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 5e465b239..b4278424b 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -172,6 +172,7 @@ async def search( "search_result": [context] if context else None, "dataset_id": datasets[0].id, "dataset_name": datasets[0].name, + "dataset_tenant_id": datasets[0].tenant_id, "graphs": graphs, } ) @@ -181,6 +182,7 @@ async def search( "search_result": [result] if result else None, "dataset_id": datasets[0].id, "dataset_name": datasets[0].name, + "dataset_tenant_id": datasets[0].tenant_id, "graphs": graphs, } ) diff --git a/cognee/tests/test_multi_tenancy.py b/cognee/tests/test_multi_tenancy.py new file mode 100644 index 000000000..7cdcda8d8 --- /dev/null +++ b/cognee/tests/test_multi_tenancy.py @@ -0,0 +1,165 @@ +import cognee +import pytest + +from cognee.modules.users.exceptions import PermissionDeniedError +from cognee.modules.users.tenants.methods import select_tenant +from cognee.modules.users.methods import get_user +from cognee.shared.logging_utils import get_logger +from cognee.modules.search.types import SearchType +from cognee.modules.users.methods import create_user +from cognee.modules.users.permissions.methods import authorized_give_permission_on_datasets +from cognee.modules.users.roles.methods import add_user_to_role +from cognee.modules.users.roles.methods import create_role +from cognee.modules.users.tenants.methods import create_tenant +from cognee.modules.users.tenants.methods import add_user_to_tenant +from cognee.modules.engine.operations.setup import setup +from cognee.shared.logging_utils import setup_logging, CRITICAL + +logger = get_logger() + + +async def main(): + # Create a clean slate for cognee -- reset data and system state + print("Resetting cognee data...") + await cognee.prune.prune_data() + await cognee.prune.prune_system(metadata=True) + print("Data reset complete.\n") + + # Set up the necessary databases and tables for user management. + await setup() + + # Add document for user_1, add it under dataset name AI + text = """A quantum computer is a computer that takes advantage of quantum mechanical phenomena. + At small scales, physical matter exhibits properties of both particles and waves, and quantum computing leverages + this behavior, specifically quantum superposition and entanglement, using specialized hardware that supports the + preparation and manipulation of quantum state""" + + print("Creating user_1: user_1@example.com") + user_1 = await create_user("user_1@example.com", "example") + await cognee.add([text], dataset_name="AI", user=user_1) + + print("\nCreating user_2: user_2@example.com") + user_2 = await create_user("user_2@example.com", "example") + + # Run cognify for both datasets as the appropriate user/owner + print("\nCreating different datasets for user_1 (AI dataset) and user_2 (QUANTUM dataset)") + ai_cognify_result = await cognee.cognify(["AI"], user=user_1) + + # Extract dataset_ids from cognify results + def extract_dataset_id_from_cognify(cognify_result): + """Extract dataset_id from cognify output dictionary""" + for dataset_id, pipeline_result in cognify_result.items(): + return dataset_id # Return the first dataset_id + return None + + # Get dataset IDs from cognify results + # Note: When we want to work with datasets from other users (search, add, cognify and etc.) we must supply dataset + # information through dataset_id using dataset name only looks for datasets owned by current user + ai_dataset_id = extract_dataset_id_from_cognify(ai_cognify_result) + + # We can see here that user_1 can read his own dataset (AI dataset) + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="What is in the document?", + user=user_1, + datasets=[ai_dataset_id], + ) + + # Verify that user_2 cannot access user_1's dataset without permission + with pytest.raises(PermissionDeniedError): + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="What is in the document?", + user=user_2, + datasets=[ai_dataset_id], + ) + + # Create new tenant and role, add user_2 to tenant and role + tenant_id = await create_tenant("CogneeLab", user_1.id) + await select_tenant(user_id=user_1.id, tenant_id=tenant_id) + role_id = await create_role(role_name="Researcher", owner_id=user_1.id) + await add_user_to_tenant( + user_id=user_2.id, tenant_id=tenant_id, owner_id=user_1.id, set_as_active_tenant=True + ) + await add_user_to_role(user_id=user_2.id, role_id=role_id, owner_id=user_1.id) + + # Assert that user_1 cannot give permissions on his dataset to role before switching to the correct tenant + # AI dataset was made with default tenant and not CogneeLab tenant + with pytest.raises(PermissionDeniedError): + await authorized_give_permission_on_datasets( + role_id, + [ai_dataset_id], + "read", + user_1.id, + ) + + # We need to refresh the user object with changes made when switching tenants + user_1 = await get_user(user_1.id) + await cognee.add([text], dataset_name="AI_COGNEE_LAB", user=user_1) + ai_cognee_lab_cognify_result = await cognee.cognify(["AI_COGNEE_LAB"], user=user_1) + + ai_cognee_lab_dataset_id = extract_dataset_id_from_cognify(ai_cognee_lab_cognify_result) + + await authorized_give_permission_on_datasets( + role_id, + [ai_cognee_lab_dataset_id], + "read", + user_1.id, + ) + + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="What is in the document?", + user=user_2, + dataset_ids=[ai_cognee_lab_dataset_id], + ) + for result in search_results: + print(f"{result}\n") + + # Let's test changing tenants + tenant_id = await create_tenant("CogneeLab2", user_1.id) + await select_tenant(user_id=user_1.id, tenant_id=tenant_id) + + user_1 = await get_user(user_1.id) + await cognee.add([text], dataset_name="AI_COGNEE_LAB", user=user_1) + await cognee.cognify(["AI_COGNEE_LAB"], user=user_1) + + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="What is in the document?", + user=user_1, + ) + + # Assert only AI_COGNEE_LAB dataset from CogneeLab2 tenant is visible as the currently selected tenant + assert len(search_results) == 1, ( + f"Search results must only contain one dataset from current tenant: {search_results}" + ) + assert search_results[0]["dataset_name"] == "AI_COGNEE_LAB", ( + f"Dict must contain dataset name 'AI_COGNEE_LAB': {search_results[0]}" + ) + assert search_results[0]["dataset_tenant_id"] == user_1.tenant_id, ( + f"Dataset tenant_id must be same as user_1 tenant_id: {search_results[0]}" + ) + + # Switch back to no tenant (default tenant) + await select_tenant(user_id=user_1.id, tenant_id=None) + # Refresh user_1 object + user_1 = await get_user(user_1.id) + search_results = await cognee.search( + query_type=SearchType.GRAPH_COMPLETION, + query_text="What is in the document?", + user=user_1, + ) + assert len(search_results) == 1, ( + f"Search results must only contain one dataset from default tenant: {search_results}" + ) + assert search_results[0]["dataset_name"] == "AI", ( + f"Dict must contain dataset name 'AI': {search_results[0]}" + ) + + +if __name__ == "__main__": + import asyncio + + logger = setup_logging(log_level=CRITICAL) + asyncio.run(main())