cognee/cognee/modules/sync/models/SyncOperation.py

from uuid import uuid4
from enum import Enum
from typing import Optional, List
from datetime import datetime, timezone
from sqlalchemy import (
    Column,
    Text,
    DateTime,
    UUID as SQLAlchemy_UUID,
    Integer,
    Enum as SQLEnum,
    JSON,
)

from cognee.infrastructure.databases.relational import Base


class SyncStatus(str, Enum):
    """Enumeration of possible sync operation statuses."""

    STARTED = "started"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class SyncOperation(Base):
    """
    Database model for tracking sync operations.

    This model stores information about background sync operations,
    allowing users to monitor progress and query the status of their sync requests.
    """

    __tablename__ = "sync_operations"

    # Primary identifiers
    id = Column(SQLAlchemy_UUID, primary_key=True, default=uuid4, doc="Database primary key")
    run_id = Column(Text, unique=True, index=True, doc="Public run ID returned to users")

    # Status and progress tracking
    status = Column(
        SQLEnum(SyncStatus), default=SyncStatus.STARTED, doc="Current status of the sync operation"
    )
    progress_percentage = Column(Integer, default=0, doc="Progress percentage (0-100)")

    # Operation metadata
    dataset_ids = Column(JSON, doc="Array of dataset IDs being synced")
    dataset_names = Column(JSON, doc="Array of dataset names being synced")
    user_id = Column(SQLAlchemy_UUID, index=True, doc="ID of the user who initiated the sync")

    # Timing information
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        doc="When the sync was initiated",
    )
    started_at = Column(DateTime(timezone=True), doc="When the actual sync processing began")
    completed_at = Column(
        DateTime(timezone=True), doc="When the sync finished (success or failure)"
    )

    # Operation details
    total_records_to_sync = Column(Integer, doc="Total number of records to sync")
    total_records_to_download = Column(Integer, doc="Total number of records to download")
    total_records_to_upload = Column(Integer, doc="Total number of records to upload")
    records_downloaded = Column(Integer, default=0, doc="Number of records successfully downloaded")
    records_uploaded = Column(Integer, default=0, doc="Number of records successfully uploaded")
    bytes_downloaded = Column(Integer, default=0, doc="Total bytes downloaded from cloud")
    bytes_uploaded = Column(Integer, default=0, doc="Total bytes uploaded to cloud")

    # Data lineage tracking per dataset
    dataset_sync_hashes = Column(
        JSON, doc="Mapping of dataset_id -> {uploaded: [hashes], downloaded: [hashes]}"
    )
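
    # Illustrative shape of dataset_sync_hashes, inferred from the doc string above.
    # The dataset ID and hash strings are made-up examples, not real values:
    #   {
    #       "<dataset-uuid>": {
    #           "uploaded": ["<content_hash>", ...],
    #           "downloaded": ["<content_hash>", ...],
    #       },
    #   }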

    # Error handling
    error_message = Column(Text, doc="Error message if sync failed")
    retry_count = Column(Integer, default=0, doc="Number of retry attempts")

    def get_duration_seconds(self) -> Optional[float]:
        """Get the duration of the sync operation in seconds."""
        if not self.created_at:
            return None
        end_time = self.completed_at or datetime.now(timezone.utc)
        return (end_time - self.created_at).total_seconds()

    def get_progress_info(self) -> dict:
        """Get comprehensive progress information."""
        total_records_processed = (self.records_downloaded or 0) + (self.records_uploaded or 0)
        total_bytes_transferred = (self.bytes_downloaded or 0) + (self.bytes_uploaded or 0)
        return {
            "status": self.status.value,
            "progress_percentage": self.progress_percentage,
            "records_processed": f"{total_records_processed}/{self.total_records_to_sync or 'unknown'}",
            "records_downloaded": self.records_downloaded or 0,
            "records_uploaded": self.records_uploaded or 0,
            "bytes_transferred": total_bytes_transferred,
            "bytes_downloaded": self.bytes_downloaded or 0,
            "bytes_uploaded": self.bytes_uploaded or 0,
            "duration_seconds": self.get_duration_seconds(),
            "error_message": self.error_message,
            "dataset_sync_hashes": self.dataset_sync_hashes or {},
        }
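
    # Illustrative get_progress_info() output for an in-flight sync; the concrete
    # values below are made-up examples:
    #   {
    #       "status": "in_progress",
    #       "progress_percentage": 40,
    #       "records_processed": "42/100",
    #       "records_downloaded": 30,
    #       "records_uploaded": 12,
    #       ...
    #   }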

    def _get_all_sync_hashes(self) -> List[str]:
        """Get all content hashes for data created/modified during this sync operation."""
        all_hashes = set()
        dataset_hashes = self.dataset_sync_hashes or {}

        for dataset_id, operations in dataset_hashes.items():
            if isinstance(operations, dict):
                all_hashes.update(operations.get("uploaded", []))
                all_hashes.update(operations.get("downloaded", []))

        return list(all_hashes)

    def _get_dataset_sync_hashes(self, dataset_id: str) -> dict:
        """Get uploaded/downloaded hashes for a specific dataset."""
        dataset_hashes = self.dataset_sync_hashes or {}
        return dataset_hashes.get(dataset_id, {"uploaded": [], "downloaded": []})

    def was_data_synced(self, content_hash: str, dataset_id: Optional[str] = None) -> bool:
        """
        Check if a specific piece of data was part of this sync operation.

        Args:
            content_hash: The content hash to check for
            dataset_id: Optional - check only within this dataset
        """
        if dataset_id:
            dataset_hashes = self._get_dataset_sync_hashes(dataset_id)
            return content_hash in dataset_hashes.get(
                "uploaded", []
            ) or content_hash in dataset_hashes.get("downloaded", [])

        all_hashes = self._get_all_sync_hashes()
        return content_hash in all_hashes
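

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the model). The engine/session
# wiring below is plain SQLAlchemy and is an assumption; cognee's own relational
# helpers may create and manage sessions differently, and the run/dataset values
# are hypothetical.
#
#     from sqlalchemy import create_engine
#     from sqlalchemy.orm import Session
#
#     engine = create_engine("sqlite:///sync_example.db")  # assumed local engine
#     Base.metadata.create_all(engine)
#
#     with Session(engine) as session:
#         op = SyncOperation(
#             run_id="run-123",                 # hypothetical public run ID
#             dataset_ids=["ds-1"],             # hypothetical dataset ID
#             dataset_names=["demo_dataset"],
#         )
#         session.add(op)
#         session.commit()
#
#         # Record progress as the background sync moves along.
#         op.status = SyncStatus.IN_PROGRESS
#         op.records_uploaded = 10
#         op.bytes_uploaded = 2048
#         session.commit()
#
#         print(op.get_progress_info())
#         print(op.was_data_synced("some_content_hash", dataset_id="ds-1"))
# ---------------------------------------------------------------------------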