feat: update sync to be two way (#1359)


## Description
Cloud sync is now two-way and multi-dataset: for each dataset the client diffs content hashes against the cloud, uploads files missing remotely, downloads files missing locally, and then triggers cognify remotely (when anything was uploaded) and locally (when anything was downloaded). The `/v1/sync` endpoint accepts a list of `dataset_ids`, returns 409 while another sync is still running, and a new `GET /v1/sync/status` endpoint reports running syncs. Progress, per-direction record/byte counts, and per-dataset uploaded/downloaded hashes are tracked in a new `sync_operations` table, with retried database updates and a retry loop around the background sync itself.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
Daulet Amirkhanov committed 2025-09-11 14:34:43 +01:00 (committed by GitHub)
parent a29ef8b2d3
commit 47cb34e89c
11 changed files with 970 additions and 199 deletions

View file

@ -137,6 +137,14 @@ REQUIRE_AUTHENTICATION=False
# It enforces the use of LanceDB and KuzuDB and uses them to create a database per Cognee user + dataset
ENABLE_BACKEND_ACCESS_CONTROL=False
################################################################################
# ☁️ Cloud Sync Settings
################################################################################
# Cognee Cloud API settings for syncing data to/from cloud infrastructure
COGNEE_CLOUD_API_URL="http://localhost:8001"
COGNEE_CLOUD_AUTH_TOKEN="your-auth-token"
################################################################################
# 🛠️ DEV Settings
################################################################################
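For reference, the sync module reads the Cloud Sync settings above with `os.getenv`, falling back to the sample values (see `_get_cloud_base_url` and `_get_cloud_auth_token` later in this PR); a minimal sketch, where `get_cloud_sync_settings` is a hypothetical helper:

```python
import os

def get_cloud_sync_settings() -> tuple[str, str]:
    # Defaults mirror the sample values in the .env template above.
    api_url = os.getenv("COGNEE_CLOUD_API_URL", "http://localhost:8001")
    auth_token = os.getenv("COGNEE_CLOUD_AUTH_TOKEN", "your-auth-token")
    return api_url, auth_token
```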

View file

@ -0,0 +1,98 @@
"""Add sync_operations table
Revision ID: 211ab850ef3d
Revises: 9e7a3cb85175
Create Date: 2025-09-10 20:11:13.534829
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "211ab850ef3d"
down_revision: Union[str, None] = "9e7a3cb85175"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Check if table already exists (it might be created by Base.metadata.create_all() in initial migration)
connection = op.get_bind()
inspector = sa.inspect(connection)
if "sync_operations" not in inspector.get_table_names():
# Table doesn't exist, create it normally
op.create_table(
"sync_operations",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("run_id", sa.Text(), nullable=True),
sa.Column(
"status",
sa.Enum(
"STARTED",
"IN_PROGRESS",
"COMPLETED",
"FAILED",
"CANCELLED",
name="syncstatus",
create_type=False,
),
nullable=True,
),
sa.Column("progress_percentage", sa.Integer(), nullable=True),
sa.Column("dataset_ids", sa.JSON(), nullable=True),
sa.Column("dataset_names", sa.JSON(), nullable=True),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("total_records_to_sync", sa.Integer(), nullable=True),
sa.Column("total_records_to_download", sa.Integer(), nullable=True),
sa.Column("total_records_to_upload", sa.Integer(), nullable=True),
sa.Column("records_downloaded", sa.Integer(), nullable=True),
sa.Column("records_uploaded", sa.Integer(), nullable=True),
sa.Column("bytes_downloaded", sa.Integer(), nullable=True),
sa.Column("bytes_uploaded", sa.Integer(), nullable=True),
sa.Column("dataset_sync_hashes", sa.JSON(), nullable=True),
sa.Column("error_message", sa.Text(), nullable=True),
sa.Column("retry_count", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_sync_operations_run_id"), "sync_operations", ["run_id"], unique=True
)
op.create_index(
op.f("ix_sync_operations_user_id"), "sync_operations", ["user_id"], unique=False
)
else:
# Table already exists, but we might need to add missing columns or indexes
# For now, just log that the table already exists
print("sync_operations table already exists, skipping creation")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Only drop if table exists (might have been created by Base.metadata.create_all())
connection = op.get_bind()
inspector = sa.inspect(connection)
if "sync_operations" in inspector.get_table_names():
op.drop_index(op.f("ix_sync_operations_user_id"), table_name="sync_operations")
op.drop_index(op.f("ix_sync_operations_run_id"), table_name="sync_operations")
op.drop_table("sync_operations")
# Drop the enum type that was created (only if no other tables are using it)
sa.Enum(name="syncstatus").drop(op.get_bind(), checkfirst=True)
else:
print("sync_operations table doesn't exist, skipping downgrade")
# ### end Alembic commands ###
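For orientation, the `syncstatus` database enum created above corresponds to the `SyncStatus` Python enum imported elsewhere in this PR; a minimal sketch inferred from the migration (the exact value strings are an assumption):

```python
from enum import Enum

class SyncStatus(Enum):
    # Labels match the "syncstatus" enum created by the migration above.
    STARTED = "STARTED"
    IN_PROGRESS = "IN_PROGRESS"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"
```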

View file

@ -19,6 +19,7 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
db_engine = get_relational_engine()
# we might want to delete this
await_only(db_engine.create_database())

View file

@ -3,7 +3,7 @@ from .sync import (
SyncResponse,
LocalFileInfo,
CheckMissingHashesRequest,
CheckMissingHashesResponse,
CheckHashesDiffResponse,
PruneDatasetRequest,
)
@ -12,6 +12,6 @@ __all__ = [
"SyncResponse",
"LocalFileInfo",
"CheckMissingHashesRequest",
"CheckMissingHashesResponse",
"CheckHashesDiffResponse",
"PruneDatasetRequest",
]

View file

@ -1,5 +1,5 @@
from uuid import UUID
from typing import Optional
from typing import Optional, List
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
@ -8,6 +8,7 @@ from cognee.api.DTO import InDTO
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
from cognee.modules.sync.methods import get_running_sync_operations_for_user, get_sync_operation
from cognee.shared.utils import send_telemetry
from cognee.shared.logging_utils import get_logger
from cognee.api.v1.sync import SyncResponse
@ -19,7 +20,7 @@ logger = get_logger()
class SyncRequest(InDTO):
"""Request model for sync operations."""
dataset_id: Optional[UUID] = None
dataset_ids: Optional[List[UUID]] = None
def get_sync_router() -> APIRouter:
@ -40,7 +41,7 @@ def get_sync_router() -> APIRouter:
## Request Body (JSON)
```json
{
"dataset_id": "123e4567-e89b-12d3-a456-426614174000"
"dataset_ids": ["123e4567-e89b-12d3-a456-426614174000", "456e7890-e12b-34c5-d678-901234567000"]
}
```
@ -48,8 +49,8 @@ def get_sync_router() -> APIRouter:
Returns immediate response for the sync operation:
- **run_id**: Unique identifier for tracking the background sync operation
- **status**: Always "started" (operation runs in background)
- **dataset_id**: ID of the dataset being synced
- **dataset_name**: Name of the dataset being synced
- **dataset_ids**: List of dataset IDs being synced
- **dataset_names**: List of dataset names being synced
- **message**: Description of the background operation
- **timestamp**: When the sync was initiated
- **user_id**: User who initiated the sync
@ -64,15 +65,21 @@ def get_sync_router() -> APIRouter:
## Example Usage
```bash
# Sync dataset to cloud by ID (JSON request)
# Sync multiple datasets to cloud by IDs (JSON request)
curl -X POST "http://localhost:8000/api/v1/sync" \\
-H "Content-Type: application/json" \\
-H "Cookie: auth_token=your-token" \\
-d '{"dataset_id": "123e4567-e89b-12d3-a456-426614174000"}'
-d '{"dataset_ids": ["123e4567-e89b-12d3-a456-426614174000", "456e7890-e12b-34c5-d678-901234567000"]}'
# Sync all user datasets (empty request body or null dataset_ids)
curl -X POST "http://localhost:8000/api/v1/sync" \\
-H "Content-Type: application/json" \\
-H "Cookie: auth_token=your-token" \\
-d '{}'
```
## Error Codes
- **400 Bad Request**: Invalid dataset_id format
- **400 Bad Request**: Invalid dataset_ids format
- **401 Unauthorized**: Invalid or missing authentication
- **403 Forbidden**: User doesn't have permission to access dataset
- **404 Not Found**: Dataset not found
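Equivalent to the curl calls above, a minimal Python client sketch (assumes a local server and the cookie-based auth token shown in the examples; `requests` is not part of this PR):

```python
import requests

# Start a background sync for two datasets; send {} to sync all of the user's datasets.
response = requests.post(
    "http://localhost:8000/api/v1/sync",
    json={
        "dataset_ids": [
            "123e4567-e89b-12d3-a456-426614174000",
            "456e7890-e12b-34c5-d678-901234567000",
        ]
    },
    cookies={"auth_token": "your-token"},
)

if response.status_code == 409:
    print("Sync not started:", response.json().get("error"))
else:
    response.raise_for_status()
    print("Started sync with run_id:", response.json()["run_id"])
```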
@ -92,32 +99,48 @@ def get_sync_router() -> APIRouter:
user.id,
additional_properties={
"endpoint": "POST /v1/sync",
"dataset_id": str(request.dataset_id) if request.dataset_id else "*",
"dataset_ids": [str(id) for id in request.dataset_ids]
if request.dataset_ids
else "*",
},
)
from cognee.api.v1.sync import sync as cognee_sync
try:
# Retrieve existing dataset and check permissions
datasets = await get_specific_user_permission_datasets(
user.id, "write", [request.dataset_id] if request.dataset_id else None
)
sync_results = {}
for dataset in datasets:
await set_database_global_context_variables(dataset.id, dataset.owner_id)
# Execute cloud sync operation
sync_result = await cognee_sync(
dataset=dataset,
user=user,
# Check if user has any running sync operations
running_syncs = await get_running_sync_operations_for_user(user.id)
if running_syncs:
# Return information about the existing sync operation
existing_sync = running_syncs[0] # Get the most recent running sync
return JSONResponse(
status_code=409,
content={
"error": "Sync operation already in progress",
"details": {
"run_id": existing_sync.run_id,
"status": "already_running",
"dataset_ids": existing_sync.dataset_ids,
"dataset_names": existing_sync.dataset_names,
"message": f"You have a sync operation already in progress with run_id '{existing_sync.run_id}'. Use the status endpoint to monitor progress, or wait for it to complete before starting a new sync.",
"timestamp": existing_sync.created_at.isoformat(),
"progress_percentage": existing_sync.progress_percentage,
},
},
)
sync_results[str(dataset.id)] = sync_result
# Retrieve existing dataset and check permissions
datasets = await get_specific_user_permission_datasets(
user.id, "write", request.dataset_ids if request.dataset_ids else None
)
return sync_results
# Execute new cloud sync operation for all datasets
sync_result = await cognee_sync(
datasets=datasets,
user=user,
)
return sync_result
except ValueError as e:
return JSONResponse(status_code=400, content={"error": str(e)})
@ -131,4 +154,88 @@ def get_sync_router() -> APIRouter:
logger.error(f"Cloud sync operation failed: {str(e)}")
return JSONResponse(status_code=409, content={"error": "Cloud sync operation failed."})
@router.get("/status")
async def get_sync_status_overview(
user: User = Depends(get_authenticated_user),
):
"""
Check if there are any running sync operations for the current user.
This endpoint provides a simple check to see if the user has any active sync operations
without needing to know specific run IDs.
## Response
Returns a simple status overview:
- **has_running_sync**: Boolean indicating if there are any running syncs
- **running_sync_count**: Number of currently running sync operations
- **latest_running_sync** (optional): Information about the most recent running sync if any exists
## Example Usage
```bash
curl -X GET "http://localhost:8000/api/v1/sync/status" \\
-H "Cookie: auth_token=your-token"
```
## Example Responses
**No running syncs:**
```json
{
"has_running_sync": false,
"running_sync_count": 0
}
```
**With running sync:**
```json
{
"has_running_sync": true,
"running_sync_count": 1,
"latest_running_sync": {
"run_id": "12345678-1234-5678-9012-123456789012",
"dataset_name": "My Dataset",
"progress_percentage": 45,
"created_at": "2025-01-01T00:00:00Z"
}
}
```
"""
send_telemetry(
"Sync Status Overview API Endpoint Invoked",
user.id,
additional_properties={
"endpoint": "GET /v1/sync/status",
},
)
try:
# Get any running sync operations for the user
running_syncs = await get_running_sync_operations_for_user(user.id)
response = {
"has_running_sync": len(running_syncs) > 0,
"running_sync_count": len(running_syncs),
}
# If there are running syncs, include info about the latest one
if running_syncs:
latest_sync = running_syncs[0] # Already ordered by created_at desc
response["latest_running_sync"] = {
"run_id": latest_sync.run_id,
"dataset_ids": latest_sync.dataset_ids,
"dataset_names": latest_sync.dataset_names,
"progress_percentage": latest_sync.progress_percentage,
"created_at": latest_sync.created_at.isoformat()
if latest_sync.created_at
else None,
}
return response
except Exception as e:
logger.error(f"Failed to get sync status overview: {str(e)}")
return JSONResponse(
status_code=500, content={"error": "Failed to get sync status overview"}
)
return router
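A small polling sketch against the new status endpoint (same assumptions about local server and cookie auth; `wait_for_sync` is a hypothetical helper, not part of this PR):

```python
import time
import requests

def wait_for_sync(base_url: str = "http://localhost:8000", poll_seconds: int = 5) -> None:
    """Poll GET /api/v1/sync/status until no sync operation is running."""
    while True:
        resp = requests.get(
            f"{base_url}/api/v1/sync/status",
            cookies={"auth_token": "your-token"},
        )
        resp.raise_for_status()
        status = resp.json()
        if not status["has_running_sync"]:
            print("No running syncs")
            return
        latest = status.get("latest_running_sync", {})
        print(f"Sync {latest.get('run_id')} at {latest.get('progress_percentage')}%")
        time.sleep(poll_seconds)
```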

View file

@ -1,3 +1,4 @@
import io
import os
import uuid
import asyncio
@ -5,8 +6,12 @@ import aiohttp
from pydantic import BaseModel
from typing import List, Optional
from datetime import datetime, timezone
from dataclasses import dataclass
from cognee.api.v1.cognify import cognify
from cognee.infrastructure.files.storage import get_file_storage
from cognee.tasks.ingestion.ingest_data import ingest_data
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.data.models import Dataset
@ -19,7 +24,28 @@ from cognee.modules.sync.methods import (
mark_sync_failed,
)
logger = get_logger()
logger = get_logger("sync")
async def _safe_update_progress(run_id: str, stage: str, **kwargs):
"""
Safely update sync progress with better error handling and context.
Args:
run_id: Sync operation run ID
stage: Description of the current stage, used for logging
**kwargs: Fields to update (progress_percentage, records_downloaded, records_uploaded, etc.)
"""
try:
await update_sync_operation(run_id, **kwargs)
logger.info(f"Sync {run_id}: Progress updated during {stage}")
except Exception as e:
# Log error but don't fail the sync - progress updates are nice-to-have
logger.warning(
f"Sync {run_id}: Non-critical progress update failed during {stage}: {str(e)}"
)
# Continue without raising - sync operation is more important than progress tracking
class LocalFileInfo(BaseModel):
@ -38,13 +64,16 @@ class LocalFileInfo(BaseModel):
class CheckMissingHashesRequest(BaseModel):
"""Request model for checking missing hashes in a dataset"""
dataset_id: str
dataset_name: str
hashes: List[str]
class CheckMissingHashesResponse(BaseModel):
class CheckHashesDiffResponse(BaseModel):
"""Response model for missing hashes check"""
missing: List[str]
missing_on_remote: List[str]
missing_on_local: List[str]
class PruneDatasetRequest(BaseModel):
@ -58,34 +87,34 @@ class SyncResponse(BaseModel):
run_id: str
status: str # "started" for immediate response
dataset_id: str
dataset_name: str
dataset_ids: List[str]
dataset_names: List[str]
message: str
timestamp: str
user_id: str
async def sync(
dataset: Dataset,
datasets: List[Dataset],
user: User,
) -> SyncResponse:
"""
Sync local Cognee data to Cognee Cloud.
This function handles synchronization of local datasets, knowledge graphs, and
This function handles synchronization of multiple datasets, knowledge graphs, and
processed data to the Cognee Cloud infrastructure. It uploads local data for
cloud-based processing, backup, and sharing.
Args:
dataset: Dataset object to sync (permissions already verified)
datasets: List of Dataset objects to sync (permissions already verified)
user: User object for authentication and permissions
Returns:
SyncResponse model with immediate response:
- run_id: Unique identifier for tracking this sync operation
- status: Always "started" (sync runs in background)
- dataset_id: ID of the dataset being synced
- dataset_name: Name of the dataset being synced
- dataset_ids: List of dataset IDs being synced
- dataset_names: List of dataset names being synced
- message: Description of what's happening
- timestamp: When the sync was initiated
- user_id: User who initiated the sync
@ -94,8 +123,8 @@ async def sync(
ConnectionError: If Cognee Cloud service is unreachable
Exception: For other sync-related errors
"""
if not dataset:
raise ValueError("Dataset must be provided for sync operation")
if not datasets:
raise ValueError("At least one dataset must be provided for sync operation")
# Generate a unique run ID
run_id = str(uuid.uuid4())
@ -103,12 +132,16 @@ async def sync(
# Get current timestamp
timestamp = datetime.now(timezone.utc).isoformat()
logger.info(f"Starting cloud sync operation {run_id}: dataset {dataset.name} ({dataset.id})")
dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets])
logger.info(f"Starting cloud sync operation {run_id}: datasets {dataset_info}")
# Create sync operation record in database (total_records will be set during background sync)
try:
await create_sync_operation(
run_id=run_id, dataset_id=dataset.id, dataset_name=dataset.name, user_id=user.id
run_id=run_id,
dataset_ids=[d.id for d in datasets],
dataset_names=[d.name for d in datasets],
user_id=user.id,
)
logger.info(f"Created sync operation record for {run_id}")
except Exception as e:
@ -116,44 +149,74 @@ async def sync(
# Continue without database tracking if record creation fails
# Start the sync operation in the background
asyncio.create_task(_perform_background_sync(run_id, dataset, user))
asyncio.create_task(_perform_background_sync(run_id, datasets, user))
# Return immediately with run_id
return SyncResponse(
run_id=run_id,
status="started",
dataset_id=str(dataset.id),
dataset_name=dataset.name,
message=f"Sync operation started in background. Use run_id '{run_id}' to track progress.",
dataset_ids=[str(d.id) for d in datasets],
dataset_names=[d.name for d in datasets],
message=f"Sync operation started in background for {len(datasets)} datasets. Use run_id '{run_id}' to track progress.",
timestamp=timestamp,
user_id=str(user.id),
)
async def _perform_background_sync(run_id: str, dataset: Dataset, user: User) -> None:
"""Perform the actual sync operation in the background."""
async def _perform_background_sync(run_id: str, datasets: List[Dataset], user: User) -> None:
"""Perform the actual sync operation in the background for multiple datasets."""
start_time = datetime.now(timezone.utc)
try:
logger.info(
f"Background sync {run_id}: Starting sync for dataset {dataset.name} ({dataset.id})"
)
dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets])
logger.info(f"Background sync {run_id}: Starting sync for datasets {dataset_info}")
# Mark sync as in progress
await mark_sync_started(run_id)
# Perform the actual sync operation
records_processed, bytes_transferred = await _sync_to_cognee_cloud(dataset, user, run_id)
MAX_RETRY_COUNT = 3
retry_count = 0
while retry_count < MAX_RETRY_COUNT:
try:
(
records_downloaded,
records_uploaded,
bytes_downloaded,
bytes_uploaded,
dataset_sync_hashes,
) = await _sync_to_cognee_cloud(datasets, user, run_id)
break
except Exception as e:
retry_count += 1
logger.error(
f"Background sync {run_id}: Attempt {retry_count}/{MAX_RETRY_COUNT} failed with error: {str(e)}"
)
await update_sync_operation(run_id, retry_count=retry_count)
await asyncio.sleep(2**retry_count)
continue
if retry_count == MAX_RETRY_COUNT:
logger.error(f"Background sync {run_id}: Failed after {MAX_RETRY_COUNT} retries")
await mark_sync_failed(run_id, f"Failed after {MAX_RETRY_COUNT} retries")
return
end_time = datetime.now(timezone.utc)
duration = (end_time - start_time).total_seconds()
logger.info(
f"Background sync {run_id}: Completed successfully. Records: {records_processed}, Bytes: {bytes_transferred}, Duration: {duration}s"
f"Background sync {run_id}: Completed successfully. Downloaded: {records_downloaded} records/{bytes_downloaded} bytes, Uploaded: {records_uploaded} records/{bytes_uploaded} bytes, Duration: {duration}s"
)
# Mark sync as completed with final stats
await mark_sync_completed(run_id, records_processed, bytes_transferred)
# Mark sync as completed with final stats and data lineage
await mark_sync_completed(
run_id,
records_downloaded,
records_uploaded,
bytes_downloaded,
bytes_uploaded,
dataset_sync_hashes,
)
except Exception as e:
end_time = datetime.now(timezone.utc)
@ -165,89 +228,248 @@ async def _perform_background_sync(run_id: str, dataset: Dataset, user: User) ->
await mark_sync_failed(run_id, str(e))
async def _sync_to_cognee_cloud(dataset: Dataset, user: User, run_id: str) -> tuple[int, int]:
async def _sync_to_cognee_cloud(
datasets: List[Dataset], user: User, run_id: str
) -> tuple[int, int, int, int, dict]:
"""
Sync data between local storage and Cognee Cloud using an idempotent, two-way process:
1. For each dataset, diff local content hashes against the cloud
2. Upload files missing on the cloud and download files missing locally
3. Trigger remote cognify if anything was uploaded and local cognify if anything was downloaded
"""
logger.info(f"Starting sync to Cognee Cloud: dataset {dataset.name} ({dataset.id})")
dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets])
logger.info(f"Starting sync to Cognee Cloud: datasets {dataset_info}")
total_records_downloaded = 0
total_records_uploaded = 0
total_bytes_downloaded = 0
total_bytes_uploaded = 0
dataset_sync_hashes = {}
try:
# Get cloud configuration
cloud_base_url = await _get_cloud_base_url()
cloud_auth_token = await _get_cloud_auth_token(user)
logger.info(f"Cloud API URL: {cloud_base_url}")
# Step 1: Sync files for all datasets concurrently
sync_files_tasks = [
_sync_dataset_files(dataset, cloud_base_url, cloud_auth_token, user, run_id)
for dataset in datasets
]
# Step 1: Extract local file info with stored hashes
local_files = await _extract_local_files_with_hashes(dataset, user, run_id)
logger.info(f"Found {len(local_files)} local files to sync")
logger.info(f"Starting concurrent file sync for {len(datasets)} datasets")
# Update sync operation with total file count
try:
await update_sync_operation(run_id, processed_records=0)
except Exception as e:
logger.warning(f"Failed to initialize sync progress: {str(e)}")
has_any_uploads = False
has_any_downloads = False
processed_datasets = []
completed_datasets = 0
if not local_files:
logger.info("No files to sync - dataset is empty")
return 0, 0
# Process datasets concurrently and accumulate results
for completed_task in asyncio.as_completed(sync_files_tasks):
try:
dataset_result = await completed_task
completed_datasets += 1
# Step 2: Check what files are missing on cloud
local_hashes = [f.content_hash for f in local_files]
missing_hashes = await _check_missing_hashes(
cloud_base_url, cloud_auth_token, dataset.id, local_hashes, run_id
)
logger.info(f"Cloud is missing {len(missing_hashes)} out of {len(local_hashes)} files")
# Update progress based on completed datasets (0-80% for file sync)
file_sync_progress = int((completed_datasets / len(datasets)) * 80)
await _safe_update_progress(
run_id, "file_sync", progress_percentage=file_sync_progress
)
# Update progress
try:
await update_sync_operation(run_id, progress_percentage=25)
except Exception as e:
logger.warning(f"Failed to update progress: {str(e)}")
if dataset_result is None:
logger.info(
f"Progress: {completed_datasets}/{len(datasets)} datasets processed ({file_sync_progress}%)"
)
continue
# Step 3: Upload missing files
bytes_uploaded = await _upload_missing_files(
cloud_base_url, cloud_auth_token, dataset, local_files, missing_hashes, run_id
)
logger.info(f"Upload complete: {len(missing_hashes)} files, {bytes_uploaded} bytes")
total_records_downloaded += dataset_result.records_downloaded
total_records_uploaded += dataset_result.records_uploaded
total_bytes_downloaded += dataset_result.bytes_downloaded
total_bytes_uploaded += dataset_result.bytes_uploaded
# Update progress
try:
await update_sync_operation(run_id, progress_percentage=75)
except Exception as e:
logger.warning(f"Failed to update progress: {str(e)}")
# Build per-dataset hash tracking for data lineage
dataset_sync_hashes[dataset_result.dataset_id] = {
"uploaded": dataset_result.uploaded_hashes,
"downloaded": dataset_result.downloaded_hashes,
}
# Step 4: Trigger cognify processing on cloud dataset (only if new files were uploaded)
if missing_hashes:
await _trigger_remote_cognify(cloud_base_url, cloud_auth_token, dataset.id, run_id)
logger.info(f"Cognify processing triggered for dataset {dataset.id}")
if dataset_result.has_uploads:
has_any_uploads = True
if dataset_result.has_downloads:
has_any_downloads = True
processed_datasets.append(dataset_result.dataset_id)
logger.info(
f"Progress: {completed_datasets}/{len(datasets)} datasets processed ({file_sync_progress}%) - "
f"Completed file sync for dataset {dataset_result.dataset_name}: "
f"{dataset_result.records_uploaded} files ({dataset_result.bytes_uploaded} bytes), "
f"{dataset_result.records_downloaded} files ({dataset_result.bytes_downloaded} bytes)"
)
except Exception as e:
completed_datasets += 1
logger.error(f"Dataset file sync failed: {str(e)}")
# Update progress even for failed datasets
file_sync_progress = int((completed_datasets / len(datasets)) * 80)
await _safe_update_progress(
run_id, "file_sync", progress_percentage=file_sync_progress
)
# Continue with other datasets even if one fails
# Step 2: Trigger cognify processing once for all datasets (only if any files were uploaded)
# Update progress to 90% before cognify
await _safe_update_progress(run_id, "cognify", progress_percentage=90)
if has_any_uploads and processed_datasets:
logger.info(
f"Progress: 90% - Triggering cognify processing for {len(processed_datasets)} datasets with new files"
)
try:
# Trigger cognify for all datasets at once - use first dataset as reference point
await _trigger_remote_cognify(
cloud_base_url, cloud_auth_token, datasets[0].id, run_id
)
logger.info("Cognify processing triggered successfully for all datasets")
except Exception as e:
logger.warning(f"Failed to trigger cognify processing: {str(e)}")
# Don't fail the entire sync if cognify fails
else:
logger.info(
f"Skipping cognify processing - no new files were uploaded for dataset {dataset.id}"
"Progress: 90% - Skipping cognify processing - no new files were uploaded across any datasets"
)
# Final progress
try:
await update_sync_operation(run_id, progress_percentage=100)
except Exception as e:
logger.warning(f"Failed to update final progress: {str(e)}")
# Step 3: Trigger local cognify processing if any files were downloaded
if has_any_downloads and processed_datasets:
logger.info(
f"Progress: 95% - Triggering local cognify processing for {len(processed_datasets)} datasets with downloaded files"
)
try:
await cognify()
logger.info("Local cognify processing completed successfully for all datasets")
except Exception as e:
logger.warning(f"Failed to run local cognify processing: {str(e)}")
# Don't fail the entire sync if local cognify fails
else:
logger.info(
"Progress: 95% - Skipping local cognify processing - no new files were downloaded across any datasets"
)
records_processed = len(local_files)
# Update final progress
try:
await _safe_update_progress(
run_id,
"final",
progress_percentage=100,
total_records_to_sync=total_records_uploaded + total_records_downloaded,
total_records_to_download=total_records_downloaded,
total_records_to_upload=total_records_uploaded,
records_downloaded=total_records_downloaded,
records_uploaded=total_records_uploaded,
)
except Exception as e:
logger.warning(f"Failed to update final sync progress: {str(e)}")
logger.info(
f"Sync completed successfully: {records_processed} records, {bytes_uploaded} bytes uploaded"
f"Multi-dataset sync completed: {len(datasets)} datasets processed, downloaded {total_records_downloaded} records/{total_bytes_downloaded} bytes, uploaded {total_records_uploaded} records/{total_bytes_uploaded} bytes"
)
return records_processed, bytes_uploaded
return (
total_records_downloaded,
total_records_uploaded,
total_bytes_downloaded,
total_bytes_uploaded,
dataset_sync_hashes,
)
except Exception as e:
logger.error(f"Sync failed: {str(e)}")
raise ConnectionError(f"Cloud sync failed: {str(e)}")
@dataclass
class DatasetSyncResult:
"""Result of syncing files for a single dataset."""
dataset_name: str
dataset_id: str
records_downloaded: int
records_uploaded: int
bytes_downloaded: int
bytes_uploaded: int
has_uploads: bool # Whether any files were uploaded (for cognify decision)
has_downloads: bool # Whether any files were downloaded (for cognify decision)
uploaded_hashes: List[str] # Content hashes of files uploaded during sync
downloaded_hashes: List[str] # Content hashes of files downloaded during sync
async def _sync_dataset_files(
dataset: Dataset, cloud_base_url: str, cloud_auth_token: str, user: User, run_id: str
) -> Optional[DatasetSyncResult]:
"""
Sync files for a single dataset (2-way: upload to cloud, download from cloud).
Does NOT trigger cognify - that's done separately once for all datasets.
Returns:
DatasetSyncResult with sync results or None if dataset was empty
"""
logger.info(f"Syncing files for dataset: {dataset.name} ({dataset.id})")
try:
# Step 1: Extract local file info with stored hashes
local_files = await _extract_local_files_with_hashes(dataset, user, run_id)
logger.info(f"Found {len(local_files)} local files for dataset {dataset.name}")
if not local_files:
logger.info(f"No files to sync for dataset {dataset.name} - skipping")
return None
# Step 2: Check what files are missing on cloud
local_hashes = [f.content_hash for f in local_files]
hashes_diff_response = await _check_hashes_diff(
cloud_base_url, cloud_auth_token, dataset, local_hashes, run_id
)
hashes_missing_on_remote = hashes_diff_response.missing_on_remote
hashes_missing_on_local = hashes_diff_response.missing_on_local
logger.info(
f"Dataset {dataset.name}: {len(hashes_missing_on_remote)} files to upload, {len(hashes_missing_on_local)} files to download"
)
# Step 3: Upload files that are missing on cloud
bytes_uploaded = await _upload_missing_files(
cloud_base_url, cloud_auth_token, dataset, local_files, hashes_missing_on_remote, run_id
)
logger.info(
f"Dataset {dataset.name}: Upload complete - {len(hashes_missing_on_remote)} files, {bytes_uploaded} bytes"
)
# Step 4: Download files that are missing locally
bytes_downloaded = await _download_missing_files(
cloud_base_url, cloud_auth_token, dataset, hashes_missing_on_local, user
)
logger.info(
f"Dataset {dataset.name}: Download complete - {len(hashes_missing_on_local)} files, {bytes_downloaded} bytes"
)
return DatasetSyncResult(
dataset_name=dataset.name,
dataset_id=str(dataset.id),
records_downloaded=len(hashes_missing_on_local),
records_uploaded=len(hashes_missing_on_remote),
bytes_downloaded=bytes_downloaded,
bytes_uploaded=bytes_uploaded,
has_uploads=len(hashes_missing_on_remote) > 0,
has_downloads=len(hashes_missing_on_local) > 0,
uploaded_hashes=hashes_missing_on_remote,
downloaded_hashes=hashes_missing_on_local,
)
except Exception as e:
logger.error(f"Failed to sync files for dataset {dataset.name} ({dataset.id}): {str(e)}")
raise # Re-raise to be handled by the caller
async def _extract_local_files_with_hashes(
dataset: Dataset, user: User, run_id: str
) -> List[LocalFileInfo]:
@ -334,43 +556,42 @@ async def _get_file_size(file_path: str) -> int:
async def _get_cloud_base_url() -> str:
"""Get Cognee Cloud API base URL."""
# TODO: Make this configurable via environment variable or config
return os.getenv("COGNEE_CLOUD_API_URL", "http://localhost:8001")
async def _get_cloud_auth_token(user: User) -> str:
"""Get authentication token for Cognee Cloud API."""
# TODO: Implement proper authentication with Cognee Cloud
# This should get or refresh an API token for the user
return os.getenv("COGNEE_CLOUD_AUTH_TOKEN", "your-auth-token-here")
return os.getenv("COGNEE_CLOUD_AUTH_TOKEN", "your-auth-token")
async def _check_missing_hashes(
cloud_base_url: str, auth_token: str, dataset_id: str, local_hashes: List[str], run_id: str
) -> List[str]:
async def _check_hashes_diff(
cloud_base_url: str, auth_token: str, dataset: Dataset, local_hashes: List[str], run_id: str
) -> CheckHashesDiffResponse:
"""
Step 1: Check which hashes are missing on cloud.
Check which hashes are missing on cloud.
Returns:
CheckHashesDiffResponse: hashes missing on the cloud (to upload) and hashes missing locally (to download)
"""
url = f"{cloud_base_url}/api/sync/{dataset_id}/diff"
url = f"{cloud_base_url}/api/sync/{dataset.id}/diff"
headers = {"X-Api-Key": auth_token, "Content-Type": "application/json"}
payload = CheckMissingHashesRequest(hashes=local_hashes)
payload = CheckMissingHashesRequest(
dataset_id=str(dataset.id), dataset_name=dataset.name, hashes=local_hashes
)
logger.info(f"Checking missing hashes on cloud for dataset {dataset_id}")
logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}")
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload.dict(), headers=headers) as response:
if response.status == 200:
data = await response.json()
missing_response = CheckMissingHashesResponse(**data)
missing_response = CheckHashesDiffResponse(**data)
logger.info(
f"Cloud reports {len(missing_response.missing)} missing files out of {len(local_hashes)} total"
f"Cloud is missing {len(missing_response.missing_on_remote)} out of {len(local_hashes)} files, local is missing {len(missing_response.missing_on_local)} files"
)
return missing_response.missing
return missing_response
else:
error_text = await response.text()
logger.error(
@ -385,22 +606,137 @@ async def _check_missing_hashes(
raise ConnectionError(f"Failed to check missing hashes: {str(e)}")
async def _download_missing_files(
cloud_base_url: str,
auth_token: str,
dataset: Dataset,
hashes_missing_on_local: List[str],
user: User,
) -> int:
"""
Download files that are missing locally from the cloud.
Returns:
int: Total bytes downloaded
"""
logger.info(f"Downloading {len(hashes_missing_on_local)} missing files from cloud")
if not hashes_missing_on_local:
logger.info("No files need to be downloaded - all files already exist locally")
return 0
total_bytes_downloaded = 0
downloaded_count = 0
headers = {"X-Api-Key": auth_token}
async with aiohttp.ClientSession() as session:
for file_hash in hashes_missing_on_local:
try:
# Download file from cloud by hash
download_url = f"{cloud_base_url}/api/sync/{dataset.id}/data/{file_hash}"
logger.debug(f"Downloading file with hash: {file_hash}")
async with session.get(download_url, headers=headers) as response:
if response.status == 200:
file_content = await response.read()
file_size = len(file_content)
# Get file metadata from response headers
file_name = response.headers.get("X-File-Name", f"file_{file_hash}")
# Save file locally using ingestion pipeline
await _save_downloaded_file(
dataset, file_hash, file_name, file_content, user
)
total_bytes_downloaded += file_size
downloaded_count += 1
logger.debug(f"Successfully downloaded {file_name} ({file_size} bytes)")
elif response.status == 404:
logger.warning(f"File with hash {file_hash} not found on cloud")
continue
else:
error_text = await response.text()
logger.error(
f"Failed to download file {file_hash}: Status {response.status} - {error_text}"
)
continue
except Exception as e:
logger.error(f"Error downloading file {file_hash}: {str(e)}")
continue
logger.info(
f"Download summary: {downloaded_count}/{len(hashes_missing_on_local)} files downloaded, {total_bytes_downloaded} bytes total"
)
return total_bytes_downloaded
class InMemoryDownload:
def __init__(self, data: bytes, filename: str):
self.file = io.BufferedReader(io.BytesIO(data))
self.filename = filename
async def _save_downloaded_file(
dataset: Dataset,
file_hash: str,
file_name: str,
file_content: bytes,
user: User,
) -> None:
"""
Save a downloaded file to local storage and register it in the dataset.
Uses the existing ingest_data function for consistency with normal ingestion.
Args:
dataset: The dataset to add the file to
file_hash: MD5 hash of the file content
file_name: Original file name
file_content: Raw file content bytes
user: User on whose behalf the downloaded data is ingested
"""
try:
# Create a temporary file-like object from the bytes
file_obj = InMemoryDownload(file_content, file_name)
# The user from the sync operation is passed through so ingestion is attributed correctly
# Use the existing ingest_data function to properly handle the file
# This ensures consistency with normal file ingestion
await ingest_data(
data=file_obj,
dataset_name=dataset.name,
user=user,
dataset_id=dataset.id,
)
logger.debug(f"Successfully saved downloaded file: {file_name} (hash: {file_hash})")
except Exception as e:
logger.error(f"Failed to save downloaded file {file_name}: {str(e)}")
raise
async def _upload_missing_files(
cloud_base_url: str,
auth_token: str,
dataset: Dataset,
local_files: List[LocalFileInfo],
missing_hashes: List[str],
hashes_missing_on_remote: List[str],
run_id: str,
) -> int:
"""
Step 2: Upload files that are missing on cloud.
Upload files that are missing on cloud.
Returns:
int: Total bytes uploaded
"""
# Filter local files to only those with missing hashes
files_to_upload = [f for f in local_files if f.content_hash in missing_hashes]
files_to_upload = [f for f in local_files if f.content_hash in hashes_missing_on_remote]
logger.info(f"Uploading {len(files_to_upload)} missing files to cloud")
@ -442,13 +778,6 @@ async def _upload_missing_files(
if response.status in [200, 201]:
total_bytes_uploaded += len(file_content)
uploaded_count += 1
# Update progress periodically
if uploaded_count % 10 == 0:
progress = (
25 + (uploaded_count / len(files_to_upload)) * 50
) # 25-75% range
await update_sync_operation(run_id, progress_percentage=int(progress))
else:
error_text = await response.text()
logger.error(
@ -470,7 +799,7 @@ async def _prune_cloud_dataset(
cloud_base_url: str, auth_token: str, dataset_id: str, local_hashes: List[str], run_id: str
) -> None:
"""
Step 3: Prune cloud dataset to match local state.
Prune cloud dataset to match local state.
"""
url = f"{cloud_base_url}/api/sync/{dataset_id}?prune=true"
headers = {"X-Api-Key": auth_token, "Content-Type": "application/json"}
@ -506,7 +835,7 @@ async def _trigger_remote_cognify(
cloud_base_url: str, auth_token: str, dataset_id: str, run_id: str
) -> None:
"""
Step 4: Trigger cognify processing on the cloud dataset.
Trigger cognify processing on the cloud dataset.
This initiates knowledge graph processing on the synchronized dataset
using the cloud infrastructure.

View file

@ -1,5 +1,9 @@
from .create_sync_operation import create_sync_operation
from .get_sync_operation import get_sync_operation, get_user_sync_operations
from .get_sync_operation import (
get_sync_operation,
get_user_sync_operations,
get_running_sync_operations_for_user,
)
from .update_sync_operation import (
update_sync_operation,
mark_sync_started,
@ -11,6 +15,7 @@ __all__ = [
"create_sync_operation",
"get_sync_operation",
"get_user_sync_operations",
"get_running_sync_operations_for_user",
"update_sync_operation",
"mark_sync_started",
"mark_sync_completed",

View file

@ -1,5 +1,5 @@
from uuid import UUID
from typing import Optional
from typing import Optional, List
from datetime import datetime, timezone
from cognee.modules.sync.models import SyncOperation, SyncStatus
from cognee.infrastructure.databases.relational import get_relational_engine
@ -7,20 +7,24 @@ from cognee.infrastructure.databases.relational import get_relational_engine
async def create_sync_operation(
run_id: str,
dataset_id: UUID,
dataset_name: str,
dataset_ids: List[UUID],
dataset_names: List[str],
user_id: UUID,
total_records: Optional[int] = None,
total_records_to_sync: Optional[int] = None,
total_records_to_download: Optional[int] = None,
total_records_to_upload: Optional[int] = None,
) -> SyncOperation:
"""
Create a new sync operation record in the database.
Args:
run_id: Unique public identifier for this sync operation
dataset_id: UUID of the dataset being synced
dataset_name: Name of the dataset being synced
dataset_ids: List of dataset UUIDs being synced
dataset_names: List of dataset names being synced
user_id: UUID of the user who initiated the sync
total_records: Total number of records to sync (if known)
total_records_to_sync: Total number of records to sync (if known)
total_records_to_download: Total number of records to download (if known)
total_records_to_upload: Total number of records to upload (if known)
Returns:
SyncOperation: The created sync operation record
@ -29,11 +33,15 @@ async def create_sync_operation(
sync_operation = SyncOperation(
run_id=run_id,
dataset_id=dataset_id,
dataset_name=dataset_name,
dataset_ids=[
str(uuid) for uuid in dataset_ids
], # Convert UUIDs to strings for JSON storage
dataset_names=dataset_names,
user_id=user_id,
status=SyncStatus.STARTED,
total_records=total_records,
total_records_to_sync=total_records_to_sync,
total_records_to_download=total_records_to_download,
total_records_to_upload=total_records_to_upload,
created_at=datetime.now(timezone.utc),
)

View file

@ -1,7 +1,7 @@
from uuid import UUID
from typing import List, Optional
from sqlalchemy import select, desc
from cognee.modules.sync.models import SyncOperation
from sqlalchemy import select, desc, and_
from cognee.modules.sync.models import SyncOperation, SyncStatus
from cognee.infrastructure.databases.relational import get_relational_engine
@ -77,3 +77,31 @@ async def get_sync_operations_by_dataset(
)
result = await session.execute(query)
return list(result.scalars().all())
async def get_running_sync_operations_for_user(user_id: UUID) -> List[SyncOperation]:
"""
Get all currently running sync operations for a specific user.
Checks for operations with STARTED or IN_PROGRESS status.
Args:
user_id: UUID of the user
Returns:
List[SyncOperation]: List of running sync operations for the user
"""
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
query = (
select(SyncOperation)
.where(
and_(
SyncOperation.user_id == user_id,
SyncOperation.status.in_([SyncStatus.STARTED, SyncStatus.IN_PROGRESS]),
)
)
.order_by(desc(SyncOperation.created_at))
)
result = await session.execute(query)
return list(result.scalars().all())

View file

@ -1,16 +1,74 @@
from typing import Optional
import asyncio
from typing import Optional, List
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError, TimeoutError
from cognee.modules.sync.models import SyncOperation, SyncStatus
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
logger = get_logger("sync.db_operations")
async def _retry_db_operation(operation_func, run_id: str, max_retries: int = 3):
"""
Retry database operations with exponential backoff for transient failures.
Args:
operation_func: Async function to retry
run_id: Run ID for logging context
max_retries: Maximum number of retry attempts
Returns:
Result of the operation function
Raises:
Exception: Re-raises the last exception if all retries fail
"""
attempt = 0
last_exception = None
while attempt < max_retries:
try:
return await operation_func()
except (DisconnectionError, OperationalError, TimeoutError) as e:
attempt += 1
last_exception = e
if attempt >= max_retries:
logger.error(
f"Database operation failed after {max_retries} attempts for run_id {run_id}: {str(e)}"
)
break
backoff_time = calculate_backoff(attempt - 1) # calculate_backoff is 0-indexed
logger.warning(
f"Database operation failed for run_id {run_id}, retrying in {backoff_time:.2f}s (attempt {attempt}/{max_retries}): {str(e)}"
)
await asyncio.sleep(backoff_time)
except Exception as e:
# Non-transient errors should not be retried
logger.error(f"Non-retryable database error for run_id {run_id}: {str(e)}")
raise
# If we get here, all retries failed
raise last_exception
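`calculate_backoff` is imported from `cognee.infrastructure.utils` and is not shown in this diff; a typical exponential-backoff-with-jitter helper consistent with how it is called here (0-indexed attempt, returns seconds) might look like the following sketch, which is an assumption rather than the actual implementation:

```python
import random

def calculate_backoff(attempt: int, base_delay: float = 1.0, max_delay: float = 30.0) -> float:
    """Illustrative only: exponential backoff with jitter for a 0-indexed attempt number."""
    delay = min(base_delay * (2 ** attempt), max_delay)
    return delay * random.uniform(0.5, 1.0)  # jitter spreads out concurrent retries
```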
async def update_sync_operation(
run_id: str,
status: Optional[SyncStatus] = None,
progress_percentage: Optional[int] = None,
processed_records: Optional[int] = None,
bytes_transferred: Optional[int] = None,
records_downloaded: Optional[int] = None,
total_records_to_sync: Optional[int] = None,
total_records_to_download: Optional[int] = None,
total_records_to_upload: Optional[int] = None,
records_uploaded: Optional[int] = None,
bytes_downloaded: Optional[int] = None,
bytes_uploaded: Optional[int] = None,
dataset_sync_hashes: Optional[dict] = None,
error_message: Optional[str] = None,
retry_count: Optional[int] = None,
started_at: Optional[datetime] = None,
@ -23,8 +81,14 @@ async def update_sync_operation(
run_id: The public run_id of the sync operation to update
status: New status for the operation
progress_percentage: Progress percentage (0-100)
processed_records: Number of records processed so far
bytes_transferred: Total bytes transferred
records_downloaded: Number of records downloaded so far
total_records_to_sync: Total number of records that need to be synced
total_records_to_download: Total number of records to download from cloud
total_records_to_upload: Total number of records to upload to cloud
records_uploaded: Number of records uploaded so far
bytes_downloaded: Total bytes downloaded from cloud
bytes_uploaded: Total bytes uploaded to cloud
dataset_sync_hashes: Dict mapping dataset_id -> {uploaded: [hashes], downloaded: [hashes]}
error_message: Error message if operation failed
retry_count: Number of retry attempts
started_at: When the actual processing started
@ -33,57 +97,116 @@ async def update_sync_operation(
Returns:
SyncOperation: The updated sync operation record, or None if not found
"""
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
# Find the sync operation
query = select(SyncOperation).where(SyncOperation.run_id == run_id)
result = await session.execute(query)
sync_operation = result.scalars().first()
async def _perform_update():
db_engine = get_relational_engine()
if not sync_operation:
return None
async with db_engine.get_async_session() as session:
try:
# Find the sync operation
query = select(SyncOperation).where(SyncOperation.run_id == run_id)
result = await session.execute(query)
sync_operation = result.scalars().first()
# Update fields that were provided
if status is not None:
sync_operation.status = status
if not sync_operation:
logger.warning(f"Sync operation not found for run_id: {run_id}")
return None
if progress_percentage is not None:
sync_operation.progress_percentage = max(0, min(100, progress_percentage))
# Log what we're updating for debugging
updates = []
if status is not None:
updates.append(f"status={status.value}")
if progress_percentage is not None:
updates.append(f"progress={progress_percentage}%")
if records_downloaded is not None:
updates.append(f"downloaded={records_downloaded}")
if records_uploaded is not None:
updates.append(f"uploaded={records_uploaded}")
if total_records_to_sync is not None:
updates.append(f"total_sync={total_records_to_sync}")
if total_records_to_download is not None:
updates.append(f"total_download={total_records_to_download}")
if total_records_to_upload is not None:
updates.append(f"total_upload={total_records_to_upload}")
if processed_records is not None:
sync_operation.processed_records = processed_records
if updates:
logger.debug(f"Updating sync operation {run_id}: {', '.join(updates)}")
if bytes_transferred is not None:
sync_operation.bytes_transferred = bytes_transferred
# Update fields that were provided
if status is not None:
sync_operation.status = status
if error_message is not None:
sync_operation.error_message = error_message
if progress_percentage is not None:
sync_operation.progress_percentage = max(0, min(100, progress_percentage))
if retry_count is not None:
sync_operation.retry_count = retry_count
if records_downloaded is not None:
sync_operation.records_downloaded = records_downloaded
if started_at is not None:
sync_operation.started_at = started_at
if records_uploaded is not None:
sync_operation.records_uploaded = records_uploaded
if completed_at is not None:
sync_operation.completed_at = completed_at
if total_records_to_sync is not None:
sync_operation.total_records_to_sync = total_records_to_sync
# Auto-set completion timestamp for terminal statuses
if (
status in [SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED]
and completed_at is None
):
sync_operation.completed_at = datetime.now(timezone.utc)
if total_records_to_download is not None:
sync_operation.total_records_to_download = total_records_to_download
# Auto-set started timestamp when moving to IN_PROGRESS
if status == SyncStatus.IN_PROGRESS and sync_operation.started_at is None:
sync_operation.started_at = datetime.now(timezone.utc)
if total_records_to_upload is not None:
sync_operation.total_records_to_upload = total_records_to_upload
await session.commit()
await session.refresh(sync_operation)
if bytes_downloaded is not None:
sync_operation.bytes_downloaded = bytes_downloaded
return sync_operation
if bytes_uploaded is not None:
sync_operation.bytes_uploaded = bytes_uploaded
if dataset_sync_hashes is not None:
sync_operation.dataset_sync_hashes = dataset_sync_hashes
if error_message is not None:
sync_operation.error_message = error_message
if retry_count is not None:
sync_operation.retry_count = retry_count
if started_at is not None:
sync_operation.started_at = started_at
if completed_at is not None:
sync_operation.completed_at = completed_at
# Auto-set completion timestamp for terminal statuses
if (
status in [SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED]
and completed_at is None
):
sync_operation.completed_at = datetime.now(timezone.utc)
# Auto-set started timestamp when moving to IN_PROGRESS
if status == SyncStatus.IN_PROGRESS and sync_operation.started_at is None:
sync_operation.started_at = datetime.now(timezone.utc)
await session.commit()
await session.refresh(sync_operation)
logger.debug(f"Successfully updated sync operation {run_id}")
return sync_operation
except SQLAlchemyError as e:
logger.error(
f"Database error updating sync operation {run_id}: {str(e)}", exc_info=True
)
await session.rollback()
raise
except Exception as e:
logger.error(
f"Unexpected error updating sync operation {run_id}: {str(e)}", exc_info=True
)
await session.rollback()
raise
# Use retry logic for the database operation
return await _retry_db_operation(_perform_update, run_id)
async def mark_sync_started(run_id: str) -> Optional[SyncOperation]:
@ -94,15 +217,23 @@ async def mark_sync_started(run_id: str) -> Optional[SyncOperation]:
async def mark_sync_completed(
run_id: str, processed_records: int, bytes_transferred: int
run_id: str,
records_downloaded: int = 0,
records_uploaded: int = 0,
bytes_downloaded: int = 0,
bytes_uploaded: int = 0,
dataset_sync_hashes: Optional[dict] = None,
) -> Optional[SyncOperation]:
"""Convenience method to mark a sync operation as completed successfully."""
return await update_sync_operation(
run_id=run_id,
status=SyncStatus.COMPLETED,
progress_percentage=100,
processed_records=processed_records,
bytes_transferred=bytes_transferred,
records_downloaded=records_downloaded,
records_uploaded=records_uploaded,
bytes_downloaded=bytes_downloaded,
bytes_uploaded=bytes_uploaded,
dataset_sync_hashes=dataset_sync_hashes,
completed_at=datetime.now(timezone.utc),
)

View file

@ -1,8 +1,16 @@
from uuid import uuid4
from enum import Enum
from typing import Optional
from typing import Optional, List
from datetime import datetime, timezone
from sqlalchemy import Column, Text, DateTime, UUID as SQLAlchemy_UUID, Integer, Enum as SQLEnum
from sqlalchemy import (
Column,
Text,
DateTime,
UUID as SQLAlchemy_UUID,
Integer,
Enum as SQLEnum,
JSON,
)
from cognee.infrastructure.databases.relational import Base
@ -38,8 +46,8 @@ class SyncOperation(Base):
progress_percentage = Column(Integer, default=0, doc="Progress percentage (0-100)")
# Operation metadata
dataset_id = Column(SQLAlchemy_UUID, index=True, doc="ID of the dataset being synced")
dataset_name = Column(Text, doc="Name of the dataset being synced")
dataset_ids = Column(JSON, doc="Array of dataset IDs being synced")
dataset_names = Column(JSON, doc="Array of dataset names being synced")
user_id = Column(SQLAlchemy_UUID, index=True, doc="ID of the user who initiated the sync")
# Timing information
@ -54,18 +62,24 @@ class SyncOperation(Base):
)
# Operation details
total_records = Column(Integer, doc="Total number of records to sync")
processed_records = Column(Integer, default=0, doc="Number of records successfully processed")
bytes_transferred = Column(Integer, default=0, doc="Total bytes transferred to cloud")
total_records_to_sync = Column(Integer, doc="Total number of records to sync")
total_records_to_download = Column(Integer, doc="Total number of records to download")
total_records_to_upload = Column(Integer, doc="Total number of records to upload")
records_downloaded = Column(Integer, default=0, doc="Number of records successfully downloaded")
records_uploaded = Column(Integer, default=0, doc="Number of records successfully uploaded")
bytes_downloaded = Column(Integer, default=0, doc="Total bytes downloaded from cloud")
bytes_uploaded = Column(Integer, default=0, doc="Total bytes uploaded to cloud")
# Data lineage tracking per dataset
dataset_sync_hashes = Column(
JSON, doc="Mapping of dataset_id -> {uploaded: [hashes], downloaded: [hashes]}"
)
# Error handling
error_message = Column(Text, doc="Error message if sync failed")
retry_count = Column(Integer, default=0, doc="Number of retry attempts")
# Additional metadata (can be added later when needed)
# cloud_endpoint = Column(Text, doc="Cloud endpoint used for sync")
# compression_enabled = Column(Text, doc="Whether compression was used")
def get_duration_seconds(self) -> Optional[float]:
"""Get the duration of the sync operation in seconds."""
if not self.created_at:
@ -76,11 +90,53 @@ class SyncOperation(Base):
def get_progress_info(self) -> dict:
"""Get comprehensive progress information."""
total_records_processed = (self.records_downloaded or 0) + (self.records_uploaded or 0)
total_bytes_transferred = (self.bytes_downloaded or 0) + (self.bytes_uploaded or 0)
return {
"status": self.status.value,
"progress_percentage": self.progress_percentage,
"records_processed": f"{self.processed_records or 0}/{self.total_records or 'unknown'}",
"bytes_transferred": self.bytes_transferred or 0,
"records_processed": f"{total_records_processed}/{self.total_records_to_sync or 'unknown'}",
"records_downloaded": self.records_downloaded or 0,
"records_uploaded": self.records_uploaded or 0,
"bytes_transferred": total_bytes_transferred,
"bytes_downloaded": self.bytes_downloaded or 0,
"bytes_uploaded": self.bytes_uploaded or 0,
"duration_seconds": self.get_duration_seconds(),
"error_message": self.error_message,
"dataset_sync_hashes": self.dataset_sync_hashes or {},
}
def _get_all_sync_hashes(self) -> List[str]:
"""Get all content hashes for data created/modified during this sync operation."""
all_hashes = set()
dataset_hashes = self.dataset_sync_hashes or {}
for dataset_id, operations in dataset_hashes.items():
if isinstance(operations, dict):
all_hashes.update(operations.get("uploaded", []))
all_hashes.update(operations.get("downloaded", []))
return list(all_hashes)
def _get_dataset_sync_hashes(self, dataset_id: str) -> dict:
"""Get uploaded/downloaded hashes for a specific dataset."""
dataset_hashes = self.dataset_sync_hashes or {}
return dataset_hashes.get(dataset_id, {"uploaded": [], "downloaded": []})
def was_data_synced(self, content_hash: str, dataset_id: str = None) -> bool:
"""
Check if a specific piece of data was part of this sync operation.
Args:
content_hash: The content hash to check for
dataset_id: Optional - check only within this dataset
"""
if dataset_id:
dataset_hashes = self._get_dataset_sync_hashes(dataset_id)
return content_hash in dataset_hashes.get(
"uploaded", []
) or content_hash in dataset_hashes.get("downloaded", [])
all_hashes = self._get_all_sync_hashes()
return content_hash in all_hashes
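To make the lineage structure concrete, `dataset_sync_hashes` holds one entry per dataset; a small illustrative value and how `was_data_synced` reads it (the hash and UUID values are made up):

```python
dataset_sync_hashes = {
    "123e4567-e89b-12d3-a456-426614174000": {
        "uploaded": ["9e107d9d372bb6826bd81d3542a419d6"],
        "downloaded": ["e4d909c290d0fb1ca068ffaddf22cbd0"],
    },
}

# On a SyncOperation row carrying the value above:
#   sync_op.was_data_synced("9e107d9d372bb6826bd81d3542a419d6")  # True, found across all datasets
#   sync_op.was_data_synced(
#       "e4d909c290d0fb1ca068ffaddf22cbd0",
#       dataset_id="123e4567-e89b-12d3-a456-426614174000",
#   )  # True, scoped to one dataset
#   sync_op.was_data_synced("ffffffffffffffffffffffffffffffff")  # False
```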