feat: Sharepoint and OneDrive auth

This commit is contained in:
Eric Hare 2025-09-17 11:28:25 -07:00
parent 8dc737c124
commit 8dd6497eb4
4 changed files with 304 additions and 107 deletions

View file

@ -1,3 +1,4 @@
from pathlib import Path
import httpx
import uuid
from datetime import datetime, timedelta
@ -20,10 +21,12 @@ class OneDriveConnector(BaseConnector):
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
project_root = Path(__file__).resolve().parent.parent.parent.parent
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
self.oauth = OneDriveOAuth(
client_id=self.get_client_id(),
client_secret=self.get_client_secret(),
token_file=config.get("token_file", "onedrive_token.json"),
token_file=token_file,
)
self.subscription_id = config.get("subscription_id") or config.get(
"webhook_channel_id"

View file

@ -1,11 +1,12 @@
import os
import json
import aiofiles
from typing import Optional
import msal
from datetime import datetime
import httpx
class OneDriveOAuth:
"""Handles Microsoft Graph OAuth authentication flow"""
"""Direct token management for OneDrive, bypassing MSAL cache format"""
SCOPES = [
"offline_access",
@ -20,76 +21,168 @@ class OneDriveOAuth:
client_id: str,
client_secret: str,
token_file: str = "onedrive_token.json",
authority: str = "https://login.microsoftonline.com/common",
):
self.client_id = client_id
self.client_secret = client_secret
self.token_file = token_file
self.authority = authority
self.token_cache = msal.SerializableTokenCache()
self._tokens = None
self._load_tokens()
# Load existing cache if available
def _load_tokens(self):
"""Load tokens from file"""
if os.path.exists(self.token_file):
with open(self.token_file, "r") as f:
self.token_cache.deserialize(f.read())
self._tokens = json.loads(f.read())
print(f"Loaded tokens from {self.token_file}")
else:
print(f"No token file found at {self.token_file}")
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority,
token_cache=self.token_cache,
)
async def _save_tokens(self):
"""Save tokens to file"""
if self._tokens:
async with aiofiles.open(self.token_file, "w") as f:
await f.write(json.dumps(self._tokens, indent=2))
async def save_cache(self):
"""Persist the token cache to file"""
async with aiofiles.open(self.token_file, "w") as f:
await f.write(self.token_cache.serialize())
def _is_token_expired(self) -> bool:
"""Check if current access token is expired"""
if not self._tokens or 'expiry' not in self._tokens:
return True
expiry_str = self._tokens['expiry']
# Handle different expiry formats
try:
if expiry_str.endswith('Z'):
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
else:
expiry_dt = datetime.fromisoformat(expiry_str)
# Add 5-minute buffer
import datetime as dt
now = datetime.now()
return now >= (expiry_dt - dt.timedelta(minutes=5))
except:
return True
async def _refresh_access_token(self) -> bool:
"""Refresh the access token using refresh token"""
if not self._tokens or 'refresh_token' not in self._tokens:
return False
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'refresh_token': self._tokens['refresh_token'],
'grant_type': 'refresh_token',
'scope': ' '.join(self.SCOPES)
}
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.TOKEN_ENDPOINT, data=data)
response.raise_for_status()
token_data = response.json()
# Update tokens
self._tokens['token'] = token_data['access_token']
if 'refresh_token' in token_data:
self._tokens['refresh_token'] = token_data['refresh_token']
# Calculate expiry
expires_in = token_data.get('expires_in', 3600)
import datetime as dt
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens['expiry'] = expiry.isoformat()
await self._save_tokens()
print("Access token refreshed successfully")
return True
except Exception as e:
print(f"Failed to refresh token: {e}")
return False
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials"""
if not self._tokens:
return False
# If token is expired, try to refresh
if self._is_token_expired():
print("Token expired, attempting refresh...")
if await self._refresh_access_token():
return True
else:
return False
return True
def get_access_token(self) -> str:
"""Get current access token"""
if not self._tokens or 'token' not in self._tokens:
raise ValueError("No access token available")
if self._is_token_expired():
raise ValueError("Access token expired and refresh failed")
return self._tokens['token']
async def revoke_credentials(self):
"""Clear tokens"""
self._tokens = None
if os.path.exists(self.token_file):
os.remove(self.token_file)
# Keep these methods for compatibility with your existing OAuth flow
def create_authorization_url(self, redirect_uri: str) -> str:
"""Create authorization URL for OAuth flow"""
return self.app.get_authorization_request_url(
self.SCOPES, redirect_uri=redirect_uri
)
from urllib.parse import urlencode
params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': redirect_uri,
'scope': ' '.join(self.SCOPES),
'response_mode': 'query'
}
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
return f"{auth_url}?{urlencode(params)}"
async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str
) -> bool:
"""Handle OAuth callback and exchange code for tokens"""
result = self.app.acquire_token_by_authorization_code(
authorization_code,
scopes=self.SCOPES,
redirect_uri=redirect_uri,
)
if "access_token" in result:
await self.save_cache()
return True
raise ValueError(result.get("error_description") or "Authorization failed")
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': authorization_code,
'grant_type': 'authorization_code',
'redirect_uri': redirect_uri,
'scope': ' '.join(self.SCOPES)
}
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials"""
accounts = self.app.get_accounts()
if not accounts:
return False
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
if "access_token" in result:
await self.save_cache()
return True
return False
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.TOKEN_ENDPOINT, data=data)
response.raise_for_status()
token_data = response.json()
def get_access_token(self) -> str:
"""Get an access token for Microsoft Graph"""
accounts = self.app.get_accounts()
if not accounts:
raise ValueError("Not authenticated")
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
if "access_token" not in result:
raise ValueError(
result.get("error_description") or "Failed to acquire access token"
)
return result["access_token"]
# Store tokens in our format
import datetime as dt
expires_in = token_data.get('expires_in', 3600)
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens = {
'token': token_data['access_token'],
'refresh_token': token_data['refresh_token'],
'scopes': self.SCOPES,
'expiry': expiry.isoformat()
}
async def revoke_credentials(self):
"""Clear token cache and remove token file"""
self.token_cache.clear()
if os.path.exists(self.token_file):
os.remove(self.token_file)
await self._save_tokens()
print("Authorization successful, tokens saved")
return True
except Exception as e:
print(f"Authorization failed: {e}")
return False

View file

@ -1,3 +1,4 @@
from pathlib import Path
import httpx
import uuid
from datetime import datetime, timedelta
@ -20,10 +21,12 @@ class SharePointConnector(BaseConnector):
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
project_root = Path(__file__).resolve().parent.parent.parent.parent
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
self.oauth = SharePointOAuth(
client_id=self.get_client_id(),
client_secret=self.get_client_secret(),
token_file=config.get("token_file", "sharepoint_token.json"),
token_file=token_file,
)
self.subscription_id = config.get("subscription_id") or config.get(
"webhook_channel_id"

View file

@ -1,11 +1,12 @@
import os
import json
import aiofiles
from typing import Optional
import msal
from datetime import datetime
import httpx
class SharePointOAuth:
"""Handles Microsoft Graph OAuth authentication flow"""
"""Direct token management for SharePoint, bypassing MSAL cache format"""
SCOPES = [
"offline_access",
@ -21,76 +22,173 @@ class SharePointOAuth:
client_id: str,
client_secret: str,
token_file: str = "sharepoint_token.json",
authority: str = "https://login.microsoftonline.com/common",
authority: str = "https://login.microsoftonline.com/common", # Keep for compatibility
):
self.client_id = client_id
self.client_secret = client_secret
self.token_file = token_file
self.authority = authority
self.token_cache = msal.SerializableTokenCache()
self.authority = authority # Keep for compatibility but not used
self._tokens = None
self._load_tokens()
# Load existing cache if available
def _load_tokens(self):
"""Load tokens from file"""
if os.path.exists(self.token_file):
with open(self.token_file, "r") as f:
self.token_cache.deserialize(f.read())
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority,
token_cache=self.token_cache,
)
self._tokens = json.loads(f.read())
print(f"Loaded tokens from {self.token_file}")
else:
print(f"No token file found at {self.token_file}")
async def save_cache(self):
"""Persist the token cache to file"""
async with aiofiles.open(self.token_file, "w") as f:
await f.write(self.token_cache.serialize())
"""Persist tokens to file (renamed for compatibility)"""
await self._save_tokens()
async def _save_tokens(self):
"""Save tokens to file"""
if self._tokens:
async with aiofiles.open(self.token_file, "w") as f:
await f.write(json.dumps(self._tokens, indent=2))
def _is_token_expired(self) -> bool:
"""Check if current access token is expired"""
if not self._tokens or 'expiry' not in self._tokens:
return True
expiry_str = self._tokens['expiry']
# Handle different expiry formats
try:
if expiry_str.endswith('Z'):
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
else:
expiry_dt = datetime.fromisoformat(expiry_str)
# Add 5-minute buffer
import datetime as dt
now = datetime.now()
return now >= (expiry_dt - dt.timedelta(minutes=5))
except:
return True
async def _refresh_access_token(self) -> bool:
"""Refresh the access token using refresh token"""
if not self._tokens or 'refresh_token' not in self._tokens:
return False
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'refresh_token': self._tokens['refresh_token'],
'grant_type': 'refresh_token',
'scope': ' '.join(self.SCOPES)
}
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.TOKEN_ENDPOINT, data=data)
response.raise_for_status()
token_data = response.json()
# Update tokens
self._tokens['token'] = token_data['access_token']
if 'refresh_token' in token_data:
self._tokens['refresh_token'] = token_data['refresh_token']
# Calculate expiry
expires_in = token_data.get('expires_in', 3600)
import datetime as dt
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens['expiry'] = expiry.isoformat()
await self._save_tokens()
print("Access token refreshed successfully")
return True
except Exception as e:
print(f"Failed to refresh token: {e}")
return False
def create_authorization_url(self, redirect_uri: str) -> str:
"""Create authorization URL for OAuth flow"""
return self.app.get_authorization_request_url(
self.SCOPES, redirect_uri=redirect_uri
)
from urllib.parse import urlencode
params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': redirect_uri,
'scope': ' '.join(self.SCOPES),
'response_mode': 'query'
}
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
return f"{auth_url}?{urlencode(params)}"
async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str
) -> bool:
"""Handle OAuth callback and exchange code for tokens"""
result = self.app.acquire_token_by_authorization_code(
authorization_code,
scopes=self.SCOPES,
redirect_uri=redirect_uri,
)
if "access_token" in result:
await self.save_cache()
return True
raise ValueError(result.get("error_description") or "Authorization failed")
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': authorization_code,
'grant_type': 'authorization_code',
'redirect_uri': redirect_uri,
'scope': ' '.join(self.SCOPES)
}
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.TOKEN_ENDPOINT, data=data)
response.raise_for_status()
token_data = response.json()
# Store tokens in our format
import datetime as dt
expires_in = token_data.get('expires_in', 3600)
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens = {
'token': token_data['access_token'],
'refresh_token': token_data['refresh_token'],
'scopes': self.SCOPES,
'expiry': expiry.isoformat()
}
await self._save_tokens()
print("Authorization successful, tokens saved")
return True
except Exception as e:
print(f"Authorization failed: {e}")
return False
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials"""
accounts = self.app.get_accounts()
if not accounts:
if not self._tokens:
return False
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
if "access_token" in result:
await self.save_cache()
return True
return False
# If token is expired, try to refresh
if self._is_token_expired():
print("Token expired, attempting refresh...")
if await self._refresh_access_token():
return True
else:
return False
return True
def get_access_token(self) -> str:
"""Get an access token for Microsoft Graph"""
accounts = self.app.get_accounts()
if not accounts:
raise ValueError("Not authenticated")
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
if "access_token" not in result:
raise ValueError(
result.get("error_description") or "Failed to acquire access token"
)
return result["access_token"]
"""Get current access token"""
if not self._tokens or 'token' not in self._tokens:
raise ValueError("No access token available")
if self._is_token_expired():
raise ValueError("Access token expired and refresh failed")
return self._tokens['token']
async def revoke_credentials(self):
"""Clear token cache and remove token file"""
self.token_cache.clear()
"""Clear tokens"""
self._tokens = None
if os.path.exists(self.token_file):
os.remove(self.token_file)