feat: Sharepoint and OneDrive auth

This commit is contained in:
Eric Hare 2025-09-17 11:28:25 -07:00
parent 8dc737c124
commit 8dd6497eb4
4 changed files with 304 additions and 107 deletions

View file

@ -1,3 +1,4 @@
from pathlib import Path
import httpx import httpx
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -20,10 +21,12 @@ class OneDriveConnector(BaseConnector):
def __init__(self, config: Dict[str, Any]): def __init__(self, config: Dict[str, Any]):
super().__init__(config) super().__init__(config)
project_root = Path(__file__).resolve().parent.parent.parent.parent
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
self.oauth = OneDriveOAuth( self.oauth = OneDriveOAuth(
client_id=self.get_client_id(), client_id=self.get_client_id(),
client_secret=self.get_client_secret(), client_secret=self.get_client_secret(),
token_file=config.get("token_file", "onedrive_token.json"), token_file=token_file,
) )
self.subscription_id = config.get("subscription_id") or config.get( self.subscription_id = config.get("subscription_id") or config.get(
"webhook_channel_id" "webhook_channel_id"

View file

@ -1,11 +1,12 @@
import os import os
import json
import aiofiles import aiofiles
from typing import Optional from datetime import datetime
import msal import httpx
class OneDriveOAuth: class OneDriveOAuth:
"""Handles Microsoft Graph OAuth authentication flow""" """Direct token management for OneDrive, bypassing MSAL cache format"""
SCOPES = [ SCOPES = [
"offline_access", "offline_access",
@ -20,76 +21,168 @@ class OneDriveOAuth:
client_id: str, client_id: str,
client_secret: str, client_secret: str,
token_file: str = "onedrive_token.json", token_file: str = "onedrive_token.json",
authority: str = "https://login.microsoftonline.com/common",
): ):
self.client_id = client_id self.client_id = client_id
self.client_secret = client_secret self.client_secret = client_secret
self.token_file = token_file self.token_file = token_file
self.authority = authority self._tokens = None
self.token_cache = msal.SerializableTokenCache() self._load_tokens()
# Load existing cache if available def _load_tokens(self):
"""Load tokens from file"""
if os.path.exists(self.token_file): if os.path.exists(self.token_file):
with open(self.token_file, "r") as f: with open(self.token_file, "r") as f:
self.token_cache.deserialize(f.read()) self._tokens = json.loads(f.read())
print(f"Loaded tokens from {self.token_file}")
else:
print(f"No token file found at {self.token_file}")
self.app = msal.ConfidentialClientApplication( async def _save_tokens(self):
client_id=self.client_id, """Save tokens to file"""
client_credential=self.client_secret, if self._tokens:
authority=self.authority, async with aiofiles.open(self.token_file, "w") as f:
token_cache=self.token_cache, await f.write(json.dumps(self._tokens, indent=2))
)
async def save_cache(self): def _is_token_expired(self) -> bool:
"""Persist the token cache to file""" """Check if current access token is expired"""
async with aiofiles.open(self.token_file, "w") as f: if not self._tokens or 'expiry' not in self._tokens:
await f.write(self.token_cache.serialize()) return True
expiry_str = self._tokens['expiry']
# Handle different expiry formats
try:
if expiry_str.endswith('Z'):
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
else:
expiry_dt = datetime.fromisoformat(expiry_str)
# Add 5-minute buffer
import datetime as dt
now = datetime.now()
return now >= (expiry_dt - dt.timedelta(minutes=5))
except:
return True
async def _refresh_access_token(self) -> bool:
"""Refresh the access token using refresh token"""
if not self._tokens or 'refresh_token' not in self._tokens:
return False
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'refresh_token': self._tokens['refresh_token'],
'grant_type': 'refresh_token',
'scope': ' '.join(self.SCOPES)
}
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.TOKEN_ENDPOINT, data=data)
response.raise_for_status()
token_data = response.json()
# Update tokens
self._tokens['token'] = token_data['access_token']
if 'refresh_token' in token_data:
self._tokens['refresh_token'] = token_data['refresh_token']
# Calculate expiry
expires_in = token_data.get('expires_in', 3600)
import datetime as dt
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens['expiry'] = expiry.isoformat()
await self._save_tokens()
print("Access token refreshed successfully")
return True
except Exception as e:
print(f"Failed to refresh token: {e}")
return False
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials"""
if not self._tokens:
return False
# If token is expired, try to refresh
if self._is_token_expired():
print("Token expired, attempting refresh...")
if await self._refresh_access_token():
return True
else:
return False
return True
def get_access_token(self) -> str:
"""Get current access token"""
if not self._tokens or 'token' not in self._tokens:
raise ValueError("No access token available")
if self._is_token_expired():
raise ValueError("Access token expired and refresh failed")
return self._tokens['token']
async def revoke_credentials(self):
"""Clear tokens"""
self._tokens = None
if os.path.exists(self.token_file):
os.remove(self.token_file)
# Keep these methods for compatibility with your existing OAuth flow
def create_authorization_url(self, redirect_uri: str) -> str: def create_authorization_url(self, redirect_uri: str) -> str:
"""Create authorization URL for OAuth flow""" """Create authorization URL for OAuth flow"""
return self.app.get_authorization_request_url( from urllib.parse import urlencode
self.SCOPES, redirect_uri=redirect_uri
) params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': redirect_uri,
'scope': ' '.join(self.SCOPES),
'response_mode': 'query'
}
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
return f"{auth_url}?{urlencode(params)}"
async def handle_authorization_callback( async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str self, authorization_code: str, redirect_uri: str
) -> bool: ) -> bool:
"""Handle OAuth callback and exchange code for tokens""" """Handle OAuth callback and exchange code for tokens"""
result = self.app.acquire_token_by_authorization_code( data = {
authorization_code, 'client_id': self.client_id,
scopes=self.SCOPES, 'client_secret': self.client_secret,
redirect_uri=redirect_uri, 'code': authorization_code,
) 'grant_type': 'authorization_code',
if "access_token" in result: 'redirect_uri': redirect_uri,
await self.save_cache() 'scope': ' '.join(self.SCOPES)
return True }
raise ValueError(result.get("error_description") or "Authorization failed")
async def is_authenticated(self) -> bool: async with httpx.AsyncClient() as client:
"""Check if we have valid credentials""" try:
accounts = self.app.get_accounts() response = await client.post(self.TOKEN_ENDPOINT, data=data)
if not accounts: response.raise_for_status()
return False token_data = response.json()
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
if "access_token" in result:
await self.save_cache()
return True
return False
def get_access_token(self) -> str: # Store tokens in our format
"""Get an access token for Microsoft Graph""" import datetime as dt
accounts = self.app.get_accounts() expires_in = token_data.get('expires_in', 3600)
if not accounts: expiry = datetime.now() + dt.timedelta(seconds=expires_in)
raise ValueError("Not authenticated")
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
if "access_token" not in result:
raise ValueError(
result.get("error_description") or "Failed to acquire access token"
)
return result["access_token"]
async def revoke_credentials(self): self._tokens = {
"""Clear token cache and remove token file""" 'token': token_data['access_token'],
self.token_cache.clear() 'refresh_token': token_data['refresh_token'],
if os.path.exists(self.token_file): 'scopes': self.SCOPES,
os.remove(self.token_file) 'expiry': expiry.isoformat()
}
await self._save_tokens()
print("Authorization successful, tokens saved")
return True
except Exception as e:
print(f"Authorization failed: {e}")
return False

View file

@ -1,3 +1,4 @@
from pathlib import Path
import httpx import httpx
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -20,10 +21,12 @@ class SharePointConnector(BaseConnector):
def __init__(self, config: Dict[str, Any]): def __init__(self, config: Dict[str, Any]):
super().__init__(config) super().__init__(config)
project_root = Path(__file__).resolve().parent.parent.parent.parent
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
self.oauth = SharePointOAuth( self.oauth = SharePointOAuth(
client_id=self.get_client_id(), client_id=self.get_client_id(),
client_secret=self.get_client_secret(), client_secret=self.get_client_secret(),
token_file=config.get("token_file", "sharepoint_token.json"), token_file=token_file,
) )
self.subscription_id = config.get("subscription_id") or config.get( self.subscription_id = config.get("subscription_id") or config.get(
"webhook_channel_id" "webhook_channel_id"

View file

@ -1,11 +1,12 @@
import os import os
import json
import aiofiles import aiofiles
from typing import Optional from datetime import datetime
import msal import httpx
class SharePointOAuth: class SharePointOAuth:
"""Handles Microsoft Graph OAuth authentication flow""" """Direct token management for SharePoint, bypassing MSAL cache format"""
SCOPES = [ SCOPES = [
"offline_access", "offline_access",
@ -21,76 +22,173 @@ class SharePointOAuth:
client_id: str, client_id: str,
client_secret: str, client_secret: str,
token_file: str = "sharepoint_token.json", token_file: str = "sharepoint_token.json",
authority: str = "https://login.microsoftonline.com/common", authority: str = "https://login.microsoftonline.com/common", # Keep for compatibility
): ):
self.client_id = client_id self.client_id = client_id
self.client_secret = client_secret self.client_secret = client_secret
self.token_file = token_file self.token_file = token_file
self.authority = authority self.authority = authority # Keep for compatibility but not used
self.token_cache = msal.SerializableTokenCache() self._tokens = None
self._load_tokens()
# Load existing cache if available def _load_tokens(self):
"""Load tokens from file"""
if os.path.exists(self.token_file): if os.path.exists(self.token_file):
with open(self.token_file, "r") as f: with open(self.token_file, "r") as f:
self.token_cache.deserialize(f.read()) self._tokens = json.loads(f.read())
print(f"Loaded tokens from {self.token_file}")
self.app = msal.ConfidentialClientApplication( else:
client_id=self.client_id, print(f"No token file found at {self.token_file}")
client_credential=self.client_secret,
authority=self.authority,
token_cache=self.token_cache,
)
async def save_cache(self): async def save_cache(self):
"""Persist the token cache to file""" """Persist tokens to file (renamed for compatibility)"""
async with aiofiles.open(self.token_file, "w") as f: await self._save_tokens()
await f.write(self.token_cache.serialize())
async def _save_tokens(self):
"""Save tokens to file"""
if self._tokens:
async with aiofiles.open(self.token_file, "w") as f:
await f.write(json.dumps(self._tokens, indent=2))
def _is_token_expired(self) -> bool:
"""Check if current access token is expired"""
if not self._tokens or 'expiry' not in self._tokens:
return True
expiry_str = self._tokens['expiry']
# Handle different expiry formats
try:
if expiry_str.endswith('Z'):
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
else:
expiry_dt = datetime.fromisoformat(expiry_str)
# Add 5-minute buffer
import datetime as dt
now = datetime.now()
return now >= (expiry_dt - dt.timedelta(minutes=5))
except:
return True
async def _refresh_access_token(self) -> bool:
"""Refresh the access token using refresh token"""
if not self._tokens or 'refresh_token' not in self._tokens:
return False
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'refresh_token': self._tokens['refresh_token'],
'grant_type': 'refresh_token',
'scope': ' '.join(self.SCOPES)
}
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.TOKEN_ENDPOINT, data=data)
response.raise_for_status()
token_data = response.json()
# Update tokens
self._tokens['token'] = token_data['access_token']
if 'refresh_token' in token_data:
self._tokens['refresh_token'] = token_data['refresh_token']
# Calculate expiry
expires_in = token_data.get('expires_in', 3600)
import datetime as dt
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens['expiry'] = expiry.isoformat()
await self._save_tokens()
print("Access token refreshed successfully")
return True
except Exception as e:
print(f"Failed to refresh token: {e}")
return False
def create_authorization_url(self, redirect_uri: str) -> str: def create_authorization_url(self, redirect_uri: str) -> str:
"""Create authorization URL for OAuth flow""" """Create authorization URL for OAuth flow"""
return self.app.get_authorization_request_url( from urllib.parse import urlencode
self.SCOPES, redirect_uri=redirect_uri
) params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': redirect_uri,
'scope': ' '.join(self.SCOPES),
'response_mode': 'query'
}
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
return f"{auth_url}?{urlencode(params)}"
async def handle_authorization_callback( async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str self, authorization_code: str, redirect_uri: str
) -> bool: ) -> bool:
"""Handle OAuth callback and exchange code for tokens""" """Handle OAuth callback and exchange code for tokens"""
result = self.app.acquire_token_by_authorization_code( data = {
authorization_code, 'client_id': self.client_id,
scopes=self.SCOPES, 'client_secret': self.client_secret,
redirect_uri=redirect_uri, 'code': authorization_code,
) 'grant_type': 'authorization_code',
if "access_token" in result: 'redirect_uri': redirect_uri,
await self.save_cache() 'scope': ' '.join(self.SCOPES)
return True }
raise ValueError(result.get("error_description") or "Authorization failed")
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.TOKEN_ENDPOINT, data=data)
response.raise_for_status()
token_data = response.json()
# Store tokens in our format
import datetime as dt
expires_in = token_data.get('expires_in', 3600)
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens = {
'token': token_data['access_token'],
'refresh_token': token_data['refresh_token'],
'scopes': self.SCOPES,
'expiry': expiry.isoformat()
}
await self._save_tokens()
print("Authorization successful, tokens saved")
return True
except Exception as e:
print(f"Authorization failed: {e}")
return False
async def is_authenticated(self) -> bool: async def is_authenticated(self) -> bool:
"""Check if we have valid credentials""" """Check if we have valid credentials"""
accounts = self.app.get_accounts() if not self._tokens:
if not accounts:
return False return False
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
if "access_token" in result: # If token is expired, try to refresh
await self.save_cache() if self._is_token_expired():
return True print("Token expired, attempting refresh...")
return False if await self._refresh_access_token():
return True
else:
return False
return True
def get_access_token(self) -> str: def get_access_token(self) -> str:
"""Get an access token for Microsoft Graph""" """Get current access token"""
accounts = self.app.get_accounts() if not self._tokens or 'token' not in self._tokens:
if not accounts: raise ValueError("No access token available")
raise ValueError("Not authenticated")
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0]) if self._is_token_expired():
if "access_token" not in result: raise ValueError("Access token expired and refresh failed")
raise ValueError(
result.get("error_description") or "Failed to acquire access token" return self._tokens['token']
)
return result["access_token"]
async def revoke_credentials(self): async def revoke_credentials(self):
"""Clear token cache and remove token file""" """Clear tokens"""
self.token_cache.clear() self._tokens = None
if os.path.exists(self.token_file): if os.path.exists(self.token_file):
os.remove(self.token_file) os.remove(self.token_file)