feat: Enhanced GDrive Connector

Eric Hare 2025-09-03 07:29:58 -07:00
parent e810aef588
commit f878c38c40
2 changed files with 363 additions and 195 deletions


@@ -3,8 +3,9 @@ import io
import os
import uuid
from datetime import datetime
import asyncio
from typing import Dict, List, Any, Optional
from googleapiclient.discovery import build
from typing import Dict, List, Any, Optional, Set
from googleapiclient.discovery import build # noqa: F401 (kept for symmetry with oauth)
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload
@@ -15,18 +16,18 @@ from .oauth import GoogleDriveOAuth
# Global worker service cache for process pools
_worker_drive_service = None
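# Each pool worker lazily builds and caches its own service instance:
# googleapiclient HTTP objects are not safe to share across forked
# processes, so initialization happens once per worker PID.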
def get_worker_drive_service(client_id: str, client_secret: str, token_file: str):
"""Get or create a Google Drive service instance for this worker process"""
global _worker_drive_service
if _worker_drive_service is None:
print(f"🔧 Initializing Google Drive service in worker process (PID: {os.getpid()})")
# Create OAuth instance and load credentials in worker
from .oauth import GoogleDriveOAuth
oauth = GoogleDriveOAuth(client_id=client_id, client_secret=client_secret, token_file=token_file)
from .oauth import GoogleDriveOAuth as _GoogleDriveOAuth
oauth = _GoogleDriveOAuth(client_id=client_id, client_secret=client_secret, token_file=token_file)
# Load credentials synchronously in worker
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
@@ -35,7 +36,7 @@ def get_worker_drive_service(client_id: str, client_secret: str, token_file: str
print(f"✅ Google Drive service ready in worker process (PID: {os.getpid()})")
finally:
loop.close()
return _worker_drive_service
@@ -56,7 +57,7 @@ def _sync_get_metadata_worker(client_id, client_secret, token_file, file_id):
service = get_worker_drive_service(client_id, client_secret, token_file)
return service.files().get(
fileId=file_id,
fields="id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners, size"
fields="id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners, size, trashed, parents"
).execute()
@@ -64,35 +65,35 @@ def _sync_download_worker(client_id, client_secret, token_file, file_id, mime_ty
"""Worker function for downloading files in process pool"""
import signal
import time
# File size limits (in bytes)
MAX_REGULAR_FILE_SIZE = 100 * 1024 * 1024 # 100MB for regular files
MAX_GOOGLE_WORKSPACE_SIZE = 50 * 1024 * 1024 # 50MB for Google Workspace docs (they can't be streamed)
MAX_REGULAR_FILE_SIZE = 1000 * 1024 * 1024 # 1000MB for regular files
MAX_GOOGLE_WORKSPACE_SIZE = 500 * 1024 * 1024 # 500MB for Google Workspace docs (they can't be streamed)
# Check file size limits
if file_size:
if mime_type.startswith('application/vnd.google-apps.') and file_size > MAX_GOOGLE_WORKSPACE_SIZE:
raise ValueError(f"Google Workspace file too large: {file_size} bytes (max {MAX_GOOGLE_WORKSPACE_SIZE})")
elif not mime_type.startswith('application/vnd.google-apps.') and file_size > MAX_REGULAR_FILE_SIZE:
raise ValueError(f"File too large: {file_size} bytes (max {MAX_REGULAR_FILE_SIZE})")
# Dynamic timeout based on file size (minimum 60s, 10s per MB, max 300s)
if file_size:
file_size_mb = file_size / (1024 * 1024)
timeout_seconds = min(300, max(60, int(file_size_mb * 10)))
else:
timeout_seconds = 60 # Default timeout if size unknown
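# Illustrative values: a 12 MB file gets min(300, max(60, 120)) = 120s,
# a 2 MB file floors at 60s, and anything over 30 MB caps at 300s.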
# Set a timeout for the entire download operation
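# Note: SIGALRM is POSIX-only, so this per-download timeout assumes Unix workers.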
def timeout_handler(signum, frame):
raise TimeoutError(f"File download timed out after {timeout_seconds} seconds")
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout_seconds)
try:
service = get_worker_drive_service(client_id, client_secret, token_file)
# For Google native formats, export as PDF
if mime_type.startswith('application/vnd.google-apps.'):
export_format = 'application/pdf'
@@ -100,15 +101,15 @@ def _sync_download_worker(client_id, client_secret, token_file, file_id, mime_ty
else:
# For regular files, download directly
request = service.files().get_media(fileId=file_id)
# Download file with chunked approach
file_io = io.BytesIO()
downloader = MediaIoBaseDownload(file_io, request, chunksize=1024*1024) # 1MB chunks
downloader = MediaIoBaseDownload(file_io, request, chunksize=1024 * 1024) # 1MB chunks
done = False
retry_count = 0
max_retries = 2
while not done and retry_count < max_retries:
try:
status, done = downloader.next_chunk()
@@ -118,9 +119,9 @@ def _sync_download_worker(client_id, client_secret, token_file, file_id, mime_ty
if retry_count >= max_retries:
raise e
time.sleep(1) # Brief pause before retry
return file_io.getvalue()
finally:
# Cancel the alarm
signal.alarm(0)
@@ -128,16 +129,16 @@ def _sync_download_worker(client_id, client_secret, token_file, file_id, mime_ty
class GoogleDriveConnector(BaseConnector):
"""Google Drive connector with OAuth and webhook support"""
# OAuth environment variables
CLIENT_ID_ENV_VAR = "GOOGLE_OAUTH_CLIENT_ID"
CLIENT_SECRET_ENV_VAR = "GOOGLE_OAUTH_CLIENT_SECRET"
# Connector metadata
CONNECTOR_NAME = "Google Drive"
CONNECTOR_DESCRIPTION = "Connect your Google Drive to automatically sync documents"
CONNECTOR_ICON = "google-drive"
# Supported file types that can be processed by docling
SUPPORTED_MIMETYPES = {
'application/pdf',
@@ -149,11 +150,11 @@ class GoogleDriveConnector(BaseConnector):
'text/html',
'application/rtf',
# Google Docs native formats - we'll export these
'application/vnd.google-apps.document', # Google Docs -> PDF
'application/vnd.google-apps.presentation', # Google Slides -> PDF
'application/vnd.google-apps.spreadsheet', # Google Sheets -> PDF
}
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self.oauth = GoogleDriveOAuth(
@@ -166,34 +167,48 @@ class GoogleDriveConnector(BaseConnector):
self.webhook_channel_id = config.get('webhook_channel_id') or config.get('subscription_id')
# Load existing webhook resource ID (Google Drive requires this to stop a channel)
self.webhook_resource_id = config.get('resource_id')
# ---- NEW: selection & filtering config ----
# Explicit include lists
self.selected_file_ids: Set[str] = set(config.get("file_ids") or config.get("selected_file_ids") or [])
self.selected_folder_ids: Set[str] = set(config.get("folder_ids") or config.get("selected_folder_ids") or [])
# Recursive traversal for folders (default True)
self.recursive: bool = bool(config.get("recursive", True))
# Optional MIME overrides
self.include_mime_types: Set[str] = set(config.get("include_mime_types", []))
self.exclude_mime_types: Set[str] = set(config.get("exclude_mime_types", []))
# Cache of expanded folder IDs (when recursive)
self._expanded_folder_ids: Optional[Set[str]] = None
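# Illustrative config consuming the keys above (IDs are hypothetical):
#   {
#       "folder_ids": ["1AbCDeFolder"],        # scope sync to these folders
#       "recursive": True,                     # include descendant folders
#       "file_ids": ["1XyZwFile"],             # plus these individual files
#       "exclude_mime_types": ["text/html"],   # trim the default MIME set
#   }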
async def authenticate(self) -> bool:
"""Authenticate with Google Drive"""
try:
if await self.oauth.is_authenticated():
self.service = self.oauth.get_service()
self._authenticated = True
# Pre-expand folder tree if needed
if self.selected_folder_ids and self.recursive:
self._expanded_folder_ids = await self._expand_folders_recursive(self.selected_folder_ids)
return True
return False
except Exception as e:
print(f"Authentication failed: {e}")
return False
async def setup_subscription(self) -> str:
"""Set up Google Drive push notifications"""
if not self._authenticated:
raise ValueError("Not authenticated")
# Generate unique channel ID
channel_id = str(uuid.uuid4())
# Set up push notification
# Note: This requires a publicly accessible webhook endpoint
webhook_url = self.config.get('webhook_url')
if not webhook_url:
raise ValueError("webhook_url required in config for subscriptions")
try:
body = {
'id': channel_id,
@@ -202,12 +217,12 @@ class GoogleDriveConnector(BaseConnector):
'payload': True,
'expiration': str(int((datetime.now().timestamp() + 86400) * 1000)) # 24 hours
}
result = self.service.changes().watch(
pageToken=self._get_start_page_token(),
body=body
).execute()
self.webhook_channel_id = channel_id
# Persist the resourceId returned by Google to allow proper cleanup
try:
@@ -215,17 +230,17 @@ class GoogleDriveConnector(BaseConnector):
except Exception:
self.webhook_resource_id = None
return channel_id
except HttpError as e:
print(f"Failed to set up subscription: {e}")
raise
def _get_start_page_token(self) -> str:
"""Get the current page token for change notifications"""
return self.service.changes().getStartPageToken().execute()['startPageToken']
async def list_files(self, page_token: Optional[str] = None, limit: Optional[int] = None) -> Dict[str, Any]:
"""List all supported files in Google Drive.
"""List supported files in Google Drive, scoped by selected folders/files if configured.
Uses a thread pool (not the shared process pool) to avoid issues with
Google API clients in forked processes and adds light retries for
@@ -234,13 +249,65 @@ class GoogleDriveConnector(BaseConnector):
if not self._authenticated:
raise ValueError("Not authenticated")
# Build query for supported file types
mimetype_query = " or ".join([f"mimeType='{mt}'" for mt in self.SUPPORTED_MIMETYPES])
query = f"({mimetype_query}) and trashed=false"
# ---- compute MIME type filter (with optional overrides) ----
effective_mimes = set(self.SUPPORTED_MIMETYPES)
if self.include_mime_types:
effective_mimes |= set(self.include_mime_types)
if self.exclude_mime_types:
effective_mimes -= set(self.exclude_mime_types)
mimetype_query = " or ".join([f"mimeType='{mt}'" for mt in sorted(effective_mimes)])
# ---- build parents scope if folders were selected ----
parent_ids: Set[str] = set()
if self.selected_folder_ids:
if self.recursive:
parent_ids = set(self._expanded_folder_ids or [])
else:
parent_ids = set(self.selected_folder_ids)
parents_query = ""
if parent_ids:
parents_query = " and (" + " or ".join([f"'{fid}' in parents" for fid in sorted(parent_ids)]) + ")"
# Final query
query = f"({mimetype_query}) and trashed=false{parents_query}"
# Use provided limit or default to 100, max 1000 (Google Drive API limit)
page_size = min(limit or 100, 1000)
# ---- fast-path explicit file IDs (ignores page_token/size intentionally) ----
if self.selected_file_ids:
files: List[Dict[str, Any]] = []
loop = asyncio.get_event_loop()
from utils.process_pool import process_pool
async def fetch(fid: str):
try:
meta = await loop.run_in_executor(
process_pool,
_sync_get_metadata_worker,
self.oauth.client_id,
self.oauth.client_secret,
self.oauth.token_file,
fid,
)
if meta.get("mimeType") in effective_mimes and not meta.get("trashed", False):
files.append({
'id': meta['id'],
'name': meta['name'],
'mimeType': meta['mimeType'],
'modifiedTime': meta['modifiedTime'],
'createdTime': meta['createdTime'],
'webViewLink': meta.get('webViewLink'),
'permissions': meta.get('permissions', []),
'owners': meta.get('owners', []),
})
except Exception:
# Skip missing/forbidden IDs silently; caller can log if desired
pass
await asyncio.gather(*[fetch(fid) for fid in sorted(self.selected_file_ids)])
return {'files': files, 'nextPageToken': None}
def _sync_list_files_inner():
import time
attempts = 0
@@ -252,7 +319,7 @@ class GoogleDriveConnector(BaseConnector):
q=query,
pageSize=page_size,
pageToken=page_token,
fields="nextPageToken, files(id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners)"
fields="nextPageToken, files(id, name, mimeType, modifiedTime, createdTime, webViewLink, permissions, owners, parents, trashed)"
).execute()
except Exception as e:
attempts += 1
@@ -267,7 +334,6 @@ class GoogleDriveConnector(BaseConnector):
try:
# Offload blocking HTTP call to default ThreadPoolExecutor
import asyncio
loop = asyncio.get_event_loop()
results = await loop.run_in_executor(None, _sync_list_files_inner)
@@ -292,37 +358,36 @@ class GoogleDriveConnector(BaseConnector):
except HttpError as e:
print(f"Failed to list files: {e}")
raise
async def get_file_content(self, file_id: str) -> ConnectorDocument:
"""Get file content and metadata"""
if not self._authenticated:
raise ValueError("Not authenticated")
try:
# Get file metadata (run in thread pool to avoid blocking)
import asyncio
loop = asyncio.get_event_loop()
# Use the same process pool as docling processing
from utils.process_pool import process_pool
file_metadata = await loop.run_in_executor(
process_pool,
_sync_get_metadata_worker,
self.oauth.client_id,
self.oauth.client_secret,
self.oauth.token_file,
file_id
)
# Download file content (pass file size for timeout calculation)
file_size = file_metadata.get('size')
if file_size:
file_size = int(file_size) # Ensure it's an integer
content = await self._download_file_content(file_id, file_metadata['mimeType'], file_size)
# Extract ACL information
acl = self._extract_acl(file_metadata)
return ConnectorDocument(
id=file_id,
filename=file_metadata['name'],
@@ -337,46 +402,89 @@ class GoogleDriveConnector(BaseConnector):
'owners': file_metadata.get('owners', [])
}
)
except HttpError as e:
print(f"Failed to get file content: {e}")
raise
async def _download_file_content(self, file_id: str, mime_type: str, file_size: Optional[int] = None) -> bytes:
"""Download file content, converting Google Docs formats if needed"""
# Download file (run in process pool to avoid blocking)
import asyncio
loop = asyncio.get_event_loop()
# Use the same process pool as docling processing
from utils.process_pool import process_pool
return await loop.run_in_executor(
process_pool,
_sync_download_worker,
self.oauth.client_id,
self.oauth.client_secret,
self.oauth.token_file,
file_id,
mime_type,
file_size
)
# ---- helper to expand folder IDs recursively ----
async def _expand_folders_recursive(self, root_folder_ids: Set[str]) -> Set[str]:
"""Return set of folder IDs including all descendants of root_folder_ids."""
loop = asyncio.get_event_loop()
FOLDER_MIME = "application/vnd.google-apps.folder"
def _list_children(parents: Set[str]) -> List[Dict[str, Any]]:
# Query: all folders whose parents include any in 'parents'
# Chunk 'parents' to avoid overly long queries
all_found: List[Dict[str, Any]] = []
parent_list = list(parents)
CHUNK = 20
for i in range(0, len(parent_list), CHUNK):
chunk = parent_list[i:i + CHUNK]
parents_q = " or ".join([f"'{pid}' in parents" for pid in chunk])
q = f"mimeType='{FOLDER_MIME}' and trashed=false and ({parents_q})"
page_token = None
while True:
resp = self.service.files().list(
q=q,
pageSize=1000,
pageToken=page_token,
fields="nextPageToken, files(id, parents)"
).execute()
all_found.extend(resp.get("files", []))
page_token = resp.get("nextPageToken")
if not page_token:
break
return all_found
# BFS
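# Each level issues ceil(len(frontier)/CHUNK) list queries; 'visited' also
# guards against re-walking folders reachable through more than one parent.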
visited: Set[str] = set(root_folder_ids)
frontier: Set[str] = set(root_folder_ids)
while frontier:
children = await loop.run_in_executor(None, _list_children, frontier)
next_frontier: Set[str] = set()
for f in children:
fid = f.get("id")
if fid and fid not in visited:
visited.add(fid)
next_frontier.add(fid)
frontier = next_frontier
return visited
def _extract_acl(self, file_metadata: Dict[str, Any]) -> DocumentACL:
"""Extract ACL information from file metadata"""
user_permissions = {}
group_permissions = {}
owner = None
if file_metadata.get('owners'):
owner = file_metadata['owners'][0].get('emailAddress')
# Process permissions
for perm in file_metadata.get('permissions', []):
email = perm.get('emailAddress')
role = perm.get('role', 'reader')
perm_type = perm.get('type')
if perm_type == 'user' and email:
user_permissions[email] = role
elif perm_type == 'group' and email:
@@ -385,13 +493,13 @@ class GoogleDriveConnector(BaseConnector):
# Domain-wide permissions - could be treated as a group
domain = perm.get('domain', 'unknown-domain')
group_permissions[f"domain:{domain}"] = role
return DocumentACL(
owner=owner,
user_permissions=user_permissions,
group_permissions=group_permissions
)
def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]:
"""Extract Google Drive channel ID from webhook headers"""
return headers.get('x-goog-channel-id')
@@ -399,87 +507,106 @@ class GoogleDriveConnector(BaseConnector):
def extract_webhook_resource_id(self, headers: Dict[str, str]) -> Optional[str]:
"""Extract Google Drive resource ID from webhook headers"""
return headers.get('x-goog-resource-id')
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
"""Handle Google Drive webhook notification"""
if not self._authenticated:
raise ValueError("Not authenticated")
# Google Drive sends headers with the important info
headers = payload.get('_headers', {})
# Extract Google Drive specific headers
channel_id = headers.get('x-goog-channel-id')
resource_state = headers.get('x-goog-resource-state')
if not channel_id:
print("[WEBHOOK] No channel ID found in Google Drive webhook")
return []
# Check if this webhook belongs to this connection
if self.webhook_channel_id != channel_id:
print(f"[WEBHOOK] Channel ID mismatch: expected {self.webhook_channel_id}, got {channel_id}")
return []
# Only process certain states (ignore 'sync' which is just a ping)
if resource_state not in ['exists', 'not_exists', 'change']:
print(f"[WEBHOOK] Ignoring resource state: {resource_state}")
return []
try:
# Extract page token from the resource URI if available
page_token = None
headers = payload.get('_headers', {})
resource_uri = headers.get('x-goog-resource-uri')
if resource_uri and 'pageToken=' in resource_uri:
# Extract page token from URI like:
# https://www.googleapis.com/drive/v3/changes?alt=json&pageToken=4337807
import urllib.parse
parsed = urllib.parse.urlparse(resource_uri)
query_params = urllib.parse.parse_qs(parsed.query)
page_token = query_params.get('pageToken', [None])[0]
if not page_token:
print("[WEBHOOK] No page token found, cannot identify specific changes")
return []
print(f"[WEBHOOK] Getting changes since page token: {page_token}")
# Get list of changes since the page token
changes = self.service.changes().list(
pageToken=page_token,
fields="changes(fileId, file(id, name, mimeType, trashed, parents))"
).execute()
# Recompute effective mimes (same as list_files)
effective_mimes = set(self.SUPPORTED_MIMETYPES)
if self.include_mime_types:
effective_mimes |= set(self.include_mime_types)
if self.exclude_mime_types:
effective_mimes -= set(self.exclude_mime_types)
affected_files = []
for change in changes.get('changes', []):
file_info = change.get('file', {})
file_info = change.get('file', {}) or {}
file_id = change.get('fileId')
if not file_id:
continue
# Only include supported file types that aren't trashed
mime_type = file_info.get('mimeType', '')
is_trashed = file_info.get('trashed', False)
if not is_trashed and mime_type in self.SUPPORTED_MIMETYPES:
# ---- scope filtering for webhook ----
in_scope = True
# file-based inclusion
if self.selected_file_ids and file_id not in self.selected_file_ids:
in_scope = False
# folder-based inclusion
if in_scope and self.selected_folder_ids:
allowed_parents = self._expanded_folder_ids if self.recursive else self.selected_folder_ids
file_parents = set((file_info.get('parents') or []))
if not (file_parents & set(allowed_parents or [])):
in_scope = False
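# Note: a change reports only the file's immediate parents, so recursive
# scoping relies on the descendant set pre-expanded in authenticate().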
if not is_trashed and mime_type in effective_mimes and in_scope:
print(f"[WEBHOOK] File changed: {file_info.get('name', 'Unknown')} ({file_id})")
affected_files.append(file_id)
elif is_trashed:
print(f"[WEBHOOK] File deleted/trashed: {file_info.get('name', 'Unknown')} ({file_id})")
# TODO: Handle file deletion (remove from index)
else:
print(f"[WEBHOOK] Ignoring unsupported file type: {mime_type}")
print(f"[WEBHOOK] Ignoring unsupported or out-of-scope file: mime={mime_type} id={file_id}")
print(f"[WEBHOOK] Found {len(affected_files)} affected supported files")
return affected_files
except HttpError as e:
print(f"Failed to handle webhook: {e}")
return []
async def cleanup_subscription(self, subscription_id: str) -> bool:
"""Clean up Google Drive subscription for this connection.
@@ -487,13 +614,13 @@ class GoogleDriveConnector(BaseConnector):
"""
if not self._authenticated:
return False
try:
# Google Channels API requires both 'id' (channel) and 'resourceId'
if not self.webhook_resource_id:
raise ValueError("Missing resource_id for cleanup; ensure subscription state is persisted")
body = {'id': subscription_id, 'resourceId': self.webhook_resource_id}
self.service.channels().stop(body=body).execute()
return True
except HttpError as e:


@@ -1,7 +1,8 @@
import os
import json
import asyncio
from typing import Dict, Any, Optional
from typing import Optional, Iterable, Sequence
from datetime import datetime
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import Flow
@@ -10,138 +11,178 @@ import aiofiles
class GoogleDriveOAuth:
"""Handles Google Drive OAuth authentication flow"""
SCOPES = [
'openid',
'email',
'profile',
'https://www.googleapis.com/auth/drive.readonly',
'https://www.googleapis.com/auth/drive.metadata.readonly'
]
"""Handles Google Drive OAuth authentication flow with scope-upgrade detection."""
# Core scopes needed by your connector:
# - drive.readonly: content/export
# - drive.metadata.readonly: owners/permissions/parents
REQUIRED_SCOPES = {
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
}
# Optional OIDC/userinfo scopes (nice-to-have; keep them if you use identity info elsewhere)
OPTIONAL_SCOPES = {"openid", "email", "profile"}
# Final scopes we request during auth
SCOPES = sorted(list(REQUIRED_SCOPES | OPTIONAL_SCOPES))
AUTH_ENDPOINT = "https://accounts.google.com/o/oauth2/v2/auth"
TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token"
def __init__(self, client_id: str = None, client_secret: str = None, token_file: str = "token.json"):
def __init__(self, client_id: str, client_secret: str, token_file: str = "token.json"):
self.client_id = client_id
self.client_secret = client_secret
self.token_file = token_file
self.creds: Optional[Credentials] = None
self._flow_state: Optional[str] = None
self._flow: Optional[Flow] = None
# ---------------------------
# Internal helpers
# ---------------------------
@staticmethod
def _scopes_satisfied(creds: Credentials, required: Iterable[str]) -> bool:
"""
Check that the credential's scopes include all required scopes.
Google sometimes returns None/[] for creds.scopes after refresh; guard for that.
"""
try:
got = set(creds.scopes or [])
except Exception:
got = set()
return set(required).issubset(got)
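# e.g. a legacy token granted only drive.metadata.readonly fails this check
# against REQUIRED_SCOPES, triggering re-auth in load_credentials().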
# ---------------------------
# Public methods
# ---------------------------
async def load_credentials(self) -> Optional[Credentials]:
"""Load existing credentials from token file"""
"""Load existing credentials from token file; refresh if expired.
If token exists but lacks required scopes, delete it to force re-auth with upgraded scopes.
"""
if os.path.exists(self.token_file):
async with aiofiles.open(self.token_file, 'r') as f:
async with aiofiles.open(self.token_file, "r") as f:
token_data = json.loads(await f.read())
# Create credentials from token data
# Build creds from stored token data (be tolerant of missing fields)
scopes_from_file: Sequence[str] = token_data.get("scopes") or self.SCOPES
self.creds = Credentials(
token=token_data.get('token'),
refresh_token=token_data.get('refresh_token'),
id_token=token_data.get('id_token'),
token_uri="https://oauth2.googleapis.com/token",
token=token_data.get("token"),
refresh_token=token_data.get("refresh_token"),
id_token=token_data.get("id_token"),
token_uri=self.TOKEN_ENDPOINT,
client_id=self.client_id,
client_secret=self.client_secret, # Need for refresh
scopes=token_data.get('scopes', self.SCOPES)
client_secret=self.client_secret,
scopes=scopes_from_file,
)
# Set expiry if available (ensure timezone-naive for Google auth compatibility)
if token_data.get('expiry'):
from datetime import datetime
expiry_dt = datetime.fromisoformat(token_data['expiry'])
# Remove timezone info to make it naive (Google auth expects naive datetimes)
self.creds.expiry = expiry_dt.replace(tzinfo=None)
# If credentials are expired, refresh them
if self.creds and self.creds.expired and self.creds.refresh_token:
self.creds.refresh(Request())
await self.save_credentials()
# Restore expiry (as naive datetime for google-auth)
if token_data.get("expiry"):
try:
expiry_dt = datetime.fromisoformat(token_data["expiry"])
self.creds.expiry = expiry_dt.replace(tzinfo=None)
except Exception:
# If malformed, let refresh handle it
pass
# If expired and we have a refresh token, try refreshing
if self.creds and self.creds.expired and self.creds.refresh_token:
try:
self.creds.refresh(Request())
finally:
await self.save_credentials()
# *** Scope-upgrade detection ***
if self.creds and not self._scopes_satisfied(self.creds, self.REQUIRED_SCOPES):
# Old/narrow token; remove it to force a clean re-consent with broader scopes
try:
os.remove(self.token_file)
except FileNotFoundError:
pass
self.creds = None # signal caller that we need to re-auth
return self.creds
async def save_credentials(self):
"""Save credentials to token file (without client_secret)"""
if self.creds:
# Create minimal token data without client_secret
token_data = {
"token": self.creds.token,
"refresh_token": self.creds.refresh_token,
"id_token": self.creds.id_token,
"scopes": self.creds.scopes,
}
# Add expiry if available
if self.creds.expiry:
token_data["expiry"] = self.creds.expiry.isoformat()
async with aiofiles.open(self.token_file, 'w') as f:
await f.write(json.dumps(token_data, indent=2))
def create_authorization_url(self, redirect_uri: str) -> str:
"""Create authorization URL for OAuth flow"""
# Create flow from client credentials directly
"""Save credentials to token file (no client_secret)."""
if not self.creds:
return
token_data = {
"token": self.creds.token,
"refresh_token": self.creds.refresh_token,
"id_token": self.creds.id_token,
# Persist the scopes we actually have now; sort for determinism
"scopes": sorted(list(self.creds.scopes or self.SCOPES)),
}
if self.creds.expiry:
token_data["expiry"] = self.creds.expiry.isoformat()
async with aiofiles.open(self.token_file, "w") as f:
await f.write(json.dumps(token_data, indent=2))
def create_authorization_url(self, redirect_uri: str, *, force_consent: bool = True) -> str:
"""Create authorization URL for OAuth flow.
Set force_consent=True to guarantee Google prompts for the broader scopes (scope upgrade).
"""
client_config = {
"web": {
"client_id": self.client_id,
"client_secret": self.client_secret,
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token"
# Use v2 endpoints consistently
"auth_uri": self.AUTH_ENDPOINT,
"token_uri": self.TOKEN_ENDPOINT,
}
}
flow = Flow.from_client_config(
client_config,
scopes=self.SCOPES,
redirect_uri=redirect_uri
)
flow = Flow.from_client_config(client_config, scopes=self.SCOPES, redirect_uri=redirect_uri)
auth_url, _ = flow.authorization_url(
access_type='offline',
include_granted_scopes='true',
prompt='consent' # Force consent to get refresh token
access_type="offline",
# include_granted_scopes=True can cause Google to reuse an old narrow grant.
# For upgrades, it's safer to disable it when force_consent is True.
include_granted_scopes="false" if force_consent else "true",
prompt="consent" if force_consent else None, # ensure we actually see the consent screen
)
# Store flow state for later use
self._flow_state = flow.state
self._flow = flow
return auth_url
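# Typical round trip (illustrative redirect URI):
#   url = oauth.create_authorization_url("https://app.example.com/oauth/cb")
#   # ...user consents; Google redirects back with ?code=...&state=...
#   await oauth.handle_authorization_callback(code, state)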
async def handle_authorization_callback(self, authorization_code: str, state: str) -> bool:
"""Handle OAuth callback and exchange code for tokens"""
if not hasattr(self, '_flow') or self._flow_state != state:
raise ValueError("Invalid OAuth state")
# Exchange authorization code for credentials
"""Handle OAuth callback and exchange code for tokens."""
if not self._flow or self._flow_state != state:
raise ValueError("Invalid or missing OAuth state")
self._flow.fetch_token(code=authorization_code)
self.creds = self._flow.credentials
# Save credentials
await self.save_credentials()
return True
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials"""
"""Return True if we have a usable credential with all required scopes."""
if not self.creds:
await self.load_credentials()
return self.creds and self.creds.valid
return bool(self.creds and self.creds.valid and self._scopes_satisfied(self.creds, self.REQUIRED_SCOPES))
def get_service(self):
"""Get authenticated Google Drive service"""
if not self.creds or not self.creds.valid:
raise ValueError("Not authenticated")
return build('drive', 'v3', credentials=self.creds)
"""Get authenticated Google Drive service."""
if not self.creds or not self.creds.valid or not self._scopes_satisfied(self.creds, self.REQUIRED_SCOPES):
raise ValueError("Not authenticated with required scopes")
# cache_discovery=False avoids a deprecation warning chatter in some environments
return build("drive", "v3", credentials=self.creds, cache_discovery=False)
async def revoke_credentials(self):
"""Revoke credentials and delete token file"""
"""Revoke credentials and delete token file."""
if self.creds:
self.creds.revoke(Request())
try:
self.creds.revoke(Request())
except Exception:
# Revocation is best-effort; continue to clear local token
pass
if os.path.exists(self.token_file):
os.remove(self.token_file)
self.creds = None
try:
os.remove(self.token_file)
except FileNotFoundError:
pass
self.creds = None