feat: Sharepoint and OneDrive auth
This commit is contained in:
parent
8dc737c124
commit
8dd6497eb4
4 changed files with 304 additions and 107 deletions
|
|
@ -1,3 +1,4 @@
|
|||
from pathlib import Path
|
||||
import httpx
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
|
@ -20,10 +21,12 @@ class OneDriveConnector(BaseConnector):
|
|||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
project_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
|
||||
self.oauth = OneDriveOAuth(
|
||||
client_id=self.get_client_id(),
|
||||
client_secret=self.get_client_secret(),
|
||||
token_file=config.get("token_file", "onedrive_token.json"),
|
||||
token_file=token_file,
|
||||
)
|
||||
self.subscription_id = config.get("subscription_id") or config.get(
|
||||
"webhook_channel_id"
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import os
|
||||
import json
|
||||
import aiofiles
|
||||
from typing import Optional
|
||||
import msal
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
|
||||
|
||||
class OneDriveOAuth:
|
||||
"""Handles Microsoft Graph OAuth authentication flow"""
|
||||
"""Direct token management for OneDrive, bypassing MSAL cache format"""
|
||||
|
||||
SCOPES = [
|
||||
"offline_access",
|
||||
|
|
@ -20,76 +21,168 @@ class OneDriveOAuth:
|
|||
client_id: str,
|
||||
client_secret: str,
|
||||
token_file: str = "onedrive_token.json",
|
||||
authority: str = "https://login.microsoftonline.com/common",
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.token_file = token_file
|
||||
self.authority = authority
|
||||
self.token_cache = msal.SerializableTokenCache()
|
||||
self._tokens = None
|
||||
self._load_tokens()
|
||||
|
||||
# Load existing cache if available
|
||||
def _load_tokens(self):
|
||||
"""Load tokens from file"""
|
||||
if os.path.exists(self.token_file):
|
||||
with open(self.token_file, "r") as f:
|
||||
self.token_cache.deserialize(f.read())
|
||||
self._tokens = json.loads(f.read())
|
||||
print(f"Loaded tokens from {self.token_file}")
|
||||
else:
|
||||
print(f"No token file found at {self.token_file}")
|
||||
|
||||
self.app = msal.ConfidentialClientApplication(
|
||||
client_id=self.client_id,
|
||||
client_credential=self.client_secret,
|
||||
authority=self.authority,
|
||||
token_cache=self.token_cache,
|
||||
)
|
||||
async def _save_tokens(self):
|
||||
"""Save tokens to file"""
|
||||
if self._tokens:
|
||||
async with aiofiles.open(self.token_file, "w") as f:
|
||||
await f.write(json.dumps(self._tokens, indent=2))
|
||||
|
||||
async def save_cache(self):
|
||||
"""Persist the token cache to file"""
|
||||
async with aiofiles.open(self.token_file, "w") as f:
|
||||
await f.write(self.token_cache.serialize())
|
||||
def _is_token_expired(self) -> bool:
|
||||
"""Check if current access token is expired"""
|
||||
if not self._tokens or 'expiry' not in self._tokens:
|
||||
return True
|
||||
|
||||
expiry_str = self._tokens['expiry']
|
||||
# Handle different expiry formats
|
||||
try:
|
||||
if expiry_str.endswith('Z'):
|
||||
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
|
||||
else:
|
||||
expiry_dt = datetime.fromisoformat(expiry_str)
|
||||
|
||||
# Add 5-minute buffer
|
||||
import datetime as dt
|
||||
now = datetime.now()
|
||||
return now >= (expiry_dt - dt.timedelta(minutes=5))
|
||||
except:
|
||||
return True
|
||||
|
||||
async def _refresh_access_token(self) -> bool:
|
||||
"""Refresh the access token using refresh token"""
|
||||
if not self._tokens or 'refresh_token' not in self._tokens:
|
||||
return False
|
||||
|
||||
data = {
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'refresh_token': self._tokens['refresh_token'],
|
||||
'grant_type': 'refresh_token',
|
||||
'scope': ' '.join(self.SCOPES)
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(self.TOKEN_ENDPOINT, data=data)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
|
||||
# Update tokens
|
||||
self._tokens['token'] = token_data['access_token']
|
||||
if 'refresh_token' in token_data:
|
||||
self._tokens['refresh_token'] = token_data['refresh_token']
|
||||
|
||||
# Calculate expiry
|
||||
expires_in = token_data.get('expires_in', 3600)
|
||||
import datetime as dt
|
||||
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
|
||||
self._tokens['expiry'] = expiry.isoformat()
|
||||
|
||||
await self._save_tokens()
|
||||
print("Access token refreshed successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to refresh token: {e}")
|
||||
return False
|
||||
|
||||
async def is_authenticated(self) -> bool:
|
||||
"""Check if we have valid credentials"""
|
||||
if not self._tokens:
|
||||
return False
|
||||
|
||||
# If token is expired, try to refresh
|
||||
if self._is_token_expired():
|
||||
print("Token expired, attempting refresh...")
|
||||
if await self._refresh_access_token():
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_access_token(self) -> str:
|
||||
"""Get current access token"""
|
||||
if not self._tokens or 'token' not in self._tokens:
|
||||
raise ValueError("No access token available")
|
||||
|
||||
if self._is_token_expired():
|
||||
raise ValueError("Access token expired and refresh failed")
|
||||
|
||||
return self._tokens['token']
|
||||
|
||||
async def revoke_credentials(self):
|
||||
"""Clear tokens"""
|
||||
self._tokens = None
|
||||
if os.path.exists(self.token_file):
|
||||
os.remove(self.token_file)
|
||||
|
||||
# Keep these methods for compatibility with your existing OAuth flow
|
||||
def create_authorization_url(self, redirect_uri: str) -> str:
|
||||
"""Create authorization URL for OAuth flow"""
|
||||
return self.app.get_authorization_request_url(
|
||||
self.SCOPES, redirect_uri=redirect_uri
|
||||
)
|
||||
from urllib.parse import urlencode
|
||||
|
||||
params = {
|
||||
'client_id': self.client_id,
|
||||
'response_type': 'code',
|
||||
'redirect_uri': redirect_uri,
|
||||
'scope': ' '.join(self.SCOPES),
|
||||
'response_mode': 'query'
|
||||
}
|
||||
|
||||
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||
return f"{auth_url}?{urlencode(params)}"
|
||||
|
||||
async def handle_authorization_callback(
|
||||
self, authorization_code: str, redirect_uri: str
|
||||
) -> bool:
|
||||
"""Handle OAuth callback and exchange code for tokens"""
|
||||
result = self.app.acquire_token_by_authorization_code(
|
||||
authorization_code,
|
||||
scopes=self.SCOPES,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
if "access_token" in result:
|
||||
await self.save_cache()
|
||||
return True
|
||||
raise ValueError(result.get("error_description") or "Authorization failed")
|
||||
data = {
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'code': authorization_code,
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': redirect_uri,
|
||||
'scope': ' '.join(self.SCOPES)
|
||||
}
|
||||
|
||||
async def is_authenticated(self) -> bool:
|
||||
"""Check if we have valid credentials"""
|
||||
accounts = self.app.get_accounts()
|
||||
if not accounts:
|
||||
return False
|
||||
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
|
||||
if "access_token" in result:
|
||||
await self.save_cache()
|
||||
return True
|
||||
return False
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(self.TOKEN_ENDPOINT, data=data)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
|
||||
def get_access_token(self) -> str:
|
||||
"""Get an access token for Microsoft Graph"""
|
||||
accounts = self.app.get_accounts()
|
||||
if not accounts:
|
||||
raise ValueError("Not authenticated")
|
||||
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
|
||||
if "access_token" not in result:
|
||||
raise ValueError(
|
||||
result.get("error_description") or "Failed to acquire access token"
|
||||
)
|
||||
return result["access_token"]
|
||||
# Store tokens in our format
|
||||
import datetime as dt
|
||||
expires_in = token_data.get('expires_in', 3600)
|
||||
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
|
||||
|
||||
self._tokens = {
|
||||
'token': token_data['access_token'],
|
||||
'refresh_token': token_data['refresh_token'],
|
||||
'scopes': self.SCOPES,
|
||||
'expiry': expiry.isoformat()
|
||||
}
|
||||
|
||||
async def revoke_credentials(self):
|
||||
"""Clear token cache and remove token file"""
|
||||
self.token_cache.clear()
|
||||
if os.path.exists(self.token_file):
|
||||
os.remove(self.token_file)
|
||||
await self._save_tokens()
|
||||
print("Authorization successful, tokens saved")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Authorization failed: {e}")
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from pathlib import Path
|
||||
import httpx
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
|
@ -20,10 +21,12 @@ class SharePointConnector(BaseConnector):
|
|||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
project_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
|
||||
self.oauth = SharePointOAuth(
|
||||
client_id=self.get_client_id(),
|
||||
client_secret=self.get_client_secret(),
|
||||
token_file=config.get("token_file", "sharepoint_token.json"),
|
||||
token_file=token_file,
|
||||
)
|
||||
self.subscription_id = config.get("subscription_id") or config.get(
|
||||
"webhook_channel_id"
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import os
|
||||
import json
|
||||
import aiofiles
|
||||
from typing import Optional
|
||||
import msal
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
|
||||
|
||||
class SharePointOAuth:
|
||||
"""Handles Microsoft Graph OAuth authentication flow"""
|
||||
"""Direct token management for SharePoint, bypassing MSAL cache format"""
|
||||
|
||||
SCOPES = [
|
||||
"offline_access",
|
||||
|
|
@ -21,76 +22,173 @@ class SharePointOAuth:
|
|||
client_id: str,
|
||||
client_secret: str,
|
||||
token_file: str = "sharepoint_token.json",
|
||||
authority: str = "https://login.microsoftonline.com/common",
|
||||
authority: str = "https://login.microsoftonline.com/common", # Keep for compatibility
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.token_file = token_file
|
||||
self.authority = authority
|
||||
self.token_cache = msal.SerializableTokenCache()
|
||||
self.authority = authority # Keep for compatibility but not used
|
||||
self._tokens = None
|
||||
self._load_tokens()
|
||||
|
||||
# Load existing cache if available
|
||||
def _load_tokens(self):
|
||||
"""Load tokens from file"""
|
||||
if os.path.exists(self.token_file):
|
||||
with open(self.token_file, "r") as f:
|
||||
self.token_cache.deserialize(f.read())
|
||||
|
||||
self.app = msal.ConfidentialClientApplication(
|
||||
client_id=self.client_id,
|
||||
client_credential=self.client_secret,
|
||||
authority=self.authority,
|
||||
token_cache=self.token_cache,
|
||||
)
|
||||
self._tokens = json.loads(f.read())
|
||||
print(f"Loaded tokens from {self.token_file}")
|
||||
else:
|
||||
print(f"No token file found at {self.token_file}")
|
||||
|
||||
async def save_cache(self):
|
||||
"""Persist the token cache to file"""
|
||||
async with aiofiles.open(self.token_file, "w") as f:
|
||||
await f.write(self.token_cache.serialize())
|
||||
"""Persist tokens to file (renamed for compatibility)"""
|
||||
await self._save_tokens()
|
||||
|
||||
async def _save_tokens(self):
|
||||
"""Save tokens to file"""
|
||||
if self._tokens:
|
||||
async with aiofiles.open(self.token_file, "w") as f:
|
||||
await f.write(json.dumps(self._tokens, indent=2))
|
||||
|
||||
def _is_token_expired(self) -> bool:
|
||||
"""Check if current access token is expired"""
|
||||
if not self._tokens or 'expiry' not in self._tokens:
|
||||
return True
|
||||
|
||||
expiry_str = self._tokens['expiry']
|
||||
# Handle different expiry formats
|
||||
try:
|
||||
if expiry_str.endswith('Z'):
|
||||
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
|
||||
else:
|
||||
expiry_dt = datetime.fromisoformat(expiry_str)
|
||||
|
||||
# Add 5-minute buffer
|
||||
import datetime as dt
|
||||
now = datetime.now()
|
||||
return now >= (expiry_dt - dt.timedelta(minutes=5))
|
||||
except:
|
||||
return True
|
||||
|
||||
async def _refresh_access_token(self) -> bool:
|
||||
"""Refresh the access token using refresh token"""
|
||||
if not self._tokens or 'refresh_token' not in self._tokens:
|
||||
return False
|
||||
|
||||
data = {
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'refresh_token': self._tokens['refresh_token'],
|
||||
'grant_type': 'refresh_token',
|
||||
'scope': ' '.join(self.SCOPES)
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(self.TOKEN_ENDPOINT, data=data)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
|
||||
# Update tokens
|
||||
self._tokens['token'] = token_data['access_token']
|
||||
if 'refresh_token' in token_data:
|
||||
self._tokens['refresh_token'] = token_data['refresh_token']
|
||||
|
||||
# Calculate expiry
|
||||
expires_in = token_data.get('expires_in', 3600)
|
||||
import datetime as dt
|
||||
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
|
||||
self._tokens['expiry'] = expiry.isoformat()
|
||||
|
||||
await self._save_tokens()
|
||||
print("Access token refreshed successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to refresh token: {e}")
|
||||
return False
|
||||
|
||||
def create_authorization_url(self, redirect_uri: str) -> str:
|
||||
"""Create authorization URL for OAuth flow"""
|
||||
return self.app.get_authorization_request_url(
|
||||
self.SCOPES, redirect_uri=redirect_uri
|
||||
)
|
||||
from urllib.parse import urlencode
|
||||
|
||||
params = {
|
||||
'client_id': self.client_id,
|
||||
'response_type': 'code',
|
||||
'redirect_uri': redirect_uri,
|
||||
'scope': ' '.join(self.SCOPES),
|
||||
'response_mode': 'query'
|
||||
}
|
||||
|
||||
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||
return f"{auth_url}?{urlencode(params)}"
|
||||
|
||||
async def handle_authorization_callback(
|
||||
self, authorization_code: str, redirect_uri: str
|
||||
) -> bool:
|
||||
"""Handle OAuth callback and exchange code for tokens"""
|
||||
result = self.app.acquire_token_by_authorization_code(
|
||||
authorization_code,
|
||||
scopes=self.SCOPES,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
if "access_token" in result:
|
||||
await self.save_cache()
|
||||
return True
|
||||
raise ValueError(result.get("error_description") or "Authorization failed")
|
||||
data = {
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'code': authorization_code,
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': redirect_uri,
|
||||
'scope': ' '.join(self.SCOPES)
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(self.TOKEN_ENDPOINT, data=data)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
|
||||
# Store tokens in our format
|
||||
import datetime as dt
|
||||
expires_in = token_data.get('expires_in', 3600)
|
||||
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
|
||||
|
||||
self._tokens = {
|
||||
'token': token_data['access_token'],
|
||||
'refresh_token': token_data['refresh_token'],
|
||||
'scopes': self.SCOPES,
|
||||
'expiry': expiry.isoformat()
|
||||
}
|
||||
|
||||
await self._save_tokens()
|
||||
print("Authorization successful, tokens saved")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Authorization failed: {e}")
|
||||
return False
|
||||
|
||||
async def is_authenticated(self) -> bool:
|
||||
"""Check if we have valid credentials"""
|
||||
accounts = self.app.get_accounts()
|
||||
if not accounts:
|
||||
if not self._tokens:
|
||||
return False
|
||||
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
|
||||
if "access_token" in result:
|
||||
await self.save_cache()
|
||||
return True
|
||||
return False
|
||||
|
||||
# If token is expired, try to refresh
|
||||
if self._is_token_expired():
|
||||
print("Token expired, attempting refresh...")
|
||||
if await self._refresh_access_token():
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_access_token(self) -> str:
|
||||
"""Get an access token for Microsoft Graph"""
|
||||
accounts = self.app.get_accounts()
|
||||
if not accounts:
|
||||
raise ValueError("Not authenticated")
|
||||
result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
|
||||
if "access_token" not in result:
|
||||
raise ValueError(
|
||||
result.get("error_description") or "Failed to acquire access token"
|
||||
)
|
||||
return result["access_token"]
|
||||
"""Get current access token"""
|
||||
if not self._tokens or 'token' not in self._tokens:
|
||||
raise ValueError("No access token available")
|
||||
|
||||
if self._is_token_expired():
|
||||
raise ValueError("Access token expired and refresh failed")
|
||||
|
||||
return self._tokens['token']
|
||||
|
||||
async def revoke_credentials(self):
|
||||
"""Clear token cache and remove token file"""
|
||||
self.token_cache.clear()
|
||||
"""Clear tokens"""
|
||||
self._tokens = None
|
||||
if os.path.exists(self.token_file):
|
||||
os.remove(self.token_file)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue