feat: update sync to be two way (#1359)
## Description
Makes dataset sync with Cognee Cloud two-way: multiple datasets can be synced in one request, files missing on either side are uploaded or downloaded based on stored content hashes, a `sync_operations` table tracks progress and retries, and a new `GET /v1/sync/status` endpoint reports running sync operations.

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
Parent: a29ef8b2d3
Commit: 47cb34e89c
11 changed files with 970 additions and 199 deletions
@@ -137,6 +137,14 @@ REQUIRE_AUTHENTICATION=False
 # It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset
 ENABLE_BACKEND_ACCESS_CONTROL=False
 
+################################################################################
+# ☁️ Cloud Sync Settings
+################################################################################
+
+# Cognee Cloud API settings for syncing data to/from cloud infrastructure
+COGNEE_CLOUD_API_URL="http://localhost:8001"
+COGNEE_CLOUD_AUTH_TOKEN="your-auth-token"
+
 ################################################################################
 # 🛠️ DEV Settings
 ################################################################################
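The two new variables are plain environment settings. A minimal sketch of reading them from a local `.env` file, assuming `python-dotenv` is available; the variable names come from the template above, everything else is illustrative:

```python
import os

from dotenv import load_dotenv  # assumes python-dotenv is installed

# Load variables from a local .env file into the process environment.
load_dotenv()

# Defaults mirror the values shown in the template above.
cloud_api_url = os.getenv("COGNEE_CLOUD_API_URL", "http://localhost:8001")
cloud_auth_token = os.getenv("COGNEE_CLOUD_AUTH_TOKEN", "your-auth-token")

print(f"Syncing against {cloud_api_url}")
```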
alembic/versions/211ab850ef3d_add_sync_operations_table.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+"""Add sync_operations table
+
+Revision ID: 211ab850ef3d
+Revises: 9e7a3cb85175
+Create Date: 2025-09-10 20:11:13.534829
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "211ab850ef3d"
+down_revision: Union[str, None] = "9e7a3cb85175"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+
+    # Check if table already exists (it might be created by Base.metadata.create_all() in initial migration)
+    connection = op.get_bind()
+    inspector = sa.inspect(connection)
+
+    if "sync_operations" not in inspector.get_table_names():
+        # Table doesn't exist, create it normally
+        op.create_table(
+            "sync_operations",
+            sa.Column("id", sa.UUID(), nullable=False),
+            sa.Column("run_id", sa.Text(), nullable=True),
+            sa.Column(
+                "status",
+                sa.Enum(
+                    "STARTED",
+                    "IN_PROGRESS",
+                    "COMPLETED",
+                    "FAILED",
+                    "CANCELLED",
+                    name="syncstatus",
+                    create_type=False,
+                ),
+                nullable=True,
+            ),
+            sa.Column("progress_percentage", sa.Integer(), nullable=True),
+            sa.Column("dataset_ids", sa.JSON(), nullable=True),
+            sa.Column("dataset_names", sa.JSON(), nullable=True),
+            sa.Column("user_id", sa.UUID(), nullable=True),
+            sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
+            sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
+            sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
+            sa.Column("total_records_to_sync", sa.Integer(), nullable=True),
+            sa.Column("total_records_to_download", sa.Integer(), nullable=True),
+            sa.Column("total_records_to_upload", sa.Integer(), nullable=True),
+            sa.Column("records_downloaded", sa.Integer(), nullable=True),
+            sa.Column("records_uploaded", sa.Integer(), nullable=True),
+            sa.Column("bytes_downloaded", sa.Integer(), nullable=True),
+            sa.Column("bytes_uploaded", sa.Integer(), nullable=True),
+            sa.Column("dataset_sync_hashes", sa.JSON(), nullable=True),
+            sa.Column("error_message", sa.Text(), nullable=True),
+            sa.Column("retry_count", sa.Integer(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+        op.create_index(
+            op.f("ix_sync_operations_run_id"), "sync_operations", ["run_id"], unique=True
+        )
+        op.create_index(
+            op.f("ix_sync_operations_user_id"), "sync_operations", ["user_id"], unique=False
+        )
+    else:
+        # Table already exists, but we might need to add missing columns or indexes
+        # For now, just log that the table already exists
+        print("sync_operations table already exists, skipping creation")
+
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+
+    # Only drop if table exists (might have been created by Base.metadata.create_all())
+    connection = op.get_bind()
+    inspector = sa.inspect(connection)
+
+    if "sync_operations" in inspector.get_table_names():
+        op.drop_index(op.f("ix_sync_operations_user_id"), table_name="sync_operations")
+        op.drop_index(op.f("ix_sync_operations_run_id"), table_name="sync_operations")
+        op.drop_table("sync_operations")
+
+        # Drop the enum type that was created (only if no other tables are using it)
+        sa.Enum(name="syncstatus").drop(op.get_bind(), checkfirst=True)
+    else:
+        print("sync_operations table doesn't exist, skipping downgrade")
+
+    # ### end Alembic commands ###
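A minimal sketch of applying and rolling back this revision programmatically with Alembic's command API, assuming the repository's `alembic.ini` is on the current path (projects may wrap this differently):

```python
from alembic import command
from alembic.config import Config

# Point Alembic at the project's migration configuration.
cfg = Config("alembic.ini")

# Apply all pending revisions, including 211ab850ef3d (sync_operations table).
command.upgrade(cfg, "head")

# Roll back just this revision if needed (down_revision taken from the file above).
# command.downgrade(cfg, "9e7a3cb85175")
```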
@@ -19,6 +19,7 @@ depends_on: Union[str, Sequence[str], None] = None
 
 def upgrade() -> None:
     db_engine = get_relational_engine()
+    # we might want to delete this
     await_only(db_engine.create_database())
 
@@ -3,7 +3,7 @@ from .sync import (
     SyncResponse,
     LocalFileInfo,
     CheckMissingHashesRequest,
-    CheckMissingHashesResponse,
+    CheckHashesDiffResponse,
     PruneDatasetRequest,
 )

@@ -12,6 +12,6 @@ __all__ = [
     "SyncResponse",
     "LocalFileInfo",
     "CheckMissingHashesRequest",
-    "CheckMissingHashesResponse",
+    "CheckHashesDiffResponse",
     "PruneDatasetRequest",
 ]
@@ -1,5 +1,5 @@
 from uuid import UUID
-from typing import Optional
+from typing import Optional, List
 from fastapi import APIRouter, Depends
 from fastapi.responses import JSONResponse

@@ -8,6 +8,7 @@ from cognee.api.DTO import InDTO
 from cognee.modules.users.models import User
 from cognee.modules.users.methods import get_authenticated_user
 from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
+from cognee.modules.sync.methods import get_running_sync_operations_for_user, get_sync_operation
 from cognee.shared.utils import send_telemetry
 from cognee.shared.logging_utils import get_logger
 from cognee.api.v1.sync import SyncResponse

@@ -19,7 +20,7 @@ logger = get_logger()
 class SyncRequest(InDTO):
     """Request model for sync operations."""
 
-    dataset_id: Optional[UUID] = None
+    dataset_ids: Optional[List[UUID]] = None
 
 
 def get_sync_router() -> APIRouter:

@@ -40,7 +41,7 @@ def get_sync_router() -> APIRouter:
         ## Request Body (JSON)
         ```json
         {
-            "dataset_id": "123e4567-e89b-12d3-a456-426614174000"
+            "dataset_ids": ["123e4567-e89b-12d3-a456-426614174000", "456e7890-e12b-34c5-d678-901234567000"]
         }
         ```
 

@@ -48,8 +49,8 @@ def get_sync_router() -> APIRouter:
         Returns immediate response for the sync operation:
         - **run_id**: Unique identifier for tracking the background sync operation
         - **status**: Always "started" (operation runs in background)
-        - **dataset_id**: ID of the dataset being synced
-        - **dataset_name**: Name of the dataset being synced
+        - **dataset_ids**: List of dataset IDs being synced
+        - **dataset_names**: List of dataset names being synced
         - **message**: Description of the background operation
         - **timestamp**: When the sync was initiated
         - **user_id**: User who initiated the sync

@@ -64,15 +65,21 @@ def get_sync_router() -> APIRouter:
 
         ## Example Usage
         ```bash
-        # Sync dataset to cloud by ID (JSON request)
+        # Sync multiple datasets to cloud by IDs (JSON request)
         curl -X POST "http://localhost:8000/api/v1/sync" \\
             -H "Content-Type: application/json" \\
             -H "Cookie: auth_token=your-token" \\
-            -d '{"dataset_id": "123e4567-e89b-12d3-a456-426614174000"}'
+            -d '{"dataset_ids": ["123e4567-e89b-12d3-a456-426614174000", "456e7890-e12b-34c5-d678-901234567000"]}'
+
+        # Sync all user datasets (empty request body or null dataset_ids)
+        curl -X POST "http://localhost:8000/api/v1/sync" \\
+            -H "Content-Type: application/json" \\
+            -H "Cookie: auth_token=your-token" \\
+            -d '{}'
         ```
 
         ## Error Codes
-        - **400 Bad Request**: Invalid dataset_id format
+        - **400 Bad Request**: Invalid dataset_ids format
         - **401 Unauthorized**: Invalid or missing authentication
         - **403 Forbidden**: User doesn't have permission to access dataset
         - **404 Not Found**: Dataset not found
@@ -92,32 +99,48 @@ def get_sync_router() -> APIRouter:
             user.id,
             additional_properties={
                 "endpoint": "POST /v1/sync",
-                "dataset_id": str(request.dataset_id) if request.dataset_id else "*",
+                "dataset_ids": [str(id) for id in request.dataset_ids]
+                if request.dataset_ids
+                else "*",
             },
         )
 
         from cognee.api.v1.sync import sync as cognee_sync
 
        try:
-            # Retrieve existing dataset and check permissions
-            datasets = await get_specific_user_permission_datasets(
-                user.id, "write", [request.dataset_id] if request.dataset_id else None
-            )
-
-            sync_results = {}
-
-            for dataset in datasets:
-                await set_database_global_context_variables(dataset.id, dataset.owner_id)
-
-                # Execute cloud sync operation
-                sync_result = await cognee_sync(
-                    dataset=dataset,
-                    user=user,
-                )
-
-                sync_results[str(dataset.id)] = sync_result
-
-            return sync_results
+            # Check if user has any running sync operations
+            running_syncs = await get_running_sync_operations_for_user(user.id)
+            if running_syncs:
+                # Return information about the existing sync operation
+                existing_sync = running_syncs[0]  # Get the most recent running sync
+                return JSONResponse(
+                    status_code=409,
+                    content={
+                        "error": "Sync operation already in progress",
+                        "details": {
+                            "run_id": existing_sync.run_id,
+                            "status": "already_running",
+                            "dataset_ids": existing_sync.dataset_ids,
+                            "dataset_names": existing_sync.dataset_names,
+                            "message": f"You have a sync operation already in progress with run_id '{existing_sync.run_id}'. Use the status endpoint to monitor progress, or wait for it to complete before starting a new sync.",
+                            "timestamp": existing_sync.created_at.isoformat(),
+                            "progress_percentage": existing_sync.progress_percentage,
+                        },
+                    },
+                )
+
+            # Retrieve existing dataset and check permissions
+            datasets = await get_specific_user_permission_datasets(
+                user.id, "write", request.dataset_ids if request.dataset_ids else None
+            )
+
+            # Execute new cloud sync operation for all datasets
+            sync_result = await cognee_sync(
+                datasets=datasets,
+                user=user,
+            )
+
+            return sync_result
 
         except ValueError as e:
             return JSONResponse(status_code=400, content={"error": str(e)})
@@ -131,4 +154,88 @@ def get_sync_router() -> APIRouter:
             logger.error(f"Cloud sync operation failed: {str(e)}")
             return JSONResponse(status_code=409, content={"error": "Cloud sync operation failed."})
 
+    @router.get("/status")
+    async def get_sync_status_overview(
+        user: User = Depends(get_authenticated_user),
+    ):
+        """
+        Check if there are any running sync operations for the current user.
+
+        This endpoint provides a simple check to see if the user has any active sync operations
+        without needing to know specific run IDs.
+
+        ## Response
+        Returns a simple status overview:
+        - **has_running_sync**: Boolean indicating if there are any running syncs
+        - **running_sync_count**: Number of currently running sync operations
+        - **latest_running_sync** (optional): Information about the most recent running sync if any exists
+
+        ## Example Usage
+        ```bash
+        curl -X GET "http://localhost:8000/api/v1/sync/status" \\
+            -H "Cookie: auth_token=your-token"
+        ```
+
+        ## Example Responses
+
+        **No running syncs:**
+        ```json
+        {
+            "has_running_sync": false,
+            "running_sync_count": 0
+        }
+        ```
+
+        **With running sync:**
+        ```json
+        {
+            "has_running_sync": true,
+            "running_sync_count": 1,
+            "latest_running_sync": {
+                "run_id": "12345678-1234-5678-9012-123456789012",
+                "dataset_name": "My Dataset",
+                "progress_percentage": 45,
+                "created_at": "2025-01-01T00:00:00Z"
+            }
+        }
+        ```
+        """
+        send_telemetry(
+            "Sync Status Overview API Endpoint Invoked",
+            user.id,
+            additional_properties={
+                "endpoint": "GET /v1/sync/status",
+            },
+        )
+
+        try:
+            # Get any running sync operations for the user
+            running_syncs = await get_running_sync_operations_for_user(user.id)
+
+            response = {
+                "has_running_sync": len(running_syncs) > 0,
+                "running_sync_count": len(running_syncs),
+            }
+
+            # If there are running syncs, include info about the latest one
+            if running_syncs:
+                latest_sync = running_syncs[0]  # Already ordered by created_at desc
+                response["latest_running_sync"] = {
+                    "run_id": latest_sync.run_id,
+                    "dataset_ids": latest_sync.dataset_ids,
+                    "dataset_names": latest_sync.dataset_names,
+                    "progress_percentage": latest_sync.progress_percentage,
+                    "created_at": latest_sync.created_at.isoformat()
+                    if latest_sync.created_at
+                    else None,
+                }
+
+            return response
+
+        except Exception as e:
+            logger.error(f"Failed to get sync status overview: {str(e)}")
+            return JSONResponse(
+                status_code=500, content={"error": "Failed to get sync status overview"}
+            )
+
     return router
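A minimal client-side sketch of the workflow these endpoints document: start a sync for one or more datasets, treat HTTP 409 as "a sync is already running", and poll the status overview. The `requests` usage and cookie name mirror the curl examples above; the dataset UUID is a placeholder:

```python
import time

import requests  # assumed available; mirrors the curl examples in the route docstrings

BASE_URL = "http://localhost:8000/api/v1"
COOKIES = {"auth_token": "your-token"}

# Start a two-way sync for specific datasets (omit dataset_ids to sync everything).
resp = requests.post(
    f"{BASE_URL}/sync",
    json={"dataset_ids": ["123e4567-e89b-12d3-a456-426614174000"]},
    cookies=COOKIES,
)

if resp.status_code == 409:
    # Either a sync is already running (payload carries "details") or the sync failed to start.
    print("Sync not started:", resp.json())
else:
    resp.raise_for_status()
    print("Started sync", resp.json()["run_id"])

    # Poll the overview endpoint until no sync is running any more.
    while True:
        status = requests.get(f"{BASE_URL}/sync/status", cookies=COOKIES).json()
        if not status["has_running_sync"]:
            break
        print("progress:", status["latest_running_sync"]["progress_percentage"], "%")
        time.sleep(5)
```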
@@ -1,3 +1,4 @@
+import io
 import os
 import uuid
 import asyncio

@@ -5,8 +6,12 @@ import aiohttp
 from pydantic import BaseModel
 from typing import List, Optional
 from datetime import datetime, timezone
+from dataclasses import dataclass
+
+from cognee.api.v1.cognify import cognify
 
 from cognee.infrastructure.files.storage import get_file_storage
+from cognee.tasks.ingestion.ingest_data import ingest_data
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.users.models import User
 from cognee.modules.data.models import Dataset

@@ -19,7 +24,28 @@ from cognee.modules.sync.methods import (
     mark_sync_failed,
 )
 
-logger = get_logger()
+logger = get_logger("sync")
+
+
+async def _safe_update_progress(run_id: str, stage: str, **kwargs):
+    """
+    Safely update sync progress with better error handling and context.
+
+    Args:
+        run_id: Sync operation run ID
+        progress_percentage: Progress percentage (0-100)
+        stage: Description of current stage for logging
+        **kwargs: Additional fields to update (records_downloaded, records_uploaded, etc.)
+    """
+    try:
+        await update_sync_operation(run_id, **kwargs)
+        logger.info(f"Sync {run_id}: Progress updated during {stage}")
+    except Exception as e:
+        # Log error but don't fail the sync - progress updates are nice-to-have
+        logger.warning(
+            f"Sync {run_id}: Non-critical progress update failed during {stage}: {str(e)}"
+        )
+        # Continue without raising - sync operation is more important than progress tracking
 
 
 class LocalFileInfo(BaseModel):

@@ -38,13 +64,16 @@ class LocalFileInfo(BaseModel):
 class CheckMissingHashesRequest(BaseModel):
     """Request model for checking missing hashes in a dataset"""
 
+    dataset_id: str
+    dataset_name: str
     hashes: List[str]
 
 
-class CheckMissingHashesResponse(BaseModel):
+class CheckHashesDiffResponse(BaseModel):
     """Response model for missing hashes check"""
 
-    missing: List[str]
+    missing_on_remote: List[str]
+    missing_on_local: List[str]
 
 
 class PruneDatasetRequest(BaseModel):
@@ -58,34 +87,34 @@ class SyncResponse(BaseModel):
 
     run_id: str
     status: str  # "started" for immediate response
-    dataset_id: str
-    dataset_name: str
+    dataset_ids: List[str]
+    dataset_names: List[str]
     message: str
     timestamp: str
     user_id: str
 
 
 async def sync(
-    dataset: Dataset,
+    datasets: List[Dataset],
     user: User,
 ) -> SyncResponse:
     """
     Sync local Cognee data to Cognee Cloud.
 
-    This function handles synchronization of local datasets, knowledge graphs, and
+    This function handles synchronization of multiple datasets, knowledge graphs, and
     processed data to the Cognee Cloud infrastructure. It uploads local data for
     cloud-based processing, backup, and sharing.
 
     Args:
-        dataset: Dataset object to sync (permissions already verified)
+        datasets: List of Dataset objects to sync (permissions already verified)
         user: User object for authentication and permissions
 
     Returns:
         SyncResponse model with immediate response:
         - run_id: Unique identifier for tracking this sync operation
         - status: Always "started" (sync runs in background)
-        - dataset_id: ID of the dataset being synced
-        - dataset_name: Name of the dataset being synced
+        - dataset_ids: List of dataset IDs being synced
+        - dataset_names: List of dataset names being synced
         - message: Description of what's happening
         - timestamp: When the sync was initiated
         - user_id: User who initiated the sync

@@ -94,8 +123,8 @@ async def sync(
         ConnectionError: If Cognee Cloud service is unreachable
         Exception: For other sync-related errors
     """
-    if not dataset:
-        raise ValueError("Dataset must be provided for sync operation")
+    if not datasets:
+        raise ValueError("At least one dataset must be provided for sync operation")
 
     # Generate a unique run ID
     run_id = str(uuid.uuid4())

@@ -103,12 +132,16 @@ async def sync(
     # Get current timestamp
     timestamp = datetime.now(timezone.utc).isoformat()
 
-    logger.info(f"Starting cloud sync operation {run_id}: dataset {dataset.name} ({dataset.id})")
+    dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets])
+    logger.info(f"Starting cloud sync operation {run_id}: datasets {dataset_info}")
 
     # Create sync operation record in database (total_records will be set during background sync)
     try:
         await create_sync_operation(
-            run_id=run_id, dataset_id=dataset.id, dataset_name=dataset.name, user_id=user.id
+            run_id=run_id,
+            dataset_ids=[d.id for d in datasets],
+            dataset_names=[d.name for d in datasets],
+            user_id=user.id,
         )
         logger.info(f"Created sync operation record for {run_id}")
     except Exception as e:

@@ -116,44 +149,74 @@ async def sync(
         # Continue without database tracking if record creation fails
 
     # Start the sync operation in the background
-    asyncio.create_task(_perform_background_sync(run_id, dataset, user))
+    asyncio.create_task(_perform_background_sync(run_id, datasets, user))
 
     # Return immediately with run_id
     return SyncResponse(
         run_id=run_id,
         status="started",
-        dataset_id=str(dataset.id),
-        dataset_name=dataset.name,
-        message=f"Sync operation started in background. Use run_id '{run_id}' to track progress.",
+        dataset_ids=[str(d.id) for d in datasets],
+        dataset_names=[d.name for d in datasets],
+        message=f"Sync operation started in background for {len(datasets)} datasets. Use run_id '{run_id}' to track progress.",
         timestamp=timestamp,
         user_id=str(user.id),
     )
 
 
-async def _perform_background_sync(run_id: str, dataset: Dataset, user: User) -> None:
-    """Perform the actual sync operation in the background."""
+async def _perform_background_sync(run_id: str, datasets: List[Dataset], user: User) -> None:
+    """Perform the actual sync operation in the background for multiple datasets."""
     start_time = datetime.now(timezone.utc)
 
     try:
-        logger.info(
-            f"Background sync {run_id}: Starting sync for dataset {dataset.name} ({dataset.id})"
-        )
+        dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets])
+        logger.info(f"Background sync {run_id}: Starting sync for datasets {dataset_info}")
 
         # Mark sync as in progress
         await mark_sync_started(run_id)
 
         # Perform the actual sync operation
-        records_processed, bytes_transferred = await _sync_to_cognee_cloud(dataset, user, run_id)
+        MAX_RETRY_COUNT = 3
+        retry_count = 0
+        while retry_count < MAX_RETRY_COUNT:
+            try:
+                (
+                    records_downloaded,
+                    records_uploaded,
+                    bytes_downloaded,
+                    bytes_uploaded,
+                    dataset_sync_hashes,
+                ) = await _sync_to_cognee_cloud(datasets, user, run_id)
+                break
+            except Exception as e:
+                retry_count += 1
+                logger.error(
+                    f"Background sync {run_id}: Failed after {retry_count} retries with error: {str(e)}"
+                )
+                await update_sync_operation(run_id, retry_count=retry_count)
+                await asyncio.sleep(2**retry_count)
+                continue
+
+        if retry_count == MAX_RETRY_COUNT:
+            logger.error(f"Background sync {run_id}: Failed after {MAX_RETRY_COUNT} retries")
+            await mark_sync_failed(run_id, "Failed after 3 retries")
+            return
 
         end_time = datetime.now(timezone.utc)
         duration = (end_time - start_time).total_seconds()
 
         logger.info(
-            f"Background sync {run_id}: Completed successfully. Records: {records_processed}, Bytes: {bytes_transferred}, Duration: {duration}s"
+            f"Background sync {run_id}: Completed successfully. Downloaded: {records_downloaded} records/{bytes_downloaded} bytes, Uploaded: {records_uploaded} records/{bytes_uploaded} bytes, Duration: {duration}s"
         )
 
-        # Mark sync as completed with final stats
-        await mark_sync_completed(run_id, records_processed, bytes_transferred)
+        # Mark sync as completed with final stats and data lineage
+        await mark_sync_completed(
+            run_id,
+            records_downloaded,
+            records_uploaded,
+            bytes_downloaded,
+            bytes_uploaded,
+            dataset_sync_hashes,
+        )
 
     except Exception as e:
         end_time = datetime.now(timezone.utc)
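The retry loop in `_perform_background_sync` above follows a standard exponential-backoff pattern: sleep `2**retry_count` seconds between attempts and give up after three failures. A generic sketch of the same idea as a reusable helper, illustrative only and not part of this commit:

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


async def retry_with_backoff(
    operation: Callable[[], Awaitable[T]],
    max_retries: int = 3,
) -> T:
    """Run an async operation, retrying with exponential backoff on failure."""
    attempt = 0
    while True:
        try:
            return await operation()
        except Exception:
            attempt += 1
            if attempt >= max_retries:
                raise  # give up after max_retries failed attempts
            await asyncio.sleep(2**attempt)  # 2s, 4s, 8s, ...
```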
@@ -165,89 +228,248 @@ async def _perform_background_sync(run_id: str, datasets: List[Dataset], user: User) -> None:
         await mark_sync_failed(run_id, str(e))
 
 
-async def _sync_to_cognee_cloud(dataset: Dataset, user: User, run_id: str) -> tuple[int, int]:
+async def _sync_to_cognee_cloud(
+    datasets: List[Dataset], user: User, run_id: str
+) -> tuple[int, int, int, int, dict]:
     """
     Sync local data to Cognee Cloud using three-step idempotent process:
     1. Extract local files with stored MD5 hashes and check what's missing on cloud
     2. Upload missing files individually
     3. Prune cloud dataset to match local state
     """
-    logger.info(f"Starting sync to Cognee Cloud: dataset {dataset.name} ({dataset.id})")
+    dataset_info = ", ".join([f"{d.name} ({d.id})" for d in datasets])
+    logger.info(f"Starting sync to Cognee Cloud: datasets {dataset_info}")
+
+    total_records_downloaded = 0
+    total_records_uploaded = 0
+    total_bytes_downloaded = 0
+    total_bytes_uploaded = 0
+    dataset_sync_hashes = {}
 
     try:
         # Get cloud configuration
         cloud_base_url = await _get_cloud_base_url()
         cloud_auth_token = await _get_cloud_auth_token(user)
 
-        logger.info(f"Cloud API URL: {cloud_base_url}")
-
-        # Step 1: Extract local file info with stored hashes
-        local_files = await _extract_local_files_with_hashes(dataset, user, run_id)
-        logger.info(f"Found {len(local_files)} local files to sync")
-
-        # Update sync operation with total file count
-        try:
-            await update_sync_operation(run_id, processed_records=0)
-        except Exception as e:
-            logger.warning(f"Failed to initialize sync progress: {str(e)}")
-
-        if not local_files:
-            logger.info("No files to sync - dataset is empty")
-            return 0, 0
-
-        # Step 2: Check what files are missing on cloud
-        local_hashes = [f.content_hash for f in local_files]
-        missing_hashes = await _check_missing_hashes(
-            cloud_base_url, cloud_auth_token, dataset.id, local_hashes, run_id
-        )
-        logger.info(f"Cloud is missing {len(missing_hashes)} out of {len(local_hashes)} files")
-
-        # Update progress
-        try:
-            await update_sync_operation(run_id, progress_percentage=25)
-        except Exception as e:
-            logger.warning(f"Failed to update progress: {str(e)}")
-
-        # Step 3: Upload missing files
-        bytes_uploaded = await _upload_missing_files(
-            cloud_base_url, cloud_auth_token, dataset, local_files, missing_hashes, run_id
-        )
-        logger.info(f"Upload complete: {len(missing_hashes)} files, {bytes_uploaded} bytes")
-
-        # Update progress
-        try:
-            await update_sync_operation(run_id, progress_percentage=75)
-        except Exception as e:
-            logger.warning(f"Failed to update progress: {str(e)}")
-
-        # Step 4: Trigger cognify processing on cloud dataset (only if new files were uploaded)
-        if missing_hashes:
-            await _trigger_remote_cognify(cloud_base_url, cloud_auth_token, dataset.id, run_id)
-            logger.info(f"Cognify processing triggered for dataset {dataset.id}")
-        else:
-            logger.info(
-                f"Skipping cognify processing - no new files were uploaded for dataset {dataset.id}"
-            )
-
-        # Final progress
-        try:
-            await update_sync_operation(run_id, progress_percentage=100)
-        except Exception as e:
-            logger.warning(f"Failed to update final progress: {str(e)}")
-
-        records_processed = len(local_files)
-
-        logger.info(
-            f"Sync completed successfully: {records_processed} records, {bytes_uploaded} bytes uploaded"
-        )
-
-        return records_processed, bytes_uploaded
+        # Step 1: Sync files for all datasets concurrently
+        sync_files_tasks = [
+            _sync_dataset_files(dataset, cloud_base_url, cloud_auth_token, user, run_id)
+            for dataset in datasets
+        ]
+
+        logger.info(f"Starting concurrent file sync for {len(datasets)} datasets")
+
+        has_any_uploads = False
+        has_any_downloads = False
+        processed_datasets = []
+        completed_datasets = 0
+
+        # Process datasets concurrently and accumulate results
+        for completed_task in asyncio.as_completed(sync_files_tasks):
+            try:
+                dataset_result = await completed_task
+                completed_datasets += 1
+
+                # Update progress based on completed datasets (0-80% for file sync)
+                file_sync_progress = int((completed_datasets / len(datasets)) * 80)
+                await _safe_update_progress(
+                    run_id, "file_sync", progress_percentage=file_sync_progress
+                )
+
+                if dataset_result is None:
+                    logger.info(
+                        f"Progress: {completed_datasets}/{len(datasets)} datasets processed ({file_sync_progress}%)"
+                    )
+                    continue
+
+                total_records_downloaded += dataset_result.records_downloaded
+                total_records_uploaded += dataset_result.records_uploaded
+                total_bytes_downloaded += dataset_result.bytes_downloaded
+                total_bytes_uploaded += dataset_result.bytes_uploaded
+
+                # Build per-dataset hash tracking for data lineage
+                dataset_sync_hashes[dataset_result.dataset_id] = {
+                    "uploaded": dataset_result.uploaded_hashes,
+                    "downloaded": dataset_result.downloaded_hashes,
+                }
+
+                if dataset_result.has_uploads:
+                    has_any_uploads = True
+                if dataset_result.has_downloads:
+                    has_any_downloads = True
+
+                processed_datasets.append(dataset_result.dataset_id)
+
+                logger.info(
+                    f"Progress: {completed_datasets}/{len(datasets)} datasets processed ({file_sync_progress}%) - "
+                    f"Completed file sync for dataset {dataset_result.dataset_name}: "
+                    f"↑{dataset_result.records_uploaded} files ({dataset_result.bytes_uploaded} bytes), "
+                    f"↓{dataset_result.records_downloaded} files ({dataset_result.bytes_downloaded} bytes)"
+                )
+            except Exception as e:
+                completed_datasets += 1
+                logger.error(f"Dataset file sync failed: {str(e)}")
+                # Update progress even for failed datasets
+                file_sync_progress = int((completed_datasets / len(datasets)) * 80)
+                await _safe_update_progress(
+                    run_id, "file_sync", progress_percentage=file_sync_progress
+                )
+                # Continue with other datasets even if one fails
+
+        # Step 2: Trigger cognify processing once for all datasets (only if any files were uploaded)
+        # Update progress to 90% before cognify
+        await _safe_update_progress(run_id, "cognify", progress_percentage=90)
+
+        if has_any_uploads and processed_datasets:
+            logger.info(
+                f"Progress: 90% - Triggering cognify processing for {len(processed_datasets)} datasets with new files"
+            )
+            try:
+                # Trigger cognify for all datasets at once - use first dataset as reference point
+                await _trigger_remote_cognify(
+                    cloud_base_url, cloud_auth_token, datasets[0].id, run_id
+                )
+                logger.info("Cognify processing triggered successfully for all datasets")
+            except Exception as e:
+                logger.warning(f"Failed to trigger cognify processing: {str(e)}")
+                # Don't fail the entire sync if cognify fails
+        else:
+            logger.info(
+                "Progress: 90% - Skipping cognify processing - no new files were uploaded across any datasets"
+            )
+
+        # Step 3: Trigger local cognify processing if any files were downloaded
+        if has_any_downloads and processed_datasets:
+            logger.info(
+                f"Progress: 95% - Triggering local cognify processing for {len(processed_datasets)} datasets with downloaded files"
+            )
+            try:
+                await cognify()
+                logger.info("Local cognify processing completed successfully for all datasets")
+            except Exception as e:
+                logger.warning(f"Failed to run local cognify processing: {str(e)}")
+                # Don't fail the entire sync if local cognify fails
+        else:
+            logger.info(
+                "Progress: 95% - Skipping local cognify processing - no new files were downloaded across any datasets"
+            )
+
+        # Update final progress
+        try:
+            await _safe_update_progress(
+                run_id,
+                "final",
+                progress_percentage=100,
+                total_records_to_sync=total_records_uploaded + total_records_downloaded,
+                total_records_to_download=total_records_downloaded,
+                total_records_to_upload=total_records_uploaded,
+                records_downloaded=total_records_downloaded,
+                records_uploaded=total_records_uploaded,
+            )
+        except Exception as e:
+            logger.warning(f"Failed to update final sync progress: {str(e)}")
+
+        logger.info(
+            f"Multi-dataset sync completed: {len(datasets)} datasets processed, downloaded {total_records_downloaded} records/{total_bytes_downloaded} bytes, uploaded {total_records_uploaded} records/{total_bytes_uploaded} bytes"
+        )
+
+        return (
+            total_records_downloaded,
+            total_records_uploaded,
+            total_bytes_downloaded,
+            total_bytes_uploaded,
+            dataset_sync_hashes,
+        )
 
     except Exception as e:
         logger.error(f"Sync failed: {str(e)}")
         raise ConnectionError(f"Cloud sync failed: {str(e)}")
+
+
+@dataclass
+class DatasetSyncResult:
+    """Result of syncing files for a single dataset."""
+
+    dataset_name: str
+    dataset_id: str
+    records_downloaded: int
+    records_uploaded: int
+    bytes_downloaded: int
+    bytes_uploaded: int
+    has_uploads: bool  # Whether any files were uploaded (for cognify decision)
+    has_downloads: bool  # Whether any files were downloaded (for cognify decision)
+    uploaded_hashes: List[str]  # Content hashes of files uploaded during sync
+    downloaded_hashes: List[str]  # Content hashes of files downloaded during sync
+
+
+async def _sync_dataset_files(
+    dataset: Dataset, cloud_base_url: str, cloud_auth_token: str, user: User, run_id: str
+) -> Optional[DatasetSyncResult]:
+    """
+    Sync files for a single dataset (2-way: upload to cloud, download from cloud).
+    Does NOT trigger cognify - that's done separately once for all datasets.
+
+    Returns:
+        DatasetSyncResult with sync results or None if dataset was empty
+    """
+    logger.info(f"Syncing files for dataset: {dataset.name} ({dataset.id})")
+
+    try:
+        # Step 1: Extract local file info with stored hashes
+        local_files = await _extract_local_files_with_hashes(dataset, user, run_id)
+        logger.info(f"Found {len(local_files)} local files for dataset {dataset.name}")
+
+        if not local_files:
+            logger.info(f"No files to sync for dataset {dataset.name} - skipping")
+            return None
+
+        # Step 2: Check what files are missing on cloud
+        local_hashes = [f.content_hash for f in local_files]
+        hashes_diff_response = await _check_hashes_diff(
+            cloud_base_url, cloud_auth_token, dataset, local_hashes, run_id
+        )
+
+        hashes_missing_on_remote = hashes_diff_response.missing_on_remote
+        hashes_missing_on_local = hashes_diff_response.missing_on_local
+
+        logger.info(
+            f"Dataset {dataset.name}: {len(hashes_missing_on_remote)} files to upload, {len(hashes_missing_on_local)} files to download"
+        )
+
+        # Step 3: Upload files that are missing on cloud
+        bytes_uploaded = await _upload_missing_files(
+            cloud_base_url, cloud_auth_token, dataset, local_files, hashes_missing_on_remote, run_id
+        )
+        logger.info(
+            f"Dataset {dataset.name}: Upload complete - {len(hashes_missing_on_remote)} files, {bytes_uploaded} bytes"
+        )
+
+        # Step 4: Download files that are missing locally
+        bytes_downloaded = await _download_missing_files(
+            cloud_base_url, cloud_auth_token, dataset, hashes_missing_on_local, user
+        )
+        logger.info(
+            f"Dataset {dataset.name}: Download complete - {len(hashes_missing_on_local)} files, {bytes_downloaded} bytes"
+        )
+
+        return DatasetSyncResult(
+            dataset_name=dataset.name,
+            dataset_id=str(dataset.id),
+            records_downloaded=len(hashes_missing_on_local),
+            records_uploaded=len(hashes_missing_on_remote),
+            bytes_downloaded=bytes_downloaded,
+            bytes_uploaded=bytes_uploaded,
+            has_uploads=len(hashes_missing_on_remote) > 0,
+            has_downloads=len(hashes_missing_on_local) > 0,
+            uploaded_hashes=hashes_missing_on_remote,
+            downloaded_hashes=hashes_missing_on_local,
+        )
+
+    except Exception as e:
+        logger.error(f"Failed to sync files for dataset {dataset.name} ({dataset.id}): {str(e)}")
+        raise  # Re-raise to be handled by the caller
 
 
 async def _extract_local_files_with_hashes(
     dataset: Dataset, user: User, run_id: str
 ) -> List[LocalFileInfo]:

@@ -334,43 +556,42 @@ async def _get_file_size(file_path: str) -> int:
 
 async def _get_cloud_base_url() -> str:
     """Get Cognee Cloud API base URL."""
-    # TODO: Make this configurable via environment variable or config
     return os.getenv("COGNEE_CLOUD_API_URL", "http://localhost:8001")
 
 
 async def _get_cloud_auth_token(user: User) -> str:
     """Get authentication token for Cognee Cloud API."""
-    # TODO: Implement proper authentication with Cognee Cloud
-    # This should get or refresh an API token for the user
-    return os.getenv("COGNEE_CLOUD_AUTH_TOKEN", "your-auth-token-here")
+    return os.getenv("COGNEE_CLOUD_AUTH_TOKEN", "your-auth-token")
 
 
-async def _check_missing_hashes(
-    cloud_base_url: str, auth_token: str, dataset_id: str, local_hashes: List[str], run_id: str
-) -> List[str]:
+async def _check_hashes_diff(
+    cloud_base_url: str, auth_token: str, dataset: Dataset, local_hashes: List[str], run_id: str
+) -> CheckHashesDiffResponse:
     """
-    Step 1: Check which hashes are missing on cloud.
+    Check which hashes are missing on cloud.
 
     Returns:
         List[str]: MD5 hashes that need to be uploaded
     """
-    url = f"{cloud_base_url}/api/sync/{dataset_id}/diff"
+    url = f"{cloud_base_url}/api/sync/{dataset.id}/diff"
     headers = {"X-Api-Key": auth_token, "Content-Type": "application/json"}
 
-    payload = CheckMissingHashesRequest(hashes=local_hashes)
+    payload = CheckMissingHashesRequest(
+        dataset_id=str(dataset.id), dataset_name=dataset.name, hashes=local_hashes
+    )
 
-    logger.info(f"Checking missing hashes on cloud for dataset {dataset_id}")
+    logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}")
 
     try:
         async with aiohttp.ClientSession() as session:
             async with session.post(url, json=payload.dict(), headers=headers) as response:
                 if response.status == 200:
                     data = await response.json()
-                    missing_response = CheckMissingHashesResponse(**data)
+                    missing_response = CheckHashesDiffResponse(**data)
                     logger.info(
-                        f"Cloud reports {len(missing_response.missing)} missing files out of {len(local_hashes)} total"
+                        f"Cloud is missing {len(missing_response.missing_on_remote)} out of {len(local_hashes)} files, local is missing {len(missing_response.missing_on_local)} files"
                     )
-                    return missing_response.missing
+                    return missing_response
                 else:
                     error_text = await response.text()
                     logger.error(
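Conceptually, the `/diff` endpoint that `_check_hashes_diff` calls is a set comparison over content hashes: the cloud reports which local hashes it does not have (`missing_on_remote`) and which of its own hashes the client did not send (`missing_on_local`). A minimal illustration of that semantics; the server-side implementation is not part of this diff:

```python
from typing import List, Tuple


def hash_diff(local_hashes: List[str], remote_hashes: List[str]) -> Tuple[List[str], List[str]]:
    """Return (missing_on_remote, missing_on_local) for a two-way sync decision."""
    local, remote = set(local_hashes), set(remote_hashes)
    missing_on_remote = sorted(local - remote)  # files the client should upload
    missing_on_local = sorted(remote - local)   # files the client should download
    return missing_on_remote, missing_on_local


# Example: "a" exists only locally, "c" exists only on the cloud.
assert hash_diff(["a", "b"], ["b", "c"]) == (["a"], ["c"])
```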
@@ -385,22 +606,137 @@ async def _check_missing_hashes(
         raise ConnectionError(f"Failed to check missing hashes: {str(e)}")
 
 
+async def _download_missing_files(
+    cloud_base_url: str,
+    auth_token: str,
+    dataset: Dataset,
+    hashes_missing_on_local: List[str],
+    user: User,
+) -> int:
+    """
+    Download files that are missing locally from the cloud.
+
+    Returns:
+        int: Total bytes downloaded
+    """
+    logger.info(f"Downloading {len(hashes_missing_on_local)} missing files from cloud")
+
+    if not hashes_missing_on_local:
+        logger.info("No files need to be downloaded - all files already exist locally")
+        return 0
+
+    total_bytes_downloaded = 0
+    downloaded_count = 0
+
+    headers = {"X-Api-Key": auth_token}
+
+    async with aiohttp.ClientSession() as session:
+        for file_hash in hashes_missing_on_local:
+            try:
+                # Download file from cloud by hash
+                download_url = f"{cloud_base_url}/api/sync/{dataset.id}/data/{file_hash}"
+
+                logger.debug(f"Downloading file with hash: {file_hash}")
+
+                async with session.get(download_url, headers=headers) as response:
+                    if response.status == 200:
+                        file_content = await response.read()
+                        file_size = len(file_content)
+
+                        # Get file metadata from response headers
+                        file_name = response.headers.get("X-File-Name", f"file_{file_hash}")
+
+                        # Save file locally using ingestion pipeline
+                        await _save_downloaded_file(
+                            dataset, file_hash, file_name, file_content, user
+                        )
+
+                        total_bytes_downloaded += file_size
+                        downloaded_count += 1
+
+                        logger.debug(f"Successfully downloaded {file_name} ({file_size} bytes)")
+
+                    elif response.status == 404:
+                        logger.warning(f"File with hash {file_hash} not found on cloud")
+                        continue
+                    else:
+                        error_text = await response.text()
+                        logger.error(
+                            f"Failed to download file {file_hash}: Status {response.status} - {error_text}"
+                        )
+                        continue
+
+            except Exception as e:
+                logger.error(f"Error downloading file {file_hash}: {str(e)}")
+                continue
+
+    logger.info(
+        f"Download summary: {downloaded_count}/{len(hashes_missing_on_local)} files downloaded, {total_bytes_downloaded} bytes total"
+    )
+    return total_bytes_downloaded
+
+
+class InMemoryDownload:
+    def __init__(self, data: bytes, filename: str):
+        self.file = io.BufferedReader(io.BytesIO(data))
+        self.filename = filename
+
+
+async def _save_downloaded_file(
+    dataset: Dataset,
+    file_hash: str,
+    file_name: str,
+    file_content: bytes,
+    user: User,
+) -> None:
+    """
+    Save a downloaded file to local storage and register it in the dataset.
+    Uses the existing ingest_data function for consistency with normal ingestion.
+
+    Args:
+        dataset: The dataset to add the file to
+        file_hash: MD5 hash of the file content
+        file_name: Original file name
+        file_content: Raw file content bytes
+    """
+    try:
+        # Create a temporary file-like object from the bytes
+        file_obj = InMemoryDownload(file_content, file_name)
+
+        # User is injected as dependency
+
+        # Use the existing ingest_data function to properly handle the file
+        # This ensures consistency with normal file ingestion
+        await ingest_data(
+            data=file_obj,
+            dataset_name=dataset.name,
+            user=user,
+            dataset_id=dataset.id,
+        )
+
+        logger.debug(f"Successfully saved downloaded file: {file_name} (hash: {file_hash})")
+
+    except Exception as e:
+        logger.error(f"Failed to save downloaded file {file_name}: {str(e)}")
+        raise
+
+
 async def _upload_missing_files(
     cloud_base_url: str,
     auth_token: str,
     dataset: Dataset,
     local_files: List[LocalFileInfo],
-    missing_hashes: List[str],
+    hashes_missing_on_remote: List[str],
     run_id: str,
 ) -> int:
     """
-    Step 2: Upload files that are missing on cloud.
+    Upload files that are missing on cloud.
 
     Returns:
         int: Total bytes uploaded
     """
     # Filter local files to only those with missing hashes
-    files_to_upload = [f for f in local_files if f.content_hash in missing_hashes]
+    files_to_upload = [f for f in local_files if f.content_hash in hashes_missing_on_remote]
 
     logger.info(f"Uploading {len(files_to_upload)} missing files to cloud")
 

@@ -442,13 +778,6 @@ async def _upload_missing_files(
                     if response.status in [200, 201]:
                         total_bytes_uploaded += len(file_content)
                         uploaded_count += 1
-
-                        # Update progress periodically
-                        if uploaded_count % 10 == 0:
-                            progress = (
-                                25 + (uploaded_count / len(files_to_upload)) * 50
-                            )  # 25-75% range
-                            await update_sync_operation(run_id, progress_percentage=int(progress))
                     else:
                         error_text = await response.text()
                         logger.error(
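The whole upload/download protocol keys on the stored MD5 content hashes that `_extract_local_files_with_hashes` collects. For reference, a minimal sketch of computing such a hash for a local file; this is a hypothetical helper, not part of the diff:

```python
import hashlib


def md5_content_hash(file_path: str, chunk_size: int = 1 << 20) -> str:
    """Compute the MD5 hex digest of a file's contents, reading in 1 MiB chunks."""
    digest = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```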
@@ -470,7 +799,7 @@ async def _prune_cloud_dataset(
     cloud_base_url: str, auth_token: str, dataset_id: str, local_hashes: List[str], run_id: str
 ) -> None:
     """
-    Step 3: Prune cloud dataset to match local state.
+    Prune cloud dataset to match local state.
     """
     url = f"{cloud_base_url}/api/sync/{dataset_id}?prune=true"
     headers = {"X-Api-Key": auth_token, "Content-Type": "application/json"}

@@ -506,7 +835,7 @@ async def _trigger_remote_cognify(
     cloud_base_url: str, auth_token: str, dataset_id: str, run_id: str
 ) -> None:
     """
-    Step 4: Trigger cognify processing on the cloud dataset.
+    Trigger cognify processing on the cloud dataset.
 
     This initiates knowledge graph processing on the synchronized dataset
     using the cloud infrastructure.
@@ -1,5 +1,9 @@
 from .create_sync_operation import create_sync_operation
-from .get_sync_operation import get_sync_operation, get_user_sync_operations
+from .get_sync_operation import (
+    get_sync_operation,
+    get_user_sync_operations,
+    get_running_sync_operations_for_user,
+)
 from .update_sync_operation import (
     update_sync_operation,
     mark_sync_started,

@@ -11,6 +15,7 @@ __all__ = [
     "create_sync_operation",
     "get_sync_operation",
     "get_user_sync_operations",
+    "get_running_sync_operations_for_user",
     "update_sync_operation",
     "mark_sync_started",
     "mark_sync_completed",
@@ -1,5 +1,5 @@
 from uuid import UUID
-from typing import Optional
+from typing import Optional, List
 from datetime import datetime, timezone
 from cognee.modules.sync.models import SyncOperation, SyncStatus
 from cognee.infrastructure.databases.relational import get_relational_engine
@@ -7,20 +7,24 @@ from cognee.infrastructure.databases.relational import get_relational_engine

 async def create_sync_operation(
     run_id: str,
-    dataset_id: UUID,
-    dataset_name: str,
+    dataset_ids: List[UUID],
+    dataset_names: List[str],
     user_id: UUID,
-    total_records: Optional[int] = None,
+    total_records_to_sync: Optional[int] = None,
+    total_records_to_download: Optional[int] = None,
+    total_records_to_upload: Optional[int] = None,
 ) -> SyncOperation:
     """
     Create a new sync operation record in the database.

     Args:
         run_id: Unique public identifier for this sync operation
-        dataset_id: UUID of the dataset being synced
-        dataset_name: Name of the dataset being synced
+        dataset_ids: List of dataset UUIDs being synced
+        dataset_names: List of dataset names being synced
         user_id: UUID of the user who initiated the sync
-        total_records: Total number of records to sync (if known)
+        total_records_to_sync: Total number of records to sync (if known)
+        total_records_to_download: Total number of records to download (if known)
+        total_records_to_upload: Total number of records to upload (if known)

     Returns:
         SyncOperation: The created sync operation record
@@ -29,11 +33,15 @@ async def create_sync_operation(
     sync_operation = SyncOperation(
         run_id=run_id,
-        dataset_id=dataset_id,
-        dataset_name=dataset_name,
+        dataset_ids=[
+            str(uuid) for uuid in dataset_ids
+        ],  # Convert UUIDs to strings for JSON storage
+        dataset_names=dataset_names,
         user_id=user_id,
         status=SyncStatus.STARTED,
-        total_records=total_records,
+        total_records_to_sync=total_records_to_sync,
+        total_records_to_download=total_records_to_download,
+        total_records_to_upload=total_records_to_upload,
         created_at=datetime.now(timezone.utc),
     )
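A minimal usage sketch of the updated signature from an async caller; the import path, run_id format, and counts are assumptions for illustration, not code from this change.

```python
import asyncio
from uuid import uuid4

# Import path assumed from the operations package shown above.
from cognee.modules.sync.operations import create_sync_operation


async def start_tracking_sync():
    operation = await create_sync_operation(
        run_id=f"sync_{uuid4().hex}",  # hypothetical run_id format
        dataset_ids=[uuid4(), uuid4()],
        dataset_names=["research_papers", "meeting_notes"],
        user_id=uuid4(),
        total_records_to_upload=120,
        total_records_to_download=45,
    )
    return operation


if __name__ == "__main__":
    asyncio.run(start_tracking_sync())
```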
@@ -1,7 +1,7 @@
 from uuid import UUID
 from typing import List, Optional
-from sqlalchemy import select, desc
-from cognee.modules.sync.models import SyncOperation
+from sqlalchemy import select, desc, and_
+from cognee.modules.sync.models import SyncOperation, SyncStatus
 from cognee.infrastructure.databases.relational import get_relational_engine

@@ -77,3 +77,31 @@ async def get_sync_operations_by_dataset(
         )
         result = await session.execute(query)
         return list(result.scalars().all())
+
+
+async def get_running_sync_operations_for_user(user_id: UUID) -> List[SyncOperation]:
+    """
+    Get all currently running sync operations for a specific user.
+    Checks for operations with STARTED or IN_PROGRESS status.
+
+    Args:
+        user_id: UUID of the user
+
+    Returns:
+        List[SyncOperation]: List of running sync operations for the user
+    """
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        query = (
+            select(SyncOperation)
+            .where(
+                and_(
+                    SyncOperation.user_id == user_id,
+                    SyncOperation.status.in_([SyncStatus.STARTED, SyncStatus.IN_PROGRESS]),
+                )
+            )
+            .order_by(desc(SyncOperation.created_at))
+        )
+        result = await session.execute(query)
+        return list(result.scalars().all())
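One likely use of the new helper is to refuse to start a second sync for the same user; below is a sketch under that assumption (the import path and exception type are not from this diff).

```python
from uuid import UUID

from cognee.modules.sync.operations import get_running_sync_operations_for_user  # path assumed


async def ensure_no_sync_in_progress(user_id: UUID) -> None:
    """Raise if the user already has a STARTED or IN_PROGRESS sync operation."""
    running = await get_running_sync_operations_for_user(user_id)
    if running:
        run_ids = ", ".join(op.run_id for op in running)
        # A plain ValueError is used here for illustration only.
        raise ValueError(f"Sync already in progress for user {user_id}: {run_ids}")
```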
@@ -1,16 +1,74 @@
-from typing import Optional
+import asyncio
+from typing import Optional, List
 from datetime import datetime, timezone
 from sqlalchemy import select
+from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError, TimeoutError
 from cognee.modules.sync.models import SyncOperation, SyncStatus
 from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.utils.calculate_backoff import calculate_backoff
+
+logger = get_logger("sync.db_operations")
+
+
+async def _retry_db_operation(operation_func, run_id: str, max_retries: int = 3):
+    """
+    Retry database operations with exponential backoff for transient failures.
+
+    Args:
+        operation_func: Async function to retry
+        run_id: Run ID for logging context
+        max_retries: Maximum number of retry attempts
+
+    Returns:
+        Result of the operation function
+
+    Raises:
+        Exception: Re-raises the last exception if all retries fail
+    """
+    attempt = 0
+    last_exception = None
+
+    while attempt < max_retries:
+        try:
+            return await operation_func()
+        except (DisconnectionError, OperationalError, TimeoutError) as e:
+            attempt += 1
+            last_exception = e
+
+            if attempt >= max_retries:
+                logger.error(
+                    f"Database operation failed after {max_retries} attempts for run_id {run_id}: {str(e)}"
+                )
+                break
+
+            backoff_time = calculate_backoff(attempt - 1)  # calculate_backoff is 0-indexed
+            logger.warning(
+                f"Database operation failed for run_id {run_id}, retrying in {backoff_time:.2f}s (attempt {attempt}/{max_retries}): {str(e)}"
+            )
+            await asyncio.sleep(backoff_time)
+
+        except Exception as e:
+            # Non-transient errors should not be retried
+            logger.error(f"Non-retryable database error for run_id {run_id}: {str(e)}")
+            raise
+
+    # If we get here, all retries failed
+    raise last_exception


 async def update_sync_operation(
     run_id: str,
     status: Optional[SyncStatus] = None,
     progress_percentage: Optional[int] = None,
-    processed_records: Optional[int] = None,
-    bytes_transferred: Optional[int] = None,
+    records_downloaded: Optional[int] = None,
+    total_records_to_sync: Optional[int] = None,
+    total_records_to_download: Optional[int] = None,
+    total_records_to_upload: Optional[int] = None,
+    records_uploaded: Optional[int] = None,
+    bytes_downloaded: Optional[int] = None,
+    bytes_uploaded: Optional[int] = None,
+    dataset_sync_hashes: Optional[dict] = None,
     error_message: Optional[str] = None,
     retry_count: Optional[int] = None,
     started_at: Optional[datetime] = None,
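The retry helper delegates its delay to `calculate_backoff`, which is imported but not shown in this diff. A common shape for such a function is capped exponential backoff with jitter, roughly as sketched below; the base delay, cap, and jitter strategy are assumptions, not the project's actual implementation.

```python
import random


def calculate_backoff(attempt: int, base_delay: float = 1.0, max_delay: float = 30.0) -> float:
    """Return a delay in seconds for a 0-indexed retry attempt.

    Doubles the delay on each attempt, caps it at max_delay, and applies
    jitter so concurrent retries do not all wake up at the same moment.
    """
    delay = min(base_delay * (2 ** attempt), max_delay)
    return delay * random.uniform(0.5, 1.0)
```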
@@ -23,8 +81,14 @@ async def update_sync_operation(
         run_id: The public run_id of the sync operation to update
         status: New status for the operation
         progress_percentage: Progress percentage (0-100)
-        processed_records: Number of records processed so far
-        bytes_transferred: Total bytes transferred
+        records_downloaded: Number of records downloaded so far
+        total_records_to_sync: Total number of records that need to be synced
+        total_records_to_download: Total number of records to download from cloud
+        total_records_to_upload: Total number of records to upload to cloud
+        records_uploaded: Number of records uploaded so far
+        bytes_downloaded: Total bytes downloaded from cloud
+        bytes_uploaded: Total bytes uploaded to cloud
+        dataset_sync_hashes: Dict mapping dataset_id -> {uploaded: [hashes], downloaded: [hashes]}
         error_message: Error message if operation failed
         retry_count: Number of retry attempts
         started_at: When the actual processing started
@@ -33,57 +97,116 @@ async def update_sync_operation(
     Returns:
         SyncOperation: The updated sync operation record, or None if not found
     """
-    db_engine = get_relational_engine()
-
-    async with db_engine.get_async_session() as session:
-        # Find the sync operation
-        query = select(SyncOperation).where(SyncOperation.run_id == run_id)
-        result = await session.execute(query)
-        sync_operation = result.scalars().first()
-
-        if not sync_operation:
-            return None
-
-        # Update fields that were provided
-        if status is not None:
-            sync_operation.status = status
-
-        if progress_percentage is not None:
-            sync_operation.progress_percentage = max(0, min(100, progress_percentage))
-
-        if processed_records is not None:
-            sync_operation.processed_records = processed_records
-
-        if bytes_transferred is not None:
-            sync_operation.bytes_transferred = bytes_transferred
-
-        if error_message is not None:
-            sync_operation.error_message = error_message
-
-        if retry_count is not None:
-            sync_operation.retry_count = retry_count
-
-        if started_at is not None:
-            sync_operation.started_at = started_at
-
-        if completed_at is not None:
-            sync_operation.completed_at = completed_at
-
-        # Auto-set completion timestamp for terminal statuses
-        if (
-            status in [SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED]
-            and completed_at is None
-        ):
-            sync_operation.completed_at = datetime.now(timezone.utc)
-
-        # Auto-set started timestamp when moving to IN_PROGRESS
-        if status == SyncStatus.IN_PROGRESS and sync_operation.started_at is None:
-            sync_operation.started_at = datetime.now(timezone.utc)
-
-        await session.commit()
-        await session.refresh(sync_operation)
-
-        return sync_operation
+
+    async def _perform_update():
+        db_engine = get_relational_engine()
+
+        async with db_engine.get_async_session() as session:
+            try:
+                # Find the sync operation
+                query = select(SyncOperation).where(SyncOperation.run_id == run_id)
+                result = await session.execute(query)
+                sync_operation = result.scalars().first()
+
+                if not sync_operation:
+                    logger.warning(f"Sync operation not found for run_id: {run_id}")
+                    return None
+
+                # Log what we're updating for debugging
+                updates = []
+                if status is not None:
+                    updates.append(f"status={status.value}")
+                if progress_percentage is not None:
+                    updates.append(f"progress={progress_percentage}%")
+                if records_downloaded is not None:
+                    updates.append(f"downloaded={records_downloaded}")
+                if records_uploaded is not None:
+                    updates.append(f"uploaded={records_uploaded}")
+                if total_records_to_sync is not None:
+                    updates.append(f"total_sync={total_records_to_sync}")
+                if total_records_to_download is not None:
+                    updates.append(f"total_download={total_records_to_download}")
+                if total_records_to_upload is not None:
+                    updates.append(f"total_upload={total_records_to_upload}")
+
+                if updates:
+                    logger.debug(f"Updating sync operation {run_id}: {', '.join(updates)}")
+
+                # Update fields that were provided
+                if status is not None:
+                    sync_operation.status = status
+
+                if progress_percentage is not None:
+                    sync_operation.progress_percentage = max(0, min(100, progress_percentage))
+
+                if records_downloaded is not None:
+                    sync_operation.records_downloaded = records_downloaded
+
+                if records_uploaded is not None:
+                    sync_operation.records_uploaded = records_uploaded
+
+                if total_records_to_sync is not None:
+                    sync_operation.total_records_to_sync = total_records_to_sync
+
+                if total_records_to_download is not None:
+                    sync_operation.total_records_to_download = total_records_to_download
+
+                if total_records_to_upload is not None:
+                    sync_operation.total_records_to_upload = total_records_to_upload
+
+                if bytes_downloaded is not None:
+                    sync_operation.bytes_downloaded = bytes_downloaded
+
+                if bytes_uploaded is not None:
+                    sync_operation.bytes_uploaded = bytes_uploaded
+
+                if dataset_sync_hashes is not None:
+                    sync_operation.dataset_sync_hashes = dataset_sync_hashes
+
+                if error_message is not None:
+                    sync_operation.error_message = error_message
+
+                if retry_count is not None:
+                    sync_operation.retry_count = retry_count
+
+                if started_at is not None:
+                    sync_operation.started_at = started_at
+
+                if completed_at is not None:
+                    sync_operation.completed_at = completed_at
+
+                # Auto-set completion timestamp for terminal statuses
+                if (
+                    status in [SyncStatus.COMPLETED, SyncStatus.FAILED, SyncStatus.CANCELLED]
+                    and completed_at is None
+                ):
+                    sync_operation.completed_at = datetime.now(timezone.utc)
+
+                # Auto-set started timestamp when moving to IN_PROGRESS
+                if status == SyncStatus.IN_PROGRESS and sync_operation.started_at is None:
+                    sync_operation.started_at = datetime.now(timezone.utc)
+
+                await session.commit()
+                await session.refresh(sync_operation)
+
+                logger.debug(f"Successfully updated sync operation {run_id}")
+                return sync_operation
+
+            except SQLAlchemyError as e:
+                logger.error(
+                    f"Database error updating sync operation {run_id}: {str(e)}", exc_info=True
+                )
+                await session.rollback()
+                raise
+            except Exception as e:
+                logger.error(
+                    f"Unexpected error updating sync operation {run_id}: {str(e)}", exc_info=True
+                )
+                await session.rollback()
+                raise
+
+    # Use retry logic for the database operation
+    return await _retry_db_operation(_perform_update, run_id)


 async def mark_sync_started(run_id: str) -> Optional[SyncOperation]:
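A usage sketch of the rewritten updater from a sync step that reports incremental upload progress; the 25-75% band and the counts are illustrative, and the import path for the operations package is assumed.

```python
from cognee.modules.sync.models import SyncStatus
from cognee.modules.sync.operations import update_sync_operation  # path assumed


async def report_upload_progress(run_id: str, uploaded: int, total: int, bytes_sent: int) -> None:
    """Push a partial progress update for an in-flight sync run."""
    # Map upload progress onto an illustrative 25-75% band of the overall run.
    progress = 25 + int((uploaded / max(total, 1)) * 50)
    await update_sync_operation(
        run_id=run_id,
        status=SyncStatus.IN_PROGRESS,
        progress_percentage=progress,
        records_uploaded=uploaded,
        bytes_uploaded=bytes_sent,
    )
```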
@@ -94,15 +217,23 @@ async def mark_sync_started(run_id: str) -> Optional[SyncOperation]:


 async def mark_sync_completed(
-    run_id: str, processed_records: int, bytes_transferred: int
+    run_id: str,
+    records_downloaded: int = 0,
+    records_uploaded: int = 0,
+    bytes_downloaded: int = 0,
+    bytes_uploaded: int = 0,
+    dataset_sync_hashes: Optional[dict] = None,
 ) -> Optional[SyncOperation]:
     """Convenience method to mark a sync operation as completed successfully."""
     return await update_sync_operation(
         run_id=run_id,
         status=SyncStatus.COMPLETED,
         progress_percentage=100,
-        processed_records=processed_records,
-        bytes_transferred=bytes_transferred,
+        records_downloaded=records_downloaded,
+        records_uploaded=records_uploaded,
+        bytes_downloaded=bytes_downloaded,
+        bytes_uploaded=bytes_uploaded,
+        dataset_sync_hashes=dataset_sync_hashes,
         completed_at=datetime.now(timezone.utc),
     )
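A sketch of closing out a run with the expanded convenience wrapper; the totals and hashes are placeholders, and the dataset_sync_hashes shape follows the mapping documented in the docstring above.

```python
from cognee.modules.sync.operations import mark_sync_completed  # path assumed


async def finish_sync_run(run_id: str) -> None:
    """Mark a run as COMPLETED with illustrative totals and per-dataset lineage."""
    await mark_sync_completed(
        run_id=run_id,
        records_uploaded=120,
        records_downloaded=45,
        bytes_uploaded=8_400_000,
        bytes_downloaded=2_100_000,
        dataset_sync_hashes={
            "11111111-1111-1111-1111-111111111111": {
                "uploaded": ["hash-aaa", "hash-bbb"],
                "downloaded": ["hash-ccc"],
            }
        },
    )
```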
@@ -1,8 +1,16 @@
 from uuid import uuid4
 from enum import Enum
-from typing import Optional
+from typing import Optional, List
 from datetime import datetime, timezone
-from sqlalchemy import Column, Text, DateTime, UUID as SQLAlchemy_UUID, Integer, Enum as SQLEnum
+from sqlalchemy import (
+    Column,
+    Text,
+    DateTime,
+    UUID as SQLAlchemy_UUID,
+    Integer,
+    Enum as SQLEnum,
+    JSON,
+)

 from cognee.infrastructure.databases.relational import Base

@@ -38,8 +46,8 @@ class SyncOperation(Base):
     progress_percentage = Column(Integer, default=0, doc="Progress percentage (0-100)")

     # Operation metadata
-    dataset_id = Column(SQLAlchemy_UUID, index=True, doc="ID of the dataset being synced")
-    dataset_name = Column(Text, doc="Name of the dataset being synced")
+    dataset_ids = Column(JSON, doc="Array of dataset IDs being synced")
+    dataset_names = Column(JSON, doc="Array of dataset names being synced")
     user_id = Column(SQLAlchemy_UUID, index=True, doc="ID of the user who initiated the sync")

     # Timing information
@@ -54,18 +62,24 @@ class SyncOperation(Base):
     )

     # Operation details
-    total_records = Column(Integer, doc="Total number of records to sync")
-    processed_records = Column(Integer, default=0, doc="Number of records successfully processed")
-    bytes_transferred = Column(Integer, default=0, doc="Total bytes transferred to cloud")
+    total_records_to_sync = Column(Integer, doc="Total number of records to sync")
+    total_records_to_download = Column(Integer, doc="Total number of records to download")
+    total_records_to_upload = Column(Integer, doc="Total number of records to upload")
+
+    records_downloaded = Column(Integer, default=0, doc="Number of records successfully downloaded")
+    records_uploaded = Column(Integer, default=0, doc="Number of records successfully uploaded")
+    bytes_downloaded = Column(Integer, default=0, doc="Total bytes downloaded from cloud")
+    bytes_uploaded = Column(Integer, default=0, doc="Total bytes uploaded to cloud")
+
+    # Data lineage tracking per dataset
+    dataset_sync_hashes = Column(
+        JSON, doc="Mapping of dataset_id -> {uploaded: [hashes], downloaded: [hashes]}"
+    )

     # Error handling
     error_message = Column(Text, doc="Error message if sync failed")
     retry_count = Column(Integer, default=0, doc="Number of retry attempts")

-    # Additional metadata (can be added later when needed)
-    # cloud_endpoint = Column(Text, doc="Cloud endpoint used for sync")
-    # compression_enabled = Column(Text, doc="Whether compression was used")
-
     def get_duration_seconds(self) -> Optional[float]:
         """Get the duration of the sync operation in seconds."""
         if not self.created_at:
@@ -76,11 +90,53 @@ class SyncOperation(Base):

     def get_progress_info(self) -> dict:
         """Get comprehensive progress information."""
+        total_records_processed = (self.records_downloaded or 0) + (self.records_uploaded or 0)
+        total_bytes_transferred = (self.bytes_downloaded or 0) + (self.bytes_uploaded or 0)
+
         return {
             "status": self.status.value,
             "progress_percentage": self.progress_percentage,
-            "records_processed": f"{self.processed_records or 0}/{self.total_records or 'unknown'}",
-            "bytes_transferred": self.bytes_transferred or 0,
+            "records_processed": f"{total_records_processed}/{self.total_records_to_sync or 'unknown'}",
+            "records_downloaded": self.records_downloaded or 0,
+            "records_uploaded": self.records_uploaded or 0,
+            "bytes_transferred": total_bytes_transferred,
+            "bytes_downloaded": self.bytes_downloaded or 0,
+            "bytes_uploaded": self.bytes_uploaded or 0,
             "duration_seconds": self.get_duration_seconds(),
             "error_message": self.error_message,
+            "dataset_sync_hashes": self.dataset_sync_hashes or {},
         }
+
+    def _get_all_sync_hashes(self) -> List[str]:
+        """Get all content hashes for data created/modified during this sync operation."""
+        all_hashes = set()
+        dataset_hashes = self.dataset_sync_hashes or {}
+
+        for dataset_id, operations in dataset_hashes.items():
+            if isinstance(operations, dict):
+                all_hashes.update(operations.get("uploaded", []))
+                all_hashes.update(operations.get("downloaded", []))
+
+        return list(all_hashes)
+
+    def _get_dataset_sync_hashes(self, dataset_id: str) -> dict:
+        """Get uploaded/downloaded hashes for a specific dataset."""
+        dataset_hashes = self.dataset_sync_hashes or {}
+        return dataset_hashes.get(dataset_id, {"uploaded": [], "downloaded": []})
+
+    def was_data_synced(self, content_hash: str, dataset_id: str = None) -> bool:
+        """
+        Check if a specific piece of data was part of this sync operation.
+
+        Args:
+            content_hash: The content hash to check for
+            dataset_id: Optional - check only within this dataset
+        """
+        if dataset_id:
+            dataset_hashes = self.get_dataset_sync_hashes(dataset_id)
+            return content_hash in dataset_hashes.get(
+                "uploaded", []
+            ) or content_hash in dataset_hashes.get("downloaded", [])
+
+        all_hashes = self.get_all_sync_hashes()
+        return content_hash in all_hashes
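A sketch of reading back a sync operation and summarizing it through the new model helpers; it assumes `get_sync_operation` looks a record up by its run_id, and the hash lookup goes through the stored mapping exposed by `get_progress_info`.

```python
from cognee.modules.sync.operations import get_sync_operation  # lookup by run_id assumed


async def print_sync_summary(run_id: str, content_hash: str, dataset_id: str) -> None:
    """Fetch a sync operation and report its progress and lineage."""
    operation = await get_sync_operation(run_id)
    if operation is None:
        print(f"No sync operation found for run_id {run_id}")
        return

    info = operation.get_progress_info()
    print(
        f"{info['status']}: {info['records_processed']} records, "
        f"{info['bytes_transferred']} bytes, {info['duration_seconds']}s"
    )

    dataset_hashes = info["dataset_sync_hashes"].get(dataset_id, {})
    if content_hash in dataset_hashes.get("uploaded", []) or content_hash in dataset_hashes.get(
        "downloaded", []
    ):
        print(f"{content_hash} was part of this sync run for dataset {dataset_id}")
```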