diff --git a/src/connectors/onedrive/connector.py b/src/connectors/onedrive/connector.py index 8b800b3d..0664e874 100644 --- a/src/connectors/onedrive/connector.py +++ b/src/connectors/onedrive/connector.py @@ -1,3 +1,4 @@ +from pathlib import Path import httpx import uuid from datetime import datetime, timedelta @@ -20,10 +21,12 @@ class OneDriveConnector(BaseConnector): def __init__(self, config: Dict[str, Any]): super().__init__(config) + project_root = Path(__file__).resolve().parent.parent.parent.parent + token_file = config.get("token_file") or str(project_root / "onedrive_token.json") self.oauth = OneDriveOAuth( client_id=self.get_client_id(), client_secret=self.get_client_secret(), - token_file=config.get("token_file", "onedrive_token.json"), + token_file=token_file, ) self.subscription_id = config.get("subscription_id") or config.get( "webhook_channel_id" diff --git a/src/connectors/onedrive/oauth.py b/src/connectors/onedrive/oauth.py index a81124e6..ad2f17d1 100644 --- a/src/connectors/onedrive/oauth.py +++ b/src/connectors/onedrive/oauth.py @@ -1,11 +1,12 @@ import os +import json import aiofiles -from typing import Optional -import msal +from datetime import datetime +import httpx class OneDriveOAuth: - """Handles Microsoft Graph OAuth authentication flow""" + """Direct token management for OneDrive, bypassing MSAL cache format""" SCOPES = [ "offline_access", @@ -20,76 +21,168 @@ class OneDriveOAuth: client_id: str, client_secret: str, token_file: str = "onedrive_token.json", - authority: str = "https://login.microsoftonline.com/common", ): self.client_id = client_id self.client_secret = client_secret self.token_file = token_file - self.authority = authority - self.token_cache = msal.SerializableTokenCache() + self._tokens = None + self._load_tokens() - # Load existing cache if available + def _load_tokens(self): + """Load tokens from file""" if os.path.exists(self.token_file): with open(self.token_file, "r") as f: - self.token_cache.deserialize(f.read()) + self._tokens = json.loads(f.read()) + print(f"Loaded tokens from {self.token_file}") + else: + print(f"No token file found at {self.token_file}") - self.app = msal.ConfidentialClientApplication( - client_id=self.client_id, - client_credential=self.client_secret, - authority=self.authority, - token_cache=self.token_cache, - ) + async def _save_tokens(self): + """Save tokens to file""" + if self._tokens: + async with aiofiles.open(self.token_file, "w") as f: + await f.write(json.dumps(self._tokens, indent=2)) - async def save_cache(self): - """Persist the token cache to file""" - async with aiofiles.open(self.token_file, "w") as f: - await f.write(self.token_cache.serialize()) + def _is_token_expired(self) -> bool: + """Check if current access token is expired""" + if not self._tokens or 'expiry' not in self._tokens: + return True + + expiry_str = self._tokens['expiry'] + # Handle different expiry formats + try: + if expiry_str.endswith('Z'): + expiry_dt = datetime.fromisoformat(expiry_str[:-1]) + else: + expiry_dt = datetime.fromisoformat(expiry_str) + + # Add 5-minute buffer + import datetime as dt + now = datetime.now() + return now >= (expiry_dt - dt.timedelta(minutes=5)) + except: + return True + async def _refresh_access_token(self) -> bool: + """Refresh the access token using refresh token""" + if not self._tokens or 'refresh_token' not in self._tokens: + return False + + data = { + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'refresh_token': self._tokens['refresh_token'], + 'grant_type': 'refresh_token', + 'scope': ' '.join(self.SCOPES) + } + + async with httpx.AsyncClient() as client: + try: + response = await client.post(self.TOKEN_ENDPOINT, data=data) + response.raise_for_status() + token_data = response.json() + + # Update tokens + self._tokens['token'] = token_data['access_token'] + if 'refresh_token' in token_data: + self._tokens['refresh_token'] = token_data['refresh_token'] + + # Calculate expiry + expires_in = token_data.get('expires_in', 3600) + import datetime as dt + expiry = datetime.now() + dt.timedelta(seconds=expires_in) + self._tokens['expiry'] = expiry.isoformat() + + await self._save_tokens() + print("Access token refreshed successfully") + return True + + except Exception as e: + print(f"Failed to refresh token: {e}") + return False + + async def is_authenticated(self) -> bool: + """Check if we have valid credentials""" + if not self._tokens: + return False + + # If token is expired, try to refresh + if self._is_token_expired(): + print("Token expired, attempting refresh...") + if await self._refresh_access_token(): + return True + else: + return False + + return True + + def get_access_token(self) -> str: + """Get current access token""" + if not self._tokens or 'token' not in self._tokens: + raise ValueError("No access token available") + + if self._is_token_expired(): + raise ValueError("Access token expired and refresh failed") + + return self._tokens['token'] + + async def revoke_credentials(self): + """Clear tokens""" + self._tokens = None + if os.path.exists(self.token_file): + os.remove(self.token_file) + + # Keep these methods for compatibility with your existing OAuth flow def create_authorization_url(self, redirect_uri: str) -> str: """Create authorization URL for OAuth flow""" - return self.app.get_authorization_request_url( - self.SCOPES, redirect_uri=redirect_uri - ) + from urllib.parse import urlencode + + params = { + 'client_id': self.client_id, + 'response_type': 'code', + 'redirect_uri': redirect_uri, + 'scope': ' '.join(self.SCOPES), + 'response_mode': 'query' + } + + auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" + return f"{auth_url}?{urlencode(params)}" async def handle_authorization_callback( self, authorization_code: str, redirect_uri: str ) -> bool: """Handle OAuth callback and exchange code for tokens""" - result = self.app.acquire_token_by_authorization_code( - authorization_code, - scopes=self.SCOPES, - redirect_uri=redirect_uri, - ) - if "access_token" in result: - await self.save_cache() - return True - raise ValueError(result.get("error_description") or "Authorization failed") + data = { + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'code': authorization_code, + 'grant_type': 'authorization_code', + 'redirect_uri': redirect_uri, + 'scope': ' '.join(self.SCOPES) + } - async def is_authenticated(self) -> bool: - """Check if we have valid credentials""" - accounts = self.app.get_accounts() - if not accounts: - return False - result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0]) - if "access_token" in result: - await self.save_cache() - return True - return False + async with httpx.AsyncClient() as client: + try: + response = await client.post(self.TOKEN_ENDPOINT, data=data) + response.raise_for_status() + token_data = response.json() - def get_access_token(self) -> str: - """Get an access token for Microsoft Graph""" - accounts = self.app.get_accounts() - if not accounts: - raise ValueError("Not authenticated") - result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0]) - if "access_token" not in result: - raise ValueError( - result.get("error_description") or "Failed to acquire access token" - ) - return result["access_token"] + # Store tokens in our format + import datetime as dt + expires_in = token_data.get('expires_in', 3600) + expiry = datetime.now() + dt.timedelta(seconds=expires_in) + + self._tokens = { + 'token': token_data['access_token'], + 'refresh_token': token_data['refresh_token'], + 'scopes': self.SCOPES, + 'expiry': expiry.isoformat() + } - async def revoke_credentials(self): - """Clear token cache and remove token file""" - self.token_cache.clear() - if os.path.exists(self.token_file): - os.remove(self.token_file) + await self._save_tokens() + print("Authorization successful, tokens saved") + return True + + except Exception as e: + print(f"Authorization failed: {e}") + return False diff --git a/src/connectors/sharepoint/connector.py b/src/connectors/sharepoint/connector.py index 7135cc8e..8282f891 100644 --- a/src/connectors/sharepoint/connector.py +++ b/src/connectors/sharepoint/connector.py @@ -1,3 +1,4 @@ +from pathlib import Path import httpx import uuid from datetime import datetime, timedelta @@ -20,10 +21,12 @@ class SharePointConnector(BaseConnector): def __init__(self, config: Dict[str, Any]): super().__init__(config) + project_root = Path(__file__).resolve().parent.parent.parent.parent + token_file = config.get("token_file") or str(project_root / "onedrive_token.json") self.oauth = SharePointOAuth( client_id=self.get_client_id(), client_secret=self.get_client_secret(), - token_file=config.get("token_file", "sharepoint_token.json"), + token_file=token_file, ) self.subscription_id = config.get("subscription_id") or config.get( "webhook_channel_id" diff --git a/src/connectors/sharepoint/oauth.py b/src/connectors/sharepoint/oauth.py index fa7424e9..77ce1ecc 100644 --- a/src/connectors/sharepoint/oauth.py +++ b/src/connectors/sharepoint/oauth.py @@ -1,11 +1,12 @@ import os +import json import aiofiles -from typing import Optional -import msal +from datetime import datetime +import httpx class SharePointOAuth: - """Handles Microsoft Graph OAuth authentication flow""" + """Direct token management for SharePoint, bypassing MSAL cache format""" SCOPES = [ "offline_access", @@ -21,76 +22,173 @@ class SharePointOAuth: client_id: str, client_secret: str, token_file: str = "sharepoint_token.json", - authority: str = "https://login.microsoftonline.com/common", + authority: str = "https://login.microsoftonline.com/common", # Keep for compatibility ): self.client_id = client_id self.client_secret = client_secret self.token_file = token_file - self.authority = authority - self.token_cache = msal.SerializableTokenCache() + self.authority = authority # Keep for compatibility but not used + self._tokens = None + self._load_tokens() - # Load existing cache if available + def _load_tokens(self): + """Load tokens from file""" if os.path.exists(self.token_file): with open(self.token_file, "r") as f: - self.token_cache.deserialize(f.read()) - - self.app = msal.ConfidentialClientApplication( - client_id=self.client_id, - client_credential=self.client_secret, - authority=self.authority, - token_cache=self.token_cache, - ) + self._tokens = json.loads(f.read()) + print(f"Loaded tokens from {self.token_file}") + else: + print(f"No token file found at {self.token_file}") async def save_cache(self): - """Persist the token cache to file""" - async with aiofiles.open(self.token_file, "w") as f: - await f.write(self.token_cache.serialize()) + """Persist tokens to file (renamed for compatibility)""" + await self._save_tokens() + + async def _save_tokens(self): + """Save tokens to file""" + if self._tokens: + async with aiofiles.open(self.token_file, "w") as f: + await f.write(json.dumps(self._tokens, indent=2)) + + def _is_token_expired(self) -> bool: + """Check if current access token is expired""" + if not self._tokens or 'expiry' not in self._tokens: + return True + + expiry_str = self._tokens['expiry'] + # Handle different expiry formats + try: + if expiry_str.endswith('Z'): + expiry_dt = datetime.fromisoformat(expiry_str[:-1]) + else: + expiry_dt = datetime.fromisoformat(expiry_str) + + # Add 5-minute buffer + import datetime as dt + now = datetime.now() + return now >= (expiry_dt - dt.timedelta(minutes=5)) + except: + return True + + async def _refresh_access_token(self) -> bool: + """Refresh the access token using refresh token""" + if not self._tokens or 'refresh_token' not in self._tokens: + return False + + data = { + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'refresh_token': self._tokens['refresh_token'], + 'grant_type': 'refresh_token', + 'scope': ' '.join(self.SCOPES) + } + + async with httpx.AsyncClient() as client: + try: + response = await client.post(self.TOKEN_ENDPOINT, data=data) + response.raise_for_status() + token_data = response.json() + + # Update tokens + self._tokens['token'] = token_data['access_token'] + if 'refresh_token' in token_data: + self._tokens['refresh_token'] = token_data['refresh_token'] + + # Calculate expiry + expires_in = token_data.get('expires_in', 3600) + import datetime as dt + expiry = datetime.now() + dt.timedelta(seconds=expires_in) + self._tokens['expiry'] = expiry.isoformat() + + await self._save_tokens() + print("Access token refreshed successfully") + return True + + except Exception as e: + print(f"Failed to refresh token: {e}") + return False def create_authorization_url(self, redirect_uri: str) -> str: """Create authorization URL for OAuth flow""" - return self.app.get_authorization_request_url( - self.SCOPES, redirect_uri=redirect_uri - ) + from urllib.parse import urlencode + + params = { + 'client_id': self.client_id, + 'response_type': 'code', + 'redirect_uri': redirect_uri, + 'scope': ' '.join(self.SCOPES), + 'response_mode': 'query' + } + + auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" + return f"{auth_url}?{urlencode(params)}" async def handle_authorization_callback( self, authorization_code: str, redirect_uri: str ) -> bool: """Handle OAuth callback and exchange code for tokens""" - result = self.app.acquire_token_by_authorization_code( - authorization_code, - scopes=self.SCOPES, - redirect_uri=redirect_uri, - ) - if "access_token" in result: - await self.save_cache() - return True - raise ValueError(result.get("error_description") or "Authorization failed") + data = { + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'code': authorization_code, + 'grant_type': 'authorization_code', + 'redirect_uri': redirect_uri, + 'scope': ' '.join(self.SCOPES) + } + + async with httpx.AsyncClient() as client: + try: + response = await client.post(self.TOKEN_ENDPOINT, data=data) + response.raise_for_status() + token_data = response.json() + + # Store tokens in our format + import datetime as dt + expires_in = token_data.get('expires_in', 3600) + expiry = datetime.now() + dt.timedelta(seconds=expires_in) + + self._tokens = { + 'token': token_data['access_token'], + 'refresh_token': token_data['refresh_token'], + 'scopes': self.SCOPES, + 'expiry': expiry.isoformat() + } + + await self._save_tokens() + print("Authorization successful, tokens saved") + return True + + except Exception as e: + print(f"Authorization failed: {e}") + return False async def is_authenticated(self) -> bool: """Check if we have valid credentials""" - accounts = self.app.get_accounts() - if not accounts: + if not self._tokens: return False - result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0]) - if "access_token" in result: - await self.save_cache() - return True - return False + + # If token is expired, try to refresh + if self._is_token_expired(): + print("Token expired, attempting refresh...") + if await self._refresh_access_token(): + return True + else: + return False + + return True def get_access_token(self) -> str: - """Get an access token for Microsoft Graph""" - accounts = self.app.get_accounts() - if not accounts: - raise ValueError("Not authenticated") - result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0]) - if "access_token" not in result: - raise ValueError( - result.get("error_description") or "Failed to acquire access token" - ) - return result["access_token"] + """Get current access token""" + if not self._tokens or 'token' not in self._tokens: + raise ValueError("No access token available") + + if self._is_token_expired(): + raise ValueError("Access token expired and refresh failed") + + return self._tokens['token'] async def revoke_credentials(self): - """Clear token cache and remove token file""" - self.token_cache.clear() + """Clear tokens""" + self._tokens = None if os.path.exists(self.token_file): os.remove(self.token_file)