diff --git a/.env.template b/.env.template index 28980de74..e9e9fb571 100644 --- a/.env.template +++ b/.env.template @@ -137,6 +137,14 @@ REQUIRE_AUTHENTICATION=False # It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset ENABLE_BACKEND_ACCESS_CONTROL=False +################################################################################ +# ☁️ Cloud Sync Settings +################################################################################ + +# Cognee Cloud API settings for syncing data to/from cloud infrastructure +COGNEE_CLOUD_API_URL="http://localhost:8001" +COGNEE_CLOUD_AUTH_TOKEN="your-auth-token" + ################################################################################ # 🛠️ DEV Settings ################################################################################ diff --git a/alembic/versions/211ab850ef3d_add_sync_operations_table.py b/alembic/versions/211ab850ef3d_add_sync_operations_table.py new file mode 100644 index 000000000..f22c7c6e2 --- /dev/null +++ b/alembic/versions/211ab850ef3d_add_sync_operations_table.py @@ -0,0 +1,98 @@ +"""Add sync_operations table + +Revision ID: 211ab850ef3d +Revises: 9e7a3cb85175 +Create Date: 2025-09-10 20:11:13.534829 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "211ab850ef3d" +down_revision: Union[str, None] = "9e7a3cb85175" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + + # Check if table already exists (it might be created by Base.metadata.create_all() in initial migration) + connection = op.get_bind() + inspector = sa.inspect(connection) + + if "sync_operations" not in inspector.get_table_names(): + # Table doesn't exist, create it normally + op.create_table( + "sync_operations", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("run_id", sa.Text(), nullable=True), + sa.Column( + "status", + sa.Enum( + "STARTED", + "IN_PROGRESS", + "COMPLETED", + "FAILED", + "CANCELLED", + name="syncstatus", + create_type=False, + ), + nullable=True, + ), + sa.Column("progress_percentage", sa.Integer(), nullable=True), + sa.Column("dataset_ids", sa.JSON(), nullable=True), + sa.Column("dataset_names", sa.JSON(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("total_records_to_sync", sa.Integer(), nullable=True), + sa.Column("total_records_to_download", sa.Integer(), nullable=True), + sa.Column("total_records_to_upload", sa.Integer(), nullable=True), + sa.Column("records_downloaded", sa.Integer(), nullable=True), + sa.Column("records_uploaded", sa.Integer(), nullable=True), + sa.Column("bytes_downloaded", sa.Integer(), nullable=True), + sa.Column("bytes_uploaded", sa.Integer(), nullable=True), + sa.Column("dataset_sync_hashes", sa.JSON(), nullable=True), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("retry_count", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_sync_operations_run_id"), "sync_operations", ["run_id"], unique=True + ) + op.create_index( + op.f("ix_sync_operations_user_id"), "sync_operations", ["user_id"], unique=False + ) + else: + # Table already exists, but we might need to add missing columns or indexes + # 
For now, just log that the table already exists + print("sync_operations table already exists, skipping creation") + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + # Only drop if table exists (might have been created by Base.metadata.create_all()) + connection = op.get_bind() + inspector = sa.inspect(connection) + + if "sync_operations" in inspector.get_table_names(): + op.drop_index(op.f("ix_sync_operations_user_id"), table_name="sync_operations") + op.drop_index(op.f("ix_sync_operations_run_id"), table_name="sync_operations") + op.drop_table("sync_operations") + + # Drop the enum type that was created (only if no other tables are using it) + sa.Enum(name="syncstatus").drop(op.get_bind(), checkfirst=True) + else: + print("sync_operations table doesn't exist, skipping downgrade") + + # ### end Alembic commands ### diff --git a/alembic/versions/8057ae7329c2_initial_migration.py b/alembic/versions/8057ae7329c2_initial_migration.py index 48e795327..aa0ecd4b8 100644 --- a/alembic/versions/8057ae7329c2_initial_migration.py +++ b/alembic/versions/8057ae7329c2_initial_migration.py @@ -19,6 +19,7 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: db_engine = get_relational_engine() + # we might want to delete this await_only(db_engine.create_database()) diff --git a/cognee/api/v1/sync/__init__.py b/cognee/api/v1/sync/__init__.py index 7f60bc948..64bd4f2a5 100644 --- a/cognee/api/v1/sync/__init__.py +++ b/cognee/api/v1/sync/__init__.py @@ -3,7 +3,7 @@ from .sync import ( SyncResponse, LocalFileInfo, CheckMissingHashesRequest, - CheckMissingHashesResponse, + CheckHashesDiffResponse, PruneDatasetRequest, ) @@ -12,6 +12,6 @@ __all__ = [ "SyncResponse", "LocalFileInfo", "CheckMissingHashesRequest", - "CheckMissingHashesResponse", + "CheckHashesDiffResponse", "PruneDatasetRequest", ] diff --git a/cognee/api/v1/sync/routers/get_sync_router.py 
b/cognee/api/v1/sync/routers/get_sync_router.py index 735fd13bd..d74ae4e7d 100644 --- a/cognee/api/v1/sync/routers/get_sync_router.py +++ b/cognee/api/v1/sync/routers/get_sync_router.py @@ -1,5 +1,5 @@ from uuid import UUID -from typing import Optional +from typing import Optional, List from fastapi import APIRouter, Depends from fastapi.responses import JSONResponse @@ -8,6 +8,7 @@ from cognee.api.DTO import InDTO from cognee.modules.users.models import User from cognee.modules.users.methods import get_authenticated_user from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets +from cognee.modules.sync.methods import get_running_sync_operations_for_user, get_sync_operation from cognee.shared.utils import send_telemetry from cognee.shared.logging_utils import get_logger from cognee.api.v1.sync import SyncResponse @@ -19,7 +20,7 @@ logger = get_logger() class SyncRequest(InDTO): """Request model for sync operations.""" - dataset_id: Optional[UUID] = None + dataset_ids: Optional[List[UUID]] = None def get_sync_router() -> APIRouter: @@ -40,7 +41,7 @@ def get_sync_router() -> APIRouter: ## Request Body (JSON) ```json { - "dataset_id": "123e4567-e89b-12d3-a456-426614174000" + "dataset_ids": ["123e4567-e89b-12d3-a456-426614174000", "456e7890-e12b-34c5-d678-901234567000"] } ``` @@ -48,8 +49,8 @@ def get_sync_router() -> APIRouter: Returns immediate response for the sync operation: - **run_id**: Unique identifier for tracking the background sync operation - **status**: Always "started" (operation runs in background) - - **dataset_id**: ID of the dataset being synced - - **dataset_name**: Name of the dataset being synced + - **dataset_ids**: List of dataset IDs being synced + - **dataset_names**: List of dataset names being synced - **message**: Description of the background operation - **timestamp**: When the sync was initiated - **user_id**: User who initiated the sync @@ -64,15 +65,21 @@ def get_sync_router() -> APIRouter: ## Example 
Usage ```bash - # Sync dataset to cloud by ID (JSON request) + # Sync multiple datasets to cloud by IDs (JSON request) curl -X POST "http://localhost:8000/api/v1/sync" \\ -H "Content-Type: application/json" \\ -H "Cookie: auth_token=your-token" \\ - -d '{"dataset_id": "123e4567-e89b-12d3-a456-426614174000"}' + -d '{"dataset_ids": ["123e4567-e89b-12d3-a456-426614174000", "456e7890-e12b-34c5-d678-901234567000"]}' + + # Sync all user datasets (empty request body or null dataset_ids) + curl -X POST "http://localhost:8000/api/v1/sync" \\ + -H "Content-Type: application/json" \\ + -H "Cookie: auth_token=your-token" \\ + -d '{}' ``` ## Error Codes - - **400 Bad Request**: Invalid dataset_id format + - **400 Bad Request**: Invalid dataset_ids format - **401 Unauthorized**: Invalid or missing authentication - **403 Forbidden**: User doesn't have permission to access dataset - **404 Not Found**: Dataset not found @@ -92,32 +99,48 @@ def get_sync_router() -> APIRouter: user.id, additional_properties={ "endpoint": "POST /v1/sync", - "dataset_id": str(request.dataset_id) if request.dataset_id else "*", + "dataset_ids": [str(id) for id in request.dataset_ids] + if request.dataset_ids + else "*", }, ) from cognee.api.v1.sync import sync as cognee_sync try: - # Retrieve existing dataset and check permissions - datasets = await get_specific_user_permission_datasets( - user.id, "write", [request.dataset_id] if request.dataset_id else None - ) - - sync_results = {} - - for dataset in datasets: - await set_database_global_context_variables(dataset.id, dataset.owner_id) - - # Execute cloud sync operation - sync_result = await cognee_sync( - dataset=dataset, - user=user, + # Check if user has any running sync operations + running_syncs = await get_running_sync_operations_for_user(user.id) + if running_syncs: + # Return information about the existing sync operation + existing_sync = running_syncs[0] # Get the most recent running sync + return JSONResponse( + status_code=409, + content={ 
+ "error": "Sync operation already in progress", + "details": { + "run_id": existing_sync.run_id, + "status": "already_running", + "dataset_ids": existing_sync.dataset_ids, + "dataset_names": existing_sync.dataset_names, + "message": f"You have a sync operation already in progress with run_id '{existing_sync.run_id}'. Use the status endpoint to monitor progress, or wait for it to complete before starting a new sync.", + "timestamp": existing_sync.created_at.isoformat(), + "progress_percentage": existing_sync.progress_percentage, + }, + }, ) - sync_results[str(dataset.id)] = sync_result + # Retrieve existing dataset and check permissions + datasets = await get_specific_user_permission_datasets( + user.id, "write", request.dataset_ids if request.dataset_ids else None + ) - return sync_results + # Execute new cloud sync operation for all datasets + sync_result = await cognee_sync( + datasets=datasets, + user=user, + ) + + return sync_result except ValueError as e: return JSONResponse(status_code=400, content={"error": str(e)}) @@ -131,4 +154,88 @@ def get_sync_router() -> APIRouter: logger.error(f"Cloud sync operation failed: {str(e)}") return JSONResponse(status_code=409, content={"error": "Cloud sync operation failed."}) + @router.get("/status") + async def get_sync_status_overview( + user: User = Depends(get_authenticated_user), + ): + """ + Check if there are any running sync operations for the current user. + + This endpoint provides a simple check to see if the user has any active sync operations + without needing to know specific run IDs. 
+ + ## Response + Returns a simple status overview: + - **has_running_sync**: Boolean indicating if there are any running syncs + - **running_sync_count**: Number of currently running sync operations + - **latest_running_sync** (optional): Information about the most recent running sync if any exists + + ## Example Usage + ```bash + curl -X GET "http://localhost:8000/api/v1/sync/status" \\ + -H "Cookie: auth_token=your-token" + ``` + + ## Example Responses + + **No running syncs:** + ```json + { + "has_running_sync": false, + "running_sync_count": 0 + } + ``` + + **With running sync:** + ```json + { + "has_running_sync": true, + "running_sync_count": 1, + "latest_running_sync": { + "run_id": "12345678-1234-5678-9012-123456789012", + "dataset_name": "My Dataset", + "progress_percentage": 45, + "created_at": "2025-01-01T00:00:00Z" + } + } + ``` + """ + send_telemetry( + "Sync Status Overview API Endpoint Invoked", + user.id, + additional_properties={ + "endpoint": "GET /v1/sync/status", + }, + ) + + try: + # Get any running sync operations for the user + running_syncs = await get_running_sync_operations_for_user(user.id) + + response = { + "has_running_sync": len(running_syncs) > 0, + "running_sync_count": len(running_syncs), + } + + # If there are running syncs, include info about the latest one + if running_syncs: + latest_sync = running_syncs[0] # Already ordered by created_at desc + response["latest_running_sync"] = { + "run_id": latest_sync.run_id, + "dataset_ids": latest_sync.dataset_ids, + "dataset_names": latest_sync.dataset_names, + "progress_percentage": latest_sync.progress_percentage, + "created_at": latest_sync.created_at.isoformat() + if latest_sync.created_at + else None, + } + + return response + + except Exception as e: + logger.error(f"Failed to get sync status overview: {str(e)}") + return JSONResponse( + status_code=500, content={"error": "Failed to get sync status overview"} + ) + return router diff --git a/cognee/api/v1/sync/sync.py 
b/cognee/api/v1/sync/sync.py index 4c6b58eac..54339c1c4 100644 --- a/cognee/api/v1/sync/sync.py +++ b/cognee/api/v1/sync/sync.py @@ -1,3 +1,4 @@ +import io import os import uuid import asyncio @@ -5,8 +6,12 @@ import aiohttp from pydantic import BaseModel from typing import List, Optional from datetime import datetime, timezone +from dataclasses import dataclass + +from cognee.api.v1.cognify import cognify from cognee.infrastructure.files.storage import get_file_storage +from cognee.tasks.ingestion.ingest_data import ingest_data from cognee.shared.logging_utils import get_logger from cognee.modules.users.models import User from cognee.modules.data.models import Dataset @@ -19,7 +24,28 @@ from cognee.modules.sync.methods import ( mark_sync_failed, ) -logger = get_logger() +logger = get_logger("sync") + + +async def _safe_update_progress(run_id: str, stage: str, **kwargs): + """ + Safely update sync progress with better error handling and context. + + Args: + run_id: Sync operation run ID + stage: Description of current stage for logging + **kwargs: Additional fields to update (e.g. progress_percentage, + records_downloaded, records_uploaded, etc.)
+ """ + try: + await update_sync_operation(run_id, **kwargs) + logger.info(f"Sync {run_id}: Progress updated during {stage}") + except Exception as e: + # Log error but don't fail the sync - progress updates are nice-to-have + logger.warning( + f"Sync {run_id}: Non-critical progress update failed during {stage}: {str(e)}" + ) + # Continue without raising - sync operation is more important than progress tracking class LocalFileInfo(BaseModel): @@ -38,13 +64,16 @@ class LocalFileInfo(BaseModel): class CheckMissingHashesRequest(BaseModel): """Request model for checking missing hashes in a dataset""" + dataset_id: str + dataset_name: str hashes: List[str] -class CheckMissingHashesResponse(BaseModel): +class CheckHashesDiffResponse(BaseModel): """Response model for missing hashes check""" - missing: List[str] + missing_on_remote: List[str] + missing_on_local: List[str] class PruneDatasetRequest(BaseModel): @@ -58,34 +87,34 @@ class SyncResponse(BaseModel): run_id: str status: str # "started" for immediate response - dataset_id: str - dataset_name: str + dataset_ids: List[str] + dataset_names: List[str] message: str timestamp: str user_id: str async def sync( - dataset: Dataset, + datasets: List[Dataset], user: User, ) -> SyncResponse: """ Sync local Cognee data to Cognee Cloud. - This function handles synchronization of local datasets, knowledge graphs, and + This function handles synchronization of multiple datasets, knowledge graphs, and processed data to the Cognee Cloud infrastructure. It uploads local data for cloud-based processing, backup, and sharing. 
Args: - dataset: Dataset object to sync (permissions already verified) + datasets: List of Dataset objects to sync (permissions already verified) user: User object for authentication and permissions Returns: SyncResponse model with immediate response: - run_id: Unique identifier for tracking this sync operation - status: Always "started" (sync runs in background) - - dataset_id: ID of the dataset being synced - - dataset_name: Name of the dataset being synced + - dataset_ids: List of dataset IDs being synced + - dataset_names: List of dataset names being synced - message: Description of what's happening - timestamp: When the sync was initiated - user_id: User who initiated the sync @@ -94,8 +123,8 @@ async def sync( ConnectionError: If Cognee Cloud service is unreachable Exception: For other sync-related errors """ - if not dataset: - raise ValueError("Dataset must be provided for sync operation") + if not datasets: + raise ValueError("At least one dataset must be provided for sync operation") # Generate a unique run ID run_id = str(uuid.uuid4()) @@ -103,12 +132,16 @@ async def sync( # Get current timestamp timestamp = datetime.now(timezone.utc).isoformat() - logger.info(f"Starting cloud sync operation {run_id}: dataset {dataset.name} ({dataset.id})") + dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets]) + logger.info(f"Starting cloud sync operation {run_id}: datasets {dataset_info}") # Create sync operation record in database (total_records will be set during background sync) try: await create_sync_operation( - run_id=run_id, dataset_id=dataset.id, dataset_name=dataset.name, user_id=user.id + run_id=run_id, + dataset_ids=[d.id for d in datasets], + dataset_names=[d.name for d in datasets], + user_id=user.id, ) logger.info(f"Created sync operation record for {run_id}") except Exception as e: @@ -116,44 +149,74 @@ async def sync( # Continue without database tracking if record creation fails # Start the sync operation in the background - 
asyncio.create_task(_perform_background_sync(run_id, dataset, user)) + asyncio.create_task(_perform_background_sync(run_id, datasets, user)) # Return immediately with run_id return SyncResponse( run_id=run_id, status="started", - dataset_id=str(dataset.id), - dataset_name=dataset.name, - message=f"Sync operation started in background. Use run_id '{run_id}' to track progress.", + dataset_ids=[str(d.id) for d in datasets], + dataset_names=[d.name for d in datasets], + message=f"Sync operation started in background for {len(datasets)} datasets. Use run_id '{run_id}' to track progress.", timestamp=timestamp, user_id=str(user.id), ) -async def _perform_background_sync(run_id: str, dataset: Dataset, user: User) -> None: - """Perform the actual sync operation in the background.""" +async def _perform_background_sync(run_id: str, datasets: List[Dataset], user: User) -> None: + """Perform the actual sync operation in the background for multiple datasets.""" start_time = datetime.now(timezone.utc) try: - logger.info( - f"Background sync {run_id}: Starting sync for dataset {dataset.name} ({dataset.id})" - ) + dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets]) + logger.info(f"Background sync {run_id}: Starting sync for datasets {dataset_info}") # Mark sync as in progress await mark_sync_started(run_id) # Perform the actual sync operation - records_processed, bytes_transferred = await _sync_to_cognee_cloud(dataset, user, run_id) + MAX_RETRY_COUNT = 3 + retry_count = 0 + while retry_count < MAX_RETRY_COUNT: + try: + ( + records_downloaded, + records_uploaded, + bytes_downloaded, + bytes_uploaded, + dataset_sync_hashes, + ) = await _sync_to_cognee_cloud(datasets, user, run_id) + break + except Exception as e: + retry_count += 1 + logger.error( + f"Background sync {run_id}: Failed after {retry_count} retries with error: {str(e)}" + ) + await update_sync_operation(run_id, retry_count=retry_count) + await asyncio.sleep(2**retry_count) + continue + + if retry_count == 
MAX_RETRY_COUNT: + logger.error(f"Background sync {run_id}: Failed after {MAX_RETRY_COUNT} retries") + await mark_sync_failed(run_id, "Failed after 3 retries") + return end_time = datetime.now(timezone.utc) duration = (end_time - start_time).total_seconds() logger.info( - f"Background sync {run_id}: Completed successfully. Records: {records_processed}, Bytes: {bytes_transferred}, Duration: {duration}s" + f"Background sync {run_id}: Completed successfully. Downloaded: {records_downloaded} records/{bytes_downloaded} bytes, Uploaded: {records_uploaded} records/{bytes_uploaded} bytes, Duration: {duration}s" ) - # Mark sync as completed with final stats - await mark_sync_completed(run_id, records_processed, bytes_transferred) + # Mark sync as completed with final stats and data lineage + await mark_sync_completed( + run_id, + records_downloaded, + records_uploaded, + bytes_downloaded, + bytes_uploaded, + dataset_sync_hashes, + ) except Exception as e: end_time = datetime.now(timezone.utc) @@ -165,89 +228,248 @@ async def _perform_background_sync(run_id: str, dataset: Dataset, user: User) -> await mark_sync_failed(run_id, str(e)) -async def _sync_to_cognee_cloud(dataset: Dataset, user: User, run_id: str) -> tuple[int, int]: +async def _sync_to_cognee_cloud( + datasets: List[Dataset], user: User, run_id: str +) -> tuple[int, int, int, int, dict]: """ Sync local data to Cognee Cloud using three-step idempotent process: 1. Extract local files with stored MD5 hashes and check what's missing on cloud 2. Upload missing files individually 3. 
Prune cloud dataset to match local state """ - logger.info(f"Starting sync to Cognee Cloud: dataset {dataset.name} ({dataset.id})") + dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets]) + logger.info(f"Starting sync to Cognee Cloud: datasets {dataset_info}") + + total_records_downloaded = 0 + total_records_uploaded = 0 + total_bytes_downloaded = 0 + total_bytes_uploaded = 0 + dataset_sync_hashes = {} try: # Get cloud configuration cloud_base_url = await _get_cloud_base_url() cloud_auth_token = await _get_cloud_auth_token(user) - logger.info(f"Cloud API URL: {cloud_base_url}") + # Step 1: Sync files for all datasets concurrently + sync_files_tasks = [ + _sync_dataset_files(dataset, cloud_base_url, cloud_auth_token, user, run_id) + for dataset in datasets + ] - # Step 1: Extract local file info with stored hashes - local_files = await _extract_local_files_with_hashes(dataset, user, run_id) - logger.info(f"Found {len(local_files)} local files to sync") + logger.info(f"Starting concurrent file sync for {len(datasets)} datasets") - # Update sync operation with total file count - try: - await update_sync_operation(run_id, processed_records=0) - except Exception as e: - logger.warning(f"Failed to initialize sync progress: {str(e)}") + has_any_uploads = False + has_any_downloads = False + processed_datasets = [] + completed_datasets = 0 - if not local_files: - logger.info("No files to sync - dataset is empty") - return 0, 0 + # Process datasets concurrently and accumulate results + for completed_task in asyncio.as_completed(sync_files_tasks): + try: + dataset_result = await completed_task + completed_datasets += 1 - # Step 2: Check what files are missing on cloud - local_hashes = [f.content_hash for f in local_files] - missing_hashes = await _check_missing_hashes( - cloud_base_url, cloud_auth_token, dataset.id, local_hashes, run_id - ) - logger.info(f"Cloud is missing {len(missing_hashes)} out of {len(local_hashes)} files") + # Update progress based on 
completed datasets (0-80% for file sync) + file_sync_progress = int((completed_datasets / len(datasets)) * 80) + await _safe_update_progress( + run_id, "file_sync", progress_percentage=file_sync_progress + ) - # Update progress - try: - await update_sync_operation(run_id, progress_percentage=25) - except Exception as e: - logger.warning(f"Failed to update progress: {str(e)}") + if dataset_result is None: + logger.info( + f"Progress: {completed_datasets}/{len(datasets)} datasets processed ({file_sync_progress}%)" + ) + continue - # Step 3: Upload missing files - bytes_uploaded = await _upload_missing_files( - cloud_base_url, cloud_auth_token, dataset, local_files, missing_hashes, run_id - ) - logger.info(f"Upload complete: {len(missing_hashes)} files, {bytes_uploaded} bytes") + total_records_downloaded += dataset_result.records_downloaded + total_records_uploaded += dataset_result.records_uploaded + total_bytes_downloaded += dataset_result.bytes_downloaded + total_bytes_uploaded += dataset_result.bytes_uploaded - # Update progress - try: - await update_sync_operation(run_id, progress_percentage=75) - except Exception as e: - logger.warning(f"Failed to update progress: {str(e)}") + # Build per-dataset hash tracking for data lineage + dataset_sync_hashes[dataset_result.dataset_id] = { + "uploaded": dataset_result.uploaded_hashes, + "downloaded": dataset_result.downloaded_hashes, + } - # Step 4: Trigger cognify processing on cloud dataset (only if new files were uploaded) - if missing_hashes: - await _trigger_remote_cognify(cloud_base_url, cloud_auth_token, dataset.id, run_id) - logger.info(f"Cognify processing triggered for dataset {dataset.id}") + if dataset_result.has_uploads: + has_any_uploads = True + if dataset_result.has_downloads: + has_any_downloads = True + + processed_datasets.append(dataset_result.dataset_id) + + logger.info( + f"Progress: {completed_datasets}/{len(datasets)} datasets processed ({file_sync_progress}%) - " + f"Completed file sync for dataset 
{dataset_result.dataset_name}: " + f"↑{dataset_result.records_uploaded} files ({dataset_result.bytes_uploaded} bytes), " + f"↓{dataset_result.records_downloaded} files ({dataset_result.bytes_downloaded} bytes)" + ) + except Exception as e: + completed_datasets += 1 + logger.error(f"Dataset file sync failed: {str(e)}") + # Update progress even for failed datasets + file_sync_progress = int((completed_datasets / len(datasets)) * 80) + await _safe_update_progress( + run_id, "file_sync", progress_percentage=file_sync_progress + ) + # Continue with other datasets even if one fails + + # Step 2: Trigger cognify processing once for all datasets (only if any files were uploaded) + # Update progress to 90% before cognify + await _safe_update_progress(run_id, "cognify", progress_percentage=90) + + if has_any_uploads and processed_datasets: + logger.info( + f"Progress: 90% - Triggering cognify processing for {len(processed_datasets)} datasets with new files" + ) + try: + # Trigger cognify for all datasets at once - use first dataset as reference point + await _trigger_remote_cognify( + cloud_base_url, cloud_auth_token, datasets[0].id, run_id + ) + logger.info("Cognify processing triggered successfully for all datasets") + except Exception as e: + logger.warning(f"Failed to trigger cognify processing: {str(e)}") + # Don't fail the entire sync if cognify fails else: logger.info( - f"Skipping cognify processing - no new files were uploaded for dataset {dataset.id}" + "Progress: 90% - Skipping cognify processing - no new files were uploaded across any datasets" ) - # Final progress - try: - await update_sync_operation(run_id, progress_percentage=100) - except Exception as e: - logger.warning(f"Failed to update final progress: {str(e)}") + # Step 3: Trigger local cognify processing if any files were downloaded + if has_any_downloads and processed_datasets: + logger.info( + f"Progress: 95% - Triggering local cognify processing for {len(processed_datasets)} datasets with downloaded 
files" + ) + try: + await cognify() + logger.info("Local cognify processing completed successfully for all datasets") + except Exception as e: + logger.warning(f"Failed to run local cognify processing: {str(e)}") + # Don't fail the entire sync if local cognify fails + else: + logger.info( + "Progress: 95% - Skipping local cognify processing - no new files were downloaded across any datasets" + ) - records_processed = len(local_files) + # Update final progress + try: + await _safe_update_progress( + run_id, + "final", + progress_percentage=100, + total_records_to_sync=total_records_uploaded + total_records_downloaded, + total_records_to_download=total_records_downloaded, + total_records_to_upload=total_records_uploaded, + records_downloaded=total_records_downloaded, + records_uploaded=total_records_uploaded, + ) + except Exception as e: + logger.warning(f"Failed to update final sync progress: {str(e)}") logger.info( - f"Sync completed successfully: {records_processed} records, {bytes_uploaded} bytes uploaded" + f"Multi-dataset sync completed: {len(datasets)} datasets processed, downloaded {total_records_downloaded} records/{total_bytes_downloaded} bytes, uploaded {total_records_uploaded} records/{total_bytes_uploaded} bytes" ) - return records_processed, bytes_uploaded + return ( + total_records_downloaded, + total_records_uploaded, + total_bytes_downloaded, + total_bytes_uploaded, + dataset_sync_hashes, + ) except Exception as e: logger.error(f"Sync failed: {str(e)}") raise ConnectionError(f"Cloud sync failed: {str(e)}") +@dataclass +class DatasetSyncResult: + """Result of syncing files for a single dataset.""" + + dataset_name: str + dataset_id: str + records_downloaded: int + records_uploaded: int + bytes_downloaded: int + bytes_uploaded: int + has_uploads: bool # Whether any files were uploaded (for cognify decision) + has_downloads: bool # Whether any files were downloaded (for cognify decision) + uploaded_hashes: List[str] # Content hashes of files uploaded 
during sync + downloaded_hashes: List[str] # Content hashes of files downloaded during sync + + +async def _sync_dataset_files( + dataset: Dataset, cloud_base_url: str, cloud_auth_token: str, user: User, run_id: str +) -> Optional[DatasetSyncResult]: + """ + Sync files for a single dataset (2-way: upload to cloud, download from cloud). + Does NOT trigger cognify - that's done separately once for all datasets. + + Returns: + DatasetSyncResult with sync results or None if dataset was empty + """ + logger.info(f"Syncing files for dataset: {dataset.name} ({dataset.id})") + + try: + # Step 1: Extract local file info with stored hashes + local_files = await _extract_local_files_with_hashes(dataset, user, run_id) + logger.info(f"Found {len(local_files)} local files for dataset {dataset.name}") + + if not local_files: + logger.info(f"No files to sync for dataset {dataset.name} - skipping") + return None + + # Step 2: Check what files are missing on cloud + local_hashes = [f.content_hash for f in local_files] + hashes_diff_response = await _check_hashes_diff( + cloud_base_url, cloud_auth_token, dataset, local_hashes, run_id + ) + + hashes_missing_on_remote = hashes_diff_response.missing_on_remote + hashes_missing_on_local = hashes_diff_response.missing_on_local + + logger.info( + f"Dataset {dataset.name}: {len(hashes_missing_on_remote)} files to upload, {len(hashes_missing_on_local)} files to download" + ) + + # Step 3: Upload files that are missing on cloud + bytes_uploaded = await _upload_missing_files( + cloud_base_url, cloud_auth_token, dataset, local_files, hashes_missing_on_remote, run_id + ) + logger.info( + f"Dataset {dataset.name}: Upload complete - {len(hashes_missing_on_remote)} files, {bytes_uploaded} bytes" + ) + + # Step 4: Download files that are missing locally + bytes_downloaded = await _download_missing_files( + cloud_base_url, cloud_auth_token, dataset, hashes_missing_on_local, user + ) + logger.info( + f"Dataset {dataset.name}: Download complete - 
{len(hashes_missing_on_local)} files, {bytes_downloaded} bytes" + ) + + return DatasetSyncResult( + dataset_name=dataset.name, + dataset_id=str(dataset.id), + records_downloaded=len(hashes_missing_on_local), + records_uploaded=len(hashes_missing_on_remote), + bytes_downloaded=bytes_downloaded, + bytes_uploaded=bytes_uploaded, + has_uploads=len(hashes_missing_on_remote) > 0, + has_downloads=len(hashes_missing_on_local) > 0, + uploaded_hashes=hashes_missing_on_remote, + downloaded_hashes=hashes_missing_on_local, + ) + + except Exception as e: + logger.error(f"Failed to sync files for dataset {dataset.name} ({dataset.id}): {str(e)}") + raise # Re-raise to be handled by the caller + + async def _extract_local_files_with_hashes( dataset: Dataset, user: User, run_id: str ) -> List[LocalFileInfo]: @@ -334,43 +556,42 @@ async def _get_file_size(file_path: str) -> int: async def _get_cloud_base_url() -> str: """Get Cognee Cloud API base URL.""" - # TODO: Make this configurable via environment variable or config return os.getenv("COGNEE_CLOUD_API_URL", "http://localhost:8001") async def _get_cloud_auth_token(user: User) -> str: """Get authentication token for Cognee Cloud API.""" - # TODO: Implement proper authentication with Cognee Cloud - # This should get or refresh an API token for the user - return os.getenv("COGNEE_CLOUD_AUTH_TOKEN", "your-auth-token-here") + return os.getenv("COGNEE_CLOUD_AUTH_TOKEN", "your-auth-token") -async def _check_missing_hashes( - cloud_base_url: str, auth_token: str, dataset_id: str, local_hashes: List[str], run_id: str -) -> List[str]: +async def _check_hashes_diff( + cloud_base_url: str, auth_token: str, dataset: Dataset, local_hashes: List[str], run_id: str +) -> CheckHashesDiffResponse: """ - Step 1: Check which hashes are missing on cloud. + Check which hashes are missing on cloud. 
Returns: List[str]: MD5 hashes that need to be uploaded """ - url = f"{cloud_base_url}/api/sync/{dataset_id}/diff" + url = f"{cloud_base_url}/api/sync/{dataset.id}/diff" headers = {"X-Api-Key": auth_token, "Content-Type": "application/json"} - payload = CheckMissingHashesRequest(hashes=local_hashes) + payload = CheckMissingHashesRequest( + dataset_id=str(dataset.id), dataset_name=dataset.name, hashes=local_hashes + ) - logger.info(f"Checking missing hashes on cloud for dataset {dataset_id}") + logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}") try: async with aiohttp.ClientSession() as session: async with session.post(url, json=payload.dict(), headers=headers) as response: if response.status == 200: data = await response.json() - missing_response = CheckMissingHashesResponse(**data) + missing_response = CheckHashesDiffResponse(**data) logger.info( - f"Cloud reports {len(missing_response.missing)} missing files out of {len(local_hashes)} total" + f"Cloud is missing {len(missing_response.missing_on_remote)} out of {len(local_hashes)} files, local is missing {len(missing_response.missing_on_local)} files" ) - return missing_response.missing + return missing_response else: error_text = await response.text() logger.error( @@ -385,22 +606,137 @@ async def _check_missing_hashes( raise ConnectionError(f"Failed to check missing hashes: {str(e)}") +async def _download_missing_files( + cloud_base_url: str, + auth_token: str, + dataset: Dataset, + hashes_missing_on_local: List[str], + user: User, +) -> int: + """ + Download files that are missing locally from the cloud. 
+ + Returns: + int: Total bytes downloaded + """ + logger.info(f"Downloading {len(hashes_missing_on_local)} missing files from cloud") + + if not hashes_missing_on_local: + logger.info("No files need to be downloaded - all files already exist locally") + return 0 + + total_bytes_downloaded = 0 + downloaded_count = 0 + + headers = {"X-Api-Key": auth_token} + + async with aiohttp.ClientSession() as session: + for file_hash in hashes_missing_on_local: + try: + # Download file from cloud by hash + download_url = f"{cloud_base_url}/api/sync/{dataset.id}/data/{file_hash}" + + logger.debug(f"Downloading file with hash: {file_hash}") + + async with session.get(download_url, headers=headers) as response: + if response.status == 200: + file_content = await response.read() + file_size = len(file_content) + + # Get file metadata from response headers + file_name = response.headers.get("X-File-Name", f"file_{file_hash}") + + # Save file locally using ingestion pipeline + await _save_downloaded_file( + dataset, file_hash, file_name, file_content, user + ) + + total_bytes_downloaded += file_size + downloaded_count += 1 + + logger.debug(f"Successfully downloaded {file_name} ({file_size} bytes)") + + elif response.status == 404: + logger.warning(f"File with hash {file_hash} not found on cloud") + continue + else: + error_text = await response.text() + logger.error( + f"Failed to download file {file_hash}: Status {response.status} - {error_text}" + ) + continue + + except Exception as e: + logger.error(f"Error downloading file {file_hash}: {str(e)}") + continue + + logger.info( + f"Download summary: {downloaded_count}/{len(hashes_missing_on_local)} files downloaded, {total_bytes_downloaded} bytes total" + ) + return total_bytes_downloaded + + +class InMemoryDownload: + def __init__(self, data: bytes, filename: str): + self.file = io.BufferedReader(io.BytesIO(data)) + self.filename = filename + + +async def _save_downloaded_file( + dataset: Dataset, + file_hash: str, + file_name: 
str, + file_content: bytes, + user: User, +) -> None: + """ + Save a downloaded file to local storage and register it in the dataset. + Uses the existing ingest_data function for consistency with normal ingestion. + + Args: + dataset: The dataset to add the file to + file_hash: MD5 hash of the file content + file_name: Original file name + file_content: Raw file content bytes + """ + try: + # Create a temporary file-like object from the bytes + file_obj = InMemoryDownload(file_content, file_name) + + # User is injected as dependency + + # Use the existing ingest_data function to properly handle the file + # This ensures consistency with normal file ingestion + await ingest_data( + data=file_obj, + dataset_name=dataset.name, + user=user, + dataset_id=dataset.id, + ) + + logger.debug(f"Successfully saved downloaded file: {file_name} (hash: {file_hash})") + + except Exception as e: + logger.error(f"Failed to save downloaded file {file_name}: {str(e)}") + raise + + async def _upload_missing_files( cloud_base_url: str, auth_token: str, dataset: Dataset, local_files: List[LocalFileInfo], - missing_hashes: List[str], + hashes_missing_on_remote: List[str], run_id: str, ) -> int: """ - Step 2: Upload files that are missing on cloud. + Upload files that are missing on cloud. 
Returns: int: Total bytes uploaded """ # Filter local files to only those with missing hashes - files_to_upload = [f for f in local_files if f.content_hash in missing_hashes] + files_to_upload = [f for f in local_files if f.content_hash in hashes_missing_on_remote] logger.info(f"Uploading {len(files_to_upload)} missing files to cloud") @@ -442,13 +778,6 @@ async def _upload_missing_files( if response.status in [200, 201]: total_bytes_uploaded += len(file_content) uploaded_count += 1 - - # Update progress periodically - if uploaded_count % 10 == 0: - progress = ( - 25 + (uploaded_count / len(files_to_upload)) * 50 - ) # 25-75% range - await update_sync_operation(run_id, progress_percentage=int(progress)) else: error_text = await response.text() logger.error( @@ -470,7 +799,7 @@ async def _prune_cloud_dataset( cloud_base_url: str, auth_token: str, dataset_id: str, local_hashes: List[str], run_id: str ) -> None: """ - Step 3: Prune cloud dataset to match local state. + Prune cloud dataset to match local state. """ url = f"{cloud_base_url}/api/sync/{dataset_id}?prune=true" headers = {"X-Api-Key": auth_token, "Content-Type": "application/json"} @@ -506,7 +835,7 @@ async def _trigger_remote_cognify( cloud_base_url: str, auth_token: str, dataset_id: str, run_id: str ) -> None: """ - Step 4: Trigger cognify processing on the cloud dataset. + Trigger cognify processing on the cloud dataset. This initiates knowledge graph processing on the synchronized dataset using the cloud infrastructure. 
diff --git a/cognee/modules/sync/methods/__init__.py b/cognee/modules/sync/methods/__init__.py index 0fb1c48e1..07f91de00 100644 --- a/cognee/modules/sync/methods/__init__.py +++ b/cognee/modules/sync/methods/__init__.py @@ -1,5 +1,9 @@ from .create_sync_operation import create_sync_operation -from .get_sync_operation import get_sync_operation, get_user_sync_operations +from .get_sync_operation import ( + get_sync_operation, + get_user_sync_operations, + get_running_sync_operations_for_user, +) from .update_sync_operation import ( update_sync_operation, mark_sync_started, @@ -11,6 +15,7 @@ __all__ = [ "create_sync_operation", "get_sync_operation", "get_user_sync_operations", + "get_running_sync_operations_for_user", "update_sync_operation", "mark_sync_started", "mark_sync_completed", diff --git a/cognee/modules/sync/methods/create_sync_operation.py b/cognee/modules/sync/methods/create_sync_operation.py index 96f53faa3..42aa46a75 100644 --- a/cognee/modules/sync/methods/create_sync_operation.py +++ b/cognee/modules/sync/methods/create_sync_operation.py @@ -1,5 +1,5 @@ from uuid import UUID -from typing import Optional +from typing import Optional, List from datetime import datetime, timezone from cognee.modules.sync.models import SyncOperation, SyncStatus from cognee.infrastructure.databases.relational import get_relational_engine @@ -7,20 +7,24 @@ from cognee.infrastructure.databases.relational import get_relational_engine async def create_sync_operation( run_id: str, - dataset_id: UUID, - dataset_name: str, + dataset_ids: List[UUID], + dataset_names: List[str], user_id: UUID, - total_records: Optional[int] = None, + total_records_to_sync: Optional[int] = None, + total_records_to_download: Optional[int] = None, + total_records_to_upload: Optional[int] = None, ) -> SyncOperation: """ Create a new sync operation record in the database. 
Args: run_id: Unique public identifier for this sync operation - dataset_id: UUID of the dataset being synced - dataset_name: Name of the dataset being synced + dataset_ids: List of dataset UUIDs being synced + dataset_names: List of dataset names being synced user_id: UUID of the user who initiated the sync - total_records: Total number of records to sync (if known) + total_records_to_sync: Total number of records to sync (if known) + total_records_to_download: Total number of records to download (if known) + total_records_to_upload: Total number of records to upload (if known) Returns: SyncOperation: The created sync operation record @@ -29,11 +33,15 @@ async def create_sync_operation( sync_operation = SyncOperation( run_id=run_id, - dataset_id=dataset_id, - dataset_name=dataset_name, + dataset_ids=[ + str(uuid) for uuid in dataset_ids + ], # Convert UUIDs to strings for JSON storage + dataset_names=dataset_names, user_id=user_id, status=SyncStatus.STARTED, - total_records=total_records, + total_records_to_sync=total_records_to_sync, + total_records_to_download=total_records_to_download, + total_records_to_upload=total_records_to_upload, created_at=datetime.now(timezone.utc), ) diff --git a/cognee/modules/sync/methods/get_sync_operation.py b/cognee/modules/sync/methods/get_sync_operation.py index f3c466a3f..992dd91c7 100644 --- a/cognee/modules/sync/methods/get_sync_operation.py +++ b/cognee/modules/sync/methods/get_sync_operation.py @@ -1,7 +1,7 @@ from uuid import UUID from typing import List, Optional -from sqlalchemy import select, desc -from cognee.modules.sync.models import SyncOperation +from sqlalchemy import select, desc, and_ +from cognee.modules.sync.models import SyncOperation, SyncStatus from cognee.infrastructure.databases.relational import get_relational_engine @@ -77,3 +77,31 @@ async def get_sync_operations_by_dataset( ) result = await session.execute(query) return list(result.scalars().all()) + + +async def 
get_running_sync_operations_for_user(user_id: UUID) -> List[SyncOperation]: + """ + Get all currently running sync operations for a specific user. + Checks for operations with STARTED or IN_PROGRESS status. + + Args: + user_id: UUID of the user + + Returns: + List[SyncOperation]: List of running sync operations for the user + """ + db_engine = get_relational_engine() + + async with db_engine.get_async_session() as session: + query = ( + select(SyncOperation) + .where( + and_( + SyncOperation.user_id == user_id, + SyncOperation.status.in_([SyncStatus.STARTED, SyncStatus.IN_PROGRESS]), + ) + ) + .order_by(desc(SyncOperation.created_at)) + ) + result = await session.execute(query) + return list(result.scalars().all()) diff --git a/cognee/modules/sync/methods/update_sync_operation.py b/cognee/modules/sync/methods/update_sync_operation.py index 04ad0c786..5530dbca0 100644 --- a/cognee/modules/sync/methods/update_sync_operation.py +++ b/cognee/modules/sync/methods/update_sync_operation.py @@ -1,16 +1,74 @@ -from typing import Optional +import asyncio +from typing import Optional, List from datetime import datetime, timezone from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError, TimeoutError from cognee.modules.sync.models import SyncOperation, SyncStatus from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.shared.logging_utils import get_logger +from cognee.infrastructure.utils.calculate_backoff import calculate_backoff + +logger = get_logger("sync.db_operations") + + +async def _retry_db_operation(operation_func, run_id: str, max_retries: int = 3): + """ + Retry database operations with exponential backoff for transient failures. 
+ + Args: + operation_func: Async function to retry + run_id: Run ID for logging context + max_retries: Maximum number of retry attempts + + Returns: + Result of the operation function + + Raises: + Exception: Re-raises the last exception if all retries fail + """ + attempt = 0 + last_exception = None + + while attempt < max_retries: + try: + return await operation_func() + except (DisconnectionError, OperationalError, TimeoutError) as e: + attempt += 1 + last_exception = e + + if attempt >= max_retries: + logger.error( + f"Database operation failed after {max_retries} attempts for run_id {run_id}: {str(e)}" + ) + break + + backoff_time = calculate_backoff(attempt - 1) # calculate_backoff is 0-indexed + logger.warning( + f"Database operation failed for run_id {run_id}, retrying in {backoff_time:.2f}s (attempt {attempt}/{max_retries}): {str(e)}" + ) + await asyncio.sleep(backoff_time) + + except Exception as e: + # Non-transient errors should not be retried + logger.error(f"Non-retryable database error for run_id {run_id}: {str(e)}") + raise + + # If we get here, all retries failed + raise last_exception async def update_sync_operation( run_id: str, status: Optional[SyncStatus] = None, progress_percentage: Optional[int] = None, - processed_records: Optional[int] = None, - bytes_transferred: Optional[int] = None, + records_downloaded: Optional[int] = None, + total_records_to_sync: Optional[int] = None, + total_records_to_download: Optional[int] = None, + total_records_to_upload: Optional[int] = None, + records_uploaded: Optional[int] = None, + bytes_downloaded: Optional[int] = None, + bytes_uploaded: Optional[int] = None, + dataset_sync_hashes: Optional[dict] = None, error_message: Optional[str] = None, retry_count: Optional[int] = None, started_at: Optional[datetime] = None, @@ -23,8 +81,14 @@ async def update_sync_operation( run_id: The public run_id of the sync operation to update status: New status for the operation progress_percentage: Progress percentage 
(0-100) - processed_records: Number of records processed so far - bytes_transferred: Total bytes transferred + records_downloaded: Number of records downloaded so far + total_records_to_sync: Total number of records that need to be synced + total_records_to_download: Total number of records to download from cloud + total_records_to_upload: Total number of records to upload to cloud + records_uploaded: Number of records uploaded so far + bytes_downloaded: Total bytes downloaded from cloud + bytes_uploaded: Total bytes uploaded to cloud + dataset_sync_hashes: Dict mapping dataset_id -> {uploaded: [hashes], downloaded: [hashes]} error_message: Error message if operation failed retry_count: Number of retry attempts started_at: When the actual processing started @@ -33,57 +97,116 @@ async def update_sync_operation( Returns: SyncOperation: The updated sync operation record, or None if not found """ - db_engine = get_relational_engine() - async with db_engine.get_async_session() as session: - # Find the sync operation - query = select(SyncOperation).where(SyncOperation.run_id == run_id) - result = await session.execute(query) - sync_operation = result.scalars().first() + async def _perform_update(): + db_engine = get_relational_engine() - if not sync_operation: - return None + async with db_engine.get_async_session() as session: + try: + # Find the sync operation + query = select(SyncOperation).where(SyncOperation.run_id == run_id) + result = await session.execute(query) + sync_operation = result.scalars().first() - # Update fields that were provided - if status is not None: - sync_operation.status = status + if not sync_operation: + logger.warning(f"Sync operation not found for run_id: {run_id}") + return None - if progress_percentage is not None: - sync_operation.progress_percentage = max(0, min(100, progress_percentage)) + # Log what we're updating for debugging + updates = [] + if status is not None: + updates.append(f"status={status.value}") + if progress_percentage 
is not None: + updates.append(f"progress={progress_percentage}%") + if records_downloaded is not None: + updates.append(f"downloaded={records_downloaded}") + if records_uploaded is not None: + updates.append(f"uploaded={records_uploaded}") + if total_records_to_sync is not None: + updates.append(f"total_sync={total_records_to_sync}") + if total_records_to_download is not None: + updates.append(f"total_download={total_records_to_download}") + if total_records_to_upload is not None: + updates.append(f"total_upload={total_records_to_upload}") - if processed_records is not None: - sync_operation.processed_records = processed_records + if updates: + logger.debug(f"Updating sync operation {run_id}: {', '.join(updates)}") - if bytes_transferred is not None: - sync_operation.bytes_transferred = bytes_transferred + # Update fields that were provided + if status is not None: + sync_operation.status = status - if error_message is not None: - sync_operation.error_message = error_message + if progress_percentage is not None: + sync_operation.progress_percentage = max(0, min(100, progress_percentage)) - if retry_count is not None: - sync_operation.retry_count = retry_count + if records_downloaded is not None: + sync_operation.records_downloaded = records_downloaded - if started_at is not None: - sync_operation.started_at = started_at + if records_uploaded is not None: + sync_operation.records_uploaded = records_uploaded - if completed_at is not None: - sync_operation.completed_at = completed_at + if total_records_to_sync is not None: + sync_operation.total_records_to_sync = total_records_to_sync - # Auto-set completion timestamp for terminal statuses - if ( - status in [SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED] - and completed_at is None - ): - sync_operation.completed_at = datetime.now(timezone.utc) + if total_records_to_download is not None: + sync_operation.total_records_to_download = total_records_to_download - # Auto-set started timestamp when moving to 
IN_PROGRESS - if status == SyncStatus.IN_PROGRESS and sync_operation.started_at is None: - sync_operation.started_at = datetime.now(timezone.utc) + if total_records_to_upload is not None: + sync_operation.total_records_to_upload = total_records_to_upload - await session.commit() - await session.refresh(sync_operation) + if bytes_downloaded is not None: + sync_operation.bytes_downloaded = bytes_downloaded - return sync_operation + if bytes_uploaded is not None: + sync_operation.bytes_uploaded = bytes_uploaded + + if dataset_sync_hashes is not None: + sync_operation.dataset_sync_hashes = dataset_sync_hashes + + if error_message is not None: + sync_operation.error_message = error_message + + if retry_count is not None: + sync_operation.retry_count = retry_count + + if started_at is not None: + sync_operation.started_at = started_at + + if completed_at is not None: + sync_operation.completed_at = completed_at + + # Auto-set completion timestamp for terminal statuses + if ( + status in [SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED] + and completed_at is None + ): + sync_operation.completed_at = datetime.now(timezone.utc) + + # Auto-set started timestamp when moving to IN_PROGRESS + if status == SyncStatus.IN_PROGRESS and sync_operation.started_at is None: + sync_operation.started_at = datetime.now(timezone.utc) + + await session.commit() + await session.refresh(sync_operation) + + logger.debug(f"Successfully updated sync operation {run_id}") + return sync_operation + + except SQLAlchemyError as e: + logger.error( + f"Database error updating sync operation {run_id}: {str(e)}", exc_info=True + ) + await session.rollback() + raise + except Exception as e: + logger.error( + f"Unexpected error updating sync operation {run_id}: {str(e)}", exc_info=True + ) + await session.rollback() + raise + + # Use retry logic for the database operation + return await _retry_db_operation(_perform_update, run_id) async def mark_sync_started(run_id: str) -> 
Optional[SyncOperation]: @@ -94,15 +217,23 @@ async def mark_sync_started(run_id: str) -> Optional[SyncOperation]: async def mark_sync_completed( - run_id: str, processed_records: int, bytes_transferred: int + run_id: str, + records_downloaded: int = 0, + records_uploaded: int = 0, + bytes_downloaded: int = 0, + bytes_uploaded: int = 0, + dataset_sync_hashes: Optional[dict] = None, ) -> Optional[SyncOperation]: """Convenience method to mark a sync operation as completed successfully.""" return await update_sync_operation( run_id=run_id, status=SyncStatus.COMPLETED, progress_percentage=100, - processed_records=processed_records, - bytes_transferred=bytes_transferred, + records_downloaded=records_downloaded, + records_uploaded=records_uploaded, + bytes_downloaded=bytes_downloaded, + bytes_uploaded=bytes_uploaded, + dataset_sync_hashes=dataset_sync_hashes, completed_at=datetime.now(timezone.utc), ) diff --git a/cognee/modules/sync/models/SyncOperation.py b/cognee/modules/sync/models/SyncOperation.py index ea63bb327..ddf82a077 100644 --- a/cognee/modules/sync/models/SyncOperation.py +++ b/cognee/modules/sync/models/SyncOperation.py @@ -1,8 +1,16 @@ from uuid import uuid4 from enum import Enum -from typing import Optional +from typing import Optional, List from datetime import datetime, timezone -from sqlalchemy import Column, Text, DateTime, UUID as SQLAlchemy_UUID, Integer, Enum as SQLEnum +from sqlalchemy import ( + Column, + Text, + DateTime, + UUID as SQLAlchemy_UUID, + Integer, + Enum as SQLEnum, + JSON, +) from cognee.infrastructure.databases.relational import Base @@ -38,8 +46,8 @@ class SyncOperation(Base): progress_percentage = Column(Integer, default=0, doc="Progress percentage (0-100)") # Operation metadata - dataset_id = Column(SQLAlchemy_UUID, index=True, doc="ID of the dataset being synced") - dataset_name = Column(Text, doc="Name of the dataset being synced") + dataset_ids = Column(JSON, doc="Array of dataset IDs being synced") + dataset_names = 
Column(JSON, doc="Array of dataset names being synced") user_id = Column(SQLAlchemy_UUID, index=True, doc="ID of the user who initiated the sync") # Timing information @@ -54,18 +62,24 @@ class SyncOperation(Base): ) # Operation details - total_records = Column(Integer, doc="Total number of records to sync") - processed_records = Column(Integer, default=0, doc="Number of records successfully processed") - bytes_transferred = Column(Integer, default=0, doc="Total bytes transferred to cloud") + total_records_to_sync = Column(Integer, doc="Total number of records to sync") + total_records_to_download = Column(Integer, doc="Total number of records to download") + total_records_to_upload = Column(Integer, doc="Total number of records to upload") + + records_downloaded = Column(Integer, default=0, doc="Number of records successfully downloaded") + records_uploaded = Column(Integer, default=0, doc="Number of records successfully uploaded") + bytes_downloaded = Column(Integer, default=0, doc="Total bytes downloaded from cloud") + bytes_uploaded = Column(Integer, default=0, doc="Total bytes uploaded to cloud") + + # Data lineage tracking per dataset + dataset_sync_hashes = Column( + JSON, doc="Mapping of dataset_id -> {uploaded: [hashes], downloaded: [hashes]}" + ) # Error handling error_message = Column(Text, doc="Error message if sync failed") retry_count = Column(Integer, default=0, doc="Number of retry attempts") - # Additional metadata (can be added later when needed) - # cloud_endpoint = Column(Text, doc="Cloud endpoint used for sync") - # compression_enabled = Column(Text, doc="Whether compression was used") - def get_duration_seconds(self) -> Optional[float]: """Get the duration of the sync operation in seconds.""" if not self.created_at: @@ -76,11 +90,53 @@ class SyncOperation(Base): def get_progress_info(self) -> dict: """Get comprehensive progress information.""" + total_records_processed = (self.records_downloaded or 0) + (self.records_uploaded or 0) + 
total_bytes_transferred = (self.bytes_downloaded or 0) + (self.bytes_uploaded or 0) + return { "status": self.status.value, "progress_percentage": self.progress_percentage, - "records_processed": f"{self.processed_records or 0}/{self.total_records or 'unknown'}", - "bytes_transferred": self.bytes_transferred or 0, + "records_processed": f"{total_records_processed}/{self.total_records_to_sync or 'unknown'}", + "records_downloaded": self.records_downloaded or 0, + "records_uploaded": self.records_uploaded or 0, + "bytes_transferred": total_bytes_transferred, + "bytes_downloaded": self.bytes_downloaded or 0, + "bytes_uploaded": self.bytes_uploaded or 0, "duration_seconds": self.get_duration_seconds(), "error_message": self.error_message, + "dataset_sync_hashes": self.dataset_sync_hashes or {}, } + + def _get_all_sync_hashes(self) -> List[str]: + """Get all content hashes for data created/modified during this sync operation.""" + all_hashes = set() + dataset_hashes = self.dataset_sync_hashes or {} + + for dataset_id, operations in dataset_hashes.items(): + if isinstance(operations, dict): + all_hashes.update(operations.get("uploaded", [])) + all_hashes.update(operations.get("downloaded", [])) + + return list(all_hashes) + + def _get_dataset_sync_hashes(self, dataset_id: str) -> dict: + """Get uploaded/downloaded hashes for a specific dataset.""" + dataset_hashes = self.dataset_sync_hashes or {} + return dataset_hashes.get(dataset_id, {"uploaded": [], "downloaded": []}) + + def was_data_synced(self, content_hash: str, dataset_id: str = None) -> bool: + """ + Check if a specific piece of data was part of this sync operation. 
+
+        Args:
+            content_hash: The content hash to check for
+            dataset_id: Optional - check only within this dataset
+        """
+        if dataset_id:
+            dataset_hashes = self._get_dataset_sync_hashes(dataset_id)
+            return content_hash in dataset_hashes.get(
+                "uploaded", []
+            ) or content_hash in dataset_hashes.get("downloaded", [])
+
+        all_hashes = self._get_all_sync_hashes()
+        return content_hash in all_hashes