openrag/src/connectors/onedrive/oauth.py
2025-09-17 11:28:25 -07:00

188 lines
6.4 KiB
Python

import json
import os
from datetime import datetime, timedelta, timezone
from urllib.parse import urlencode

import aiofiles
import httpx
class OneDriveOAuth:
    """Direct token management for OneDrive, bypassing MSAL cache format.

    Tokens are held in memory in ``self._tokens`` and mirrored to
    ``token_file`` as a JSON object with the keys ``token``,
    ``refresh_token``, ``scopes`` and ``expiry`` (an ISO-8601 timestamp).
    """

    SCOPES = [
        "offline_access",
        "Files.Read.All",
    ]
    AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"

    # A token is reported as expired this long before its recorded expiry,
    # so callers never receive a token that is about to lapse.
    EXPIRY_BUFFER = timedelta(minutes=5)
    # Lifetime (seconds) assumed when the token response has no 'expires_in'.
    DEFAULT_EXPIRES_IN = 3600

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        token_file: str = "onedrive_token.json",
    ):
        """Store the client credentials and load any previously saved tokens.

        Args:
            client_id: OAuth client id sent to the token/authorize endpoints.
            client_secret: OAuth client secret sent to the token endpoint.
            token_file: Path of the JSON file used to persist tokens.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_file = token_file
        self._tokens = None
        self._load_tokens()

    def _load_tokens(self):
        """Load tokens from ``self.token_file`` into ``self._tokens``.

        A missing, unreadable, or malformed file leaves ``self._tokens``
        untouched (``None`` right after construction) instead of raising,
        so a damaged token file only means "not authenticated".
        """
        if not os.path.exists(self.token_file):
            print(f"No token file found at {self.token_file}")
            return

        try:
            with open(self.token_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Failed to load tokens from {self.token_file}: {e}")
            return

        if not isinstance(loaded, dict):
            # Every other method indexes the tokens by key.
            print(f"Ignoring malformed token file {self.token_file}")
            return

        self._tokens = loaded
        print(f"Loaded tokens from {self.token_file}")

    async def _save_tokens(self):
        """Write ``self._tokens`` to ``self.token_file`` as JSON (no-op if empty)."""
        if self._tokens:
            async with aiofiles.open(self.token_file, "w", encoding="utf-8") as f:
                await f.write(json.dumps(self._tokens, indent=2))

    def _is_token_expired(self) -> bool:
        """Check if the current access token is expired (or about to be).

        Returns:
            ``True`` when there are no tokens, no ``expiry`` entry, the
            expiry cannot be parsed, or the expiry is less than
            ``EXPIRY_BUFFER`` away; ``False`` otherwise.
        """
        if not self._tokens or 'expiry' not in self._tokens:
            return True

        expiry_str = self._tokens['expiry']
        try:
            if expiry_str.endswith('Z'):
                # 'Z' means UTC; spell it as an explicit offset, which
                # fromisoformat accepts on every supported Python version.
                expiry_dt = datetime.fromisoformat(expiry_str[:-1] + '+00:00')
            else:
                expiry_dt = datetime.fromisoformat(expiry_str)
        except (AttributeError, TypeError, ValueError):
            # Non-string or unparsable expiry: treat as expired.
            return True

        if expiry_dt.tzinfo is None:
            # Naive timestamps are what this class writes: local wall time.
            now = datetime.now()
        else:
            # Aware timestamps must be compared against an aware "now".
            now = datetime.now(timezone.utc)
        return now >= expiry_dt - self.EXPIRY_BUFFER

    def _compute_expiry(self, token_data: dict) -> str:
        """Return the ISO timestamp (naive local time) at which a token expires."""
        expires_in = int(token_data.get('expires_in', self.DEFAULT_EXPIRES_IN))
        return (datetime.now() + timedelta(seconds=expires_in)).isoformat()

    async def _request_tokens(self, data: dict) -> dict:
        """POST form ``data`` to the token endpoint and return the decoded JSON.

        Raises whatever httpx raises for transport errors or, via
        ``raise_for_status``, for non-success responses.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(self.TOKEN_ENDPOINT, data=data)
            response.raise_for_status()
            return response.json()

    async def _refresh_access_token(self) -> bool:
        """Refresh the access token using the stored refresh token.

        Returns:
            ``True`` if a new access token was obtained and saved;
            ``False`` if there is no refresh token or any step failed.
        """
        if not self._tokens or 'refresh_token' not in self._tokens:
            return False

        data = {
            'client_id': self.client_id,
            'client_secret': self.client_secret,
            'refresh_token': self._tokens['refresh_token'],
            'grant_type': 'refresh_token',
            'scope': ' '.join(self.SCOPES),
        }

        try:
            token_data = await self._request_tokens(data)

            self._tokens['token'] = token_data['access_token']
            # Only overwrite the refresh token when the response carries one.
            if 'refresh_token' in token_data:
                self._tokens['refresh_token'] = token_data['refresh_token']
            self._tokens['expiry'] = self._compute_expiry(token_data)

            await self._save_tokens()
            print("Access token refreshed successfully")
            return True
        except Exception as e:
            # Boundary: callers rely on a boolean result, never an exception.
            print(f"Failed to refresh token: {e}")
            return False

    async def is_authenticated(self) -> bool:
        """Check if we have valid credentials, refreshing them if expired.

        Returns:
            ``False`` when no tokens are loaded or a needed refresh fails;
            ``True`` when the token is still valid or was refreshed.
        """
        if not self._tokens:
            return False

        if self._is_token_expired():
            print("Token expired, attempting refresh...")
            return await self._refresh_access_token()

        return True

    def get_access_token(self) -> str:
        """Get the current access token.

        This method does not refresh; an expired token is reported as an
        error.

        Raises:
            ValueError: If no token is stored or the stored token is expired.
        """
        if not self._tokens or 'token' not in self._tokens:
            raise ValueError("No access token available")

        if self._is_token_expired():
            raise ValueError("Access token expired and refresh failed")

        return self._tokens['token']

    async def revoke_credentials(self):
        """Clear tokens from memory and delete the token file if present."""
        self._tokens = None
        try:
            os.remove(self.token_file)
        except FileNotFoundError:
            # Nothing persisted; already in the desired state.
            pass

    # Keep these methods for compatibility with your existing OAuth flow
    def create_authorization_url(self, redirect_uri: str) -> str:
        """Create the authorization URL for the OAuth flow.

        Args:
            redirect_uri: URI the authorize endpoint should redirect back to.

        Returns:
            ``AUTH_ENDPOINT`` with the URL-encoded query string appended.
        """
        params = {
            'client_id': self.client_id,
            'response_type': 'code',
            'redirect_uri': redirect_uri,
            'scope': ' '.join(self.SCOPES),
            'response_mode': 'query',
        }
        return f"{self.AUTH_ENDPOINT}?{urlencode(params)}"

    async def handle_authorization_callback(
        self, authorization_code: str, redirect_uri: str
    ) -> bool:
        """Handle the OAuth callback and exchange the code for tokens.

        Args:
            authorization_code: The ``code`` value received on the callback.
            redirect_uri: The redirect URI sent with the token request.

        Returns:
            ``True`` if tokens were obtained and saved; ``False`` on any
            failure (including a response lacking ``access_token`` or
            ``refresh_token``).
        """
        data = {
            'client_id': self.client_id,
            'client_secret': self.client_secret,
            'code': authorization_code,
            'grant_type': 'authorization_code',
            'redirect_uri': redirect_uri,
            'scope': ' '.join(self.SCOPES),
        }

        try:
            token_data = await self._request_tokens(data)

            # Store tokens in our format
            self._tokens = {
                'token': token_data['access_token'],
                'refresh_token': token_data['refresh_token'],
                # Copy so the stored dict never aliases the class-level list.
                'scopes': list(self.SCOPES),
                'expiry': self._compute_expiry(token_data),
            }

            await self._save_tokens()
            print("Authorization successful, tokens saved")
            return True
        except Exception as e:
            # Boundary: callers rely on a boolean result, never an exception.
            print(f"Authorization failed: {e}")
            return False