feat: cognee pipeline layers (#1287)
<!-- .github/pull_request_template.md --> ## Description Add dataset authorization layer Add pipeline processing status layer ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
commit
5771b36c4c
5 changed files with 127 additions and 130 deletions
|
|
@ -1,60 +0,0 @@
|
|||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import Mapped, MappedColumn
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Enum, JSON
|
||||
from cognee.infrastructure.databases.relational import Base, UUID
|
||||
|
||||
|
||||
class OperationType(Enum):
|
||||
"""
|
||||
Define various types of operations for data handling.
|
||||
|
||||
Public methods:
|
||||
- __str__(): Returns a string representation of the operation type.
|
||||
|
||||
Instance variables:
|
||||
- MERGE_DATA: Represents the merge data operation type.
|
||||
- APPEND_DATA: Represents the append data operation type.
|
||||
"""
|
||||
|
||||
MERGE_DATA = "MERGE_DATA"
|
||||
APPEND_DATA = "APPEND_DATA"
|
||||
|
||||
|
||||
class OperationStatus(Enum):
|
||||
"""
|
||||
Represent the status of an operation with predefined states.
|
||||
"""
|
||||
|
||||
STARTED = "OPERATION_STARTED"
|
||||
IN_PROGRESS = "OPERATION_IN_PROGRESS"
|
||||
COMPLETE = "OPERATION_COMPLETE"
|
||||
ERROR = "OPERATION_ERROR"
|
||||
CANCELLED = "OPERATION_CANCELLED"
|
||||
|
||||
|
||||
class Operation(Base):
|
||||
"""
|
||||
Represents an operation in the system, extending the Base class.
|
||||
|
||||
This class defines the structure of the 'operation' table, including fields for the
|
||||
operation's ID, status, type, associated data, metadata, and creation timestamp. The
|
||||
public methods available in this class are inherited from the Base class. Instance
|
||||
variables include:
|
||||
- id: Unique identifier for the operation.
|
||||
- status: The current status of the operation.
|
||||
- operation_type: The type of operation being represented.
|
||||
- data_id: Foreign key referencing the associated data's ID.
|
||||
- meta_data: Additional metadata related to the operation.
|
||||
- created_at: Timestamp for when the operation was created.
|
||||
"""
|
||||
|
||||
__tablename__ = "operation"
|
||||
|
||||
id = Column(UUID, primary_key=True)
|
||||
status = Column(Enum(OperationStatus))
|
||||
operation_type = Column(Enum(OperationType))
|
||||
|
||||
data_id = Column(UUID, ForeignKey("data.id"))
|
||||
meta_data: Mapped[dict] = MappedColumn(type_=JSON)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.now(timezone.utc))
|
||||
59
cognee/modules/pipelines/layers/process_pipeline_check.py
Normal file
59
cognee/modules/pipelines/layers/process_pipeline_check.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
from typing import Union, Optional
|
||||
from cognee.modules.data.models import Dataset
|
||||
from cognee.modules.data.models import Data
|
||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.modules.pipelines.methods import get_pipeline_run_by_dataset
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||
PipelineRunCompleted,
|
||||
PipelineRunStarted,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def process_pipeline_check(
|
||||
dataset: Dataset, data: list[Data], pipeline_name: str
|
||||
) -> Optional[Union[PipelineRunStarted, PipelineRunCompleted]]:
|
||||
"""
|
||||
Function used to determine if pipeline is currently being processed or was already processed.
|
||||
In case pipeline was or is being processed return value is returned and current pipline execution should be stopped.
|
||||
In case pipeline is not or was not processed there will be no return value and pipeline processing can start.
|
||||
|
||||
Args:
|
||||
dataset: Dataset object
|
||||
data: List of Data
|
||||
pipeline_name: pipeline name
|
||||
|
||||
Returns: Pipeline state if it is being processed or was already processed
|
||||
|
||||
"""
|
||||
|
||||
# async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
|
||||
if isinstance(dataset, Dataset):
|
||||
task_status = await get_pipeline_status([dataset.id], pipeline_name)
|
||||
else:
|
||||
task_status = {}
|
||||
|
||||
if str(dataset.id) in task_status:
|
||||
if task_status[str(dataset.id)] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
|
||||
logger.info("Dataset %s is already being processed.", dataset.id)
|
||||
pipeline_run = await get_pipeline_run_by_dataset(dataset.id, pipeline_name)
|
||||
return PipelineRunStarted(
|
||||
pipeline_run_id=pipeline_run.pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
payload=data,
|
||||
)
|
||||
elif task_status[str(dataset.id)] == PipelineRunStatus.DATASET_PROCESSING_COMPLETED:
|
||||
logger.info("Dataset %s is already processed.", dataset.id)
|
||||
pipeline_run = await get_pipeline_run_by_dataset(dataset.id, pipeline_name)
|
||||
return PipelineRunCompleted(
|
||||
pipeline_run_id=pipeline_run.pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
)
|
||||
|
||||
return
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
from uuid import UUID
|
||||
from typing import Union, Tuple, List
|
||||
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.data.models import Dataset
|
||||
from cognee.modules.data.exceptions import DatasetNotFoundError
|
||||
from cognee.modules.data.methods import (
|
||||
get_authorized_existing_datasets,
|
||||
load_or_create_datasets,
|
||||
check_dataset_name,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_authorized_user_datasets(
|
||||
datasets: Union[str, UUID, list[str], list[UUID]], user: User = None
|
||||
) -> Tuple[User, List[Dataset]]:
|
||||
"""
|
||||
Function handles creation and dataset authorization if datasets already exist for Cognee.
|
||||
Verifies that provided user has necessary permission for provided Dataset.
|
||||
If Dataset does not exist creates the Dataset and gives permission for the user creating the dataset.
|
||||
|
||||
Args:
|
||||
user: Cognee User request is being processed for, if None default user will be used.
|
||||
datasets: Dataset names or Dataset UUID (in case Datasets already exist)
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# If no user is provided use default user
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
# Convert datasets to list
|
||||
if isinstance(datasets, str) or isinstance(datasets, UUID):
|
||||
datasets = [datasets]
|
||||
|
||||
# Get datasets user wants write permissions for (verify user has permissions if datasets are provided as well)
|
||||
# NOTE: If a user wants to write to a dataset he does not own it must be provided through UUID
|
||||
existing_datasets = await get_authorized_existing_datasets(datasets, "write", user)
|
||||
|
||||
if not datasets:
|
||||
# Get datasets from database if none sent.
|
||||
authorized_datasets = existing_datasets
|
||||
else:
|
||||
# If dataset matches an existing Dataset (by name or id), reuse it. Otherwise, create a new Dataset.
|
||||
authorized_datasets = await load_or_create_datasets(datasets, existing_datasets, user)
|
||||
|
||||
if not authorized_datasets:
|
||||
raise DatasetNotFoundError("There are no datasets to work with.")
|
||||
|
||||
for dataset in authorized_datasets:
|
||||
check_dataset_name(dataset.name)
|
||||
|
||||
return user, authorized_datasets
|
||||
|
|
@ -9,27 +9,16 @@ from cognee.shared.logging_utils import get_logger
|
|||
from cognee.modules.data.methods.get_dataset_data import get_dataset_data
|
||||
from cognee.modules.data.models import Data, Dataset
|
||||
from cognee.modules.pipelines.operations.run_tasks import run_tasks
|
||||
from cognee.modules.pipelines.models import PipelineRunStatus
|
||||
from cognee.modules.pipelines.utils import generate_pipeline_id
|
||||
from cognee.modules.pipelines.operations.get_pipeline_status import get_pipeline_status
|
||||
from cognee.modules.pipelines.methods import get_pipeline_run_by_dataset
|
||||
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.pipelines.operations import log_pipeline_run_initiated
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.modules.data.exceptions import DatasetNotFoundError
|
||||
from cognee.modules.data.methods import (
|
||||
get_authorized_existing_datasets,
|
||||
load_or_create_datasets,
|
||||
check_dataset_name,
|
||||
)
|
||||
|
||||
from cognee.modules.pipelines.models.PipelineRunInfo import (
|
||||
PipelineRunCompleted,
|
||||
PipelineRunStarted,
|
||||
from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
|
||||
resolve_authorized_user_datasets,
|
||||
)
|
||||
from cognee.modules.pipelines.layers.process_pipeline_check import process_pipeline_check
|
||||
|
||||
logger = get_logger("cognee.pipeline")
|
||||
|
||||
|
|
@ -48,29 +37,9 @@ async def cognee_pipeline(
|
|||
):
|
||||
await environment_setup_and_checks(vector_db_config, graph_db_config)
|
||||
|
||||
# If no user is provided use default user
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
user, authorized_datasets = await resolve_authorized_user_datasets(datasets, user)
|
||||
|
||||
# Convert datasets to list
|
||||
if isinstance(datasets, str) or isinstance(datasets, UUID):
|
||||
datasets = [datasets]
|
||||
|
||||
# Get datasets user wants write permissions for (verify user has permissions if datasets are provided as well)
|
||||
# NOTE: If a user wants to write to a dataset he does not own it must be provided through UUID
|
||||
existing_datasets = await get_authorized_existing_datasets(datasets, "write", user)
|
||||
|
||||
if not datasets:
|
||||
# Get datasets from database if none sent.
|
||||
datasets = existing_datasets
|
||||
else:
|
||||
# If dataset matches an existing Dataset (by name or id), reuse it. Otherwise, create a new Dataset.
|
||||
datasets = await load_or_create_datasets(datasets, existing_datasets, user)
|
||||
|
||||
if not datasets:
|
||||
raise DatasetNotFoundError("There are no datasets to work with.")
|
||||
|
||||
for dataset in datasets:
|
||||
for dataset in authorized_datasets:
|
||||
async for run_info in run_pipeline(
|
||||
dataset=dataset,
|
||||
user=user,
|
||||
|
|
@ -92,8 +61,6 @@ async def run_pipeline(
|
|||
context: dict = None,
|
||||
incremental_loading=False,
|
||||
):
|
||||
check_dataset_name(dataset.name)
|
||||
|
||||
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
||||
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||
|
||||
|
|
@ -117,39 +84,15 @@ async def run_pipeline(
|
|||
dataset_id=dataset.id,
|
||||
)
|
||||
|
||||
dataset_id = dataset.id
|
||||
|
||||
if not data:
|
||||
data: list[Data] = await get_dataset_data(dataset_id=dataset_id)
|
||||
data: list[Data] = await get_dataset_data(dataset_id=dataset.id)
|
||||
|
||||
# async with update_status_lock: TODO: Add UI lock to prevent multiple backend requests
|
||||
if isinstance(dataset, Dataset):
|
||||
task_status = await get_pipeline_status([dataset_id], pipeline_name)
|
||||
else:
|
||||
task_status = [
|
||||
PipelineRunStatus.DATASET_PROCESSING_COMPLETED
|
||||
] # TODO: this is a random assignment, find permanent solution
|
||||
|
||||
if str(dataset_id) in task_status:
|
||||
if task_status[str(dataset_id)] == PipelineRunStatus.DATASET_PROCESSING_STARTED:
|
||||
logger.info("Dataset %s is already being processed.", dataset_id)
|
||||
pipeline_run = await get_pipeline_run_by_dataset(dataset_id, pipeline_name)
|
||||
yield PipelineRunStarted(
|
||||
pipeline_run_id=pipeline_run.pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
payload=data,
|
||||
)
|
||||
return
|
||||
elif task_status[str(dataset_id)] == PipelineRunStatus.DATASET_PROCESSING_COMPLETED:
|
||||
logger.info("Dataset %s is already processed.", dataset_id)
|
||||
pipeline_run = await get_pipeline_run_by_dataset(dataset_id, pipeline_name)
|
||||
yield PipelineRunCompleted(
|
||||
pipeline_run_id=pipeline_run.pipeline_run_id,
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
)
|
||||
return
|
||||
process_pipeline_status = await process_pipeline_check(dataset, data, pipeline_name)
|
||||
if process_pipeline_status:
|
||||
# If pipeline was already processed or is currently being processed
|
||||
# return status information to async generator and finish execution
|
||||
yield process_pipeline_status
|
||||
return
|
||||
|
||||
if not isinstance(tasks, list):
|
||||
raise ValueError("Tasks must be a list")
|
||||
|
|
@ -159,7 +102,7 @@ async def run_pipeline(
|
|||
raise ValueError(f"Task {task} is not an instance of Task")
|
||||
|
||||
pipeline_run = run_tasks(
|
||||
tasks, dataset_id, data, user, pipeline_name, context, incremental_loading
|
||||
tasks, dataset.id, data, user, pipeline_name, context, incremental_loading
|
||||
)
|
||||
|
||||
async for pipeline_run_info in pipeline_run:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue