diff --git a/cognee/modules/pipelines/layers/authorized_user_datasets.py b/cognee/modules/pipelines/layers/authorized_user_datasets.py index f0ba013ec..2b4117208 100644 --- a/cognee/modules/pipelines/layers/authorized_user_datasets.py +++ b/cognee/modules/pipelines/layers/authorized_user_datasets.py @@ -1,8 +1,9 @@ from uuid import UUID -from typing import Union +from typing import Union, Tuple, List from cognee.modules.users.methods import get_default_user from cognee.modules.users.models import User +from cognee.modules.data.models import Dataset from cognee.modules.data.exceptions import DatasetNotFoundError from cognee.modules.data.methods import ( get_authorized_existing_datasets, @@ -11,7 +12,21 @@ from cognee.modules.data.methods import ( ) -async def authorized_user_datasets(user: User, datasets: Union[str, list[str], list[UUID]]): +async def authorized_user_datasets( + datasets: Union[str, list[str], list[UUID]], user: User = None +) -> Tuple[User, List[Dataset]]: + """ + Function handles creation and dataset authorization if datasets already exist for Cognee. + Verifies that provided user has necessary permission for provided Dataset. + If Dataset does not exist creates the Dataset and gives permission for the user creating the dataset. + + Args: + user: Cognee User request is being processed for, if None default user will be used. 
+ datasets: Dataset names or Dataset UUID (in case Datasets already exist) + + Returns: + Tuple of the User the request was processed for (the default user when none was provided) and the list of Datasets that user is authorized for. + """ # If no user is provided use default user if user is None: user = await get_default_user() diff --git a/cognee/modules/pipelines/operations/pipeline.py b/cognee/modules/pipelines/operations/pipeline.py index 569df8501..7a520f4a4 100644 --- a/cognee/modules/pipelines/operations/pipeline.py +++ b/cognee/modules/pipelines/operations/pipeline.py @@ -67,7 +67,7 @@ async def cognee_pipeline( await test_embedding_connection() cognee_pipeline.first_run = False # Update flag after first run - user, authorized_datasets = await authorized_user_datasets(user, datasets) + user, authorized_datasets = await authorized_user_datasets(datasets, user) for dataset in authorized_datasets: async for run_info in run_pipeline(