diff --git a/cognee/api/v1/add/add.py b/cognee/api/v1/add/add.py index 4f51729a3..798fbed97 100644 --- a/cognee/api/v1/add/add.py +++ b/cognee/api/v1/add/add.py @@ -1,19 +1,26 @@ from uuid import UUID +from fastapi import UploadFile from typing import Union, BinaryIO, List, Optional from cognee.modules.pipelines import Task from cognee.modules.users.models import User -from cognee.modules.pipelines import cognee_pipeline +from cognee.modules.users.methods import get_default_user +from cognee.modules.engine.operations.setup import setup +from cognee.modules.data.exceptions.exceptions import DatasetNotFoundError +from cognee.modules.data.methods import ( + get_authorized_dataset, + get_authorized_dataset_by_name, + create_authorized_dataset, +) +from cognee.modules.pipelines.operations.run_add_pipeline import run_add_pipeline from cognee.tasks.ingestion import ingest_data, resolve_data_directories async def add( - data: Union[BinaryIO, list[BinaryIO], str, list[str]], - dataset_name: str = "main_dataset", - user: User = None, + data: Union[BinaryIO, List[BinaryIO], str, List[str], UploadFile, List[UploadFile]], + dataset_name: Optional[str] = "main_dataset", + user: Optional[User] = None, node_set: Optional[List[str]] = None, - vector_db_config: dict = None, - graph_db_config: dict = None, dataset_id: Optional[UUID] = None, ): """ @@ -67,8 +74,6 @@ async def add( Users can only access datasets they have permissions for. node_set: Optional list of node identifiers for graph organization and access control. Used for grouping related data points in the knowledge graph. - vector_db_config: Optional configuration for vector database (for custom setups). - graph_db_config: Optional configuration for graph database (for custom setups). dataset_id: Optional specific dataset UUID to use instead of dataset_name. 
Returns: @@ -138,21 +143,41 @@ async def add( UnsupportedFileTypeError: If file format cannot be processed InvalidValueError: If LLM_API_KEY is not set or invalid """ + # Create databases if not already created + await setup() + tasks = [ Task(resolve_data_directories, include_subdirectories=True), Task(ingest_data, dataset_name, user, node_set, dataset_id), ] + if not user: + user = await get_default_user() + + if dataset_id: + authorized_dataset = await get_authorized_dataset(dataset_id, user, "write") + elif dataset_name: + authorized_dataset = await get_authorized_dataset_by_name(dataset_name, user, "write") + if not authorized_dataset: + authorized_dataset = await create_authorized_dataset( + dataset_name=dataset_name, user=user + ) + else: + raise ValueError("Either dataset_id or dataset_name must be provided.") + + if not authorized_dataset: + raise DatasetNotFoundError( + message=f"Dataset ({dataset_id or dataset_name}) not found." + ) + pipeline_run_info = None - async for run_info in cognee_pipeline( + async for run_info in run_add_pipeline( tasks=tasks, - datasets=dataset_id if dataset_id else dataset_name, data=data, + dataset=authorized_dataset, user=user, pipeline_name="add_pipeline", - vector_db_config=vector_db_config, - graph_db_config=graph_db_config, ): pipeline_run_info = run_info diff --git a/cognee/api/v1/add/routers/get_add_router.py b/cognee/api/v1/add/routers/get_add_router.py index 4519af728..add3163c5 100644 --- a/cognee/api/v1/add/routers/get_add_router.py +++ b/cognee/api/v1/add/routers/get_add_router.py @@ -2,11 +2,12 @@ import os import requests import subprocess from uuid import UUID +from io import BytesIO from fastapi import APIRouter from fastapi.responses import JSONResponse from fastapi import Form, File, UploadFile, Depends -from typing import List, Optional, Union, Literal +from typing import BinaryIO, List, Literal, Optional, Union from cognee.modules.users.models import User from cognee.modules.users.methods import 
get_authenticated_user @@ -69,10 +70,11 @@ def get_add_router() -> APIRouter: }, ) - from cognee.api.v1.add import add as cognee_add + # Swagger sends an empty string, so we convert it to None for type consistency + if datasetId == "": + datasetId = None - if not datasetId and not datasetName: - raise ValueError("Either datasetId or datasetName must be provided.") + from cognee.api.v1.add import add as cognee_add try: if ( @@ -88,19 +90,24 @@ def get_add_router() -> APIRouter: await cognee_add( "data://.data/", f"{repo_name}", + user=user, ) else: # Fetch and store the data from other types of URL using curl response = requests.get(data) response.raise_for_status() - file_data = await response.content() - # TODO: Update add call with dataset info - return await cognee_add(file_data) + file_data = response.content + binary_io_data: BinaryIO = BytesIO(file_data) + return await cognee_add( + binary_io_data, dataset_name=datasetName, user=user, dataset_id=datasetId + ) else: - add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId) + add_run = await cognee_add( + data, dataset_name=datasetName, user=user, dataset_id=datasetId + ) - return add_run.model_dump() + return add_run.model_dump() if add_run else None except Exception as error: return JSONResponse(status_code=409, content={"error": str(error)}) diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 7c7821460..01b9b691d 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -14,6 +14,8 @@ from cognee.modules.ontology.rdf_xml.OntologyResolver import OntologyResolver from cognee.modules.pipelines.models.PipelineRunInfo import PipelineRunCompleted, PipelineRunErrored from cognee.modules.pipelines.queues.pipeline_run_info_queues import push_to_queue from cognee.modules.users.models import User +from cognee.modules.users.methods.get_default_user import get_default_user +from cognee.modules.data.methods import 
get_authorized_existing_datasets from cognee.tasks.documents import ( check_permissions_on_dataset, @@ -185,13 +187,21 @@ async def cognify( ValueError: If chunks exceed max token limits (reduce chunk_size) DatabaseNotCreatedError: If databases are not properly initialized """ + if not user: + user = await get_default_user() + + if isinstance(datasets, str): + datasets = [datasets] + + user_datasets = await get_authorized_existing_datasets(datasets, "write", user) + tasks = await get_default_tasks(user, graph_model, chunker, chunk_size, ontology_file_path) if run_in_background: return await run_cognify_as_background_process( tasks=tasks, user=user, - datasets=datasets, + datasets=user_datasets, vector_db_config=vector_db_config, graph_db_config=graph_db_config, ) @@ -199,7 +209,7 @@ async def cognify( return await run_cognify_blocking( tasks=tasks, user=user, - datasets=datasets, + datasets=user_datasets, vector_db_config=vector_db_config, graph_db_config=graph_db_config, ) diff --git a/cognee/base_config.py b/cognee/base_config.py index 5081acaff..27e65d006 100644 --- a/cognee/base_config.py +++ b/cognee/base_config.py @@ -13,8 +13,8 @@ class BaseConfig(BaseSettings): langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY") langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY") langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST") - default_user_email: Optional[str] = os.getenv("DEFAULT_USER_EMAIL") - default_user_password: Optional[str] = os.getenv("DEFAULT_USER_PASSWORD") + default_user_email: str = os.getenv("DEFAULT_USER_EMAIL", "default_user@example.com") + default_user_password: str = os.getenv("DEFAULT_USER_PASSWORD", "default_password") model_config = SettingsConfigDict(env_file=".env", extra="allow") def to_dict(self) -> dict: diff --git a/cognee/modules/data/methods/__init__.py b/cognee/modules/data/methods/__init__.py index 3c62c536d..1aef34f62 100644 --- a/cognee/modules/data/methods/__init__.py +++ 
b/cognee/modules/data/methods/__init__.py @@ -1,5 +1,6 @@ # Create from .create_dataset import create_dataset +from .create_authorized_dataset import create_authorized_dataset # Get from .get_dataset import get_dataset @@ -11,6 +12,10 @@ from .get_unique_dataset_id import get_unique_dataset_id from .get_authorized_existing_datasets import get_authorized_existing_datasets from .get_dataset_ids import get_dataset_ids +# Get with Permissions +from .get_authorized_dataset import get_authorized_dataset +from .get_authorized_dataset_by_name import get_authorized_dataset_by_name + # Delete from .delete_dataset import delete_dataset from .delete_data import delete_data diff --git a/cognee/modules/data/methods/check_dataset_name.py b/cognee/modules/data/methods/check_dataset_name.py index c075ea353..e10201034 100644 --- a/cognee/modules/data/methods/check_dataset_name.py +++ b/cognee/modules/data/methods/check_dataset_name.py @@ -1,3 +1,3 @@ def check_dataset_name(dataset_name: str): if "." in dataset_name or " " in dataset_name: - raise ValueError("Dataset name cannot contain spaces or underscores") + raise ValueError(f"Dataset name cannot contain dots or spaces, got {dataset_name}") diff --git a/cognee/modules/data/methods/create_authorized_dataset.py b/cognee/modules/data/methods/create_authorized_dataset.py new file mode 100644 index 000000000..d69995d6e --- /dev/null +++ b/cognee/modules/data/methods/create_authorized_dataset.py @@ -0,0 +1,26 @@ +from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.users.models import User +from cognee.modules.data.models import Dataset +from cognee.modules.users.permissions.methods import give_permission_on_dataset +from cognee.modules.data.methods.get_unique_dataset_id import get_unique_dataset_id + + +async def create_authorized_dataset(dataset_name: str, user: User) -> Dataset: + # Dataset id should be generated based on dataset_name and owner_id/user so multiple users can use the 
same dataset_name + dataset_id = await get_unique_dataset_id(dataset_name=dataset_name, user=user) + new_dataset = Dataset(id=dataset_id, name=dataset_name, data=[]) + new_dataset.owner_id = user.id + + db_engine = get_relational_engine() + + async with db_engine.get_async_session() as session: + session.add(new_dataset) + + await session.commit() + + await give_permission_on_dataset(user, new_dataset.id, "read") + await give_permission_on_dataset(user, new_dataset.id, "write") + await give_permission_on_dataset(user, new_dataset.id, "delete") + await give_permission_on_dataset(user, new_dataset.id, "share") + + return new_dataset diff --git a/cognee/modules/data/methods/get_authorized_dataset.py b/cognee/modules/data/methods/get_authorized_dataset.py new file mode 100644 index 000000000..79e6593d7 --- /dev/null +++ b/cognee/modules/data/methods/get_authorized_dataset.py @@ -0,0 +1,15 @@ +from uuid import UUID +from typing import Optional + +from cognee.modules.users.models import User +from cognee.modules.users.permissions.methods import get_principal_datasets + +from ..models import Dataset + + +async def get_authorized_dataset( + dataset_id: UUID, user: User, permission_type: str +) -> Optional[Dataset]: + user_datasets = await get_principal_datasets(user, permission_type) + + return next((dataset for dataset in user_datasets if dataset.id == dataset_id), None) diff --git a/cognee/modules/data/methods/get_authorized_dataset_by_name.py b/cognee/modules/data/methods/get_authorized_dataset_by_name.py new file mode 100644 index 000000000..d11d3a1ea --- /dev/null +++ b/cognee/modules/data/methods/get_authorized_dataset_by_name.py @@ -0,0 +1,14 @@ +from typing import Optional + +from cognee.modules.users.models import User +from cognee.modules.users.permissions.methods import get_principal_datasets + +from ..models import Dataset + + +async def get_authorized_dataset_by_name( + dataset_name: str, user: User, permission_type: str +) -> Optional[Dataset]: + 
user_datasets = await get_principal_datasets(user, permission_type) + + return next((dataset for dataset in user_datasets if dataset.name == dataset_name), None) diff --git a/cognee/modules/data/models/Dataset.py b/cognee/modules/data/models/Dataset.py index 797401d5a..ef43a66aa 100644 --- a/cognee/modules/data/models/Dataset.py +++ b/cognee/modules/data/models/Dataset.py @@ -1,7 +1,7 @@ -from uuid import uuid4 +from uuid import uuid4, UUID as UUID_t from typing import List from datetime import datetime, timezone -from sqlalchemy.orm import relationship, Mapped +from sqlalchemy.orm import relationship, Mapped, mapped_column from sqlalchemy import Column, Text, DateTime, UUID from cognee.infrastructure.databases.relational import Base from .DatasetData import DatasetData @@ -10,14 +10,14 @@ from .DatasetData import DatasetData class Dataset(Base): __tablename__ = "datasets" - id = Column(UUID, primary_key=True, default=uuid4) + id: Mapped[UUID_t] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid4) - name = Column(Text) + name: Mapped[str] = mapped_column(Text) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc)) - owner_id = Column(UUID, index=True) + owner_id: Mapped[UUID_t] = mapped_column(UUID(as_uuid=True), index=True) acls = relationship("ACL", back_populates="dataset", cascade="all, delete-orphan") diff --git a/cognee/modules/pipelines/models/PipelineRun.py b/cognee/modules/pipelines/models/PipelineRun.py index d1031c172..c09b40798 100644 --- a/cognee/modules/pipelines/models/PipelineRun.py +++ b/cognee/modules/pipelines/models/PipelineRun.py @@ -1,6 +1,7 @@ import enum -from uuid import uuid4 +from uuid import uuid4, UUID as UUID_t from datetime import datetime, timezone +from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy import Column, DateTime, JSON, Enum, UUID, String from 
cognee.infrastructure.databases.relational import Base @@ -19,9 +20,9 @@ class PipelineRun(Base): created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) - status = Column(Enum(PipelineRunStatus)) - pipeline_run_id = Column(UUID, index=True) + status: Mapped[PipelineRunStatus] = mapped_column(Enum(PipelineRunStatus)) + pipeline_run_id: Mapped[UUID_t] = mapped_column(UUID, index=True) pipeline_name = Column(String) - pipeline_id = Column(UUID, index=True) - dataset_id = Column(UUID, index=True) + pipeline_id: Mapped[UUID_t] = mapped_column(UUID(as_uuid=True), index=True) + dataset_id: Mapped[UUID_t] = mapped_column(UUID(as_uuid=True), index=True) run_info = Column(JSON) diff --git a/cognee/modules/pipelines/operations/__init__.py b/cognee/modules/pipelines/operations/__init__.py index 21ee552f0..d55315274 100644 --- a/cognee/modules/pipelines/operations/__init__.py +++ b/cognee/modules/pipelines/operations/__init__.py @@ -2,4 +2,5 @@ from .log_pipeline_run_initiated import log_pipeline_run_initiated from .log_pipeline_run_start import log_pipeline_run_start from .log_pipeline_run_complete import log_pipeline_run_complete from .log_pipeline_run_error import log_pipeline_run_error +from .get_pipeline_status import get_pipeline_status from .pipeline import cognee_pipeline diff --git a/cognee/modules/pipelines/operations/get_pipeline_status.py b/cognee/modules/pipelines/operations/get_pipeline_status.py index 05cc7ab6e..73e7e47ef 100644 --- a/cognee/modules/pipelines/operations/get_pipeline_status.py +++ b/cognee/modules/pipelines/operations/get_pipeline_status.py @@ -1,11 +1,13 @@ from uuid import UUID from sqlalchemy import select, func from cognee.infrastructure.databases.relational import get_relational_engine -from ..models import PipelineRun +from ..models import PipelineRun, PipelineRunStatus from sqlalchemy.orm import aliased -async def get_pipeline_status(dataset_ids: list[UUID], pipeline_name: str): +async def 
get_pipeline_status( + dataset_ids: list[UUID], pipeline_name: str +) -> dict[str, PipelineRunStatus]: db_engine = get_relational_engine() async with db_engine.get_async_session() as session: diff --git a/cognee/modules/pipelines/operations/log_pipeline_run_initiated.py b/cognee/modules/pipelines/operations/log_pipeline_run_initiated.py index e68efe31e..0ac054c7b 100644 --- a/cognee/modules/pipelines/operations/log_pipeline_run_initiated.py +++ b/cognee/modules/pipelines/operations/log_pipeline_run_initiated.py @@ -4,7 +4,7 @@ from cognee.modules.pipelines.models import PipelineRun, PipelineRunStatus from cognee.modules.pipelines.utils import generate_pipeline_run_id -async def log_pipeline_run_initiated(pipeline_id: str, pipeline_name: str, dataset_id: UUID): +async def log_pipeline_run_initiated(pipeline_id: UUID, pipeline_name: str, dataset_id: UUID): pipeline_run = PipelineRun( pipeline_run_id=generate_pipeline_run_id(pipeline_id, dataset_id), pipeline_name=pipeline_name, diff --git a/cognee/modules/pipelines/operations/pipeline.py b/cognee/modules/pipelines/operations/pipeline.py index e58c15254..171b4e9c5 100644 --- a/cognee/modules/pipelines/operations/pipeline.py +++ b/cognee/modules/pipelines/operations/pipeline.py @@ -3,24 +3,23 @@ from uuid import UUID from typing import Union from cognee.shared.logging_utils import get_logger +from cognee.modules.engine.operations.setup import setup from cognee.modules.data.methods.get_dataset_data import get_dataset_data from cognee.modules.data.models import Data, Dataset from cognee.modules.pipelines.operations.run_tasks import run_tasks from cognee.modules.pipelines.models import PipelineRunStatus -from cognee.modules.pipelines.utils import generate_pipeline_id +from cognee.modules.pipelines.utils import validate_pipeline_inputs from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status from cognee.modules.pipelines.methods import get_pipeline_run_by_dataset from 
cognee.modules.pipelines.tasks.task import Task from cognee.modules.users.methods import get_default_user from cognee.modules.users.models import User -from cognee.modules.pipelines.operations import log_pipeline_run_initiated from cognee.context_global_variables import set_database_global_context_variables from cognee.modules.data.exceptions import DatasetNotFoundError from cognee.modules.data.methods import ( get_authorized_existing_datasets, load_or_create_datasets, - check_dataset_name, ) from cognee.modules.pipelines.models.PipelineRunInfo import ( @@ -28,12 +27,6 @@ from cognee.modules.pipelines.models.PipelineRunInfo import ( PipelineRunStarted, ) -from cognee.infrastructure.databases.relational import ( - create_db_and_tables as create_relational_db_and_tables, -) -from cognee.infrastructure.databases.vector.pgvector import ( - create_db_and_tables as create_pgvector_db_and_tables, -) from cognee.context_global_variables import ( graph_db_config as context_graph_db_config, vector_db_config as context_vector_db_config, @@ -44,6 +37,7 @@ logger = get_logger("cognee.pipeline") update_status_lock = asyncio.Lock() +@validate_pipeline_inputs async def cognee_pipeline( tasks: list[Task], data=None, @@ -60,9 +54,8 @@ async def cognee_pipeline( if graph_db_config: context_graph_db_config.set(graph_db_config) - # Create tables for databases - await create_relational_db_and_tables() - await create_pgvector_db_and_tables() + # Create databases if they don't exist + await setup() # Initialize first_run attribute if it doesn't exist if not hasattr(cognee_pipeline, "first_run"): @@ -84,16 +77,17 @@ async def cognee_pipeline( if isinstance(datasets, str) or isinstance(datasets, UUID): datasets = [datasets] - # Get datasets user wants write permissions for (verify user has permissions if datasets are provided as well) - # NOTE: If a user wants to write to a dataset he does not own it must be provided through UUID - existing_datasets = await 
get_authorized_existing_datasets(datasets, "write", user) + if not all([isinstance(dataset, Dataset) for dataset in datasets]): + # Get datasets user wants write permissions for (verify user has permissions if datasets are provided as well) + # NOTE: If a user wants to write to a dataset he does not own it must be provided through UUID + existing_datasets = await get_authorized_existing_datasets(datasets, "write", user) - if not datasets: - # Get datasets from database if none sent. - datasets = existing_datasets - else: - # If dataset matches an existing Dataset (by name or id), reuse it. Otherwise, create a new Dataset. - datasets = await load_or_create_datasets(datasets, existing_datasets, user) + if not datasets: + # Get datasets from database if none sent. + datasets = existing_datasets + else: + # If dataset matches an existing Dataset (by name or id), reuse it. Otherwise, create a new Dataset. + datasets = await load_or_create_datasets(datasets, existing_datasets, user) if not datasets: raise DatasetNotFoundError("There are no datasets to work with.") @@ -118,31 +112,9 @@ async def run_pipeline( pipeline_name: str = "custom_pipeline", context: dict = None, ): - check_dataset_name(dataset.name) - # Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True await set_database_global_context_variables(dataset.id, dataset.owner_id) - # Ugly hack, but no easier way to do this. - if pipeline_name == "add_pipeline": - pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name) - # Refresh the add pipeline status so data is added to a dataset. - # Without this the app_pipeline status will be DATASET_PROCESSING_COMPLETED and will skip the execution. - - await log_pipeline_run_initiated( - pipeline_id=pipeline_id, - pipeline_name="add_pipeline", - dataset_id=dataset.id, - ) - - # Refresh the cognify pipeline status after we add new files. - # Without this the cognify_pipeline status will be DATASET_PROCESSING_COMPLETED and will skip the execution. 
- await log_pipeline_run_initiated( - pipeline_id=pipeline_id, - pipeline_name="cognify_pipeline", - dataset_id=dataset.id, - ) - dataset_id = dataset.id if not data: @@ -177,13 +149,6 @@ async def run_pipeline( ) return - if not isinstance(tasks, list): - raise ValueError("Tasks must be a list") - - for task in tasks: - if not isinstance(task, Task): - raise ValueError(f"Task {task} is not an instance of Task") - pipeline_run = run_tasks(tasks, dataset_id, data, user, pipeline_name, context) async for pipeline_run_info in pipeline_run: diff --git a/cognee/modules/pipelines/operations/run_add_pipeline.py b/cognee/modules/pipelines/operations/run_add_pipeline.py new file mode 100644 index 000000000..8eaba6893 --- /dev/null +++ b/cognee/modules/pipelines/operations/run_add_pipeline.py @@ -0,0 +1,46 @@ +from cognee.shared.logging_utils import get_logger +from cognee.modules.users.models import User +from cognee.modules.data.models import Dataset +from cognee.modules.pipelines.tasks.task import Task +from cognee.modules.pipelines.operations.run_tasks import run_tasks +from cognee.modules.pipelines.operations import log_pipeline_run_initiated +from cognee.modules.pipelines.utils import generate_pipeline_id, validate_pipeline_inputs +from cognee.context_global_variables import set_database_global_context_variables + +logger = get_logger("add.pipeline") + + +@validate_pipeline_inputs +async def run_add_pipeline( + tasks: list[Task], + data, + dataset: Dataset, + user: User, + pipeline_name: str = "add_pipeline", +): + await set_database_global_context_variables(dataset.id, dataset.owner_id) + + pipeline_run = run_tasks( + tasks, + dataset.id, + data, + user, + pipeline_name, + { + "user": user, + "dataset": dataset, + }, + ) + + async for pipeline_run_info in pipeline_run: + yield pipeline_run_info + + pipeline_id = generate_pipeline_id(user.id, dataset.id, pipeline_name) + + # Refresh the cognify pipeline status after we add new files. 
+ # Without this the cognify_pipeline status will be DATASET_PROCESSING_COMPLETED and will skip the execution. + await log_pipeline_run_initiated( + pipeline_id=pipeline_id, + pipeline_name="cognify_pipeline", + dataset_id=dataset.id, + ) diff --git a/cognee/modules/pipelines/utils/__init__.py b/cognee/modules/pipelines/utils/__init__.py index 5d7609741..2db440358 100644 --- a/cognee/modules/pipelines/utils/__init__.py +++ b/cognee/modules/pipelines/utils/__init__.py @@ -1,2 +1,3 @@ from .generate_pipeline_id import generate_pipeline_id from .generate_pipeline_run_id import generate_pipeline_run_id +from .validate_pipeline_inputs import validate_pipeline_inputs diff --git a/cognee/modules/pipelines/utils/validate_pipeline_inputs.py b/cognee/modules/pipelines/utils/validate_pipeline_inputs.py new file mode 100644 index 000000000..44deb114e --- /dev/null +++ b/cognee/modules/pipelines/utils/validate_pipeline_inputs.py @@ -0,0 +1,56 @@ +import inspect +from functools import wraps + +from cognee.modules.users.models.User import User +from cognee.modules.pipelines.tasks.task import Task +from cognee.modules.data.models.Dataset import Dataset +from cognee.modules.data.methods.check_dataset_name import check_dataset_name + + +def validate_pipeline_inputs(pipeline_generator): + @wraps(pipeline_generator) + async def wrapper(*args, **kwargs): + sig = inspect.signature(pipeline_generator) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + if "tasks" in bound_args.arguments: + tasks = bound_args.arguments["tasks"] + if not isinstance(tasks, list): + raise ValueError(f"tasks must be a list, got {type(tasks).__name__}") + + for task in tasks: + if not isinstance(task, Task): + raise ValueError( + f"tasks must be a list of Task instances, got {type(task).__name__} in the list" + ) + + if "user" in bound_args.arguments: + user = bound_args.arguments["user"] + if not isinstance(user, User): + raise ValueError(f"user must be an instance of User, got 
{type(user).__name__}") + + if "dataset" in bound_args.arguments: + dataset = bound_args.arguments["dataset"] + if not isinstance(dataset, Dataset): + raise ValueError( + f"dataset must be an instance of Dataset, got {type(dataset).__name__}" + ) + check_dataset_name(dataset.name) + + if "datasets" in bound_args.arguments: + datasets = bound_args.arguments["datasets"] + if not isinstance(datasets, list): + raise ValueError(f"datasets must be a list, got {type(datasets).__name__}") + + for dataset in datasets: + if not isinstance(dataset, Dataset): + raise ValueError( + f"datasets must be a list of Dataset instances, got {type(dataset).__name__} in the list" + ) + check_dataset_name(dataset.name) + + async for run_info in pipeline_generator(*args, **kwargs): + yield run_info + + return wrapper diff --git a/cognee/modules/users/methods/create_default_user.py b/cognee/modules/users/methods/create_default_user.py index c19092b3a..029939d06 100644 --- a/cognee/modules/users/methods/create_default_user.py +++ b/cognee/modules/users/methods/create_default_user.py @@ -4,8 +4,8 @@ from cognee.base_config import get_base_config async def create_default_user(): base_config = get_base_config() - default_user_email = base_config.default_user_email or "default_user@example.com" - default_user_password = base_config.default_user_password or "default_password" + default_user_email = base_config.default_user_email + default_user_password = base_config.default_user_password user = await create_user( email=default_user_email, diff --git a/cognee/modules/users/methods/get_default_user.py b/cognee/modules/users/methods/get_default_user.py index 10779e028..845d27f55 100644 --- a/cognee/modules/users/methods/get_default_user.py +++ b/cognee/modules/users/methods/get_default_user.py @@ -1,7 +1,7 @@ -from types import SimpleNamespace +from sqlalchemy import select from sqlalchemy.orm import selectinload from sqlalchemy.exc import NoResultFound -from sqlalchemy.future import select + from 
cognee.modules.users.models import User from cognee.base_config import get_base_config from cognee.modules.users.exceptions.exceptions import UserNotFoundError @@ -10,15 +10,15 @@ from cognee.infrastructure.databases.relational import get_relational_engine from cognee.modules.users.methods.create_default_user import create_default_user -async def get_default_user() -> SimpleNamespace: +async def get_default_user() -> User: db_engine = get_relational_engine() base_config = get_base_config() - default_email = base_config.default_user_email or "default_user@example.com" + default_email: str = str(base_config.default_user_email) try: async with db_engine.get_async_session() as session: query = ( - select(User).options(selectinload(User.roles)).where(User.email == default_email) + select(User).options(selectinload(User.roles)).where(User.email == default_email) # type: ignore ) result = await session.execute(query) @@ -27,10 +27,7 @@ async def get_default_user() -> SimpleNamespace: if user is None: return await create_default_user() - # We return a SimpleNamespace to have the same user type as our SaaS - # SimpleNamespace is just a dictionary which can be accessed through attributes - auth_data = SimpleNamespace(id=user.id, tenant_id=user.tenant_id, roles=[]) - return auth_data + return user except Exception as error: if "principals" in str(error.args): raise DatabaseNotCreatedError() from error diff --git a/cognee/modules/users/models/__init__.py b/cognee/modules/users/models/__init__.py index ba2f40e49..792370a74 100644 --- a/cognee/modules/users/models/__init__.py +++ b/cognee/modules/users/models/__init__.py @@ -8,3 +8,4 @@ from .TenantDefaultPermissions import TenantDefaultPermissions from .Permission import Permission from .Tenant import Tenant from .ACL import ACL +from .Principal import Principal