import asyncio
import io
import os
import uuid
from datetime import datetime
from typing import Dict, List, Any, Optional

from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload

from utils.logging_config import get_logger
from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import GoogleDriveOAuth

logger = get_logger(__name__)

# Global worker service cache for process pools
_worker_drive_service = None


def get_worker_drive_service(client_id: str, client_secret: str, token_file: str):
    """Get or create a cached Google Drive service instance for this worker process."""
    global _worker_drive_service
    if _worker_drive_service is None:
        logger.info("Initializing Google Drive service in worker process", pid=os.getpid())

        # Create an OAuth instance and load credentials inside the worker;
        # a worker process cannot reuse the parent's service object.
        oauth = GoogleDriveOAuth(
            client_id=client_id, client_secret=client_secret, token_file=token_file
        )

        # No event loop is running in a pool worker, so drive the async
        # credential load to completion on a fresh loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(oauth.load_credentials())
            _worker_drive_service = oauth.get_service()
            logger.info("Google Drive service ready in worker process", pid=os.getpid())
        finally:
            loop.close()

    return _worker_drive_service

# Module-level functions for process pool execution (must be pickleable)
def _sync_list_files_worker(
    client_id, client_secret, token_file, query, page_token, page_size
):
    """Worker function for listing files in a process pool."""
    service = get_worker_drive_service(client_id, client_secret, token_file)
    return (
        service.files()
        .list(
            q=query,
            pageSize=page_size,
            pageToken=page_token,
            fields="nextPageToken, files(id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners)",
        )
        .execute()
    )

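# Illustrative only: a minimal sketch of driving these pickleable workers from
# a process pool. The executor, credentials, and query below are assumptions
# for demonstration, not part of the connector:
#
#     from concurrent.futures import ProcessPoolExecutor
#
#     with ProcessPoolExecutor(max_workers=2) as pool:
#         page = pool.submit(
#             _sync_list_files_worker,
#             "client-id", "client-secret", "gdrive_token.json",
#             "trashed=false", None, 10,
#         ).result()
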
def _sync_get_metadata_worker(client_id, client_secret, token_file, file_id):
    """Worker function for fetching file metadata in a process pool."""
    service = get_worker_drive_service(client_id, client_secret, token_file)
    return (
        service.files()
        .get(
            fileId=file_id,
            fields="id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners, size",
        )
        .execute()
    )

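# For orientation, the metadata call above returns a plain dict shaped by the
# `fields` parameter; the values here are invented for illustration (note the
# Drive API returns `size` as a string):
#
#     {
#         "id": "1AbC...", "name": "report.pdf", "mimeType": "application/pdf",
#         "modifiedTime": "2024-01-15T10:30:00.000Z",
#         "createdTime": "2024-01-10T08:00:00.000Z",
#         "webViewLink": "https://drive.google.com/file/d/1AbC.../view",
#         "permissions": [...], "owners": [...], "size": "204800",
#     }
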
def _sync_download_worker(
    client_id, client_secret, token_file, file_id, mime_type, file_size=None
):
    """Worker function for downloading files in a process pool."""
    import signal
    import time

    # File size limits (in bytes)
    MAX_REGULAR_FILE_SIZE = 100 * 1024 * 1024  # 100MB for regular files
    # Google Workspace docs cannot be streamed (they are exported in one shot),
    # so cap them lower
    MAX_GOOGLE_WORKSPACE_SIZE = 50 * 1024 * 1024  # 50MB

    # Enforce size limits before downloading anything
    if file_size:
        is_google_native = mime_type.startswith("application/vnd.google-apps.")
        if is_google_native and file_size > MAX_GOOGLE_WORKSPACE_SIZE:
            raise ValueError(
                f"Google Workspace file too large: {file_size} bytes (max {MAX_GOOGLE_WORKSPACE_SIZE})"
            )
        elif not is_google_native and file_size > MAX_REGULAR_FILE_SIZE:
            raise ValueError(
                f"File too large: {file_size} bytes (max {MAX_REGULAR_FILE_SIZE})"
            )

    # Dynamic timeout based on file size: minimum 60s, 10s per MB, capped at 300s
    if file_size:
        file_size_mb = file_size / (1024 * 1024)
        timeout_seconds = min(300, max(60, int(file_size_mb * 10)))
    else:
        timeout_seconds = 60  # Default timeout if size is unknown

    # Abort the whole download if it exceeds the timeout.
    # Note: SIGALRM is Unix-only; this worker assumes a POSIX host.
    def timeout_handler(signum, frame):
        raise TimeoutError(f"File download timed out after {timeout_seconds} seconds")

    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout_seconds)

    try:
        service = get_worker_drive_service(client_id, client_secret, token_file)

        if mime_type.startswith("application/vnd.google-apps."):
            # Google native formats cannot be downloaded directly; export as PDF
            request = service.files().export_media(
                fileId=file_id, mimeType="application/pdf"
            )
        else:
            # Regular files can be downloaded directly
            request = service.files().get_media(fileId=file_id)

        # Download in 1MB chunks, retrying transient chunk failures
        file_io = io.BytesIO()
        downloader = MediaIoBaseDownload(file_io, request, chunksize=1024 * 1024)

        done = False
        retry_count = 0
        max_retries = 2

        while not done and retry_count < max_retries:
            try:
                status, done = downloader.next_chunk()
                retry_count = 0  # Reset retry count on a successful chunk
            except Exception:
                retry_count += 1
                if retry_count >= max_retries:
                    raise
                time.sleep(1)  # Brief pause before retrying

        return file_io.getvalue()

    finally:
        # Always cancel the alarm, even on failure
        signal.alarm(0)

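# Portability note (an aside, not from the original code): signal.SIGALRM only
# exists on POSIX systems. A hedged, cross-platform alternative would run the
# download in a worker thread and bound it with a future timeout, roughly:
#
#     from concurrent.futures import ThreadPoolExecutor
#
#     def _download_with_timeout(fn, timeout_seconds, *args):
#         with ThreadPoolExecutor(max_workers=1) as ex:
#             return ex.submit(fn, *args).result(timeout=timeout_seconds)
#
# with the caveat that a timed-out thread is abandoned rather than killed.
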
class GoogleDriveConnector(BaseConnector):
    """Google Drive connector with OAuth and webhook support"""

    # OAuth environment variables
    CLIENT_ID_ENV_VAR = "GOOGLE_OAUTH_CLIENT_ID"
    CLIENT_SECRET_ENV_VAR = "GOOGLE_OAUTH_CLIENT_SECRET"

    # Connector metadata
    CONNECTOR_NAME = "Google Drive"
    CONNECTOR_DESCRIPTION = "Connect your Google Drive to automatically sync documents"
    CONNECTOR_ICON = "google-drive"

    # Supported file types that can be processed by docling
    SUPPORTED_MIMETYPES = {
        "application/pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",  # .docx
        "application/msword",  # .doc
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",  # .pptx
        "application/vnd.ms-powerpoint",  # .ppt
        "text/plain",
        "text/html",
        "application/rtf",
        # Google native formats - exported as PDF before processing
        "application/vnd.google-apps.document",  # Google Docs -> PDF
        "application/vnd.google-apps.presentation",  # Google Slides -> PDF
        "application/vnd.google-apps.spreadsheet",  # Google Sheets -> PDF
    }

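    # For reference, list_files() below turns this set into a Drive query of
    # the form (term order is arbitrary because SUPPORTED_MIMETYPES is a set):
    #
    #     (mimeType='application/pdf' or mimeType='text/plain' or ...) and trashed=false
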
    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.oauth = GoogleDriveOAuth(
            client_id=self.get_client_id(),
            client_secret=self.get_client_secret(),
            token_file=config.get("token_file", "gdrive_token.json"),
        )
        self.service = None
        # Load any existing webhook channel ID from config
        self.webhook_channel_id = config.get("webhook_channel_id") or config.get(
            "subscription_id"
        )
        # Load any existing webhook resource ID (Google Drive requires it to stop a channel)
        self.webhook_resource_id = config.get("resource_id")

    async def authenticate(self) -> bool:
        """Authenticate with Google Drive"""
        try:
            if await self.oauth.is_authenticated():
                self.service = self.oauth.get_service()
                self._authenticated = True
                return True
            return False
        except Exception as e:
            logger.error("Authentication failed", error=str(e))
            return False

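    # Illustrative usage sketch; the config keys are the ones read in
    # __init__ and setup_subscription, but the values are invented:
    #
    #     connector = GoogleDriveConnector({
    #         "token_file": "gdrive_token.json",
    #         "webhook_url": "https://example.com/webhooks/gdrive",
    #     })
    #     if await connector.authenticate():
    #         channel_id = await connector.setup_subscription()
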
    async def setup_subscription(self) -> str:
        """Set up Google Drive push notifications"""
        if not self._authenticated:
            raise ValueError("Not authenticated")

        # Generate a unique channel ID
        channel_id = str(uuid.uuid4())

        # Push notifications require a publicly accessible webhook endpoint
        webhook_url = self.config.get("webhook_url")
        if not webhook_url:
            raise ValueError("webhook_url required in config for subscriptions")

        try:
            body = {
                "id": channel_id,
                "type": "web_hook",
                "address": webhook_url,
                "payload": True,
                "expiration": str(
                    int((datetime.now().timestamp() + 86400) * 1000)
                ),  # 24 hours, in epoch milliseconds
            }

            result = (
                self.service.changes()
                .watch(pageToken=self._get_start_page_token(), body=body)
                .execute()
            )

            self.webhook_channel_id = channel_id
            # Persist the resourceId returned by Google; it is required to
            # stop the channel later in cleanup_subscription()
            self.webhook_resource_id = result.get("resourceId")
            return channel_id

        except HttpError as e:
            logger.error("Failed to set up subscription", error=str(e))
            raise

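    # For orientation, the changes.watch response consumed above looks roughly
    # like this (values illustrative):
    #
    #     {"kind": "api#channel", "id": "<channel uuid>",
    #      "resourceId": "o3hgv...", "expiration": "1705312200000"}
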
    def _get_start_page_token(self) -> str:
        """Get the current page token for change notifications"""
        return self.service.changes().getStartPageToken().execute()["startPageToken"]

    async def list_files(
        self, page_token: Optional[str] = None, limit: Optional[int] = None
    ) -> Dict[str, Any]:
        """List all supported files in Google Drive.

        Uses a thread pool (not the shared process pool) to avoid issues with
        Google API clients in forked processes, and adds light retries for
        transient BrokenPipe/connection errors.
        """
        if not self._authenticated:
            raise ValueError("Not authenticated")

        # Build a query restricted to supported file types
        mimetype_query = " or ".join(
            [f"mimeType='{mt}'" for mt in self.SUPPORTED_MIMETYPES]
        )
        query = f"({mimetype_query}) and trashed=false"

        # Use the provided limit, defaulting to 100 and capping at 1000
        # (the Google Drive API maximum page size)
        page_size = min(limit or 100, 1000)

        def _sync_list_files_inner():
            import time

            attempts = 0
            max_attempts = 3
            backoff = 1.0
            while True:
                try:
                    return (
                        self.service.files()
                        .list(
                            q=query,
                            pageSize=page_size,
                            pageToken=page_token,
                            fields="nextPageToken, files(id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners)",
                        )
                        .execute()
                    )
                except Exception as e:
                    attempts += 1
                    is_broken_pipe = isinstance(e, BrokenPipeError) or (
                        isinstance(e, OSError) and getattr(e, "errno", None) == 32
                    )
                    if attempts < max_attempts and is_broken_pipe:
                        time.sleep(backoff)
                        backoff = min(4.0, backoff * 2)
                        continue
                    raise

        try:
            # Offload the blocking HTTP call to the default ThreadPoolExecutor
            loop = asyncio.get_running_loop()
            results = await loop.run_in_executor(None, _sync_list_files_inner)

            files = []
            for file in results.get("files", []):
                files.append(
                    {
                        "id": file["id"],
                        "name": file["name"],
                        "mimeType": file["mimeType"],
                        "modifiedTime": file["modifiedTime"],
                        "createdTime": file["createdTime"],
                        "webViewLink": file["webViewLink"],
                        "permissions": file.get("permissions", []),
                        "owners": file.get("owners", []),
                    }
                )

            return {"files": files, "nextPageToken": results.get("nextPageToken")}

        except HttpError as e:
            logger.error("Failed to list files", error=str(e))
            raise

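    # Illustrative pagination loop over list_files(); the connector instance
    # and page size are assumptions for demonstration:
    #
    #     page_token = None
    #     while True:
    #         page = await connector.list_files(page_token=page_token, limit=100)
    #         for f in page["files"]:
    #             print(f["id"], f["name"])
    #         page_token = page.get("nextPageToken")
    #         if not page_token:
    #             break
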
    async def get_file_content(self, file_id: str) -> ConnectorDocument:
        """Get file content and metadata"""
        if not self._authenticated:
            raise ValueError("Not authenticated")

        try:
            # Fetch metadata in the shared process pool (the same pool used
            # for docling processing) to avoid blocking the event loop
            from utils.process_pool import process_pool

            loop = asyncio.get_running_loop()
            file_metadata = await loop.run_in_executor(
                process_pool,
                _sync_get_metadata_worker,
                self.oauth.client_id,
                self.oauth.client_secret,
                self.oauth.token_file,
                file_id,
            )

            # Download the file content, passing the size through for the
            # worker's timeout calculation (the Drive API returns it as a string)
            file_size = file_metadata.get("size")
            if file_size:
                file_size = int(file_size)
            content = await self._download_file_content(
                file_id, file_metadata["mimeType"], file_size
            )

            # Extract ACL information
            acl = self._extract_acl(file_metadata)

            return ConnectorDocument(
                id=file_id,
                filename=file_metadata["name"],
                mimetype=file_metadata["mimeType"],
                content=content,
                source_url=file_metadata["webViewLink"],
                acl=acl,
                modified_time=datetime.fromisoformat(
                    file_metadata["modifiedTime"].replace("Z", "+00:00")
                ).replace(tzinfo=None),
                created_time=datetime.fromisoformat(
                    file_metadata["createdTime"].replace("Z", "+00:00")
                ).replace(tzinfo=None),
                metadata={
                    "size": file_metadata.get("size"),
                    "owners": file_metadata.get("owners", []),
                },
            )

        except HttpError as e:
            logger.error("Failed to get file content", error=str(e))
            raise

    async def _download_file_content(
        self, file_id: str, mime_type: str, file_size: Optional[int] = None
    ) -> bytes:
        """Download file content, converting Google-native formats if needed"""
        # Run the download in the shared process pool (the same pool used for
        # docling processing) to avoid blocking the event loop
        from utils.process_pool import process_pool

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            process_pool,
            _sync_download_worker,
            self.oauth.client_id,
            self.oauth.client_secret,
            self.oauth.token_file,
            file_id,
            mime_type,
            file_size,
        )

    def _extract_acl(self, file_metadata: Dict[str, Any]) -> DocumentACL:
        """Extract ACL information from file metadata"""
        user_permissions = {}
        group_permissions = {}

        owner = None
        if file_metadata.get("owners"):
            owner = file_metadata["owners"][0].get("emailAddress")

        # Map each Drive permission entry onto user or group permissions
        for perm in file_metadata.get("permissions", []):
            email = perm.get("emailAddress")
            role = perm.get("role", "reader")
            perm_type = perm.get("type")

            if perm_type == "user" and email:
                user_permissions[email] = role
            elif perm_type == "group" and email:
                group_permissions[email] = role
            elif perm_type == "domain":
                # Domain-wide permissions are treated as a pseudo-group
                domain = perm.get("domain", "unknown-domain")
                group_permissions[f"domain:{domain}"] = role

        return DocumentACL(
            owner=owner,
            user_permissions=user_permissions,
            group_permissions=group_permissions,
        )

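    # Worked example of the mapping above (permission entries invented):
    #
    #     [{"type": "user", "emailAddress": "alice@acme.com", "role": "writer"},
    #      {"type": "domain", "domain": "acme.com", "role": "reader"}]
    #
    # becomes user_permissions={"alice@acme.com": "writer"} and
    # group_permissions={"domain:acme.com": "reader"}.
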
    def extract_webhook_channel_id(
        self, payload: Dict[str, Any], headers: Dict[str, str]
    ) -> Optional[str]:
        """Extract the Google Drive channel ID from webhook headers"""
        return headers.get("x-goog-channel-id")

    def extract_webhook_resource_id(self, headers: Dict[str, str]) -> Optional[str]:
        """Extract the Google Drive resource ID from webhook headers"""
        return headers.get("x-goog-resource-id")

    async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
        """Handle a Google Drive webhook notification"""
        if not self._authenticated:
            raise ValueError("Not authenticated")

        # Google Drive puts the important information in the request headers
        headers = payload.get("_headers", {})
        channel_id = headers.get("x-goog-channel-id")
        resource_state = headers.get("x-goog-resource-state")

        if not channel_id:
            logger.warning("No channel ID found in Google Drive webhook")
            return []

        # Check that this webhook belongs to this connection
        if self.webhook_channel_id != channel_id:
            logger.warning(
                "Channel ID mismatch",
                expected=self.webhook_channel_id,
                received=channel_id,
            )
            return []

        # Only process certain states (ignore 'sync', which is just a ping)
        if resource_state not in ["exists", "not_exists", "change"]:
            logger.debug("Ignoring resource state", state=resource_state)
            return []

        try:
            # Extract the page token from the resource URI if available, e.g.
            # https://www.googleapis.com/drive/v3/changes?alt=json&pageToken=4337807
            page_token = None
            resource_uri = headers.get("x-goog-resource-uri")

            if resource_uri and "pageToken=" in resource_uri:
                import urllib.parse

                parsed = urllib.parse.urlparse(resource_uri)
                query_params = urllib.parse.parse_qs(parsed.query)
                page_token = query_params.get("pageToken", [None])[0]

            if not page_token:
                logger.warning("No page token found, cannot identify specific changes")
                return []

            logger.info("Getting changes since page token", page_token=page_token)

            # Get the list of changes since the page token
            changes = (
                self.service.changes()
                .list(
                    pageToken=page_token,
                    fields="changes(fileId, file(id, name, mimeType, trashed, parents))",
                )
                .execute()
            )

            affected_files = []
            for change in changes.get("changes", []):
                file_info = change.get("file", {})
                file_id = change.get("fileId")

                if not file_id:
                    continue

                # Only include supported file types that aren't trashed
                mime_type = file_info.get("mimeType", "")
                is_trashed = file_info.get("trashed", False)

                if not is_trashed and mime_type in self.SUPPORTED_MIMETYPES:
                    logger.info(
                        "File changed",
                        filename=file_info.get("name", "Unknown"),
                        file_id=file_id,
                    )
                    affected_files.append(file_id)
                elif is_trashed:
                    logger.info(
                        "File deleted/trashed",
                        filename=file_info.get("name", "Unknown"),
                        file_id=file_id,
                    )
                    # TODO: Handle file deletion (remove from index)
                else:
                    logger.debug("Ignoring unsupported file type", mime_type=mime_type)

            logger.info("Found affected supported files", count=len(affected_files))
            return affected_files

        except HttpError as e:
            logger.error("Failed to handle webhook", error=str(e))
            return []

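    # For orientation, a Drive push notification carries its metadata in HTTP
    # headers rather than in the body; the handler above expects them under
    # payload["_headers"], roughly (values illustrative):
    #
    #     {"_headers": {
    #         "x-goog-channel-id": "d9f3...-uuid",
    #         "x-goog-resource-state": "change",
    #         "x-goog-resource-id": "o3hgv...",
    #         "x-goog-resource-uri": "https://www.googleapis.com/drive/v3/changes?alt=json&pageToken=4337807",
    #     }}
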
    async def cleanup_subscription(self, subscription_id: str) -> bool:
        """Clean up the Google Drive subscription for this connection.

        Uses the stored resource_id captured during subscription setup.
        """
        if not self._authenticated:
            return False

        try:
            # The Google Channels API requires both the channel 'id' and its 'resourceId'
            if not self.webhook_resource_id:
                raise ValueError(
                    "Missing resource_id for cleanup; ensure subscription state is persisted"
                )
            body = {"id": subscription_id, "resourceId": self.webhook_resource_id}

            self.service.channels().stop(body=body).execute()
            return True
        except HttpError as e:
            logger.error("Failed to cleanup subscription", error=str(e))
            return False
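

# Minimal manual smoke test, illustrative only: the config key matches what
# __init__ reads, but the value is a placeholder, and the surrounding
# application normally drives the connector instead. Because this module uses
# relative imports, run it as `python -m <package>.<module>` if at all.
if __name__ == "__main__":
    async def _demo():
        connector = GoogleDriveConnector({"token_file": "gdrive_token.json"})
        if not await connector.authenticate():
            print("Not authenticated; run the OAuth flow first")
            return
        page = await connector.list_files(limit=5)
        for f in page["files"]:
            print(f["mimeType"], f["name"])

    asyncio.run(_demo())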