# openrag/src/connectors/onedrive/oauth.py
# 2025-09-25 13:12:27 -07:00
#
# 322 lines
# 14 KiB
# Python

import os
import json
import logging
from typing import Optional, Dict, Any
import aiofiles
import msal
logger = logging.getLogger(__name__)


class OneDriveOAuth:
    """Handles Microsoft Graph OAuth for OneDrive (personal Microsoft accounts by default).

    Tokens are kept in an MSAL ``SerializableTokenCache`` that is persisted to
    ``token_file``. A legacy flat-JSON token file containing a ``refresh_token``
    can optionally be migrated once into the MSAL cache format (see
    ``load_credentials``).
    """

    # OIDC scopes that MSAL manages itself: MSAL adds them to auth/token requests
    # and rejects them (ValueError) when passed explicitly. They are therefore
    # stripped via _strip_reserved before any scope list is handed to MSAL.
    RESERVED_SCOPES = {"openid", "profile", "offline_access"}

    # For PERSONAL Microsoft Accounts (OneDrive consumer):
    # - AUTH_SCOPES describes what interactive auth asks for (consent + refresh
    #   token issuance). offline_access stays listed for backward compatibility;
    #   it is filtered out before calling MSAL, which requests it on its own.
    # - RESOURCE_SCOPES is used for acquire_token_silent / refresh paths.
    AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
    RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
    SCOPES = AUTH_SCOPES  # Backward-compat alias if something references .SCOPES

    # Kept for reference; MSAL derives endpoints from `authority`
    AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_file: str = "onedrive_token.json",
        authority: str = "https://login.microsoftonline.com/common",
        allow_json_refresh: bool = True,
    ):
        """
        Initialize OneDriveOAuth.

        Args:
            client_id: Azure AD application (client) ID.
            client_secret: Azure AD application client secret.
            token_file: Path to persisted token cache file (MSAL cache format).
            authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
                or tenant-specific for work/school.
            allow_json_refresh: If True, permit one-time migration from legacy flat JSON
                {"access_token","refresh_token",...}. Otherwise refuse it.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_file = token_file
        self.authority = authority
        self.allow_json_refresh = allow_json_refresh
        self.token_cache = msal.SerializableTokenCache()
        self._current_account = None
        # Remembered by create_authorization_url() for the callback step.
        self._redirect_uri: Optional[str] = None
        self.app = self._build_app()

    # ------------------------------------------------------------------ helpers

    def _build_app(self) -> "msal.ConfidentialClientApplication":
        """Create an MSAL confidential client bound to ``self.token_cache``."""
        return msal.ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_secret,
            authority=self.authority,
            token_cache=self.token_cache,
        )

    @classmethod
    def _strip_reserved(cls, scopes) -> list:
        """Return ``scopes`` (order preserved) without the MSAL-reserved OIDC scopes."""
        return [s for s in scopes if s not in cls.RESERVED_SCOPES]

    def _acquire_silent(self) -> Optional[Dict[str, Any]]:
        """Silently acquire a Graph token from the cache, refreshing if MSAL needs to.

        Tries the current account first, then every other account in the cache.
        On success ``_current_account`` is set to the account that produced the
        token and the MSAL result is returned. Otherwise the last MSAL result
        (possibly an error dict) is returned, or None if nothing could be tried.
        """
        result = None
        if self._current_account:
            result = self.app.acquire_token_silent(
                self.RESOURCE_SCOPES, account=self._current_account
            )
            if result and "access_token" in result:
                return result

        # Fallback over real cached accounts; acquire_token_silent(account=None)
        # cannot yield a delegated user token.
        for account in self.app.get_accounts():
            if account == self._current_account:
                continue  # already tried above
            candidate = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=account)
            if candidate and "access_token" in candidate:
                self._current_account = account
                return candidate
            result = candidate or result
        return result

    def _serialize_for_disk(self) -> str:
        """Ensure the token file's directory exists and return the serialized cache."""
        os.makedirs(os.path.dirname(os.path.abspath(self.token_file)), exist_ok=True)
        return self.token_cache.serialize()

    def _save_cache_sync(self) -> None:
        """Synchronous twin of ``save_cache`` for use from non-async methods (best effort)."""
        try:
            cache_data = self._serialize_for_disk()
            if cache_data:
                with open(self.token_file, "w") as f:
                    f.write(cache_data)
                logger.debug(f"Token cache saved to {self.token_file}")
        except Exception as e:
            logger.error(f"Failed to save token cache: {e}")

    # --------------------------------------------------------------- public API

    async def load_credentials(self) -> bool:
        """Load existing credentials from the token file (async).

        Returns True when an access token could be obtained, either silently from
        the persisted MSAL cache or by migrating a legacy JSON refresh token.
        Returns False (never raises) when the file is missing, empty, unusable,
        or silent acquisition fails.
        """
        try:
            logger.debug(f"OneDrive OAuth loading credentials from: {self.token_file}")
            if not os.path.exists(self.token_file):
                logger.debug(f"Token file does not exist: {self.token_file}")
                return False

            logger.debug(f"Token file exists, reading: {self.token_file}")
            async with aiofiles.open(self.token_file, "r") as f:
                cache_data = await f.read()
            logger.debug(f"Read {len(cache_data)} chars from token file")

            if not cache_data.strip():
                logger.debug(f"Token file {self.token_file} is empty")
                return False

            # 1) Legacy flat JSON. The MSAL cache is JSON as well, but it has no
            #    top-level "refresh_token" key, so it falls through to step 2.
            try:
                json_data = json.loads(cache_data)
            except json.JSONDecodeError:
                json_data = None
                logger.debug("Token file is not flat JSON; attempting MSAL cache format")
            if isinstance(json_data, dict) and "refresh_token" in json_data:
                if self.allow_json_refresh:
                    logger.debug(
                        "Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
                    )
                    return await self._refresh_from_json_token(json_data)
                logger.warning(
                    "Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
                    "Delete the file and re-auth."
                )
                return False

            # 2) MSAL cache format
            logger.debug("Attempting MSAL cache deserialization")
            self.token_cache.deserialize(cache_data)
            accounts = self.app.get_accounts()
            logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
            if not accounts:
                return False

            self._current_account = accounts[0]
            logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
            # RESOURCE_SCOPES contains no reserved scopes, as silent acquisition requires.
            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
            logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
            if result and "access_token" in result:
                logger.debug("Silent token acquisition successful")
                await self.save_cache()
                return True

            error_msg = (result or {}).get("error") or "No result"
            logger.warning(f"Silent token acquisition failed: {error_msg}")
            return False
        except Exception as e:
            logger.exception(f"Failed to load OneDrive credentials: {e}")
            return False

    async def _refresh_from_json_token(self, token_data: dict) -> bool:
        """
        Use refresh token from a legacy JSON file to get new tokens (one-time migration path).

        Prefer using an MSAL cache file and acquire_token_silent(); this path is only for
        migrating older files. On success the token file is rewritten in MSAL cache format.
        Returns False (never raises) on any failure.
        """
        try:
            refresh_token = token_data.get("refresh_token")
            if not refresh_token:
                logger.error("No refresh_token found in JSON file - cannot refresh")
                logger.error("You must re-authenticate interactively to obtain a valid token")
                return False

            refresh_scopes = self._strip_reserved(self.RESOURCE_SCOPES)
            logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
            result = self.app.acquire_token_by_refresh_token(
                refresh_token=refresh_token,
                scopes=refresh_scopes,
            )

            if result and "access_token" in result:
                logger.debug("Successfully refreshed token via legacy JSON path")
                # Overwrites token_file (the legacy JSON) with the MSAL cache format.
                await self.save_cache()
                accounts = self.app.get_accounts()
                logger.debug(f"After refresh, found {len(accounts)} accounts")
                if accounts:
                    self._current_account = accounts[0]
                    logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
                return True

            err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
            logger.error(f"Refresh token failed: {err}")
            if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
                logger.warning(
                    "Refresh denied due to unauthorized/expired scopes or invalid grant. "
                    "Delete the token file and perform interactive sign-in with correct scopes."
                )
            return False
        except Exception as e:
            logger.exception(f"Exception during refresh from JSON token: {e}")
            return False

    async def save_cache(self):
        """Persist the token cache to ``token_file`` (best effort; errors are logged)."""
        try:
            cache_data = self._serialize_for_disk()
            if cache_data:
                async with aiofiles.open(self.token_file, "w") as f:
                    await f.write(cache_data)
                logger.debug(f"Token cache saved to {self.token_file}")
        except Exception as e:
            logger.error(f"Failed to save token cache: {e}")

    def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
        """Create the authorization URL for the interactive OAuth flow.

        Args:
            redirect_uri: Callback URL registered for the app.
            state: Optional opaque value for CSRF protection; omitted when falsy.

        Returns:
            The URL to send the user to.
        """
        # Store redirect URI for later use in callback
        self._redirect_uri = redirect_uri
        # MSAL raises on reserved scopes and adds offline_access itself.
        scopes = self._strip_reserved(self.AUTH_SCOPES)
        kwargs: Dict[str, Any] = {
            "scopes": scopes,
            "redirect_uri": redirect_uri,
            "prompt": "consent",  # ensure refresh token on first run
        }
        if state:
            kwargs["state"] = state  # Optional CSRF protection
        auth_url = self.app.get_authorization_request_url(**kwargs)
        logger.debug(f"Generated auth URL: {auth_url}")
        logger.debug(f"Auth scopes: {scopes}")
        return auth_url

    async def handle_authorization_callback(
        self, authorization_code: str, redirect_uri: str
    ) -> bool:
        """Exchange an authorization code for tokens and persist the cache.

        Returns True on success, False (never raises) on any failure.
        """
        try:
            result = self.app.acquire_token_by_authorization_code(
                authorization_code,
                scopes=self._strip_reserved(self.AUTH_SCOPES),  # same as authorize step
                redirect_uri=redirect_uri,
            )
            if result and "access_token" in result:
                accounts = self.app.get_accounts()
                if accounts:
                    self._current_account = accounts[0]
                await self.save_cache()
                logger.info("OneDrive OAuth authorization successful")
                return True

            error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
            logger.error(f"OneDrive OAuth authorization failed: {error_msg}")
            return False
        except Exception as e:
            logger.exception(f"Exception during OneDrive OAuth authorization: {e}")
            return False

    async def is_authenticated(self) -> bool:
        """Check whether a valid access token can be obtained (MSAL refreshes if needed).

        Loads credentials from disk first if no account is selected yet, and
        persists the cache when the silent call refreshed tokens.
        """
        try:
            if not self._current_account:
                await self.load_credentials()

            result = self._acquire_silent()
            if result and "access_token" in result:
                if self.token_cache.has_state_changed:
                    # A refresh happened; persist the new tokens.
                    await self.save_cache()
                return True

            error_msg = (result or {}).get("error") or "No result returned"
            logger.debug(f"Token acquisition failed for current account: {error_msg}")
            return False
        except Exception as e:
            logger.error(f"Authentication check failed: {e}")
            return False

    def get_access_token(self) -> str:
        """Get an access token for Microsoft Graph.

        Raises:
            ValueError: If no cached account yields a token silently.
        """
        try:
            result = self._acquire_silent()
            if result and "access_token" in result:
                if self.token_cache.has_state_changed:
                    # A refresh happened; persist it so it survives restarts.
                    self._save_cache_sync()
                return result["access_token"]

            error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
            raise ValueError(f"Failed to acquire access token: {error_msg}")
        except Exception as e:
            logger.error(f"Failed to get access token: {e}")
            raise

    async def revoke_credentials(self):
        """Clear the in-memory token cache and remove the token file (errors are logged)."""
        try:
            self._current_account = None
            self.token_cache = msal.SerializableTokenCache()
            # Recreate MSAL app so it is bound to the fresh cache.
            self.app = self._build_app()
            if os.path.exists(self.token_file):
                os.remove(self.token_file)
                logger.info(f"Removed OneDrive token file: {self.token_file}")
        except Exception as e:
            logger.error(f"Failed to revoke OneDrive credentials: {e}")

    def get_service(self) -> str:
        """Return an access token (Graph client is just the bearer)."""
        return self.get_access_token()