from uuid import uuid4
from enum import Enum
from typing import Optional, List
from datetime import datetime, timezone
from sqlalchemy import (
    Column,
    Text,
    DateTime,
    UUID as SQLAlchemy_UUID,
    Integer,
    Enum as SQLEnum,
    JSON,
)

from cognee.infrastructure.databases.relational import Base


class SyncStatus(str, Enum):
    """Enumeration of possible sync operation statuses."""

    STARTED = "started"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class SyncOperation(Base):
    """
    Database model for tracking sync operations.

    This model stores information about background sync operations,
    allowing users to monitor progress and query the status of their sync requests.
    """

    __tablename__ = "sync_operations"

    # Primary identifiers
    id = Column(SQLAlchemy_UUID, primary_key=True, default=uuid4, doc="Database primary key")
    run_id = Column(Text, unique=True, index=True, doc="Public run ID returned to users")

    # Status and progress tracking
    status = Column(
        SQLEnum(SyncStatus), default=SyncStatus.STARTED, doc="Current status of the sync operation"
    )
    progress_percentage = Column(Integer, default=0, doc="Progress percentage (0-100)")

    # Operation metadata
    dataset_ids = Column(JSON, doc="Array of dataset IDs being synced")
    dataset_names = Column(JSON, doc="Array of dataset names being synced")
    user_id = Column(SQLAlchemy_UUID, index=True, doc="ID of the user who initiated the sync")

    # Timing information
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        doc="When the sync was initiated",
    )
    started_at = Column(DateTime(timezone=True), doc="When the actual sync processing began")
    completed_at = Column(
        DateTime(timezone=True), doc="When the sync finished (success or failure)"
    )

    # Operation details
    total_records_to_sync = Column(Integer, doc="Total number of records to sync")
    total_records_to_download = Column(Integer, doc="Total number of records to download")
    total_records_to_upload = Column(Integer, doc="Total number of records to upload")

    records_downloaded = Column(Integer, default=0, doc="Number of records successfully downloaded")
    records_uploaded = Column(Integer, default=0, doc="Number of records successfully uploaded")
    bytes_downloaded = Column(Integer, default=0, doc="Total bytes downloaded from cloud")
    bytes_uploaded = Column(Integer, default=0, doc="Total bytes uploaded to cloud")

    # Data lineage tracking per dataset
    dataset_sync_hashes = Column(
        JSON, doc="Mapping of dataset_id -> {uploaded: [hashes], downloaded: [hashes]}"
    )

    # Error handling
    error_message = Column(Text, doc="Error message if sync failed")
    retry_count = Column(Integer, default=0, doc="Number of retry attempts")

    def get_duration_seconds(self) -> Optional[float]:
        """Get the duration of the sync operation in seconds."""
        if not self.created_at:
            return None

        end_time = self.completed_at or datetime.now(timezone.utc)
        return (end_time - self.created_at).total_seconds()

    def get_progress_info(self) -> dict:
        """Get comprehensive progress information."""
        total_records_processed = (self.records_downloaded or 0) + (self.records_uploaded or 0)
        total_bytes_transferred = (self.bytes_downloaded or 0) + (self.bytes_uploaded or 0)

        return {
            "status": self.status.value,
            "progress_percentage": self.progress_percentage,
            "records_processed": f"{total_records_processed}/{self.total_records_to_sync or 'unknown'}",
            "records_downloaded": self.records_downloaded or 0,
            "records_uploaded": self.records_uploaded or 0,
            "bytes_transferred": total_bytes_transferred,
            "bytes_downloaded": self.bytes_downloaded or 0,
            "bytes_uploaded": self.bytes_uploaded or 0,
            "duration_seconds": self.get_duration_seconds(),
            "error_message": self.error_message,
            "dataset_sync_hashes": self.dataset_sync_hashes or {},
        }

    def _get_all_sync_hashes(self) -> List[str]:
        """Get all content hashes for data created/modified during this sync operation."""
        all_hashes = set()
        dataset_hashes = self.dataset_sync_hashes or {}

        for operations in dataset_hashes.values():
            if isinstance(operations, dict):
                all_hashes.update(operations.get("uploaded", []))
                all_hashes.update(operations.get("downloaded", []))

        return list(all_hashes)

    def _get_dataset_sync_hashes(self, dataset_id: str) -> dict:
        """Get uploaded/downloaded hashes for a specific dataset."""
        dataset_hashes = self.dataset_sync_hashes or {}
        return dataset_hashes.get(dataset_id, {"uploaded": [], "downloaded": []})

    def was_data_synced(self, content_hash: str, dataset_id: Optional[str] = None) -> bool:
        """
        Check if a specific piece of data was part of this sync operation.

        Args:
            content_hash: The content hash to check for
            dataset_id: Optional - check only within this dataset
        """
        if dataset_id:
            dataset_hashes = self._get_dataset_sync_hashes(dataset_id)
            return content_hash in dataset_hashes.get(
                "uploaded", []
            ) or content_hash in dataset_hashes.get("downloaded", [])

        all_hashes = self._get_all_sync_hashes()
        return content_hash in all_hashes
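

# Minimal usage sketch (an illustration, not part of the model): shows how the
# helper methods behave on a plain in-memory instance. Column defaults only
# apply on flush, so fields are set explicitly here; the run_id, dataset key,
# and hashes below are made-up example values.
if __name__ == "__main__":
    operation = SyncOperation(
        run_id="sync_run_example",
        status=SyncStatus.COMPLETED,
        progress_percentage=100,
        created_at=datetime.now(timezone.utc),
        total_records_to_sync=3,
        records_downloaded=2,
        records_uploaded=1,
        bytes_downloaded=2048,
        bytes_uploaded=1024,
        dataset_sync_hashes={
            "dataset-a": {"uploaded": ["hash1"], "downloaded": ["hash2", "hash3"]},
        },
    )

    print(operation.get_progress_info())  # records_processed == "3/3", status == "completed"
    print(operation.was_data_synced("hash1"))  # True: uploaded during this sync
    print(operation.was_data_synced("hash1", dataset_id="dataset-a"))  # True: scoped to one dataset
    print(operation.was_data_synced("hash9"))  # False: never part of this sync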