feat: Enhanced GDrive Connector
This commit is contained in:
parent
e810aef588
commit
f878c38c40
2 changed files with 363 additions and 195 deletions
|
|
@ -3,8 +3,9 @@ import io
|
|||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any, Optional
|
||||
from googleapiclient.discovery import build
|
||||
from typing import Dict, List, Any, Optional, Set
|
||||
|
||||
from googleapiclient.discovery import build # noqa: F401 (kept for symmetry with oauth)
|
||||
from googleapiclient.errors import HttpError
|
||||
from googleapiclient.http import MediaIoBaseDownload
|
||||
|
||||
|
|
@ -15,18 +16,18 @@ from .oauth import GoogleDriveOAuth
|
|||
# Global worker service cache for process pools
|
||||
_worker_drive_service = None
|
||||
|
||||
|
||||
def get_worker_drive_service(client_id: str, client_secret: str, token_file: str):
|
||||
"""Get or create a Google Drive service instance for this worker process"""
|
||||
global _worker_drive_service
|
||||
if _worker_drive_service is None:
|
||||
print(f"🔧 Initializing Google Drive service in worker process (PID: {os.getpid()})")
|
||||
|
||||
|
||||
# Create OAuth instance and load credentials in worker
|
||||
from .oauth import GoogleDriveOAuth
|
||||
oauth = GoogleDriveOAuth(client_id=client_id, client_secret=client_secret, token_file=token_file)
|
||||
|
||||
from .oauth import GoogleDriveOAuth as _GoogleDriveOAuth
|
||||
oauth = _GoogleDriveOAuth(client_id=client_id, client_secret=client_secret, token_file=token_file)
|
||||
|
||||
# Load credentials synchronously in worker
|
||||
import asyncio
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
|
|
@ -35,7 +36,7 @@ def get_worker_drive_service(client_id: str, client_secret: str, token_file: str
|
|||
print(f"✅ Google Drive service ready in worker process (PID: {os.getpid()})")
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
return _worker_drive_service
|
||||
|
||||
|
||||
|
|
@ -56,7 +57,7 @@ def _sync_get_metadata_worker(client_id, client_secret, token_file, file_id):
|
|||
service = get_worker_drive_service(client_id, client_secret, token_file)
|
||||
return service.files().get(
|
||||
fileId=file_id,
|
||||
fields="id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners, size"
|
||||
fields="id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners, size, trashed, parents"
|
||||
).execute()
|
||||
|
||||
|
||||
|
|
@ -64,35 +65,35 @@ def _sync_download_worker(client_id, client_secret, token_file, file_id, mime_ty
|
|||
"""Worker function for downloading files in process pool"""
|
||||
import signal
|
||||
import time
|
||||
|
||||
|
||||
# File size limits (in bytes)
|
||||
MAX_REGULAR_FILE_SIZE = 100 * 1024 * 1024 # 100MB for regular files
|
||||
MAX_GOOGLE_WORKSPACE_SIZE = 50 * 1024 * 1024 # 50MB for Google Workspace docs (they can't be streamed)
|
||||
|
||||
MAX_REGULAR_FILE_SIZE = 1000 * 1024 * 1024 # 1000MB for regular files
|
||||
MAX_GOOGLE_WORKSPACE_SIZE = 500 * 1024 * 1024 # 500MB for Google Workspace docs (they can't be streamed)
|
||||
|
||||
# Check file size limits
|
||||
if file_size:
|
||||
if mime_type.startswith('application/vnd.google-apps.') and file_size > MAX_GOOGLE_WORKSPACE_SIZE:
|
||||
raise ValueError(f"Google Workspace file too large: {file_size} bytes (max {MAX_GOOGLE_WORKSPACE_SIZE})")
|
||||
elif not mime_type.startswith('application/vnd.google-apps.') and file_size > MAX_REGULAR_FILE_SIZE:
|
||||
raise ValueError(f"File too large: {file_size} bytes (max {MAX_REGULAR_FILE_SIZE})")
|
||||
|
||||
|
||||
# Dynamic timeout based on file size (minimum 60s, 10s per MB, max 300s)
|
||||
if file_size:
|
||||
file_size_mb = file_size / (1024 * 1024)
|
||||
timeout_seconds = min(300, max(60, int(file_size_mb * 10)))
|
||||
else:
|
||||
timeout_seconds = 60 # Default timeout if size unknown
|
||||
|
||||
|
||||
# Set a timeout for the entire download operation
|
||||
def timeout_handler(signum, frame):
|
||||
raise TimeoutError(f"File download timed out after {timeout_seconds} seconds")
|
||||
|
||||
|
||||
signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(timeout_seconds)
|
||||
|
||||
|
||||
try:
|
||||
service = get_worker_drive_service(client_id, client_secret, token_file)
|
||||
|
||||
|
||||
# For Google native formats, export as PDF
|
||||
if mime_type.startswith('application/vnd.google-apps.'):
|
||||
export_format = 'application/pdf'
|
||||
|
|
@ -100,15 +101,15 @@ def _sync_download_worker(client_id, client_secret, token_file, file_id, mime_ty
|
|||
else:
|
||||
# For regular files, download directly
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
|
||||
|
||||
# Download file with chunked approach
|
||||
file_io = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(file_io, request, chunksize=1024*1024) # 1MB chunks
|
||||
|
||||
downloader = MediaIoBaseDownload(file_io, request, chunksize=1024 * 1024) # 1MB chunks
|
||||
|
||||
done = False
|
||||
retry_count = 0
|
||||
max_retries = 2
|
||||
|
||||
|
||||
while not done and retry_count < max_retries:
|
||||
try:
|
||||
status, done = downloader.next_chunk()
|
||||
|
|
@ -118,9 +119,9 @@ def _sync_download_worker(client_id, client_secret, token_file, file_id, mime_ty
|
|||
if retry_count >= max_retries:
|
||||
raise e
|
||||
time.sleep(1) # Brief pause before retry
|
||||
|
||||
|
||||
return file_io.getvalue()
|
||||
|
||||
|
||||
finally:
|
||||
# Cancel the alarm
|
||||
signal.alarm(0)
|
||||
|
|
@ -128,16 +129,16 @@ def _sync_download_worker(client_id, client_secret, token_file, file_id, mime_ty
|
|||
|
||||
class GoogleDriveConnector(BaseConnector):
|
||||
"""Google Drive connector with OAuth and webhook support"""
|
||||
|
||||
|
||||
# OAuth environment variables
|
||||
CLIENT_ID_ENV_VAR = "GOOGLE_OAUTH_CLIENT_ID"
|
||||
CLIENT_SECRET_ENV_VAR = "GOOGLE_OAUTH_CLIENT_SECRET"
|
||||
|
||||
|
||||
# Connector metadata
|
||||
CONNECTOR_NAME = "Google Drive"
|
||||
CONNECTOR_DESCRIPTION = "Connect your Google Drive to automatically sync documents"
|
||||
CONNECTOR_ICON = "google-drive"
|
||||
|
||||
|
||||
# Supported file types that can be processed by docling
|
||||
SUPPORTED_MIMETYPES = {
|
||||
'application/pdf',
|
||||
|
|
@ -149,11 +150,11 @@ class GoogleDriveConnector(BaseConnector):
|
|||
'text/html',
|
||||
'application/rtf',
|
||||
# Google Docs native formats - we'll export these
|
||||
'application/vnd.google-apps.document', # Google Docs -> PDF
|
||||
'application/vnd.google-apps.presentation', # Google Slides -> PDF
|
||||
'application/vnd.google-apps.spreadsheet', # Google Sheets -> PDF
|
||||
'application/vnd.google-apps.document', # Google Docs -> PDF
|
||||
'application/vnd.google-apps.presentation', # Google Slides -> PDF
|
||||
'application/vnd.google-apps.spreadsheet', # Google Sheets -> PDF
|
||||
}
|
||||
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self.oauth = GoogleDriveOAuth(
|
||||
|
|
@ -166,34 +167,48 @@ class GoogleDriveConnector(BaseConnector):
|
|||
self.webhook_channel_id = config.get('webhook_channel_id') or config.get('subscription_id')
|
||||
# Load existing webhook resource ID (Google Drive requires this to stop a channel)
|
||||
self.webhook_resource_id = config.get('resource_id')
|
||||
|
||||
|
||||
# ---- NEW: selection & filtering config ----
|
||||
# Explicit include lists
|
||||
self.selected_file_ids: Set[str] = set(config.get("file_ids") or config.get("selected_file_ids") or [])
|
||||
self.selected_folder_ids: Set[str] = set(config.get("folder_ids") or config.get("selected_folder_ids") or [])
|
||||
# Recursive traversal for folders (default True)
|
||||
self.recursive: bool = bool(config.get("recursive", True))
|
||||
# Optional MIME overrides
|
||||
self.include_mime_types: Set[str] = set(config.get("include_mime_types", []))
|
||||
self.exclude_mime_types: Set[str] = set(config.get("exclude_mime_types", []))
|
||||
# Cache of expanded folder IDs (when recursive)
|
||||
self._expanded_folder_ids: Optional[Set[str]] = None
|
||||
|
||||
async def authenticate(self) -> bool:
    """Authenticate with Google Drive and prepare the folder scope.

    Returns:
        True when the OAuth helper reports valid credentials, the Drive
        service has been obtained and (when folders are selected with
        ``recursive`` enabled) the folder tree has been expanded.
        False when not authenticated or when any step raises.
    """
    try:
        if not await self.oauth.is_authenticated():
            return False

        self.service = self.oauth.get_service()

        # Pre-expand the folder tree before marking the connector as
        # authenticated: if expansion fails we must not be left in a state
        # where _authenticated is True but _expanded_folder_ids is missing,
        # because list_files would then build a query without any folder
        # restriction.
        if self.selected_folder_ids and self.recursive:
            self._expanded_folder_ids = await self._expand_folders_recursive(self.selected_folder_ids)

        self._authenticated = True
        return True
    except Exception as e:
        # Keep the flag consistent with the False return value.
        self._authenticated = False
        print(f"Authentication failed: {e}")
        return False
|
||||
|
||||
|
||||
|
||||
async def setup_subscription(self) -> str:
|
||||
"""Set up Google Drive push notifications"""
|
||||
if not self._authenticated:
|
||||
raise ValueError("Not authenticated")
|
||||
|
||||
|
||||
# Generate unique channel ID
|
||||
channel_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
# Set up push notification
|
||||
# Note: This requires a publicly accessible webhook endpoint
|
||||
webhook_url = self.config.get('webhook_url')
|
||||
if not webhook_url:
|
||||
raise ValueError("webhook_url required in config for subscriptions")
|
||||
|
||||
|
||||
try:
|
||||
body = {
|
||||
'id': channel_id,
|
||||
|
|
@ -202,12 +217,12 @@ class GoogleDriveConnector(BaseConnector):
|
|||
'payload': True,
|
||||
'expiration': str(int((datetime.now().timestamp() + 86400) * 1000)) # 24 hours
|
||||
}
|
||||
|
||||
|
||||
result = self.service.changes().watch(
|
||||
pageToken=self._get_start_page_token(),
|
||||
body=body
|
||||
).execute()
|
||||
|
||||
|
||||
self.webhook_channel_id = channel_id
|
||||
# Persist the resourceId returned by Google to allow proper cleanup
|
||||
try:
|
||||
|
|
@ -215,17 +230,17 @@ class GoogleDriveConnector(BaseConnector):
|
|||
except Exception:
|
||||
self.webhook_resource_id = None
|
||||
return channel_id
|
||||
|
||||
|
||||
except HttpError as e:
|
||||
print(f"Failed to set up subscription: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def _get_start_page_token(self) -> str:
|
||||
"""Get the current page token for change notifications"""
|
||||
return self.service.changes().getStartPageToken().execute()['startPageToken']
|
||||
|
||||
|
||||
async def list_files(self, page_token: Optional[str] = None, limit: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""List all supported files in Google Drive.
|
||||
"""List supported files in Google Drive, scoped by selected folders/files if configured.
|
||||
|
||||
Uses a thread pool (not the shared process pool) to avoid issues with
|
||||
Google API clients in forked processes and adds light retries for
|
||||
|
|
@ -234,13 +249,65 @@ class GoogleDriveConnector(BaseConnector):
|
|||
if not self._authenticated:
|
||||
raise ValueError("Not authenticated")
|
||||
|
||||
# Build query for supported file types
|
||||
mimetype_query = " or ".join([f"mimeType='{mt}'" for mt in self.SUPPORTED_MIMETYPES])
|
||||
query = f"({mimetype_query}) and trashed=false"
|
||||
# ---- compute MIME type filter (with optional overrides) ----
|
||||
effective_mimes = set(self.SUPPORTED_MIMETYPES)
|
||||
if self.include_mime_types:
|
||||
effective_mimes |= set(self.include_mime_types)
|
||||
if self.exclude_mime_types:
|
||||
effective_mimes -= set(self.exclude_mime_types)
|
||||
mimetype_query = " or ".join([f"mimeType='{mt}'" for mt in sorted(effective_mimes)])
|
||||
|
||||
# ---- build parents scope if folders were selected ----
|
||||
parent_ids: Set[str] = set()
|
||||
if self.selected_folder_ids:
|
||||
if self.recursive:
|
||||
parent_ids = set(self._expanded_folder_ids or [])
|
||||
else:
|
||||
parent_ids = set(self.selected_folder_ids)
|
||||
parents_query = ""
|
||||
if parent_ids:
|
||||
parents_query = " and (" + " or ".join([f"'{fid}' in parents" for fid in sorted(parent_ids)]) + ")"
|
||||
|
||||
# Final query
|
||||
query = f"({mimetype_query}) and trashed=false{parents_query}"
|
||||
|
||||
# Use provided limit or default to 100, max 1000 (Google Drive API limit)
|
||||
page_size = min(limit or 100, 1000)
|
||||
|
||||
# ---- fast-path explicit file IDs (ignores page_token/size intentionally) ----
|
||||
if self.selected_file_ids:
|
||||
files: List[Dict[str, Any]] = []
|
||||
loop = asyncio.get_event_loop()
|
||||
from utils.process_pool import process_pool
|
||||
|
||||
async def fetch(fid: str):
|
||||
try:
|
||||
meta = await loop.run_in_executor(
|
||||
process_pool,
|
||||
_sync_get_metadata_worker,
|
||||
self.oauth.client_id,
|
||||
self.oauth.client_secret,
|
||||
self.oauth.token_file,
|
||||
fid,
|
||||
)
|
||||
if meta.get("mimeType") in effective_mimes and not meta.get("trashed", False):
|
||||
files.append({
|
||||
'id': meta['id'],
|
||||
'name': meta['name'],
|
||||
'mimeType': meta['mimeType'],
|
||||
'modifiedTime': meta['modifiedTime'],
|
||||
'createdTime': meta['createdTime'],
|
||||
'webViewLink': meta.get('webViewLink'),
|
||||
'permissions': meta.get('permissions', []),
|
||||
'owners': meta.get('owners', []),
|
||||
})
|
||||
except Exception:
|
||||
# Skip missing/forbidden IDs silently; caller can log if desired
|
||||
pass
|
||||
|
||||
await asyncio.gather(*[fetch(fid) for fid in sorted(self.selected_file_ids)])
|
||||
return {'files': files, 'nextPageToken': None}
|
||||
|
||||
def _sync_list_files_inner():
|
||||
import time
|
||||
attempts = 0
|
||||
|
|
@ -252,7 +319,7 @@ class GoogleDriveConnector(BaseConnector):
|
|||
q=query,
|
||||
pageSize=page_size,
|
||||
pageToken=page_token,
|
||||
fields="nextPageToken, files(id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners)"
|
||||
fields="nextPageToken, files(id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners, parents, trashed)"
|
||||
).execute()
|
||||
except Exception as e:
|
||||
attempts += 1
|
||||
|
|
@ -267,7 +334,6 @@ class GoogleDriveConnector(BaseConnector):
|
|||
|
||||
try:
|
||||
# Offload blocking HTTP call to default ThreadPoolExecutor
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
results = await loop.run_in_executor(None, _sync_list_files_inner)
|
||||
|
||||
|
|
@ -292,37 +358,36 @@ class GoogleDriveConnector(BaseConnector):
|
|||
except HttpError as e:
|
||||
print(f"Failed to list files: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_file_content(self, file_id: str) -> ConnectorDocument:
|
||||
"""Get file content and metadata"""
|
||||
if not self._authenticated:
|
||||
raise ValueError("Not authenticated")
|
||||
|
||||
|
||||
try:
|
||||
# Get file metadata (run in thread pool to avoid blocking)
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
# Use the same process pool as docling processing
|
||||
from utils.process_pool import process_pool
|
||||
file_metadata = await loop.run_in_executor(
|
||||
process_pool,
|
||||
_sync_get_metadata_worker,
|
||||
process_pool,
|
||||
_sync_get_metadata_worker,
|
||||
self.oauth.client_id,
|
||||
self.oauth.client_secret,
|
||||
self.oauth.token_file,
|
||||
file_id
|
||||
)
|
||||
|
||||
|
||||
# Download file content (pass file size for timeout calculation)
|
||||
file_size = file_metadata.get('size')
|
||||
if file_size:
|
||||
file_size = int(file_size) # Ensure it's an integer
|
||||
content = await self._download_file_content(file_id, file_metadata['mimeType'], file_size)
|
||||
|
||||
|
||||
# Extract ACL information
|
||||
acl = self._extract_acl(file_metadata)
|
||||
|
||||
|
||||
return ConnectorDocument(
|
||||
id=file_id,
|
||||
filename=file_metadata['name'],
|
||||
|
|
@ -337,46 +402,89 @@ class GoogleDriveConnector(BaseConnector):
|
|||
'owners': file_metadata.get('owners', [])
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
except HttpError as e:
|
||||
print(f"Failed to get file content: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _download_file_content(self, file_id: str, mime_type: str, file_size: Optional[int] = None) -> bytes:
    """Download file content, converting Google Docs formats if needed.

    The blocking download is delegated to ``_sync_download_worker`` running
    in the shared process pool so the event loop is not blocked.

    Args:
        file_id: Drive file ID to download.
        mime_type: MIME type of the file, forwarded to the worker.
        file_size: Size in bytes when known, forwarded to the worker;
            ``None`` when the size is unknown.

    Returns:
        Whatever the worker returns for the file (its raw content).
    """
    import asyncio

    # Use the same process pool as docling processing
    from utils.process_pool import process_pool

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        process_pool,
        _sync_download_worker,
        self.oauth.client_id,
        self.oauth.client_secret,
        self.oauth.token_file,
        file_id,
        mime_type,
        file_size,
    )
|
||||
|
||||
|
||||
# ---- helper to expand folder IDs recursively ----
|
||||
async def _expand_folders_recursive(self, root_folder_ids: Set[str]) -> Set[str]:
|
||||
"""Return set of folder IDs including all descendants of root_folder_ids."""
|
||||
loop = asyncio.get_event_loop()
|
||||
FOLDER_MIME = "application/vnd.google-apps.folder"
|
||||
|
||||
def _list_children(parents: Set[str]) -> List[Dict[str, Any]]:
|
||||
# Query: all folders whose parents include any in 'parents'
|
||||
# Chunk 'parents' to avoid overly long queries
|
||||
all_found: List[Dict[str, Any]] = []
|
||||
parent_list = list(parents)
|
||||
CHUNK = 20
|
||||
for i in range(0, len(parent_list), CHUNK):
|
||||
chunk = parent_list[i:i + CHUNK]
|
||||
parents_q = " or ".join([f"'{pid}' in parents" for pid in chunk])
|
||||
q = f"mimeType='{FOLDER_MIME}' and trashed=false and ({parents_q})"
|
||||
page_token = None
|
||||
while True:
|
||||
resp = self.service.files().list(
|
||||
q=q,
|
||||
pageSize=1000,
|
||||
pageToken=page_token,
|
||||
fields="nextPageToken, files(id, parents)"
|
||||
).execute()
|
||||
all_found.extend(resp.get("files", []))
|
||||
page_token = resp.get("nextPageToken")
|
||||
if not page_token:
|
||||
break
|
||||
return all_found
|
||||
|
||||
# BFS
|
||||
visited: Set[str] = set(root_folder_ids)
|
||||
frontier: Set[str] = set(root_folder_ids)
|
||||
while frontier:
|
||||
children = await loop.run_in_executor(None, _list_children, frontier)
|
||||
next_frontier: Set[str] = set()
|
||||
for f in children:
|
||||
fid = f.get("id")
|
||||
if fid and fid not in visited:
|
||||
visited.add(fid)
|
||||
next_frontier.add(fid)
|
||||
frontier = next_frontier
|
||||
return visited
|
||||
|
||||
def _extract_acl(self, file_metadata: Dict[str, Any]) -> DocumentACL:
|
||||
"""Extract ACL information from file metadata"""
|
||||
user_permissions = {}
|
||||
group_permissions = {}
|
||||
|
||||
|
||||
owner = None
|
||||
if file_metadata.get('owners'):
|
||||
owner = file_metadata['owners'][0].get('emailAddress')
|
||||
|
||||
|
||||
# Process permissions
|
||||
for perm in file_metadata.get('permissions', []):
|
||||
email = perm.get('emailAddress')
|
||||
role = perm.get('role', 'reader')
|
||||
perm_type = perm.get('type')
|
||||
|
||||
|
||||
if perm_type == 'user' and email:
|
||||
user_permissions[email] = role
|
||||
elif perm_type == 'group' and email:
|
||||
|
|
@ -385,13 +493,13 @@ class GoogleDriveConnector(BaseConnector):
|
|||
# Domain-wide permissions - could be treated as a group
|
||||
domain = perm.get('domain', 'unknown-domain')
|
||||
group_permissions[f"domain:{domain}"] = role
|
||||
|
||||
|
||||
return DocumentACL(
|
||||
owner=owner,
|
||||
user_permissions=user_permissions,
|
||||
group_permissions=group_permissions
|
||||
)
|
||||
|
||||
|
||||
def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
    """Return the value of the 'x-goog-channel-id' header, or None if absent.

    ``payload`` is accepted but not used; only ``headers`` is consulted.
    """
    channel_id = headers.get('x-goog-channel-id')
    return channel_id
|
||||
|
|
@ -399,87 +507,106 @@ class GoogleDriveConnector(BaseConnector):
|
|||
def extract_webhook_resource_id(self, headers: Dict[str, str]) -> Optional[str]:
    """Return the value of the 'x-goog-resource-id' header, or None if absent."""
    resource_id = headers.get('x-goog-resource-id')
    return resource_id
|
||||
|
||||
|
||||
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
    """Handle a Google Drive webhook notification.

    Reads the Drive headers from ``payload['_headers']``, verifies the
    notification belongs to this connection's channel, then lists the
    changes starting at the page token found in 'x-goog-resource-uri'.

    Args:
        payload: Webhook payload; the request headers are expected under
            the '_headers' key.

    Returns:
        File IDs of changed files that are not trashed, have an effective
        MIME type, and fall inside the configured file/folder selection.
        An empty list when the notification is ignored or the Drive API
        call fails with HttpError.

    Raises:
        ValueError: If the connector is not authenticated.
    """
    if not self._authenticated:
        raise ValueError("Not authenticated")

    # Google Drive sends headers with the important info
    headers = payload.get('_headers', {})

    # Extract Google Drive specific headers
    channel_id = headers.get('x-goog-channel-id')
    resource_state = headers.get('x-goog-resource-state')

    if not channel_id:
        print("[WEBHOOK] No channel ID found in Google Drive webhook")
        return []

    # Check if this webhook belongs to this connection
    if self.webhook_channel_id != channel_id:
        print(f"[WEBHOOK] Channel ID mismatch: expected {self.webhook_channel_id}, got {channel_id}")
        return []

    # Only process certain states (ignore 'sync' which is just a ping)
    if resource_state not in ['exists', 'not_exists', 'change']:
        print(f"[WEBHOOK] Ignoring resource state: {resource_state}")
        return []

    try:
        # Extract page token from the resource URI if available
        page_token = None
        resource_uri = headers.get('x-goog-resource-uri')

        if resource_uri and 'pageToken=' in resource_uri:
            # Extract page token from URI like:
            # https://www.googleapis.com/drive/v3/changes?alt=json&pageToken=4337807
            import urllib.parse
            parsed = urllib.parse.urlparse(resource_uri)
            query_params = urllib.parse.parse_qs(parsed.query)
            page_token = query_params.get('pageToken', [None])[0]

        if not page_token:
            print("[WEBHOOK] No page token found, cannot identify specific changes")
            return []

        print(f"[WEBHOOK] Getting changes since page token: {page_token}")

        # Recompute effective mimes (same as list_files)
        effective_mimes = set(self.SUPPORTED_MIMETYPES)
        if self.include_mime_types:
            effective_mimes |= set(self.include_mime_types)
        if self.exclude_mime_types:
            effective_mimes -= set(self.exclude_mime_types)

        # Folder scope is the same for every change, so compute it once.
        allowed_parents: set = set()
        if self.selected_folder_ids:
            allowed_parents = set(
                (self._expanded_folder_ids if self.recursive else self.selected_folder_ids) or []
            )

        affected_files = []
        # Follow nextPageToken so changes beyond the first page are not lost.
        while page_token:
            changes = self.service.changes().list(
                pageToken=page_token,
                fields="nextPageToken, changes(fileId, file(id, name, mimeType, trashed, parents))"
            ).execute()

            for change in changes.get('changes', []):
                # 'file' may be missing or None in a change entry.
                file_info = change.get('file', {}) or {}
                file_id = change.get('fileId')

                if not file_id:
                    continue

                mime_type = file_info.get('mimeType', '')
                is_trashed = file_info.get('trashed', False)

                # ---- scope filtering for webhook ----
                in_scope = True
                # file-based inclusion
                if self.selected_file_ids and file_id not in self.selected_file_ids:
                    in_scope = False
                # folder-based inclusion
                if in_scope and self.selected_folder_ids:
                    file_parents = set(file_info.get('parents') or [])
                    if not (file_parents & allowed_parents):
                        in_scope = False

                if not is_trashed and mime_type in effective_mimes and in_scope:
                    print(f"[WEBHOOK] File changed: {file_info.get('name', 'Unknown')} ({file_id})")
                    affected_files.append(file_id)
                elif is_trashed:
                    print(f"[WEBHOOK] File deleted/trashed: {file_info.get('name', 'Unknown')} ({file_id})")
                    # TODO: Handle file deletion (remove from index)
                else:
                    print(f"[WEBHOOK] Ignoring unsupported or out-of-scope file: mime={mime_type} id={file_id}")

            page_token = changes.get('nextPageToken')

        print(f"[WEBHOOK] Found {len(affected_files)} affected supported files")
        return affected_files

    except HttpError as e:
        print(f"Failed to handle webhook: {e}")
        return []
|
||||
|
||||
|
||||
async def cleanup_subscription(self, subscription_id: str) -> bool:
|
||||
"""Clean up Google Drive subscription for this connection.
|
||||
|
||||
|
|
@ -487,13 +614,13 @@ class GoogleDriveConnector(BaseConnector):
|
|||
"""
|
||||
if not self._authenticated:
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Google Channels API requires both 'id' (channel) and 'resourceId'
|
||||
if not self.webhook_resource_id:
|
||||
raise ValueError("Missing resource_id for cleanup; ensure subscription state is persisted")
|
||||
body = {'id': subscription_id, 'resourceId': self.webhook_resource_id}
|
||||
|
||||
|
||||
self.service.channels().stop(body=body).execute()
|
||||
return True
|
||||
except HttpError as e:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Optional, Iterable, Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
|
|
@ -10,138 +11,178 @@ import aiofiles
|
|||
|
||||
|
||||
class GoogleDriveOAuth:
|
||||
"""Handles Google Drive OAuth authentication flow"""
|
||||
|
||||
SCOPES = [
|
||||
'openid',
|
||||
'email',
|
||||
'profile',
|
||||
'https://www.googleapis.com/auth/drive.readonly',
|
||||
'https://www.googleapis.com/auth/drive.metadata.readonly'
|
||||
]
|
||||
|
||||
"""Handles Google Drive OAuth authentication flow with scope-upgrade detection."""
|
||||
|
||||
# Core scopes needed by your connector:
|
||||
# - drive.readonly: content/export
|
||||
# - drive.metadata.readonly: owners/permissions/parents
|
||||
REQUIRED_SCOPES = {
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
||||
}
|
||||
|
||||
# Optional OIDC/userinfo scopes (nice-to-have; keep them if you use identity info elsewhere)
|
||||
OPTIONAL_SCOPES = {"openid", "email", "profile"}
|
||||
|
||||
# Final scopes we request during auth
|
||||
SCOPES = sorted(list(REQUIRED_SCOPES | OPTIONAL_SCOPES))
|
||||
|
||||
AUTH_ENDPOINT = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token"
|
||||
|
||||
def __init__(self, client_id: str = None, client_secret: str = None, token_file: str = "token.json"):
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, token_file: str = "token.json"):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.token_file = token_file
|
||||
self.creds: Optional[Credentials] = None
|
||||
|
||||
self._flow_state: Optional[str] = None
|
||||
self._flow: Optional[Flow] = None
|
||||
|
||||
# ---------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------
|
||||
@staticmethod
|
||||
def _scopes_satisfied(creds: Credentials, required: Iterable[str]) -> bool:
|
||||
"""
|
||||
Check that the credential's scopes include all required scopes.
|
||||
Google sometimes returns None/[] for creds.scopes after refresh; guard for that.
|
||||
"""
|
||||
try:
|
||||
got = set(creds.scopes or [])
|
||||
except Exception:
|
||||
got = set()
|
||||
return set(required).issubset(got)
|
||||
|
||||
# ---------------------------
|
||||
# Public methods
|
||||
# ---------------------------
|
||||
async def load_credentials(self) -> Optional[Credentials]:
    """Load existing credentials from token file; refresh if expired.

    If the token file exists but the resulting credentials lack any of
    REQUIRED_SCOPES, the file is deleted and ``self.creds`` is reset to
    None so the caller re-authenticates with the upgraded scopes.

    Returns:
        The loaded (and possibly refreshed) credentials, or None when the
        stored token does not cover the required scopes. When the token
        file does not exist, the current value of ``self.creds`` is
        returned unchanged.
    """
    from datetime import timezone

    if not os.path.exists(self.token_file):
        return self.creds

    async with aiofiles.open(self.token_file, "r") as f:
        token_data = json.loads(await f.read())

    # Build creds from stored token data (be tolerant of missing fields)
    scopes_from_file: Sequence[str] = token_data.get("scopes") or self.SCOPES
    self.creds = Credentials(
        token=token_data.get("token"),
        refresh_token=token_data.get("refresh_token"),
        id_token=token_data.get("id_token"),
        token_uri=self.TOKEN_ENDPOINT,
        client_id=self.client_id,
        client_secret=self.client_secret,
        scopes=scopes_from_file,
    )

    # Restore expiry as a naive datetime for google-auth.
    raw_expiry = token_data.get("expiry")
    if raw_expiry:
        try:
            expiry_dt = datetime.fromisoformat(raw_expiry)
        except (TypeError, ValueError):
            # Malformed value: leave expiry unset rather than fail the load.
            pass
        else:
            if expiry_dt.tzinfo is not None:
                # Convert to UTC before dropping tzinfo; stripping the
                # offset directly would shift the instant for non-UTC values.
                expiry_dt = expiry_dt.astimezone(timezone.utc)
            self.creds.expiry = expiry_dt.replace(tzinfo=None)

    # If expired and we have a refresh token, try refreshing. The token
    # file is rewritten even when the refresh raises.
    if self.creds and self.creds.expired and self.creds.refresh_token:
        try:
            self.creds.refresh(Request())
        finally:
            await self.save_credentials()

    # Scope-upgrade detection: an old/narrow token is removed to force a
    # clean re-consent with the broader scopes.
    if self.creds and not self._scopes_satisfied(self.creds, self.REQUIRED_SCOPES):
        try:
            os.remove(self.token_file)
        except FileNotFoundError:
            pass
        self.creds = None  # signal caller that we need to re-auth

    return self.creds
|
||||
|
||||
|
||||
async def save_credentials(self):
    """Save credentials to token file (no client_secret).

    Does nothing when no credentials are loaded. Otherwise writes a JSON
    object with the token, refresh token, id token, sorted scopes and,
    when set, the expiry in ISO format.
    """
    if not self.creds:
        return

    token_data = {
        "token": self.creds.token,
        "refresh_token": self.creds.refresh_token,
        "id_token": self.creds.id_token,
        # Persist the scopes we actually have now; sort for determinism.
        # Falls back to the requested SCOPES when the credential reports none.
        "scopes": sorted(self.creds.scopes or self.SCOPES),
    }
    if self.creds.expiry:
        token_data["expiry"] = self.creds.expiry.isoformat()

    async with aiofiles.open(self.token_file, "w") as f:
        await f.write(json.dumps(token_data, indent=2))
|
||||
|
||||
def create_authorization_url(self, redirect_uri: str, *, force_consent: bool = True) -> str:
    """Create the authorization URL for the OAuth flow.

    Args:
        redirect_uri: Redirect URI handed to the flow.
        force_consent: When True, request ``prompt=consent`` and disable
            ``include_granted_scopes`` so Google prompts for the broader
            scopes (scope upgrade).

    Returns:
        The URL the user must visit to grant access.

    Side effects:
        Stores the flow and its state on ``self._flow`` / ``self._flow_state``
        for use by ``handle_authorization_callback``.
    """
    client_config = {
        "web": {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            # Use v2 endpoints consistently
            "auth_uri": self.AUTH_ENDPOINT,
            "token_uri": self.TOKEN_ENDPOINT,
        }
    }

    flow = Flow.from_client_config(client_config, scopes=self.SCOPES, redirect_uri=redirect_uri)

    auth_params = {
        "access_type": "offline",
        # include_granted_scopes=True can cause Google to reuse an old narrow
        # grant. For upgrades, it's safer to disable it when forcing consent.
        "include_granted_scopes": "false" if force_consent else "true",
    }
    if force_consent:
        # Ensure we actually see the consent screen; omit the parameter
        # entirely otherwise instead of passing prompt=None.
        auth_params["prompt"] = "consent"

    # authorization_url() returns (url, state); keep the returned state rather
    # than reading an attribute off the flow object.
    auth_url, state = flow.authorization_url(**auth_params)

    # Store flow state for later use
    self._flow_state = state
    self._flow = flow

    return auth_url
|
||||
|
||||
|
||||
async def handle_authorization_callback(self, authorization_code: str, state: str) -> bool:
    """Handle the OAuth callback and exchange the code for tokens.

    Args:
        authorization_code: Code passed to the stored flow's ``fetch_token``.
        state: State value from the callback; must equal the state stored by
            ``create_authorization_url``.

    Returns:
        True once the credentials have been fetched and saved.

    Raises:
        ValueError: If no flow has been started on this instance or the
            state does not match.
    """
    # getattr keeps this a ValueError even when create_authorization_url was
    # never called and the attributes were never set.
    flow = getattr(self, "_flow", None)
    if not flow or getattr(self, "_flow_state", None) != state:
        raise ValueError("Invalid or missing OAuth state")

    # Exchange authorization code for credentials
    flow.fetch_token(code=authorization_code)
    self.creds = flow.credentials

    # Save credentials
    await self.save_credentials()

    return True
|
||||
|
||||
|
||||
async def is_authenticated(self) -> bool:
    """Return True if we have a usable credential with all required scopes.

    Calls ``load_credentials`` first when no credentials are held in memory.
    The result is always a real ``bool``: True only when the credentials
    exist, report ``valid``, and satisfy ``REQUIRED_SCOPES``.
    """
    if not self.creds:
        await self.load_credentials()

    return bool(
        self.creds
        and self.creds.valid
        and self._scopes_satisfied(self.creds, self.REQUIRED_SCOPES)
    )
|
||||
|
||||
def get_service(self):
    """Get an authenticated Google Drive service.

    Returns:
        The object returned by ``build("drive", "v3", ...)`` for the current
        credentials.

    Raises:
        ValueError: If there are no credentials, they are not valid, or they
            do not satisfy ``REQUIRED_SCOPES``.
    """
    has_usable_creds = (
        self.creds
        and self.creds.valid
        and self._scopes_satisfied(self.creds, self.REQUIRED_SCOPES)
    )
    if not has_usable_creds:
        raise ValueError("Not authenticated with required scopes")

    # cache_discovery=False avoids a deprecation warning chatter in some environments
    return build("drive", "v3", credentials=self.creds, cache_discovery=False)
|
||||
|
||||
async def revoke_credentials(self):
    """Revoke credentials and delete the token file.

    Revocation is attempted once and is best-effort: a failure is logged and
    local cleanup still proceeds. The token file is removed if it exists and
    ``self.creds`` is reset to None in every case.
    """
    if self.creds:
        try:
            self.creds.revoke(Request())
        except Exception:
            # Revocation is best-effort; continue to clear local token, but
            # leave a trace because the grant may still be live remotely.
            import logging

            logging.getLogger(__name__).warning(
                "Credential revocation failed; removing local token anyway",
                exc_info=True,
            )

    try:
        os.remove(self.token_file)
    except FileNotFoundError:
        pass

    self.creds = None
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue