feat: update sync to be two way (#1359)
## Description

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
Parent: a29ef8b2d3
Commit: 47cb34e89c
11 changed files with 970 additions and 199 deletions
@@ -137,6 +137,14 @@ REQUIRE_AUTHENTICATION=False
# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset
ENABLE_BACKEND_ACCESS_CONTROL=False

################################################################################
# ☁️ Cloud Sync Settings
################################################################################

# Cognee Cloud API settings for syncing data to/from cloud infrastructure
COGNEE_CLOUD_API_URL="http://localhost:8001"
COGNEE_CLOUD_AUTH_TOKEN="your-auth-token"

################################################################################
# 🛠️ DEV Settings
################################################################################
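For reference, a minimal sketch of how these two new settings are consumed: the sync module further down in this diff reads them with `os.getenv`, falling back to the same defaults shown in the template above.

```python
import os

# Defaults mirror the .env template above.
cloud_api_url = os.getenv("COGNEE_CLOUD_API_URL", "http://localhost:8001")
cloud_auth_token = os.getenv("COGNEE_CLOUD_AUTH_TOKEN", "your-auth-token")
```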
alembic/versions/211ab850ef3d_add_sync_operations_table.py (new file, 98 lines)

@@ -0,0 +1,98 @@
"""Add sync_operations table

Revision ID: 211ab850ef3d
Revises: 9e7a3cb85175
Create Date: 2025-09-10 20:11:13.534829

"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "211ab850ef3d"
down_revision: Union[str, None] = "9e7a3cb85175"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###

    # Check if table already exists (it might be created by Base.metadata.create_all() in initial migration)
    connection = op.get_bind()
    inspector = sa.inspect(connection)

    if "sync_operations" not in inspector.get_table_names():
        # Table doesn't exist, create it normally
        op.create_table(
            "sync_operations",
            sa.Column("id", sa.UUID(), nullable=False),
            sa.Column("run_id", sa.Text(), nullable=True),
            sa.Column(
                "status",
                sa.Enum(
                    "STARTED",
                    "IN_PROGRESS",
                    "COMPLETED",
                    "FAILED",
                    "CANCELLED",
                    name="syncstatus",
                    create_type=False,
                ),
                nullable=True,
            ),
            sa.Column("progress_percentage", sa.Integer(), nullable=True),
            sa.Column("dataset_ids", sa.JSON(), nullable=True),
            sa.Column("dataset_names", sa.JSON(), nullable=True),
            sa.Column("user_id", sa.UUID(), nullable=True),
            sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
            sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
            sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
            sa.Column("total_records_to_sync", sa.Integer(), nullable=True),
            sa.Column("total_records_to_download", sa.Integer(), nullable=True),
            sa.Column("total_records_to_upload", sa.Integer(), nullable=True),
            sa.Column("records_downloaded", sa.Integer(), nullable=True),
            sa.Column("records_uploaded", sa.Integer(), nullable=True),
            sa.Column("bytes_downloaded", sa.Integer(), nullable=True),
            sa.Column("bytes_uploaded", sa.Integer(), nullable=True),
            sa.Column("dataset_sync_hashes", sa.JSON(), nullable=True),
            sa.Column("error_message", sa.Text(), nullable=True),
            sa.Column("retry_count", sa.Integer(), nullable=True),
            sa.PrimaryKeyConstraint("id"),
        )
        op.create_index(
            op.f("ix_sync_operations_run_id"), "sync_operations", ["run_id"], unique=True
        )
        op.create_index(
            op.f("ix_sync_operations_user_id"), "sync_operations", ["user_id"], unique=False
        )
    else:
        # Table already exists, but we might need to add missing columns or indexes
        # For now, just log that the table already exists
        print("sync_operations table already exists, skipping creation")

    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###

    # Only drop if table exists (might have been created by Base.metadata.create_all())
    connection = op.get_bind()
    inspector = sa.inspect(connection)

    if "sync_operations" in inspector.get_table_names():
        op.drop_index(op.f("ix_sync_operations_user_id"), table_name="sync_operations")
        op.drop_index(op.f("ix_sync_operations_run_id"), table_name="sync_operations")
        op.drop_table("sync_operations")

        # Drop the enum type that was created (only if no other tables are using it)
        sa.Enum(name="syncstatus").drop(op.get_bind(), checkfirst=True)
    else:
        print("sync_operations table doesn't exist, skipping downgrade")

    # ### end Alembic commands ###
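As an aside (not part of the migration itself), the new table can be sanity-checked after `alembic upgrade head` with the same SQLAlchemy inspector API the migration uses; the connection URL below is a placeholder.

```python
import sqlalchemy as sa

# Placeholder URL; use the database the Alembic environment is configured against.
engine = sa.create_engine("sqlite:///cognee.db")

inspector = sa.inspect(engine)
if "sync_operations" in inspector.get_table_names():
    # Prints column names such as run_id, status, dataset_ids, dataset_sync_hashes, ...
    print([column["name"] for column in inspector.get_columns("sync_operations")])
```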
@@ -19,6 +19,7 @@ depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    db_engine = get_relational_engine()
    # we might want to delete this
    await_only(db_engine.create_database())
@@ -3,7 +3,7 @@ from .sync import (
    SyncResponse,
    LocalFileInfo,
    CheckMissingHashesRequest,
    CheckMissingHashesResponse,
    CheckHashesDiffResponse,
    PruneDatasetRequest,
)

@@ -12,6 +12,6 @@ __all__ = [
    "SyncResponse",
    "LocalFileInfo",
    "CheckMissingHashesRequest",
    "CheckMissingHashesResponse",
    "CheckHashesDiffResponse",
    "PruneDatasetRequest",
]
@@ -1,5 +1,5 @@
from uuid import UUID
from typing import Optional
from typing import Optional, List
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse

@@ -8,6 +8,7 @@ from cognee.api.DTO import InDTO
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_authenticated_user
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
from cognee.modules.sync.methods import get_running_sync_operations_for_user, get_sync_operation
from cognee.shared.utils import send_telemetry
from cognee.shared.logging_utils import get_logger
from cognee.api.v1.sync import SyncResponse

@@ -19,7 +20,7 @@ logger = get_logger()
class SyncRequest(InDTO):
    """Request model for sync operations."""

    dataset_id: Optional[UUID] = None
    dataset_ids: Optional[List[UUID]] = None


def get_sync_router() -> APIRouter:
@@ -40,7 +41,7 @@ def get_sync_router() -> APIRouter:
        ## Request Body (JSON)
        ```json
        {
            "dataset_id": "123e4567-e89b-12d3-a456-426614174000"
            "dataset_ids": ["123e4567-e89b-12d3-a456-426614174000", "456e7890-e12b-34c5-d678-901234567000"]
        }
        ```

@@ -48,8 +49,8 @@
        Returns immediate response for the sync operation:
        - **run_id**: Unique identifier for tracking the background sync operation
        - **status**: Always "started" (operation runs in background)
        - **dataset_id**: ID of the dataset being synced
        - **dataset_name**: Name of the dataset being synced
        - **dataset_ids**: List of dataset IDs being synced
        - **dataset_names**: List of dataset names being synced
        - **message**: Description of the background operation
        - **timestamp**: When the sync was initiated
        - **user_id**: User who initiated the sync

@@ -64,15 +65,21 @@
        ## Example Usage
        ```bash
        # Sync dataset to cloud by ID (JSON request)
        # Sync multiple datasets to cloud by IDs (JSON request)
        curl -X POST "http://localhost:8000/api/v1/sync" \\
          -H "Content-Type: application/json" \\
          -H "Cookie: auth_token=your-token" \\
          -d '{"dataset_id": "123e4567-e89b-12d3-a456-426614174000"}'
          -d '{"dataset_ids": ["123e4567-e89b-12d3-a456-426614174000", "456e7890-e12b-34c5-d678-901234567000"]}'

        # Sync all user datasets (empty request body or null dataset_ids)
        curl -X POST "http://localhost:8000/api/v1/sync" \\
          -H "Content-Type: application/json" \\
          -H "Cookie: auth_token=your-token" \\
          -d '{}'
        ```

        ## Error Codes
        - **400 Bad Request**: Invalid dataset_id format
        - **400 Bad Request**: Invalid dataset_ids format
        - **401 Unauthorized**: Invalid or missing authentication
        - **403 Forbidden**: User doesn't have permission to access dataset
        - **404 Not Found**: Dataset not found
@@ -92,32 +99,48 @@ def get_sync_router() -> APIRouter:
            user.id,
            additional_properties={
                "endpoint": "POST /v1/sync",
                "dataset_id": str(request.dataset_id) if request.dataset_id else "*",
                "dataset_ids": [str(id) for id in request.dataset_ids]
                if request.dataset_ids
                else "*",
            },
        )

        from cognee.api.v1.sync import sync as cognee_sync

        try:
            # Retrieve existing dataset and check permissions
            datasets = await get_specific_user_permission_datasets(
                user.id, "write", [request.dataset_id] if request.dataset_id else None
            )

            sync_results = {}

            for dataset in datasets:
                await set_database_global_context_variables(dataset.id, dataset.owner_id)

                # Execute cloud sync operation
                sync_result = await cognee_sync(
                    dataset=dataset,
                    user=user,
            # Check if user has any running sync operations
            running_syncs = await get_running_sync_operations_for_user(user.id)
            if running_syncs:
                # Return information about the existing sync operation
                existing_sync = running_syncs[0]  # Get the most recent running sync
                return JSONResponse(
                    status_code=409,
                    content={
                        "error": "Sync operation already in progress",
                        "details": {
                            "run_id": existing_sync.run_id,
                            "status": "already_running",
                            "dataset_ids": existing_sync.dataset_ids,
                            "dataset_names": existing_sync.dataset_names,
                            "message": f"You have a sync operation already in progress with run_id '{existing_sync.run_id}'. Use the status endpoint to monitor progress, or wait for it to complete before starting a new sync.",
                            "timestamp": existing_sync.created_at.isoformat(),
                            "progress_percentage": existing_sync.progress_percentage,
                        },
                    },
                )

                sync_results[str(dataset.id)] = sync_result
            # Retrieve existing dataset and check permissions
            datasets = await get_specific_user_permission_datasets(
                user.id, "write", request.dataset_ids if request.dataset_ids else None
            )

            return sync_results
            # Execute new cloud sync operation for all datasets
            sync_result = await cognee_sync(
                datasets=datasets,
                user=user,
            )

            return sync_result

        except ValueError as e:
            return JSONResponse(status_code=400, content={"error": str(e)})
@@ -131,4 +154,88 @@ def get_sync_router() -> APIRouter:
            logger.error(f"Cloud sync operation failed: {str(e)}")
            return JSONResponse(status_code=409, content={"error": "Cloud sync operation failed."})

    @router.get("/status")
    async def get_sync_status_overview(
        user: User = Depends(get_authenticated_user),
    ):
        """
        Check if there are any running sync operations for the current user.

        This endpoint provides a simple check to see if the user has any active sync operations
        without needing to know specific run IDs.

        ## Response
        Returns a simple status overview:
        - **has_running_sync**: Boolean indicating if there are any running syncs
        - **running_sync_count**: Number of currently running sync operations
        - **latest_running_sync** (optional): Information about the most recent running sync if any exists

        ## Example Usage
        ```bash
        curl -X GET "http://localhost:8000/api/v1/sync/status" \\
          -H "Cookie: auth_token=your-token"
        ```

        ## Example Responses

        **No running syncs:**
        ```json
        {
            "has_running_sync": false,
            "running_sync_count": 0
        }
        ```

        **With running sync:**
        ```json
        {
            "has_running_sync": true,
            "running_sync_count": 1,
            "latest_running_sync": {
                "run_id": "12345678-1234-5678-9012-123456789012",
                "dataset_name": "My Dataset",
                "progress_percentage": 45,
                "created_at": "2025-01-01T00:00:00Z"
            }
        }
        ```
        """
        send_telemetry(
            "Sync Status Overview API Endpoint Invoked",
            user.id,
            additional_properties={
                "endpoint": "GET /v1/sync/status",
            },
        )

        try:
            # Get any running sync operations for the user
            running_syncs = await get_running_sync_operations_for_user(user.id)

            response = {
                "has_running_sync": len(running_syncs) > 0,
                "running_sync_count": len(running_syncs),
            }

            # If there are running syncs, include info about the latest one
            if running_syncs:
                latest_sync = running_syncs[0]  # Already ordered by created_at desc
                response["latest_running_sync"] = {
                    "run_id": latest_sync.run_id,
                    "dataset_ids": latest_sync.dataset_ids,
                    "dataset_names": latest_sync.dataset_names,
                    "progress_percentage": latest_sync.progress_percentage,
                    "created_at": latest_sync.created_at.isoformat()
                    if latest_sync.created_at
                    else None,
                }

            return response

        except Exception as e:
            logger.error(f"Failed to get sync status overview: {str(e)}")
            return JSONResponse(
                status_code=500, content={"error": "Failed to get sync status overview"}
            )

    return router
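For orientation, a minimal client-side sketch of the flow these endpoints enable: start a sync for one or more datasets, then poll the status overview until nothing is running. Endpoint paths, the auth_token cookie, and the response fields come from the docstrings above; the base URL, token, and dataset id are placeholders, and `requests` is just one convenient HTTP client.

```python
import time

import requests  # any HTTP client works; this is only an illustration

BASE_URL = "http://localhost:8000/api/v1"
COOKIES = {"auth_token": "your-token"}  # placeholder token

# Start a background sync for one dataset (an empty body would sync all user datasets).
started = requests.post(
    f"{BASE_URL}/sync",
    json={"dataset_ids": ["123e4567-e89b-12d3-a456-426614174000"]},
    cookies=COOKIES,
).json()
print("run_id:", started["run_id"])  # status is always "started"

# Poll the overview endpoint until no sync is running anymore.
while True:
    overview = requests.get(f"{BASE_URL}/sync/status", cookies=COOKIES).json()
    if not overview["has_running_sync"]:
        break
    print("progress:", overview["latest_running_sync"]["progress_percentage"], "%")
    time.sleep(5)
```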
@@ -1,3 +1,4 @@
import io
import os
import uuid
import asyncio

@@ -5,8 +6,12 @@ import aiohttp
from pydantic import BaseModel
from typing import List, Optional
from datetime import datetime, timezone
from dataclasses import dataclass

from cognee.api.v1.cognify import cognify

from cognee.infrastructure.files.storage import get_file_storage
from cognee.tasks.ingestion.ingest_data import ingest_data
from cognee.shared.logging_utils import get_logger
from cognee.modules.users.models import User
from cognee.modules.data.models import Dataset
@@ -19,7 +24,28 @@ from cognee.modules.sync.methods import (
    mark_sync_failed,
)

logger = get_logger()
logger = get_logger("sync")


async def _safe_update_progress(run_id: str, stage: str, **kwargs):
    """
    Safely update sync progress with better error handling and context.

    Args:
        run_id: Sync operation run ID
        stage: Description of current stage for logging
        **kwargs: Additional fields to update (progress_percentage, records_downloaded, records_uploaded, etc.)
    """
    try:
        await update_sync_operation(run_id, **kwargs)
        logger.info(f"Sync {run_id}: Progress updated during {stage}")
    except Exception as e:
        # Log error but don't fail the sync - progress updates are nice-to-have
        logger.warning(
            f"Sync {run_id}: Non-critical progress update failed during {stage}: {str(e)}"
        )
        # Continue without raising - sync operation is more important than progress tracking


class LocalFileInfo(BaseModel):
@@ -38,13 +64,16 @@
class CheckMissingHashesRequest(BaseModel):
    """Request model for checking missing hashes in a dataset"""

    dataset_id: str
    dataset_name: str
    hashes: List[str]


class CheckMissingHashesResponse(BaseModel):
class CheckHashesDiffResponse(BaseModel):
    """Response model for missing hashes check"""

    missing: List[str]
    missing_on_remote: List[str]
    missing_on_local: List[str]


class PruneDatasetRequest(BaseModel):
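For illustration (not part of the diff), the request model now carries the dataset identity alongside the hashes, so the cloud can compute a per-dataset diff; the values below are placeholders.

```python
payload = CheckMissingHashesRequest(
    dataset_id="123e4567-e89b-12d3-a456-426614174000",  # placeholder dataset id
    dataset_name="My Dataset",
    hashes=["9e107d9d372bb6826bd81d3542a419d6"],  # invented content hash
)
print(payload.dict())  # serialized the same way _check_hashes_diff posts it below
```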
@@ -58,34 +87,34 @@ class SyncResponse(BaseModel):

    run_id: str
    status: str  # "started" for immediate response
    dataset_id: str
    dataset_name: str
    dataset_ids: List[str]
    dataset_names: List[str]
    message: str
    timestamp: str
    user_id: str


async def sync(
    dataset: Dataset,
    datasets: List[Dataset],
    user: User,
) -> SyncResponse:
    """
    Sync local Cognee data to Cognee Cloud.

    This function handles synchronization of local datasets, knowledge graphs, and
    This function handles synchronization of multiple datasets, knowledge graphs, and
    processed data to the Cognee Cloud infrastructure. It uploads local data for
    cloud-based processing, backup, and sharing.

    Args:
        dataset: Dataset object to sync (permissions already verified)
        datasets: List of Dataset objects to sync (permissions already verified)
        user: User object for authentication and permissions

    Returns:
        SyncResponse model with immediate response:
        - run_id: Unique identifier for tracking this sync operation
        - status: Always "started" (sync runs in background)
        - dataset_id: ID of the dataset being synced
        - dataset_name: Name of the dataset being synced
        - dataset_ids: List of dataset IDs being synced
        - dataset_names: List of dataset names being synced
        - message: Description of what's happening
        - timestamp: When the sync was initiated
        - user_id: User who initiated the sync
@@ -94,8 +123,8 @@ async def sync(
        ConnectionError: If Cognee Cloud service is unreachable
        Exception: For other sync-related errors
    """
    if not dataset:
        raise ValueError("Dataset must be provided for sync operation")
    if not datasets:
        raise ValueError("At least one dataset must be provided for sync operation")

    # Generate a unique run ID
    run_id = str(uuid.uuid4())
@@ -103,12 +132,16 @@
    # Get current timestamp
    timestamp = datetime.now(timezone.utc).isoformat()

    logger.info(f"Starting cloud sync operation {run_id}: dataset {dataset.name} ({dataset.id})")
    dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets])
    logger.info(f"Starting cloud sync operation {run_id}: datasets {dataset_info}")

    # Create sync operation record in database (total_records will be set during background sync)
    try:
        await create_sync_operation(
            run_id=run_id, dataset_id=dataset.id, dataset_name=dataset.name, user_id=user.id
            run_id=run_id,
            dataset_ids=[d.id for d in datasets],
            dataset_names=[d.name for d in datasets],
            user_id=user.id,
        )
        logger.info(f"Created sync operation record for {run_id}")
    except Exception as e:
@@ -116,44 +149,74 @@
    # Continue without database tracking if record creation fails

    # Start the sync operation in the background
    asyncio.create_task(_perform_background_sync(run_id, dataset, user))
    asyncio.create_task(_perform_background_sync(run_id, datasets, user))

    # Return immediately with run_id
    return SyncResponse(
        run_id=run_id,
        status="started",
        dataset_id=str(dataset.id),
        dataset_name=dataset.name,
        message=f"Sync operation started in background. Use run_id '{run_id}' to track progress.",
        dataset_ids=[str(d.id) for d in datasets],
        dataset_names=[d.name for d in datasets],
        message=f"Sync operation started in background for {len(datasets)} datasets. Use run_id '{run_id}' to track progress.",
        timestamp=timestamp,
        user_id=str(user.id),
    )


async def _perform_background_sync(run_id: str, dataset: Dataset, user: User) -> None:
    """Perform the actual sync operation in the background."""
async def _perform_background_sync(run_id: str, datasets: List[Dataset], user: User) -> None:
    """Perform the actual sync operation in the background for multiple datasets."""
    start_time = datetime.now(timezone.utc)

    try:
        logger.info(
            f"Background sync {run_id}: Starting sync for dataset {dataset.name} ({dataset.id})"
        )
        dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets])
        logger.info(f"Background sync {run_id}: Starting sync for datasets {dataset_info}")

        # Mark sync as in progress
        await mark_sync_started(run_id)

        # Perform the actual sync operation
        records_processed, bytes_transferred = await _sync_to_cognee_cloud(dataset, user, run_id)
        MAX_RETRY_COUNT = 3
        retry_count = 0
        while retry_count < MAX_RETRY_COUNT:
            try:
                (
                    records_downloaded,
                    records_uploaded,
                    bytes_downloaded,
                    bytes_uploaded,
                    dataset_sync_hashes,
                ) = await _sync_to_cognee_cloud(datasets, user, run_id)
                break
            except Exception as e:
                retry_count += 1
                logger.error(
                    f"Background sync {run_id}: Failed after {retry_count} retries with error: {str(e)}"
                )
                await update_sync_operation(run_id, retry_count=retry_count)
                await asyncio.sleep(2**retry_count)
                continue

        if retry_count == MAX_RETRY_COUNT:
            logger.error(f"Background sync {run_id}: Failed after {MAX_RETRY_COUNT} retries")
            await mark_sync_failed(run_id, "Failed after 3 retries")
            return

        end_time = datetime.now(timezone.utc)
        duration = (end_time - start_time).total_seconds()

        logger.info(
            f"Background sync {run_id}: Completed successfully. Records: {records_processed}, Bytes: {bytes_transferred}, Duration: {duration}s"
            f"Background sync {run_id}: Completed successfully. Downloaded: {records_downloaded} records/{bytes_downloaded} bytes, Uploaded: {records_uploaded} records/{bytes_uploaded} bytes, Duration: {duration}s"
        )

        # Mark sync as completed with final stats
        await mark_sync_completed(run_id, records_processed, bytes_transferred)
        # Mark sync as completed with final stats and data lineage
        await mark_sync_completed(
            run_id,
            records_downloaded,
            records_uploaded,
            bytes_downloaded,
            bytes_uploaded,
            dataset_sync_hashes,
        )

    except Exception as e:
        end_time = datetime.now(timezone.utc)
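A quick note on the retry timing above: with `await asyncio.sleep(2**retry_count)` the waits grow exponentially, so a run that fails on every attempt sleeps roughly 2 s, 4 s, and 8 s before the operation is marked as failed.

```python
delays = [2**attempt for attempt in range(1, 4)]  # [2, 4, 8] seconds across the three retries
```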
@@ -165,89 +228,248 @@ async def _perform_background_sync(run_id: str, dataset: Dataset, user: User) ->
        await mark_sync_failed(run_id, str(e))


async def _sync_to_cognee_cloud(dataset: Dataset, user: User, run_id: str) -> tuple[int, int]:
async def _sync_to_cognee_cloud(
    datasets: List[Dataset], user: User, run_id: str
) -> tuple[int, int, int, int, dict]:
    """
    Sync local data to Cognee Cloud using three-step idempotent process:
    1. Extract local files with stored MD5 hashes and check what's missing on cloud
    2. Upload missing files individually
    3. Prune cloud dataset to match local state
    """
    logger.info(f"Starting sync to Cognee Cloud: dataset {dataset.name} ({dataset.id})")
    dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets])
    logger.info(f"Starting sync to Cognee Cloud: datasets {dataset_info}")

    total_records_downloaded = 0
    total_records_uploaded = 0
    total_bytes_downloaded = 0
    total_bytes_uploaded = 0
    dataset_sync_hashes = {}

    try:
        # Get cloud configuration
        cloud_base_url = await _get_cloud_base_url()
        cloud_auth_token = await _get_cloud_auth_token(user)

        logger.info(f"Cloud API URL: {cloud_base_url}")
        # Step 1: Sync files for all datasets concurrently
        sync_files_tasks = [
            _sync_dataset_files(dataset, cloud_base_url, cloud_auth_token, user, run_id)
            for dataset in datasets
        ]

        # Step 1: Extract local file info with stored hashes
        local_files = await _extract_local_files_with_hashes(dataset, user, run_id)
        logger.info(f"Found {len(local_files)} local files to sync")
        logger.info(f"Starting concurrent file sync for {len(datasets)} datasets")

        # Update sync operation with total file count
        try:
            await update_sync_operation(run_id, processed_records=0)
        except Exception as e:
            logger.warning(f"Failed to initialize sync progress: {str(e)}")
        has_any_uploads = False
        has_any_downloads = False
        processed_datasets = []
        completed_datasets = 0

        if not local_files:
            logger.info("No files to sync - dataset is empty")
            return 0, 0
        # Process datasets concurrently and accumulate results
        for completed_task in asyncio.as_completed(sync_files_tasks):
            try:
                dataset_result = await completed_task
                completed_datasets += 1

        # Step 2: Check what files are missing on cloud
        local_hashes = [f.content_hash for f in local_files]
        missing_hashes = await _check_missing_hashes(
            cloud_base_url, cloud_auth_token, dataset.id, local_hashes, run_id
        )
        logger.info(f"Cloud is missing {len(missing_hashes)} out of {len(local_hashes)} files")
                # Update progress based on completed datasets (0-80% for file sync)
                file_sync_progress = int((completed_datasets / len(datasets)) * 80)
                await _safe_update_progress(
                    run_id, "file_sync", progress_percentage=file_sync_progress
                )

        # Update progress
        try:
            await update_sync_operation(run_id, progress_percentage=25)
        except Exception as e:
            logger.warning(f"Failed to update progress: {str(e)}")
                if dataset_result is None:
                    logger.info(
                        f"Progress: {completed_datasets}/{len(datasets)} datasets processed ({file_sync_progress}%)"
                    )
                    continue

        # Step 3: Upload missing files
        bytes_uploaded = await _upload_missing_files(
            cloud_base_url, cloud_auth_token, dataset, local_files, missing_hashes, run_id
        )
        logger.info(f"Upload complete: {len(missing_hashes)} files, {bytes_uploaded} bytes")
                total_records_downloaded += dataset_result.records_downloaded
                total_records_uploaded += dataset_result.records_uploaded
                total_bytes_downloaded += dataset_result.bytes_downloaded
                total_bytes_uploaded += dataset_result.bytes_uploaded

        # Update progress
        try:
            await update_sync_operation(run_id, progress_percentage=75)
        except Exception as e:
            logger.warning(f"Failed to update progress: {str(e)}")
                # Build per-dataset hash tracking for data lineage
                dataset_sync_hashes[dataset_result.dataset_id] = {
                    "uploaded": dataset_result.uploaded_hashes,
                    "downloaded": dataset_result.downloaded_hashes,
                }

        # Step 4: Trigger cognify processing on cloud dataset (only if new files were uploaded)
        if missing_hashes:
            await _trigger_remote_cognify(cloud_base_url, cloud_auth_token, dataset.id, run_id)
            logger.info(f"Cognify processing triggered for dataset {dataset.id}")
                if dataset_result.has_uploads:
                    has_any_uploads = True
                if dataset_result.has_downloads:
                    has_any_downloads = True

                processed_datasets.append(dataset_result.dataset_id)

                logger.info(
                    f"Progress: {completed_datasets}/{len(datasets)} datasets processed ({file_sync_progress}%) - "
                    f"Completed file sync for dataset {dataset_result.dataset_name}: "
                    f"↑{dataset_result.records_uploaded} files ({dataset_result.bytes_uploaded} bytes), "
                    f"↓{dataset_result.records_downloaded} files ({dataset_result.bytes_downloaded} bytes)"
                )
            except Exception as e:
                completed_datasets += 1
                logger.error(f"Dataset file sync failed: {str(e)}")
                # Update progress even for failed datasets
                file_sync_progress = int((completed_datasets / len(datasets)) * 80)
                await _safe_update_progress(
                    run_id, "file_sync", progress_percentage=file_sync_progress
                )
                # Continue with other datasets even if one fails

        # Step 2: Trigger cognify processing once for all datasets (only if any files were uploaded)
        # Update progress to 90% before cognify
        await _safe_update_progress(run_id, "cognify", progress_percentage=90)

        if has_any_uploads and processed_datasets:
            logger.info(
                f"Progress: 90% - Triggering cognify processing for {len(processed_datasets)} datasets with new files"
            )
            try:
                # Trigger cognify for all datasets at once - use first dataset as reference point
                await _trigger_remote_cognify(
                    cloud_base_url, cloud_auth_token, datasets[0].id, run_id
                )
                logger.info("Cognify processing triggered successfully for all datasets")
            except Exception as e:
                logger.warning(f"Failed to trigger cognify processing: {str(e)}")
                # Don't fail the entire sync if cognify fails
        else:
            logger.info(
                f"Skipping cognify processing - no new files were uploaded for dataset {dataset.id}"
                "Progress: 90% - Skipping cognify processing - no new files were uploaded across any datasets"
            )

        # Final progress
        try:
            await update_sync_operation(run_id, progress_percentage=100)
        except Exception as e:
            logger.warning(f"Failed to update final progress: {str(e)}")
        # Step 3: Trigger local cognify processing if any files were downloaded
        if has_any_downloads and processed_datasets:
            logger.info(
                f"Progress: 95% - Triggering local cognify processing for {len(processed_datasets)} datasets with downloaded files"
            )
            try:
                await cognify()
                logger.info("Local cognify processing completed successfully for all datasets")
            except Exception as e:
                logger.warning(f"Failed to run local cognify processing: {str(e)}")
                # Don't fail the entire sync if local cognify fails
        else:
            logger.info(
                "Progress: 95% - Skipping local cognify processing - no new files were downloaded across any datasets"
            )

        records_processed = len(local_files)
        # Update final progress
        try:
            await _safe_update_progress(
                run_id,
                "final",
                progress_percentage=100,
                total_records_to_sync=total_records_uploaded + total_records_downloaded,
                total_records_to_download=total_records_downloaded,
                total_records_to_upload=total_records_uploaded,
                records_downloaded=total_records_downloaded,
                records_uploaded=total_records_uploaded,
            )
        except Exception as e:
            logger.warning(f"Failed to update final sync progress: {str(e)}")

        logger.info(
            f"Sync completed successfully: {records_processed} records, {bytes_uploaded} bytes uploaded"
            f"Multi-dataset sync completed: {len(datasets)} datasets processed, downloaded {total_records_downloaded} records/{total_bytes_downloaded} bytes, uploaded {total_records_uploaded} records/{total_bytes_uploaded} bytes"
        )

        return records_processed, bytes_uploaded
        return (
            total_records_downloaded,
            total_records_uploaded,
            total_bytes_downloaded,
            total_bytes_uploaded,
            dataset_sync_hashes,
        )

    except Exception as e:
        logger.error(f"Sync failed: {str(e)}")
        raise ConnectionError(f"Cloud sync failed: {str(e)}")
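As background (an assumption drawn from the docstrings, which describe the stored content hashes as MD5 hashes of the file content), the comparison above works because identical bytes always produce the same digest on both sides, which is what makes the diff idempotent.

```python
import hashlib


def content_hash(file_bytes: bytes) -> str:
    # MD5 hex digest of the raw bytes; unchanged content yields the same hash on every run.
    return hashlib.md5(file_bytes).hexdigest()


print(content_hash(b"hello world"))  # 5eb63bbbe01eeed093cb22bb8f5acdc3
```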
@dataclass
class DatasetSyncResult:
    """Result of syncing files for a single dataset."""

    dataset_name: str
    dataset_id: str
    records_downloaded: int
    records_uploaded: int
    bytes_downloaded: int
    bytes_uploaded: int
    has_uploads: bool  # Whether any files were uploaded (for cognify decision)
    has_downloads: bool  # Whether any files were downloaded (for cognify decision)
    uploaded_hashes: List[str]  # Content hashes of files uploaded during sync
    downloaded_hashes: List[str]  # Content hashes of files downloaded during sync


async def _sync_dataset_files(
    dataset: Dataset, cloud_base_url: str, cloud_auth_token: str, user: User, run_id: str
) -> Optional[DatasetSyncResult]:
    """
    Sync files for a single dataset (2-way: upload to cloud, download from cloud).
    Does NOT trigger cognify - that's done separately once for all datasets.

    Returns:
        DatasetSyncResult with sync results or None if dataset was empty
    """
    logger.info(f"Syncing files for dataset: {dataset.name} ({dataset.id})")

    try:
        # Step 1: Extract local file info with stored hashes
        local_files = await _extract_local_files_with_hashes(dataset, user, run_id)
        logger.info(f"Found {len(local_files)} local files for dataset {dataset.name}")

        if not local_files:
            logger.info(f"No files to sync for dataset {dataset.name} - skipping")
            return None

        # Step 2: Check what files are missing on cloud
        local_hashes = [f.content_hash for f in local_files]
        hashes_diff_response = await _check_hashes_diff(
            cloud_base_url, cloud_auth_token, dataset, local_hashes, run_id
        )

        hashes_missing_on_remote = hashes_diff_response.missing_on_remote
        hashes_missing_on_local = hashes_diff_response.missing_on_local

        logger.info(
            f"Dataset {dataset.name}: {len(hashes_missing_on_remote)} files to upload, {len(hashes_missing_on_local)} files to download"
        )

        # Step 3: Upload files that are missing on cloud
        bytes_uploaded = await _upload_missing_files(
            cloud_base_url, cloud_auth_token, dataset, local_files, hashes_missing_on_remote, run_id
        )
        logger.info(
            f"Dataset {dataset.name}: Upload complete - {len(hashes_missing_on_remote)} files, {bytes_uploaded} bytes"
        )

        # Step 4: Download files that are missing locally
        bytes_downloaded = await _download_missing_files(
            cloud_base_url, cloud_auth_token, dataset, hashes_missing_on_local, user
        )
        logger.info(
            f"Dataset {dataset.name}: Download complete - {len(hashes_missing_on_local)} files, {bytes_downloaded} bytes"
        )

        return DatasetSyncResult(
            dataset_name=dataset.name,
            dataset_id=str(dataset.id),
            records_downloaded=len(hashes_missing_on_local),
            records_uploaded=len(hashes_missing_on_remote),
            bytes_downloaded=bytes_downloaded,
            bytes_uploaded=bytes_uploaded,
            has_uploads=len(hashes_missing_on_remote) > 0,
            has_downloads=len(hashes_missing_on_local) > 0,
            uploaded_hashes=hashes_missing_on_remote,
            downloaded_hashes=hashes_missing_on_local,
        )

    except Exception as e:
        logger.error(f"Failed to sync files for dataset {dataset.name} ({dataset.id}): {str(e)}")
        raise  # Re-raise to be handled by the caller


async def _extract_local_files_with_hashes(
    dataset: Dataset, user: User, run_id: str
) -> List[LocalFileInfo]:
@@ -334,43 +556,42 @@ async def _get_file_size(file_path: str) -> int:

async def _get_cloud_base_url() -> str:
    """Get Cognee Cloud API base URL."""
    # TODO: Make this configurable via environment variable or config
    return os.getenv("COGNEE_CLOUD_API_URL", "http://localhost:8001")


async def _get_cloud_auth_token(user: User) -> str:
    """Get authentication token for Cognee Cloud API."""
    # TODO: Implement proper authentication with Cognee Cloud
    # This should get or refresh an API token for the user
    return os.getenv("COGNEE_CLOUD_AUTH_TOKEN", "your-auth-token-here")
    return os.getenv("COGNEE_CLOUD_AUTH_TOKEN", "your-auth-token")


async def _check_missing_hashes(
    cloud_base_url: str, auth_token: str, dataset_id: str, local_hashes: List[str], run_id: str
) -> List[str]:
async def _check_hashes_diff(
    cloud_base_url: str, auth_token: str, dataset: Dataset, local_hashes: List[str], run_id: str
) -> CheckHashesDiffResponse:
    """
    Step 1: Check which hashes are missing on cloud.
    Check which hashes are missing on cloud.

    Returns:
        List[str]: MD5 hashes that need to be uploaded
    """
    url = f"{cloud_base_url}/api/sync/{dataset_id}/diff"
    url = f"{cloud_base_url}/api/sync/{dataset.id}/diff"
    headers = {"X-Api-Key": auth_token, "Content-Type": "application/json"}

    payload = CheckMissingHashesRequest(hashes=local_hashes)
    payload = CheckMissingHashesRequest(
        dataset_id=str(dataset.id), dataset_name=dataset.name, hashes=local_hashes
    )

    logger.info(f"Checking missing hashes on cloud for dataset {dataset_id}")
    logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}")

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload.dict(), headers=headers) as response:
                if response.status == 200:
                    data = await response.json()
                    missing_response = CheckMissingHashesResponse(**data)
                    missing_response = CheckHashesDiffResponse(**data)
                    logger.info(
                        f"Cloud reports {len(missing_response.missing)} missing files out of {len(local_hashes)} total"
                        f"Cloud is missing {len(missing_response.missing_on_remote)} out of {len(local_hashes)} files, local is missing {len(missing_response.missing_on_local)} files"
                    )
                    return missing_response.missing
                    return missing_response
                else:
                    error_text = await response.text()
                    logger.error(
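To make the two-way contract concrete, a hedged example of the diff response the code above expects from the cloud (the hash values are invented): files listed under missing_on_remote get uploaded, files under missing_on_local get downloaded.

```python
# Shape of the JSON body that is parsed into CheckHashesDiffResponse
example_response = {
    "missing_on_remote": ["9e107d9d372bb6826bd81d3542a419d6"],  # local-only content -> upload
    "missing_on_local": ["1f3870be274f6c49b3e31a0c6728957f"],   # cloud-only content -> download
}
```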
@@ -385,22 +606,137 @@ async def _check_missing_hashes(
        raise ConnectionError(f"Failed to check missing hashes: {str(e)}")


async def _download_missing_files(
    cloud_base_url: str,
    auth_token: str,
    dataset: Dataset,
    hashes_missing_on_local: List[str],
    user: User,
) -> int:
    """
    Download files that are missing locally from the cloud.

    Returns:
        int: Total bytes downloaded
    """
    logger.info(f"Downloading {len(hashes_missing_on_local)} missing files from cloud")

    if not hashes_missing_on_local:
        logger.info("No files need to be downloaded - all files already exist locally")
        return 0

    total_bytes_downloaded = 0
    downloaded_count = 0

    headers = {"X-Api-Key": auth_token}

    async with aiohttp.ClientSession() as session:
        for file_hash in hashes_missing_on_local:
            try:
                # Download file from cloud by hash
                download_url = f"{cloud_base_url}/api/sync/{dataset.id}/data/{file_hash}"

                logger.debug(f"Downloading file with hash: {file_hash}")

                async with session.get(download_url, headers=headers) as response:
                    if response.status == 200:
                        file_content = await response.read()
                        file_size = len(file_content)

                        # Get file metadata from response headers
                        file_name = response.headers.get("X-File-Name", f"file_{file_hash}")

                        # Save file locally using ingestion pipeline
                        await _save_downloaded_file(
                            dataset, file_hash, file_name, file_content, user
                        )

                        total_bytes_downloaded += file_size
                        downloaded_count += 1

                        logger.debug(f"Successfully downloaded {file_name} ({file_size} bytes)")

                    elif response.status == 404:
                        logger.warning(f"File with hash {file_hash} not found on cloud")
                        continue
                    else:
                        error_text = await response.text()
                        logger.error(
                            f"Failed to download file {file_hash}: Status {response.status} - {error_text}"
                        )
                        continue

            except Exception as e:
                logger.error(f"Error downloading file {file_hash}: {str(e)}")
                continue

    logger.info(
        f"Download summary: {downloaded_count}/{len(hashes_missing_on_local)} files downloaded, {total_bytes_downloaded} bytes total"
    )
    return total_bytes_downloaded


class InMemoryDownload:
    def __init__(self, data: bytes, filename: str):
        self.file = io.BufferedReader(io.BytesIO(data))
        self.filename = filename


async def _save_downloaded_file(
    dataset: Dataset,
    file_hash: str,
    file_name: str,
    file_content: bytes,
    user: User,
) -> None:
    """
    Save a downloaded file to local storage and register it in the dataset.
    Uses the existing ingest_data function for consistency with normal ingestion.

    Args:
        dataset: The dataset to add the file to
        file_hash: MD5 hash of the file content
        file_name: Original file name
        file_content: Raw file content bytes
    """
    try:
        # Create a temporary file-like object from the bytes
        file_obj = InMemoryDownload(file_content, file_name)

        # User is injected as dependency

        # Use the existing ingest_data function to properly handle the file
        # This ensures consistency with normal file ingestion
        await ingest_data(
            data=file_obj,
            dataset_name=dataset.name,
            user=user,
            dataset_id=dataset.id,
        )

        logger.debug(f"Successfully saved downloaded file: {file_name} (hash: {file_hash})")

    except Exception as e:
        logger.error(f"Failed to save downloaded file {file_name}: {str(e)}")
        raise


async def _upload_missing_files(
    cloud_base_url: str,
    auth_token: str,
    dataset: Dataset,
    local_files: List[LocalFileInfo],
    missing_hashes: List[str],
    hashes_missing_on_remote: List[str],
    run_id: str,
) -> int:
    """
    Step 2: Upload files that are missing on cloud.
    Upload files that are missing on cloud.

    Returns:
        int: Total bytes uploaded
    """
    # Filter local files to only those with missing hashes
    files_to_upload = [f for f in local_files if f.content_hash in missing_hashes]
    files_to_upload = [f for f in local_files if f.content_hash in hashes_missing_on_remote]

    logger.info(f"Uploading {len(files_to_upload)} missing files to cloud")
@@ -442,13 +778,6 @@ async def _upload_missing_files(
                    if response.status in [200, 201]:
                        total_bytes_uploaded += len(file_content)
                        uploaded_count += 1

                        # Update progress periodically
                        if uploaded_count % 10 == 0:
                            progress = (
                                25 + (uploaded_count / len(files_to_upload)) * 50
                            )  # 25-75% range
                            await update_sync_operation(run_id, progress_percentage=int(progress))
                    else:
                        error_text = await response.text()
                        logger.error(
@@ -470,7 +799,7 @@ async def _prune_cloud_dataset(
    cloud_base_url: str, auth_token: str, dataset_id: str, local_hashes: List[str], run_id: str
) -> None:
    """
    Step 3: Prune cloud dataset to match local state.
    Prune cloud dataset to match local state.
    """
    url = f"{cloud_base_url}/api/sync/{dataset_id}?prune=true"
    headers = {"X-Api-Key": auth_token, "Content-Type": "application/json"}
@@ -506,7 +835,7 @@ async def _trigger_remote_cognify(
    cloud_base_url: str, auth_token: str, dataset_id: str, run_id: str
) -> None:
    """
    Step 4: Trigger cognify processing on the cloud dataset.
    Trigger cognify processing on the cloud dataset.

    This initiates knowledge graph processing on the synchronized dataset
    using the cloud infrastructure.
@@ -1,5 +1,9 @@
from .create_sync_operation import create_sync_operation
from .get_sync_operation import get_sync_operation, get_user_sync_operations
from .get_sync_operation import (
    get_sync_operation,
    get_user_sync_operations,
    get_running_sync_operations_for_user,
)
from .update_sync_operation import (
    update_sync_operation,
    mark_sync_started,

@@ -11,6 +15,7 @@ __all__ = [
    "create_sync_operation",
    "get_sync_operation",
    "get_user_sync_operations",
    "get_running_sync_operations_for_user",
    "update_sync_operation",
    "mark_sync_started",
    "mark_sync_completed",
@@ -1,5 +1,5 @@
from uuid import UUID
from typing import Optional
from typing import Optional, List
from datetime import datetime, timezone
from cognee.modules.sync.models import SyncOperation, SyncStatus
from cognee.infrastructure.databases.relational import get_relational_engine
@@ -7,20 +7,24 @@ from cognee.infrastructure.databases.relational import get_relational_engine

async def create_sync_operation(
    run_id: str,
    dataset_id: UUID,
    dataset_name: str,
    dataset_ids: List[UUID],
    dataset_names: List[str],
    user_id: UUID,
    total_records: Optional[int] = None,
    total_records_to_sync: Optional[int] = None,
    total_records_to_download: Optional[int] = None,
    total_records_to_upload: Optional[int] = None,
) -> SyncOperation:
    """
    Create a new sync operation record in the database.

    Args:
        run_id: Unique public identifier for this sync operation
        dataset_id: UUID of the dataset being synced
        dataset_name: Name of the dataset being synced
        dataset_ids: List of dataset UUIDs being synced
        dataset_names: List of dataset names being synced
        user_id: UUID of the user who initiated the sync
        total_records: Total number of records to sync (if known)
        total_records_to_sync: Total number of records to sync (if known)
        total_records_to_download: Total number of records to download (if known)
        total_records_to_upload: Total number of records to upload (if known)

    Returns:
        SyncOperation: The created sync operation record
@@ -29,11 +33,15 @@ async def create_sync_operation(

    sync_operation = SyncOperation(
        run_id=run_id,
        dataset_id=dataset_id,
        dataset_name=dataset_name,
        dataset_ids=[
            str(uuid) for uuid in dataset_ids
        ],  # Convert UUIDs to strings for JSON storage
        dataset_names=dataset_names,
        user_id=user_id,
        status=SyncStatus.STARTED,
        total_records=total_records,
        total_records_to_sync=total_records_to_sync,
        total_records_to_download=total_records_to_download,
        total_records_to_upload=total_records_to_upload,
        created_at=datetime.now(timezone.utc),
    )
@@ -1,7 +1,7 @@
from uuid import UUID
from typing import List, Optional
from sqlalchemy import select, desc
from cognee.modules.sync.models import SyncOperation
from sqlalchemy import select, desc, and_
from cognee.modules.sync.models import SyncOperation, SyncStatus
from cognee.infrastructure.databases.relational import get_relational_engine
@@ -77,3 +77,31 @@ async def get_sync_operations_by_dataset(
        )
        result = await session.execute(query)
        return list(result.scalars().all())


async def get_running_sync_operations_for_user(user_id: UUID) -> List[SyncOperation]:
    """
    Get all currently running sync operations for a specific user.
    Checks for operations with STARTED or IN_PROGRESS status.

    Args:
        user_id: UUID of the user

    Returns:
        List[SyncOperation]: List of running sync operations for the user
    """
    db_engine = get_relational_engine()

    async with db_engine.get_async_session() as session:
        query = (
            select(SyncOperation)
            .where(
                and_(
                    SyncOperation.user_id == user_id,
                    SyncOperation.status.in_([SyncStatus.STARTED, SyncStatus.IN_PROGRESS]),
                )
            )
            .order_by(desc(SyncOperation.created_at))
        )
        result = await session.execute(query)
        return list(result.scalars().all())
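For completeness, a small sketch of calling this helper outside the request cycle (the user id below is a placeholder; inside the API layer it comes from get_authenticated_user):

```python
import asyncio
from uuid import UUID

from cognee.modules.sync.methods import get_running_sync_operations_for_user


async def main():
    user_id = UUID("12345678-1234-5678-9012-123456789012")  # placeholder
    running = await get_running_sync_operations_for_user(user_id)
    for op in running:
        # Fields defined on the sync_operations table added by this PR.
        print(op.run_id, op.status, op.progress_percentage)


asyncio.run(main())
```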
@@ -1,16 +1,74 @@
from typing import Optional
import asyncio
from typing import Optional, List
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError, TimeoutError
from cognee.modules.sync.models import SyncOperation, SyncStatus
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.utils.calculate_backoff import calculate_backoff

logger = get_logger("sync.db_operations")


async def _retry_db_operation(operation_func, run_id: str, max_retries: int = 3):
    """
    Retry database operations with exponential backoff for transient failures.

    Args:
        operation_func: Async function to retry
        run_id: Run ID for logging context
        max_retries: Maximum number of retry attempts

    Returns:
        Result of the operation function

    Raises:
        Exception: Re-raises the last exception if all retries fail
    """
    attempt = 0
    last_exception = None

    while attempt < max_retries:
        try:
            return await operation_func()
        except (DisconnectionError, OperationalError, TimeoutError) as e:
            attempt += 1
            last_exception = e

            if attempt >= max_retries:
                logger.error(
                    f"Database operation failed after {max_retries} attempts for run_id {run_id}: {str(e)}"
                )
                break

            backoff_time = calculate_backoff(attempt - 1)  # calculate_backoff is 0-indexed
            logger.warning(
                f"Database operation failed for run_id {run_id}, retrying in {backoff_time:.2f}s (attempt {attempt}/{max_retries}): {str(e)}"
            )
            await asyncio.sleep(backoff_time)

        except Exception as e:
            # Non-transient errors should not be retried
            logger.error(f"Non-retryable database error for run_id {run_id}: {str(e)}")
            raise

    # If we get here, all retries failed
    raise last_exception


async def update_sync_operation(
    run_id: str,
    status: Optional[SyncStatus] = None,
    progress_percentage: Optional[int] = None,
    processed_records: Optional[int] = None,
    bytes_transferred: Optional[int] = None,
    records_downloaded: Optional[int] = None,
    total_records_to_sync: Optional[int] = None,
    total_records_to_download: Optional[int] = None,
    total_records_to_upload: Optional[int] = None,
    records_uploaded: Optional[int] = None,
    bytes_downloaded: Optional[int] = None,
    bytes_uploaded: Optional[int] = None,
    dataset_sync_hashes: Optional[dict] = None,
    error_message: Optional[str] = None,
    retry_count: Optional[int] = None,
    started_at: Optional[datetime] = None,
@@ -23,8 +81,14 @@ async def update_sync_operation(
        run_id: The public run_id of the sync operation to update
        status: New status for the operation
        progress_percentage: Progress percentage (0-100)
        processed_records: Number of records processed so far
        bytes_transferred: Total bytes transferred
        records_downloaded: Number of records downloaded so far
        total_records_to_sync: Total number of records that need to be synced
        total_records_to_download: Total number of records to download from cloud
        total_records_to_upload: Total number of records to upload to cloud
        records_uploaded: Number of records uploaded so far
        bytes_downloaded: Total bytes downloaded from cloud
        bytes_uploaded: Total bytes uploaded to cloud
        dataset_sync_hashes: Dict mapping dataset_id -> {uploaded: [hashes], downloaded: [hashes]}
        error_message: Error message if operation failed
        retry_count: Number of retry attempts
        started_at: When the actual processing started
@ -33,57 +97,116 @@ async def update_sync_operation(

     Returns:
         SyncOperation: The updated sync operation record, or None if not found
     """
-    db_engine = get_relational_engine()
-
-    async with db_engine.get_async_session() as session:
-        # Find the sync operation
-        query = select(SyncOperation).where(SyncOperation.run_id == run_id)
-        result = await session.execute(query)
-        sync_operation = result.scalars().first()
-
-        if not sync_operation:
-            return None
-
-        # Update fields that were provided
-        if status is not None:
-            sync_operation.status = status
-
-        if progress_percentage is not None:
-            sync_operation.progress_percentage = max(0, min(100, progress_percentage))
-
-        if processed_records is not None:
-            sync_operation.processed_records = processed_records
-
-        if bytes_transferred is not None:
-            sync_operation.bytes_transferred = bytes_transferred
-
-        if error_message is not None:
-            sync_operation.error_message = error_message
-
-        if retry_count is not None:
-            sync_operation.retry_count = retry_count
-
-        if started_at is not None:
-            sync_operation.started_at = started_at
-
-        if completed_at is not None:
-            sync_operation.completed_at = completed_at
-
-        # Auto-set completion timestamp for terminal statuses
-        if (
-            status in [SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED]
-            and completed_at is None
-        ):
-            sync_operation.completed_at = datetime.now(timezone.utc)
-
-        # Auto-set started timestamp when moving to IN_PROGRESS
-        if status == SyncStatus.IN_PROGRESS and sync_operation.started_at is None:
-            sync_operation.started_at = datetime.now(timezone.utc)
-
-        await session.commit()
-        await session.refresh(sync_operation)
-
-        return sync_operation
+    async def _perform_update():
+        db_engine = get_relational_engine()
+
+        async with db_engine.get_async_session() as session:
+            try:
+                # Find the sync operation
+                query = select(SyncOperation).where(SyncOperation.run_id == run_id)
+                result = await session.execute(query)
+                sync_operation = result.scalars().first()
+
+                if not sync_operation:
+                    logger.warning(f"Sync operation not found for run_id: {run_id}")
+                    return None
+
+                # Log what we're updating for debugging
+                updates = []
+                if status is not None:
+                    updates.append(f"status={status.value}")
+                if progress_percentage is not None:
+                    updates.append(f"progress={progress_percentage}%")
+                if records_downloaded is not None:
+                    updates.append(f"downloaded={records_downloaded}")
+                if records_uploaded is not None:
+                    updates.append(f"uploaded={records_uploaded}")
+                if total_records_to_sync is not None:
+                    updates.append(f"total_sync={total_records_to_sync}")
+                if total_records_to_download is not None:
+                    updates.append(f"total_download={total_records_to_download}")
+                if total_records_to_upload is not None:
+                    updates.append(f"total_upload={total_records_to_upload}")
+
+                if updates:
+                    logger.debug(f"Updating sync operation {run_id}: {', '.join(updates)}")
+
+                # Update fields that were provided
+                if status is not None:
+                    sync_operation.status = status
+
+                if progress_percentage is not None:
+                    sync_operation.progress_percentage = max(0, min(100, progress_percentage))
+
+                if records_downloaded is not None:
+                    sync_operation.records_downloaded = records_downloaded
+
+                if records_uploaded is not None:
+                    sync_operation.records_uploaded = records_uploaded
+
+                if total_records_to_sync is not None:
+                    sync_operation.total_records_to_sync = total_records_to_sync
+
+                if total_records_to_download is not None:
+                    sync_operation.total_records_to_download = total_records_to_download
+
+                if total_records_to_upload is not None:
+                    sync_operation.total_records_to_upload = total_records_to_upload
+
+                if bytes_downloaded is not None:
+                    sync_operation.bytes_downloaded = bytes_downloaded
+
+                if bytes_uploaded is not None:
+                    sync_operation.bytes_uploaded = bytes_uploaded
+
+                if dataset_sync_hashes is not None:
+                    sync_operation.dataset_sync_hashes = dataset_sync_hashes
+
+                if error_message is not None:
+                    sync_operation.error_message = error_message
+
+                if retry_count is not None:
+                    sync_operation.retry_count = retry_count
+
+                if started_at is not None:
+                    sync_operation.started_at = started_at
+
+                if completed_at is not None:
+                    sync_operation.completed_at = completed_at
+
+                # Auto-set completion timestamp for terminal statuses
+                if (
+                    status in [SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED]
+                    and completed_at is None
+                ):
+                    sync_operation.completed_at = datetime.now(timezone.utc)
+
+                # Auto-set started timestamp when moving to IN_PROGRESS
+                if status == SyncStatus.IN_PROGRESS and sync_operation.started_at is None:
+                    sync_operation.started_at = datetime.now(timezone.utc)
+
+                await session.commit()
+                await session.refresh(sync_operation)
+
+                logger.debug(f"Successfully updated sync operation {run_id}")
+                return sync_operation
+
+            except SQLAlchemyError as e:
+                logger.error(
+                    f"Database error updating sync operation {run_id}: {str(e)}", exc_info=True
+                )
+                await session.rollback()
+                raise
+            except Exception as e:
+                logger.error(
+                    f"Unexpected error updating sync operation {run_id}: {str(e)}", exc_info=True
+                )
+                await session.rollback()
+                raise
+
+    # Use retry logic for the database operation
+    return await _retry_db_operation(_perform_update, run_id)


 async def mark_sync_started(run_id: str) -> Optional[SyncOperation]:
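For orientation, a sketch of how a sync task might report mid-run progress through the reworked helper. update_sync_operation and SyncStatus are the objects shown in the hunk above; the keyword names are inferred from the function body, the counts are illustrative, and the import path is omitted because it is not visible in this diff.

```python
# Sketch only: reporting mid-run progress (names inferred from the hunk above).
async def report_progress(run_id: str) -> None:
    await update_sync_operation(
        run_id=run_id,
        status=SyncStatus.IN_PROGRESS,   # first transition also auto-sets started_at
        progress_percentage=40,
        records_downloaded=120,          # illustrative counts only
        records_uploaded=80,
    )
```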
@ -94,15 +217,23 @@ async def mark_sync_started(run_id: str) -> Optional[SyncOperation]:


 async def mark_sync_completed(
-    run_id: str, processed_records: int, bytes_transferred: int
+    run_id: str,
+    records_downloaded: int = 0,
+    records_uploaded: int = 0,
+    bytes_downloaded: int = 0,
+    bytes_uploaded: int = 0,
+    dataset_sync_hashes: Optional[dict] = None,
 ) -> Optional[SyncOperation]:
     """Convenience method to mark a sync operation as completed successfully."""
     return await update_sync_operation(
         run_id=run_id,
         status=SyncStatus.COMPLETED,
         progress_percentage=100,
-        processed_records=processed_records,
-        bytes_transferred=bytes_transferred,
+        records_downloaded=records_downloaded,
+        records_uploaded=records_uploaded,
+        bytes_downloaded=bytes_downloaded,
+        bytes_uploaded=bytes_uploaded,
+        dataset_sync_hashes=dataset_sync_hashes,
         completed_at=datetime.now(timezone.utc),
     )
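A sketch of closing out a two-way sync run with the new keyword arguments of mark_sync_completed shown above. The counts, dataset ID, and hash are hypothetical, and imports are omitted since the module path is not visible in this hunk.

```python
# Sketch only: marking a two-way sync run as completed.
async def finish_sync(run_id: str) -> None:
    await mark_sync_completed(
        run_id=run_id,
        records_downloaded=120,
        records_uploaded=80,
        bytes_downloaded=1_048_576,
        bytes_uploaded=524_288,
        dataset_sync_hashes={
            "3f1c9a2e-0000-0000-0000-000000000001": {   # hypothetical dataset ID
                "uploaded": ["sha256:ab12"],
                "downloaded": [],
            }
        },
    )
```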
@ -1,8 +1,16 @@
 from uuid import uuid4
 from enum import Enum
-from typing import Optional
+from typing import Optional, List
 from datetime import datetime, timezone
-from sqlalchemy import Column, Text, DateTime, UUID as SQLAlchemy_UUID, Integer, Enum as SQLEnum
+from sqlalchemy import (
+    Column,
+    Text,
+    DateTime,
+    UUID as SQLAlchemy_UUID,
+    Integer,
+    Enum as SQLEnum,
+    JSON,
+)

 from cognee.infrastructure.databases.relational import Base
@ -38,8 +46,8 @@ class SyncOperation(Base):
     progress_percentage = Column(Integer, default=0, doc="Progress percentage (0-100)")

     # Operation metadata
-    dataset_id = Column(SQLAlchemy_UUID, index=True, doc="ID of the dataset being synced")
-    dataset_name = Column(Text, doc="Name of the dataset being synced")
+    dataset_ids = Column(JSON, doc="Array of dataset IDs being synced")
+    dataset_names = Column(JSON, doc="Array of dataset names being synced")
     user_id = Column(SQLAlchemy_UUID, index=True, doc="ID of the user who initiated the sync")

     # Timing information
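Since the single dataset_id/dataset_name pair becomes a pair of JSON arrays, one row can now describe a multi-dataset sync. A hypothetical pair of column values (IDs and names are placeholders):

```python
# Hypothetical values for the new JSON columns.
dataset_ids = ["9b2c6c1e-aaaa-bbbb-cccc-000000000001", "9b2c6c1e-aaaa-bbbb-cccc-000000000002"]
dataset_names = ["support_tickets", "product_docs"]
```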
@ -54,18 +62,24 @@ class SyncOperation(Base):
     )

     # Operation details
-    total_records = Column(Integer, doc="Total number of records to sync")
-    processed_records = Column(Integer, default=0, doc="Number of records successfully processed")
-    bytes_transferred = Column(Integer, default=0, doc="Total bytes transferred to cloud")
+    total_records_to_sync = Column(Integer, doc="Total number of records to sync")
+    total_records_to_download = Column(Integer, doc="Total number of records to download")
+    total_records_to_upload = Column(Integer, doc="Total number of records to upload")
+
+    records_downloaded = Column(Integer, default=0, doc="Number of records successfully downloaded")
+    records_uploaded = Column(Integer, default=0, doc="Number of records successfully uploaded")
+    bytes_downloaded = Column(Integer, default=0, doc="Total bytes downloaded from cloud")
+    bytes_uploaded = Column(Integer, default=0, doc="Total bytes uploaded to cloud")
+
+    # Data lineage tracking per dataset
+    dataset_sync_hashes = Column(
+        JSON, doc="Mapping of dataset_id -> {uploaded: [hashes], downloaded: [hashes]}"
+    )

     # Error handling
     error_message = Column(Text, doc="Error message if sync failed")
     retry_count = Column(Integer, default=0, doc="Number of retry attempts")

     # Additional metadata (can be added later when needed)
     # cloud_endpoint = Column(Text, doc="Cloud endpoint used for sync")
     # compression_enabled = Column(Text, doc="Whether compression was used")

     def get_duration_seconds(self) -> Optional[float]:
         """Get the duration of the sync operation in seconds."""
         if not self.created_at:
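A rough sketch of how the split counters might be populated when a run is created. The constructor keywords are assumed from the column definitions above, not taken from the PR's own create helper, and the run identifier is hypothetical.

```python
# Sketch only: seeding a sync_operations row with the new counter columns.
operation = SyncOperation(
    run_id="sync-2025-09-10-001",    # hypothetical run identifier
    status=SyncStatus.STARTED,
    total_records_to_download=150,
    total_records_to_upload=90,
    total_records_to_sync=240,       # download + upload
)
```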
@ -76,11 +90,53 @@ class SyncOperation(Base):

     def get_progress_info(self) -> dict:
         """Get comprehensive progress information."""
+        total_records_processed = (self.records_downloaded or 0) + (self.records_uploaded or 0)
+        total_bytes_transferred = (self.bytes_downloaded or 0) + (self.bytes_uploaded or 0)
+
         return {
             "status": self.status.value,
             "progress_percentage": self.progress_percentage,
-            "records_processed": f"{self.processed_records or 0}/{self.total_records or 'unknown'}",
-            "bytes_transferred": self.bytes_transferred or 0,
+            "records_processed": f"{total_records_processed}/{self.total_records_to_sync or 'unknown'}",
+            "records_downloaded": self.records_downloaded or 0,
+            "records_uploaded": self.records_uploaded or 0,
+            "bytes_transferred": total_bytes_transferred,
+            "bytes_downloaded": self.bytes_downloaded or 0,
+            "bytes_uploaded": self.bytes_uploaded or 0,
             "duration_seconds": self.get_duration_seconds(),
             "error_message": self.error_message,
+            "dataset_sync_hashes": self.dataset_sync_hashes or {},
         }
+
+    def _get_all_sync_hashes(self) -> List[str]:
+        """Get all content hashes for data created/modified during this sync operation."""
+        all_hashes = set()
+        dataset_hashes = self.dataset_sync_hashes or {}
+
+        for dataset_id, operations in dataset_hashes.items():
+            if isinstance(operations, dict):
+                all_hashes.update(operations.get("uploaded", []))
+                all_hashes.update(operations.get("downloaded", []))
+
+        return list(all_hashes)
+
+    def _get_dataset_sync_hashes(self, dataset_id: str) -> dict:
+        """Get uploaded/downloaded hashes for a specific dataset."""
+        dataset_hashes = self.dataset_sync_hashes or {}
+        return dataset_hashes.get(dataset_id, {"uploaded": [], "downloaded": []})
+
+    def was_data_synced(self, content_hash: str, dataset_id: str = None) -> bool:
+        """
+        Check if a specific piece of data was part of this sync operation.
+
+        Args:
+            content_hash: The content hash to check for
+            dataset_id: Optional - check only within this dataset
+        """
+        if dataset_id:
+            dataset_hashes = self._get_dataset_sync_hashes(dataset_id)
+            return content_hash in dataset_hashes.get(
+                "uploaded", []
+            ) or content_hash in dataset_hashes.get("downloaded", [])
+
+        all_hashes = self._get_all_sync_hashes()
+        return content_hash in all_hashes
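For illustration, the new model helpers might be exercised like this; the operation object, hash, and dataset ID are hypothetical and assume a row already loaded from the sync_operations table.

```python
# Sketch only: reading back progress and lineage from a SyncOperation row.
info = operation.get_progress_info()
print(info["records_processed"], info["bytes_transferred"])

# Was a given content hash touched by this run, scoped to one dataset?
if operation.was_data_synced("sha256:ab12", dataset_id="9b2c6c1e-aaaa-bbbb-cccc-000000000001"):
    print("hash was uploaded or downloaded during this sync")
```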