From f03889a2b38d0ce47a5e76ddcb881d00edf09550 Mon Sep 17 00:00:00 2001 From: Eric Hare Date: Thu, 25 Sep 2025 13:12:27 -0700 Subject: [PATCH] Finally fix MSAL in onedrive/sharepoint --- src/api/connectors.py | 95 ++- src/connectors/connection_manager.py | 56 +- src/connectors/onedrive/connector.py | 624 ++++++++++++++------ src/connectors/onedrive/oauth.py | 428 +++++++++----- src/connectors/sharepoint/connector.py | 765 ++++++++++++++++++------- src/connectors/sharepoint/oauth.py | 431 +++++++++----- 6 files changed, 1649 insertions(+), 750 deletions(-) diff --git a/src/api/connectors.py b/src/api/connectors.py index c8bd6636..5fbcec86 100644 --- a/src/api/connectors.py +++ b/src/api/connectors.py @@ -132,7 +132,10 @@ async def connector_status(request: Request, connector_service, session_manager) for connection in connections: try: connector = await connector_service._get_connector(connection.connection_id) - connection_client_ids[connection.connection_id] = connector.get_client_id() + if connector is not None: + connection_client_ids[connection.connection_id] = connector.get_client_id() + else: + connection_client_ids[connection.connection_id] = None except Exception as e: logger.warning( "Could not get connector for connection", @@ -338,8 +341,8 @@ async def connector_webhook(request: Request, connector_service, session_manager ) async def connector_token(request: Request, connector_service, session_manager): - """Get access token for connector API calls (e.g., Google Picker)""" - connector_type = request.path_params.get("connector_type") + """Get access token for connector API calls (e.g., Pickers).""" + url_connector_type = request.path_params.get("connector_type") connection_id = request.query_params.get("connection_id") if not connection_id: @@ -348,37 +351,81 @@ async def connector_token(request: Request, connector_service, session_manager): user = request.state.user try: - # Get the connection and verify it belongs to the user + # 1) Load the connection and verify ownership connection = await connector_service.connection_manager.get_connection(connection_id) if not connection or connection.user_id != user.user_id: return JSONResponse({"error": "Connection not found"}, status_code=404) - # Get the connector instance + # 2) Get the ACTUAL connector instance/type for this connection_id connector = await connector_service._get_connector(connection_id) if not connector: - return JSONResponse({"error": f"Connector not available - authentication may have failed for {connector_type}"}, status_code=404) + return JSONResponse( + {"error": f"Connector not available - authentication may have failed for {url_connector_type}"}, + status_code=404, + ) - # For Google Drive, get the access token - if connector_type == "google_drive" and hasattr(connector, 'oauth'): + real_type = getattr(connector, "type", None) or getattr(connection, "connector_type", None) + if real_type is None: + return JSONResponse({"error": "Unable to determine connector type"}, status_code=500) + + # Optional: warn if URL path type disagrees with real type + if url_connector_type and url_connector_type != real_type: + # You can downgrade this to debug if you expect cross-routing. + return JSONResponse( + { + "error": "Connector type mismatch", + "detail": { + "requested_type": url_connector_type, + "actual_type": real_type, + "hint": "Call the token endpoint using the correct connector_type for this connection_id.", + }, + }, + status_code=400, + ) + + # 3) Branch by the actual connector type + # GOOGLE DRIVE (google-auth) + if real_type == "google_drive" and hasattr(connector, "oauth"): await connector.oauth.load_credentials() if connector.oauth.creds and connector.oauth.creds.valid: - return JSONResponse({ - "access_token": connector.oauth.creds.token, - "expires_in": (connector.oauth.creds.expiry.timestamp() - - __import__('time').time()) if connector.oauth.creds.expiry else None - }) - else: - return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401) - - # For OneDrive and SharePoint, get the access token - elif connector_type in ["onedrive", "sharepoint"] and hasattr(connector, 'oauth'): + expires_in = None + try: + if connector.oauth.creds.expiry: + import time + expires_in = max(0, int(connector.oauth.creds.expiry.timestamp() - time.time())) + except Exception: + expires_in = None + + return JSONResponse( + { + "access_token": connector.oauth.creds.token, + "expires_in": expires_in, + } + ) + return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401) + + # ONEDRIVE / SHAREPOINT (MSAL or custom) + if real_type in ("onedrive", "sharepoint") and hasattr(connector, "oauth"): + # Ensure cache/credentials are loaded before trying to use them try: + # Prefer a dedicated is_authenticated() that loads cache internally + if hasattr(connector.oauth, "is_authenticated"): + ok = await connector.oauth.is_authenticated() + else: + # Fallback: try to load credentials explicitly if available + ok = True + if hasattr(connector.oauth, "load_credentials"): + ok = await connector.oauth.load_credentials() + + if not ok: + return JSONResponse({"error": "Not authenticated"}, status_code=401) + + # Now safe to fetch access token access_token = connector.oauth.get_access_token() - return JSONResponse({ - "access_token": access_token, - "expires_in": None # MSAL handles token expiry internally - }) + # MSAL result has expiry, but we’re returning a raw token; keep expires_in None for simplicity + return JSONResponse({"access_token": access_token, "expires_in": None}) except ValueError as e: + # Typical when acquire_token_silent fails (e.g., needs re-auth) return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401) except Exception as e: return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500) @@ -386,7 +433,5 @@ async def connector_token(request: Request, connector_service, session_manager): return JSONResponse({"error": "Token not available for this connector type"}, status_code=400) except Exception as e: - logger.error("Error getting connector token", error=str(e)) + logger.error("Error getting connector token", exc_info=True) return JSONResponse({"error": str(e)}, status_code=500) - - diff --git a/src/connectors/connection_manager.py b/src/connectors/connection_manager.py index 2e70ee1f..07ebd5ee 100644 --- a/src/connectors/connection_manager.py +++ b/src/connectors/connection_manager.py @@ -294,32 +294,39 @@ class ConnectionManager: async def get_connector(self, connection_id: str) -> Optional[BaseConnector]: """Get an active connector instance""" + logger.debug(f"Getting connector for connection_id: {connection_id}") + # Return cached connector if available if connection_id in self.active_connectors: connector = self.active_connectors[connection_id] if connector.is_authenticated: + logger.debug(f"Returning cached authenticated connector for {connection_id}") return connector else: # Remove unauthenticated connector from cache + logger.debug(f"Removing unauthenticated connector from cache for {connection_id}") del self.active_connectors[connection_id] # Try to create and authenticate connector connection_config = self.connections.get(connection_id) if not connection_config or not connection_config.is_active: + logger.debug(f"No active connection config found for {connection_id}") return None + logger.debug(f"Creating connector for {connection_config.connector_type}") connector = self._create_connector(connection_config) - if await connector.authenticate(): + + logger.debug(f"Attempting authentication for {connection_id}") + auth_result = await connector.authenticate() + logger.debug(f"Authentication result for {connection_id}: {auth_result}") + + if auth_result: self.active_connectors[connection_id] = connector - - # Setup webhook subscription if not already set up - await self._setup_webhook_if_needed( - connection_id, connection_config, connector - ) - + # ... rest of the method return connector - - return None + else: + logger.warning(f"Authentication failed for {connection_id}") + return None def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]: """Get available connector types with their metadata""" @@ -363,20 +370,23 @@ class ConnectionManager: def _create_connector(self, config: ConnectionConfig) -> BaseConnector: """Factory method to create connector instances""" - if config.connector_type == "google_drive": - return GoogleDriveConnector(config.config) - elif config.connector_type == "sharepoint": - return SharePointConnector(config.config) - elif config.connector_type == "onedrive": - return OneDriveConnector(config.config) - elif config.connector_type == "box": - # Future: BoxConnector(config.config) - raise NotImplementedError("Box connector not implemented yet") - elif config.connector_type == "dropbox": - # Future: DropboxConnector(config.config) - raise NotImplementedError("Dropbox connector not implemented yet") - else: - raise ValueError(f"Unknown connector type: {config.connector_type}") + try: + if config.connector_type == "google_drive": + return GoogleDriveConnector(config.config) + elif config.connector_type == "sharepoint": + return SharePointConnector(config.config) + elif config.connector_type == "onedrive": + return OneDriveConnector(config.config) + elif config.connector_type == "box": + raise NotImplementedError("Box connector not implemented yet") + elif config.connector_type == "dropbox": + raise NotImplementedError("Dropbox connector not implemented yet") + else: + raise ValueError(f"Unknown connector type: {config.connector_type}") + except Exception as e: + logger.error(f"Failed to create {config.connector_type} connector: {e}") + # Re-raise the exception so caller can handle appropriately + raise async def update_last_sync(self, connection_id: str): """Update the last sync timestamp for a connection""" diff --git a/src/connectors/onedrive/connector.py b/src/connectors/onedrive/connector.py index 9a7b6760..bea1e790 100644 --- a/src/connectors/onedrive/connector.py +++ b/src/connectors/onedrive/connector.py @@ -1,235 +1,487 @@ +import logging from pathlib import Path +from typing import List, Dict, Any, Optional +from datetime import datetime import httpx -import uuid -from datetime import datetime, timedelta -from typing import Dict, List, Any, Optional from ..base import BaseConnector, ConnectorDocument, DocumentACL from .oauth import OneDriveOAuth +logger = logging.getLogger(__name__) + class OneDriveConnector(BaseConnector): - """OneDrive connector using Microsoft Graph API""" + """OneDrive connector using MSAL-based OAuth for authentication.""" + # Required BaseConnector class attributes CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID" CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET" # Connector metadata CONNECTOR_NAME = "OneDrive" - CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents" + CONNECTOR_DESCRIPTION = "Connect to OneDrive (personal) to sync documents and files" CONNECTOR_ICON = "onedrive" def __init__(self, config: Dict[str, Any]): - super().__init__(config) + logger.debug(f"OneDrive connector __init__ called with config type: {type(config)}") + logger.debug(f"OneDrive connector __init__ config value: {config}") + + if config is None: + logger.debug("Config was None, using empty dict") + config = {} + + try: + logger.debug("Calling super().__init__") + super().__init__(config) + logger.debug("super().__init__ completed successfully") + except Exception as e: + logger.error(f"super().__init__ failed: {e}") + raise + + # Initialize with defaults that allow the connector to be listed + self.client_id = None + self.client_secret = None + self.redirect_uri = config.get("redirect_uri", "http://localhost") # must match your app registration + + # Try to get credentials, but don't fail if they're missing + try: + self.client_id = self.get_client_id() + logger.debug(f"Got client_id: {self.client_id is not None}") + except Exception as e: + logger.debug(f"Failed to get client_id: {e}") + + try: + self.client_secret = self.get_client_secret() + logger.debug(f"Got client_secret: {self.client_secret is not None}") + except Exception as e: + logger.debug(f"Failed to get client_secret: {e}") + + # Token file setup project_root = Path(__file__).resolve().parent.parent.parent.parent token_file = config.get("token_file") or str(project_root / "onedrive_token.json") - self.oauth = OneDriveOAuth( - client_id=self.get_client_id(), - client_secret=self.get_client_secret(), - token_file=token_file, - ) - self.subscription_id = config.get("subscription_id") or config.get( - "webhook_channel_id" - ) - self.base_url = "https://graph.microsoft.com/v1.0" + Path(token_file).parent.mkdir(parents=True, exist_ok=True) - async def authenticate(self) -> bool: - if await self.oauth.is_authenticated(): - self._authenticated = True - return True - return False + # Only initialize OAuth if we have credentials + if self.client_id and self.client_secret: + connection_id = config.get("connection_id", "default") - async def setup_subscription(self) -> str: - if not self._authenticated: - raise ValueError("Not authenticated") + # Use token_file from config if provided, otherwise generate one + if config.get("token_file"): + oauth_token_file = config["token_file"] + else: + # Use a per-connection cache file to avoid collisions with other connectors + oauth_token_file = f"onedrive_token_{connection_id}.json" - webhook_url = self.config.get("webhook_url") - if not webhook_url: - raise ValueError("webhook_url required in config for subscriptions") + # MSA & org both work via /common for OneDrive personal testing + authority = "https://login.microsoftonline.com/common" - expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z" - body = { - "changeType": "created,updated,deleted", - "notificationUrl": webhook_url, - "resource": "/me/drive/root", - "expirationDateTime": expiration, - "clientState": str(uuid.uuid4()), + self.oauth = OneDriveOAuth( + client_id=self.client_id, + client_secret=self.client_secret, + token_file=oauth_token_file, + authority=authority, + allow_json_refresh=True, # allows one-time migration from legacy JSON if present + ) + else: + self.oauth = None + + # Track subscription ID for webhooks (note: change notifications might not be available for personal accounts) + self._subscription_id: Optional[str] = None + + # Graph API defaults + self._graph_api_version = "v1.0" + self._default_params = { + "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl" } - token = self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.post( - f"{self.base_url}/subscriptions", - json=body, - headers={"Authorization": f"Bearer {token}"}, - ) - resp.raise_for_status() - data = resp.json() + @property + def _graph_base_url(self) -> str: + """Base URL for Microsoft Graph API calls.""" + return f"https://graph.microsoft.com/{self._graph_api_version}" - self.subscription_id = data["id"] - return self.subscription_id + def emit(self, doc: ConnectorDocument) -> None: + """Emit a ConnectorDocument instance (integrate with your pipeline here).""" + logger.debug(f"Emitting OneDrive document: {doc.id} ({doc.filename})") - async def list_files( - self, page_token: Optional[str] = None, limit: int = 100 - ) -> Dict[str, Any]: - if not self._authenticated: - raise ValueError("Not authenticated") + async def authenticate(self) -> bool: + """Test authentication - BaseConnector interface.""" + logger.debug(f"OneDrive authenticate() called, oauth is None: {self.oauth is None}") + try: + if not self.oauth: + logger.debug("OneDrive authentication failed: OAuth not initialized") + self._authenticated = False + return False - params = {"$top": str(limit)} - if page_token: - params["$skiptoken"] = page_token + logger.debug("Loading OneDrive credentials...") + load_result = await self.oauth.load_credentials() + logger.debug(f"Load credentials result: {load_result}") - token = self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{self.base_url}/me/drive/root/children", - params=params, - headers={"Authorization": f"Bearer {token}"}, - ) - resp.raise_for_status() - data = resp.json() + logger.debug("Checking OneDrive authentication status...") + authenticated = await self.oauth.is_authenticated() + logger.debug(f"OneDrive is_authenticated result: {authenticated}") - files = [] - for item in data.get("value", []): - if item.get("file"): - files.append( - { - "id": item["id"], - "name": item["name"], - "mimeType": item.get("file", {}).get( - "mimeType", "application/octet-stream" - ), - "webViewLink": item.get("webUrl"), - "createdTime": item.get("createdDateTime"), - "modifiedTime": item.get("lastModifiedDateTime"), - } - ) + self._authenticated = authenticated + return authenticated + except Exception as e: + logger.error(f"OneDrive authentication failed: {e}") + import traceback + traceback.print_exc() + self._authenticated = False + return False - next_token = None - next_link = data.get("@odata.nextLink") - if next_link: - from urllib.parse import urlparse, parse_qs + def get_auth_url(self) -> str: + """Get OAuth authorization URL.""" + if not self.oauth: + raise RuntimeError("OneDrive OAuth not initialized - missing credentials") + return self.oauth.create_authorization_url(self.redirect_uri) - parsed = urlparse(next_link) - next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0] + async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]: + """Handle OAuth callback.""" + if not self.oauth: + raise RuntimeError("OneDrive OAuth not initialized - missing credentials") + try: + success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri) + if success: + self._authenticated = True + return {"status": "success"} + else: + raise ValueError("OAuth callback failed") + except Exception as e: + logger.error(f"OAuth callback failed: {e}") + raise - return {"files": files, "nextPageToken": next_token} + def sync_once(self) -> None: + """ + Perform a one-shot sync of OneDrive files and emit documents. + """ + import asyncio + + async def _async_sync(): + try: + file_list = await self.list_files(max_files=1000) + files = file_list.get("files", []) + for file_info in files: + try: + file_id = file_info.get("id") + if not file_id: + continue + doc = await self.get_file_content(file_id) + self.emit(doc) + except Exception as e: + logger.error(f"Failed to sync OneDrive file {file_info.get('name', 'unknown')}: {e}") + continue + except Exception as e: + logger.error(f"OneDrive sync_once failed: {e}") + raise + + if hasattr(asyncio, 'run'): + asyncio.run(_async_sync()) + else: + loop = asyncio.get_event_loop() + loop.run_until_complete(_async_sync()) + + async def setup_subscription(self) -> str: + """ + Set up real-time subscription for file changes. + NOTE: Change notifications may not be available for personal OneDrive accounts. + """ + webhook_url = self.config.get('webhook_url') + if not webhook_url: + logger.warning("No webhook URL configured, skipping OneDrive subscription setup") + return "no-webhook-configured" + + try: + if not await self.authenticate(): + raise RuntimeError("OneDrive authentication failed during subscription setup") + + token = self.oauth.get_access_token() + + # For OneDrive personal we target the user's drive + resource = "/me/drive/root" + + subscription_data = { + "changeType": "created,updated,deleted", + "notificationUrl": f"{webhook_url}/webhook/onedrive", + "resource": resource, + "expirationDateTime": self._get_subscription_expiry(), + "clientState": "onedrive_personal", + } + + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + + url = f"{self._graph_base_url}/subscriptions" + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=subscription_data, headers=headers, timeout=30) + response.raise_for_status() + + result = response.json() + subscription_id = result.get("id") + + if subscription_id: + self._subscription_id = subscription_id + logger.info(f"OneDrive subscription created: {subscription_id}") + return subscription_id + else: + raise ValueError("No subscription ID returned from Microsoft Graph") + + except Exception as e: + logger.error(f"Failed to setup OneDrive subscription: {e}") + raise + + def _get_subscription_expiry(self) -> str: + """Get subscription expiry time (Graph caps duration; often <= 3 days).""" + from datetime import datetime, timedelta + expiry = datetime.utcnow() + timedelta(days=3) + return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + + async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]: + """List files from OneDrive using Microsoft Graph.""" + try: + if not await self.authenticate(): + raise RuntimeError("OneDrive authentication failed during file listing") + + files: List[Dict[str, Any]] = [] + max_files_value = max_files if max_files is not None else 100 + + base_url = f"{self._graph_base_url}/me/drive/root/children" + + params = dict(self._default_params) + params["$top"] = max_files_value + + if page_token: + params["$skiptoken"] = page_token + + response = await self._make_graph_request(base_url, params=params) + data = response.json() + + items = data.get("value", []) + for item in items: + if item.get("file"): # include files only + files.append({ + "id": item.get("id", ""), + "name": item.get("name", ""), + "path": f"/drive/items/{item.get('id')}", + "size": int(item.get("size", 0)), + "modified": item.get("lastModifiedDateTime"), + "created": item.get("createdDateTime"), + "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))), + "url": item.get("webUrl", ""), + "download_url": item.get("@microsoft.graph.downloadUrl"), + }) + + # Next page + next_page_token = None + next_link = data.get("@odata.nextLink") + if next_link: + from urllib.parse import urlparse, parse_qs + parsed = urlparse(next_link) + query_params = parse_qs(parsed.query) + if "$skiptoken" in query_params: + next_page_token = query_params["$skiptoken"][0] + + return {"files": files, "next_page_token": next_page_token} + + except Exception as e: + logger.error(f"Failed to list OneDrive files: {e}") + return {"files": [], "next_page_token": None} async def get_file_content(self, file_id: str) -> ConnectorDocument: - if not self._authenticated: - raise ValueError("Not authenticated") + """Get file content and metadata.""" + try: + if not await self.authenticate(): + raise RuntimeError("OneDrive authentication failed during file content retrieval") - token = self.oauth.get_access_token() - headers = {"Authorization": f"Bearer {token}"} - async with httpx.AsyncClient() as client: - meta_resp = await client.get( - f"{self.base_url}/me/drive/items/{file_id}", headers=headers - ) - meta_resp.raise_for_status() - metadata = meta_resp.json() + file_metadata = await self._get_file_metadata_by_id(file_id) + if not file_metadata: + raise ValueError(f"File not found: {file_id}") - content_resp = await client.get( - f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers - ) - content = content_resp.content - - # Handle the possibility of this being a redirect - if content_resp.status_code in (301, 302, 303, 307, 308): - redirect_url = content_resp.headers.get("Location") - if redirect_url: - content_resp = await client.get(redirect_url) - content_resp.raise_for_status() - content = content_resp.content + download_url = file_metadata.get("download_url") + if download_url: + content = await self._download_file_from_url(download_url) else: - content_resp.raise_for_status() + content = await self._download_file_content(file_id) - perm_resp = await client.get( - f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers + acl = DocumentACL( + owner="", + user_permissions={}, + group_permissions={}, ) - perm_resp.raise_for_status() - permissions = perm_resp.json() - acl = self._parse_permissions(metadata, permissions) - modified = datetime.fromisoformat( - metadata["lastModifiedDateTime"].replace("Z", "+00:00") - ).replace(tzinfo=None) - created = datetime.fromisoformat( - metadata["createdDateTime"].replace("Z", "+00:00") - ).replace(tzinfo=None) + modified_time = self._parse_graph_date(file_metadata.get("modified")) + created_time = self._parse_graph_date(file_metadata.get("created")) - document = ConnectorDocument( - id=metadata["id"], - filename=metadata["name"], - mimetype=metadata.get("file", {}).get( - "mimeType", "application/octet-stream" - ), - content=content, - source_url=metadata.get("webUrl"), - acl=acl, - modified_time=modified, - created_time=created, - metadata={"size": metadata.get("size")}, - ) - return document - - def _parse_permissions( - self, metadata: Dict[str, Any], permissions: Dict[str, Any] - ) -> DocumentACL: - acl = DocumentACL() - owner = metadata.get("createdBy", {}).get("user", {}).get("email") - if owner: - acl.owner = owner - for perm in permissions.get("value", []): - role = perm.get("roles", ["read"])[0] - grantee = perm.get("grantedToV2") or perm.get("grantedTo") - if not grantee: - continue - user = grantee.get("user") - if user and user.get("email"): - acl.user_permissions[user["email"]] = role - group = grantee.get("group") - if group and group.get("email"): - acl.group_permissions[group["email"]] = role - return acl - - def handle_webhook_validation( - self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str] - ) -> Optional[str]: - """Handle Microsoft Graph webhook validation""" - if request_method == "GET": - validation_token = query_params.get("validationtoken") or query_params.get( - "validationToken" + return ConnectorDocument( + id=file_id, + filename=file_metadata.get("name", ""), + mimetype=file_metadata.get("mime_type", "application/octet-stream"), + content=content, + source_url=file_metadata.get("url", ""), + acl=acl, + modified_time=modified_time, + created_time=created_time, + metadata={ + "onedrive_path": file_metadata.get("path", ""), + "size": file_metadata.get("size", 0), + }, ) - if validation_token: - return validation_token + + except Exception as e: + logger.error(f"Failed to get OneDrive file content {file_id}: {e}") + raise + + async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]: + """Get file metadata by ID using Graph API.""" + try: + url = f"{self._graph_base_url}/me/drive/items/{file_id}" + params = dict(self._default_params) + + response = await self._make_graph_request(url, params=params) + item = response.json() + + if item.get("file"): + return { + "id": file_id, + "name": item.get("name", ""), + "path": f"/drive/items/{file_id}", + "size": int(item.get("size", 0)), + "modified": item.get("lastModifiedDateTime"), + "created": item.get("createdDateTime"), + "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))), + "url": item.get("webUrl", ""), + "download_url": item.get("@microsoft.graph.downloadUrl"), + } + + return None + + except Exception as e: + logger.error(f"Failed to get file metadata for {file_id}: {e}") + return None + + async def _download_file_content(self, file_id: str) -> bytes: + """Download file content by file ID using Graph API.""" + try: + url = f"{self._graph_base_url}/me/drive/items/{file_id}/content" + token = self.oauth.get_access_token() + headers = {"Authorization": f"Bearer {token}"} + + async with httpx.AsyncClient() as client: + response = await client.get(url, headers=headers, timeout=60) + response.raise_for_status() + return response.content + + except Exception as e: + logger.error(f"Failed to download file content for {file_id}: {e}") + raise + + async def _download_file_from_url(self, download_url: str) -> bytes: + """Download file content from direct download URL.""" + try: + async with httpx.AsyncClient() as client: + response = await client.get(download_url, timeout=60) + response.raise_for_status() + return response.content + except Exception as e: + logger.error(f"Failed to download from URL {download_url}: {e}") + raise + + def _parse_graph_date(self, date_str: Optional[str]) -> datetime: + """Parse Microsoft Graph date string to datetime.""" + if not date_str: + return datetime.now() + try: + if date_str.endswith('Z'): + return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None) + else: + return datetime.fromisoformat(date_str.replace('T', ' ')) + except (ValueError, AttributeError): + return datetime.now() + + async def _make_graph_request(self, url: str, method: str = "GET", + data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response: + """Make authenticated API request to Microsoft Graph.""" + token = self.oauth.get_access_token() + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient() as client: + if method.upper() == "GET": + response = await client.get(url, headers=headers, params=params, timeout=30) + elif method.upper() == "POST": + response = await client.post(url, headers=headers, json=data, timeout=30) + elif method.upper() == "DELETE": + response = await client.delete(url, headers=headers, timeout=30) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + response.raise_for_status() + return response + + def _get_mime_type(self, filename: str) -> str: + """Get MIME type based on file extension.""" + import mimetypes + mime_type, _ = mimetypes.guess_type(filename) + return mime_type or "application/octet-stream" + + # Webhook methods - BaseConnector interface + def handle_webhook_validation(self, request_method: str, + headers: Dict[str, str], + query_params: Dict[str, str]) -> Optional[str]: + """Handle webhook validation (Graph API specific).""" + if request_method == "POST" and "validationToken" in query_params: + return query_params["validationToken"] return None - def extract_webhook_channel_id( - self, payload: Dict[str, Any], headers: Dict[str, str] - ) -> Optional[str]: - """Extract SharePoint subscription ID from webhook payload""" - values = payload.get("value", []) - return values[0].get("subscriptionId") if values else None + def extract_webhook_channel_id(self, payload: Dict[str, Any], + headers: Dict[str, str]) -> Optional[str]: + """Extract channel/subscription ID from webhook payload.""" + notifications = payload.get("value", []) + if notifications: + return notifications[0].get("subscriptionId") + return None async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: - values = payload.get("value", []) - file_ids = [] - for item in values: - resource_data = item.get("resourceData", {}) - file_id = resource_data.get("id") - if file_id: - file_ids.append(file_id) - return file_ids + """Handle webhook notification and return affected file IDs.""" + affected_files: List[str] = [] + notifications = payload.get("value", []) + for notification in notifications: + resource = notification.get("resource") + if resource and "/drive/items/" in resource: + file_id = resource.split("/drive/items/")[-1] + affected_files.append(file_id) + return affected_files - async def cleanup_subscription( - self, subscription_id: str, resource_id: str = None - ) -> bool: - if not self._authenticated: + async def cleanup_subscription(self, subscription_id: str) -> bool: + """Clean up subscription - BaseConnector interface.""" + if subscription_id == "no-webhook-configured": + logger.info("No subscription to cleanup (webhook was not configured)") + return True + + try: + if not await self.authenticate(): + logger.error("OneDrive authentication failed during subscription cleanup") + return False + + token = self.oauth.get_access_token() + headers = {"Authorization": f"Bearer {token}"} + + url = f"{self._graph_base_url}/subscriptions/{subscription_id}" + + async with httpx.AsyncClient() as client: + response = await client.delete(url, headers=headers, timeout=30) + + if response.status_code in [200, 204, 404]: + logger.info(f"OneDrive subscription {subscription_id} cleaned up successfully") + return True + else: + logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}") + return False + + except Exception as e: + logger.error(f"Failed to cleanup OneDrive subscription {subscription_id}: {e}") return False - token = self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.delete( - f"{self.base_url}/subscriptions/{subscription_id}", - headers={"Authorization": f"Bearer {token}"}, - ) - return resp.status_code in (200, 204) diff --git a/src/connectors/onedrive/oauth.py b/src/connectors/onedrive/oauth.py index ad2f17d1..a2c94d15 100644 --- a/src/connectors/onedrive/oauth.py +++ b/src/connectors/onedrive/oauth.py @@ -1,18 +1,28 @@ import os import json +import logging +from typing import Optional, Dict, Any + import aiofiles -from datetime import datetime -import httpx +import msal + +logger = logging.getLogger(__name__) class OneDriveOAuth: - """Direct token management for OneDrive, bypassing MSAL cache format""" + """Handles Microsoft Graph OAuth for OneDrive (personal Microsoft accounts by default).""" - SCOPES = [ - "offline_access", - "Files.Read.All", - ] + # Reserved scopes that must NOT be sent on token or silent calls + RESERVED_SCOPES = {"openid", "profile", "offline_access"} + # For PERSONAL Microsoft Accounts (OneDrive consumer): + # - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance) + # - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths + AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"] + RESOURCE_SCOPES = ["User.Read", "Files.Read.All"] + SCOPES = AUTH_SCOPES # Backward-compat alias if something references .SCOPES + + # Kept for reference; MSAL derives endpoints from `authority` AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token" @@ -21,168 +31,292 @@ class OneDriveOAuth: client_id: str, client_secret: str, token_file: str = "onedrive_token.json", + authority: str = "https://login.microsoftonline.com/common", + allow_json_refresh: bool = True, ): + """ + Initialize OneDriveOAuth. + + Args: + client_id: Azure AD application (client) ID. + client_secret: Azure AD application client secret. + token_file: Path to persisted token cache file (MSAL cache format). + authority: Usually "https://login.microsoftonline.com/common" for MSA + org, + or tenant-specific for work/school. + allow_json_refresh: If True, permit one-time migration from legacy flat JSON + {"access_token","refresh_token",...}. Otherwise refuse it. + """ self.client_id = client_id self.client_secret = client_secret self.token_file = token_file - self._tokens = None - self._load_tokens() + self.authority = authority + self.allow_json_refresh = allow_json_refresh + self.token_cache = msal.SerializableTokenCache() + self._current_account = None - def _load_tokens(self): - """Load tokens from file""" - if os.path.exists(self.token_file): - with open(self.token_file, "r") as f: - self._tokens = json.loads(f.read()) - print(f"Loaded tokens from {self.token_file}") - else: - print(f"No token file found at {self.token_file}") + # Initialize MSAL Confidential Client + self.app = msal.ConfidentialClientApplication( + client_id=self.client_id, + client_credential=self.client_secret, + authority=self.authority, + token_cache=self.token_cache, + ) - async def _save_tokens(self): - """Save tokens to file""" - if self._tokens: - async with aiofiles.open(self.token_file, "w") as f: - await f.write(json.dumps(self._tokens, indent=2)) - - def _is_token_expired(self) -> bool: - """Check if current access token is expired""" - if not self._tokens or 'expiry' not in self._tokens: - return True - - expiry_str = self._tokens['expiry'] - # Handle different expiry formats + async def load_credentials(self) -> bool: + """Load existing credentials from token file (async).""" try: - if expiry_str.endswith('Z'): - expiry_dt = datetime.fromisoformat(expiry_str[:-1]) - else: - expiry_dt = datetime.fromisoformat(expiry_str) - - # Add 5-minute buffer - import datetime as dt - now = datetime.now() - return now >= (expiry_dt - dt.timedelta(minutes=5)) - except: - return True + logger.debug(f"OneDrive OAuth loading credentials from: {self.token_file}") + if os.path.exists(self.token_file): + logger.debug(f"Token file exists, reading: {self.token_file}") + + # Read the token file + async with aiofiles.open(self.token_file, "r") as f: + cache_data = await f.read() + logger.debug(f"Read {len(cache_data)} chars from token file") + + if cache_data.strip(): + # 1) Try legacy flat JSON first + try: + json_data = json.loads(cache_data) + if isinstance(json_data, dict) and "refresh_token" in json_data: + if self.allow_json_refresh: + logger.debug( + "Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh" + ) + return await self._refresh_from_json_token(json_data) + else: + logger.warning( + "Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. " + "Delete the file and re-auth." + ) + return False + except json.JSONDecodeError: + logger.debug("Token file is not flat JSON; attempting MSAL cache format") + + # 2) Try MSAL cache format + logger.debug("Attempting MSAL cache deserialization") + self.token_cache.deserialize(cache_data) + + # Get accounts from loaded cache + accounts = self.app.get_accounts() + logger.debug(f"Found {len(accounts)} accounts in MSAL cache") + if accounts: + self._current_account = accounts[0] + logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}") + + # Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account) + logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}") + if result and "access_token" in result: + logger.debug("Silent token acquisition successful") + await self.save_cache() + return True + else: + error_msg = (result or {}).get("error") or "No result" + logger.warning(f"Silent token acquisition failed: {error_msg}") + else: + logger.debug(f"Token file {self.token_file} is empty") + else: + logger.debug(f"Token file does not exist: {self.token_file}") - async def _refresh_access_token(self) -> bool: - """Refresh the access token using refresh token""" - if not self._tokens or 'refresh_token' not in self._tokens: return False - data = { - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'refresh_token': self._tokens['refresh_token'], - 'grant_type': 'refresh_token', - 'scope': ' '.join(self.SCOPES) - } - - async with httpx.AsyncClient() as client: - try: - response = await client.post(self.TOKEN_ENDPOINT, data=data) - response.raise_for_status() - token_data = response.json() - - # Update tokens - self._tokens['token'] = token_data['access_token'] - if 'refresh_token' in token_data: - self._tokens['refresh_token'] = token_data['refresh_token'] - - # Calculate expiry - expires_in = token_data.get('expires_in', 3600) - import datetime as dt - expiry = datetime.now() + dt.timedelta(seconds=expires_in) - self._tokens['expiry'] = expiry.isoformat() - - await self._save_tokens() - print("Access token refreshed successfully") - return True - - except Exception as e: - print(f"Failed to refresh token: {e}") - return False - - async def is_authenticated(self) -> bool: - """Check if we have valid credentials""" - if not self._tokens: + except Exception as e: + logger.error(f"Failed to load OneDrive credentials: {e}") + import traceback + traceback.print_exc() return False - # If token is expired, try to refresh - if self._is_token_expired(): - print("Token expired, attempting refresh...") - if await self._refresh_access_token(): - return True - else: + async def _refresh_from_json_token(self, token_data: dict) -> bool: + """ + Use refresh token from a legacy JSON file to get new tokens (one-time migration path). + Prefer using an MSAL cache file and acquire_token_silent(); this path is only for migrating older files. + """ + try: + refresh_token = token_data.get("refresh_token") + if not refresh_token: + logger.error("No refresh_token found in JSON file - cannot refresh") + logger.error("You must re-authenticate interactively to obtain a valid token") return False - - return True - def get_access_token(self) -> str: - """Get current access token""" - if not self._tokens or 'token' not in self._tokens: - raise ValueError("No access token available") - - if self._is_token_expired(): - raise ValueError("Access token expired and refresh failed") - - return self._tokens['token'] + # Use only RESOURCE_SCOPES when refreshing (no reserved scopes) + refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES] + logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}") - async def revoke_credentials(self): - """Clear tokens""" - self._tokens = None - if os.path.exists(self.token_file): - os.remove(self.token_file) + result = self.app.acquire_token_by_refresh_token( + refresh_token=refresh_token, + scopes=refresh_scopes, + ) - # Keep these methods for compatibility with your existing OAuth flow - def create_authorization_url(self, redirect_uri: str) -> str: - """Create authorization URL for OAuth flow""" - from urllib.parse import urlencode - - params = { - 'client_id': self.client_id, - 'response_type': 'code', - 'redirect_uri': redirect_uri, - 'scope': ' '.join(self.SCOPES), - 'response_mode': 'query' + if result and "access_token" in result: + logger.debug("Successfully refreshed token via legacy JSON path") + await self.save_cache() + + accounts = self.app.get_accounts() + logger.debug(f"After refresh, found {len(accounts)} accounts") + if accounts: + self._current_account = accounts[0] + logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}") + return True + + # Error handling + err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error" + logger.error(f"Refresh token failed: {err}") + + if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")): + logger.warning( + "Refresh denied due to unauthorized/expired scopes or invalid grant. " + "Delete the token file and perform interactive sign-in with correct scopes." + ) + + return False + + except Exception as e: + logger.error(f"Exception during refresh from JSON token: {e}") + import traceback + traceback.print_exc() + return False + + async def save_cache(self): + """Persist the token cache to file.""" + try: + # Ensure parent directory exists + parent = os.path.dirname(os.path.abspath(self.token_file)) + if parent and not os.path.exists(parent): + os.makedirs(parent, exist_ok=True) + + cache_data = self.token_cache.serialize() + if cache_data: + async with aiofiles.open(self.token_file, "w") as f: + await f.write(cache_data) + logger.debug(f"Token cache saved to {self.token_file}") + except Exception as e: + logger.error(f"Failed to save token cache: {e}") + + def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str: + """Create authorization URL for OAuth flow.""" + # Store redirect URI for later use in callback + self._redirect_uri = redirect_uri + + kwargs: Dict[str, Any] = { + # Interactive auth includes offline_access + "scopes": self.AUTH_SCOPES, + "redirect_uri": redirect_uri, + "prompt": "consent", # ensure refresh token on first run } - - auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" - return f"{auth_url}?{urlencode(params)}" + if state: + kwargs["state"] = state # Optional CSRF protection + + auth_url = self.app.get_authorization_request_url(**kwargs) + + logger.debug(f"Generated auth URL: {auth_url}") + logger.debug(f"Auth scopes: {self.AUTH_SCOPES}") + + return auth_url async def handle_authorization_callback( self, authorization_code: str, redirect_uri: str ) -> bool: - """Handle OAuth callback and exchange code for tokens""" - data = { - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'code': authorization_code, - 'grant_type': 'authorization_code', - 'redirect_uri': redirect_uri, - 'scope': ' '.join(self.SCOPES) - } + """Handle OAuth callback and exchange code for tokens.""" + try: + result = self.app.acquire_token_by_authorization_code( + authorization_code, + scopes=self.AUTH_SCOPES, # same as authorize step + redirect_uri=redirect_uri, + ) - async with httpx.AsyncClient() as client: - try: - response = await client.post(self.TOKEN_ENDPOINT, data=data) - response.raise_for_status() - token_data = response.json() + if result and "access_token" in result: + accounts = self.app.get_accounts() + if accounts: + self._current_account = accounts[0] - # Store tokens in our format - import datetime as dt - expires_in = token_data.get('expires_in', 3600) - expiry = datetime.now() + dt.timedelta(seconds=expires_in) - - self._tokens = { - 'token': token_data['access_token'], - 'refresh_token': token_data['refresh_token'], - 'scopes': self.SCOPES, - 'expiry': expiry.isoformat() - } - - await self._save_tokens() - print("Authorization successful, tokens saved") + await self.save_cache() + logger.info("OneDrive OAuth authorization successful") return True - except Exception as e: - print(f"Authorization failed: {e}") - return False + error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error" + logger.error(f"OneDrive OAuth authorization failed: {error_msg}") + return False + + except Exception as e: + logger.error(f"Exception during OneDrive OAuth authorization: {e}") + return False + + async def is_authenticated(self) -> bool: + """Check if we have valid credentials.""" + try: + # First try to load credentials if we haven't already + if not self._current_account: + await self.load_credentials() + + # Try to get a token (MSAL will refresh if needed) + if self._current_account: + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account) + if result and "access_token" in result: + return True + else: + error_msg = (result or {}).get("error") or "No result returned" + logger.debug(f"Token acquisition failed for current account: {error_msg}") + + # Fallback: try without specific account + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None) + if result and "access_token" in result: + accounts = self.app.get_accounts() + if accounts: + self._current_account = accounts[0] + return True + + return False + + except Exception as e: + logger.error(f"Authentication check failed: {e}") + return False + + def get_access_token(self) -> str: + """Get an access token for Microsoft Graph.""" + try: + # Try with current account first + if self._current_account: + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account) + if result and "access_token" in result: + return result["access_token"] + + # Fallback: try without specific account + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None) + if result and "access_token" in result: + return result["access_token"] + + # If we get here, authentication has failed + error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication" + raise ValueError(f"Failed to acquire access token: {error_msg}") + + except Exception as e: + logger.error(f"Failed to get access token: {e}") + raise + + async def revoke_credentials(self): + """Clear token cache and remove token file.""" + try: + # Clear in-memory state + self._current_account = None + self.token_cache = msal.SerializableTokenCache() + + # Recreate MSAL app with fresh cache + self.app = msal.ConfidentialClientApplication( + client_id=self.client_id, + client_credential=self.client_secret, + authority=self.authority, + token_cache=self.token_cache, + ) + + # Remove token file + if os.path.exists(self.token_file): + os.remove(self.token_file) + logger.info(f"Removed OneDrive token file: {self.token_file}") + + except Exception as e: + logger.error(f"Failed to revoke OneDrive credentials: {e}") + + def get_service(self) -> str: + """Return an access token (Graph client is just the bearer).""" + return self.get_access_token() diff --git a/src/connectors/sharepoint/connector.py b/src/connectors/sharepoint/connector.py index c31b9acd..3b9769e7 100644 --- a/src/connectors/sharepoint/connector.py +++ b/src/connectors/sharepoint/connector.py @@ -1,241 +1,564 @@ +import logging from pathlib import Path +from typing import List, Dict, Any, Optional +from urllib.parse import urlparse +from datetime import datetime import httpx -import uuid -from datetime import datetime, timedelta -from typing import Dict, List, Any, Optional from ..base import BaseConnector, ConnectorDocument, DocumentACL from .oauth import SharePointOAuth +logger = logging.getLogger(__name__) + class SharePointConnector(BaseConnector): - """SharePoint Sites connector using Microsoft Graph API""" + """SharePoint connector using MSAL-based OAuth for authentication""" + # Required BaseConnector class attributes CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID" CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET" - + # Connector metadata CONNECTOR_NAME = "SharePoint" - CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents" + CONNECTOR_DESCRIPTION = "Connect to SharePoint to sync documents and files" CONNECTOR_ICON = "sharepoint" - + def __init__(self, config: Dict[str, Any]): - super().__init__(config) + super().__init__(config) # Fix: Call parent init first + + def __init__(self, config: Dict[str, Any]): + logger.debug(f"SharePoint connector __init__ called with config type: {type(config)}") + logger.debug(f"SharePoint connector __init__ config value: {config}") + + # Ensure we always pass a valid config to the base class + if config is None: + logger.debug("Config was None, using empty dict") + config = {} + + try: + logger.debug("Calling super().__init__") + super().__init__(config) # Now safe to call with empty dict instead of None + logger.debug("super().__init__ completed successfully") + except Exception as e: + logger.error(f"super().__init__ failed: {e}") + raise + + # Initialize with defaults that allow the connector to be listed + self.client_id = None + self.client_secret = None + self.tenant_id = config.get("tenant_id", "common") + self.sharepoint_url = config.get("sharepoint_url") + self.redirect_uri = config.get("redirect_uri", "http://localhost") + + # Try to get credentials, but don't fail if they're missing + try: + logger.debug("Attempting to get client_id") + self.client_id = self.get_client_id() + logger.debug(f"Got client_id: {self.client_id is not None}") + except Exception as e: + logger.debug(f"Failed to get client_id: {e}") + pass # Credentials not available, that's OK for listing + + try: + logger.debug("Attempting to get client_secret") + self.client_secret = self.get_client_secret() + logger.debug(f"Got client_secret: {self.client_secret is not None}") + except Exception as e: + logger.debug(f"Failed to get client_secret: {e}") + pass # Credentials not available, that's OK for listing + + # Token file setup project_root = Path(__file__).resolve().parent.parent.parent.parent - token_file = config.get("token_file") or str(project_root / "onedrive_token.json") - self.oauth = SharePointOAuth( - client_id=self.get_client_id(), - client_secret=self.get_client_secret(), - token_file=token_file, - ) - self.subscription_id = config.get("subscription_id") or config.get( - "webhook_channel_id" - ) - self.base_url = "https://graph.microsoft.com/v1.0" - - # SharePoint site configuration - self.site_id = config.get("site_id") # Required for SharePoint - - async def authenticate(self) -> bool: - if await self.oauth.is_authenticated(): - self._authenticated = True - return True - return False - - async def setup_subscription(self) -> str: - if not self._authenticated: - raise ValueError("Not authenticated") - - webhook_url = self.config.get("webhook_url") - if not webhook_url: - raise ValueError("webhook_url required in config for subscriptions") - - expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z" - body = { - "changeType": "created,updated,deleted", - "notificationUrl": webhook_url, - "resource": f"/sites/{self.site_id}/drive/root", - "expirationDateTime": expiration, - "clientState": str(uuid.uuid4()), - } - - token = self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.post( - f"{self.base_url}/subscriptions", - json=body, - headers={"Authorization": f"Bearer {token}"}, - ) - resp.raise_for_status() - data = resp.json() - - self.subscription_id = data["id"] - return self.subscription_id - - async def list_files( - self, page_token: Optional[str] = None, limit: int = 100 - ) -> Dict[str, Any]: - if not self._authenticated: - raise ValueError("Not authenticated") - - params = {"$top": str(limit)} - if page_token: - params["$skiptoken"] = page_token - - token = self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{self.base_url}/sites/{self.site_id}/drive/root/children", - params=params, - headers={"Authorization": f"Bearer {token}"}, - ) - resp.raise_for_status() - data = resp.json() - - files = [] - for item in data.get("value", []): - if item.get("file"): - files.append( - { - "id": item["id"], - "name": item["name"], - "mimeType": item.get("file", {}).get( - "mimeType", "application/octet-stream" - ), - "webViewLink": item.get("webUrl"), - "createdTime": item.get("createdDateTime"), - "modifiedTime": item.get("lastModifiedDateTime"), - } - ) - - next_token = None - next_link = data.get("@odata.nextLink") - if next_link: - from urllib.parse import urlparse, parse_qs - - parsed = urlparse(next_link) - next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0] - - return {"files": files, "nextPageToken": next_token} - - async def get_file_content(self, file_id: str) -> ConnectorDocument: - if not self._authenticated: - raise ValueError("Not authenticated") - - token = self.oauth.get_access_token() - headers = {"Authorization": f"Bearer {token}"} - async with httpx.AsyncClient() as client: - meta_resp = await client.get( - f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}", - headers=headers, - ) - meta_resp.raise_for_status() - metadata = meta_resp.json() - - content_resp = await client.get( - f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/content", - headers=headers, - ) - content = content_resp.content - - # Handle the possibility of this being a redirect - if content_resp.status_code in (301, 302, 303, 307, 308): - redirect_url = content_resp.headers.get("Location") - if redirect_url: - content_resp = await client.get(redirect_url) - content_resp.raise_for_status() - content = content_resp.content + token_file = config.get("token_file") or str(project_root / "sharepoint_token.json") + Path(token_file).parent.mkdir(parents=True, exist_ok=True) + + # Only initialize OAuth if we have credentials + if self.client_id and self.client_secret: + connection_id = config.get("connection_id", "default") + + # Use token_file from config if provided, otherwise generate one + if config.get("token_file"): + oauth_token_file = config["token_file"] else: - content_resp.raise_for_status() - - perm_resp = await client.get( - f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/permissions", - headers=headers, + oauth_token_file = f"sharepoint_token_{connection_id}.json" + + authority = f"https://login.microsoftonline.com/{self.tenant_id}" if self.tenant_id != "common" else "https://login.microsoftonline.com/common" + + self.oauth = SharePointOAuth( + client_id=self.client_id, + client_secret=self.client_secret, + token_file=oauth_token_file, + authority=authority ) - perm_resp.raise_for_status() - permissions = perm_resp.json() - - acl = self._parse_permissions(metadata, permissions) - modified = datetime.fromisoformat( - metadata["lastModifiedDateTime"].replace("Z", "+00:00") - ).replace(tzinfo=None) - created = datetime.fromisoformat( - metadata["createdDateTime"].replace("Z", "+00:00") - ).replace(tzinfo=None) - - document = ConnectorDocument( - id=metadata["id"], - filename=metadata["name"], - mimetype=metadata.get("file", {}).get( - "mimeType", "application/octet-stream" - ), - content=content, - source_url=metadata.get("webUrl"), - acl=acl, - modified_time=modified, - created_time=created, - metadata={"size": metadata.get("size")}, - ) - return document - - def _parse_permissions( - self, metadata: Dict[str, Any], permissions: Dict[str, Any] - ) -> DocumentACL: - acl = DocumentACL() - owner = metadata.get("createdBy", {}).get("user", {}).get("email") - if owner: - acl.owner = owner - for perm in permissions.get("value", []): - role = perm.get("roles", ["read"])[0] - grantee = perm.get("grantedToV2") or perm.get("grantedTo") - if not grantee: - continue - user = grantee.get("user") - if user and user.get("email"): - acl.user_permissions[user["email"]] = role - group = grantee.get("group") - if group and group.get("email"): - acl.group_permissions[group["email"]] = role - return acl - - def handle_webhook_validation( - self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str] - ) -> Optional[str]: - """Handle Microsoft Graph webhook validation""" - if request_method == "GET": - validation_token = query_params.get("validationtoken") or query_params.get( - "validationToken" - ) - if validation_token: - return validation_token - return None - - def extract_webhook_channel_id( - self, payload: Dict[str, Any], headers: Dict[str, str] - ) -> Optional[str]: - """Extract SharePoint subscription ID from webhook payload""" - values = payload.get("value", []) - return values[0].get("subscriptionId") if values else None - - async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: - values = payload.get("value", []) - file_ids = [] - for item in values: - resource_data = item.get("resourceData", {}) - file_id = resource_data.get("id") - if file_id: - file_ids.append(file_id) - return file_ids - - async def cleanup_subscription( - self, subscription_id: str, resource_id: str = None - ) -> bool: - if not self._authenticated: + else: + self.oauth = None + + # Track subscription ID for webhooks + self._subscription_id: Optional[str] = None + + # Add Graph API defaults similar to Google Drive flags + self._graph_api_version = "v1.0" + self._default_params = { + "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl" + } + + @property + def _graph_base_url(self) -> str: + """Base URL for Microsoft Graph API calls""" + return f"https://graph.microsoft.com/{self._graph_api_version}" + + def emit(self, doc: ConnectorDocument) -> None: + """ + Emit a ConnectorDocument instance. + Override this method to integrate with your ingestion pipeline. + """ + logger.debug(f"Emitting SharePoint document: {doc.id} ({doc.filename})") + + async def authenticate(self) -> bool: + """Test authentication - BaseConnector interface""" + logger.debug(f"SharePoint authenticate() called, oauth is None: {self.oauth is None}") + try: + if not self.oauth: + logger.debug("SharePoint authentication failed: OAuth not initialized") + self._authenticated = False + return False + + logger.debug("Loading SharePoint credentials...") + # Try to load existing credentials first + load_result = await self.oauth.load_credentials() + logger.debug(f"Load credentials result: {load_result}") + + logger.debug("Checking SharePoint authentication status...") + authenticated = await self.oauth.is_authenticated() + logger.debug(f"SharePoint is_authenticated result: {authenticated}") + + self._authenticated = authenticated + return authenticated + except Exception as e: + logger.error(f"SharePoint authentication failed: {e}") + import traceback + traceback.print_exc() + self._authenticated = False return False - token = self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.delete( - f"{self.base_url}/subscriptions/{subscription_id}", - headers={"Authorization": f"Bearer {token}"}, + + def get_auth_url(self) -> str: + """Get OAuth authorization URL""" + if not self.oauth: + raise RuntimeError("SharePoint OAuth not initialized - missing credentials") + return self.oauth.create_authorization_url(self.redirect_uri) + + async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]: + """Handle OAuth callback""" + if not self.oauth: + raise RuntimeError("SharePoint OAuth not initialized - missing credentials") + try: + success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri) + if success: + self._authenticated = True + return {"status": "success"} + else: + raise ValueError("OAuth callback failed") + except Exception as e: + logger.error(f"OAuth callback failed: {e}") + raise + + def sync_once(self) -> None: + """ + Perform a one-shot sync of SharePoint files and emit documents. + This method mirrors the Google Drive connector's sync_once functionality. + """ + import asyncio + + async def _async_sync(): + try: + # Get list of files + file_list = await self.list_files(max_files=1000) # Adjust as needed + files = file_list.get("files", []) + + for file_info in files: + try: + file_id = file_info.get("id") + if not file_id: + continue + + # Get full document content + doc = await self.get_file_content(file_id) + self.emit(doc) + + except Exception as e: + logger.error(f"Failed to sync SharePoint file {file_info.get('name', 'unknown')}: {e}") + continue + + except Exception as e: + logger.error(f"SharePoint sync_once failed: {e}") + raise + + # Run the async sync + if hasattr(asyncio, 'run'): + asyncio.run(_async_sync()) + else: + # Python < 3.7 compatibility + loop = asyncio.get_event_loop() + loop.run_until_complete(_async_sync()) + + async def setup_subscription(self) -> str: + """Set up real-time subscription for file changes - BaseConnector interface""" + webhook_url = self.config.get('webhook_url') + if not webhook_url: + logger.warning("No webhook URL configured, skipping SharePoint subscription setup") + return "no-webhook-configured" + + try: + # Ensure we're authenticated + if not await self.authenticate(): + raise RuntimeError("SharePoint authentication failed during subscription setup") + + token = self.oauth.get_access_token() + + # Microsoft Graph subscription for SharePoint site + site_info = self._parse_sharepoint_url() + if site_info: + resource = f"sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root" + else: + resource = "/me/drive/root" + + subscription_data = { + "changeType": "created,updated,deleted", + "notificationUrl": f"{webhook_url}/webhook/sharepoint", + "resource": resource, + "expirationDateTime": self._get_subscription_expiry(), + "clientState": f"sharepoint_{self.tenant_id}" + } + + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + + url = f"{self._graph_base_url}/subscriptions" + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=subscription_data, headers=headers, timeout=30) + response.raise_for_status() + + result = response.json() + subscription_id = result.get("id") + + if subscription_id: + self._subscription_id = subscription_id + logger.info(f"SharePoint subscription created: {subscription_id}") + return subscription_id + else: + raise ValueError("No subscription ID returned from Microsoft Graph") + + except Exception as e: + logger.error(f"Failed to setup SharePoint subscription: {e}") + raise + + def _get_subscription_expiry(self) -> str: + """Get subscription expiry time (max 3 days for Graph API)""" + from datetime import datetime, timedelta + expiry = datetime.utcnow() + timedelta(days=3) # 3 days max for Graph + return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + + def _parse_sharepoint_url(self) -> Optional[Dict[str, str]]: + """Parse SharePoint URL to extract site information for Graph API""" + if not self.sharepoint_url: + return None + + try: + parsed = urlparse(self.sharepoint_url) + # Extract hostname and site name from URL like: https://contoso.sharepoint.com/sites/teamsite + host_name = parsed.netloc + path_parts = parsed.path.strip('/').split('/') + + if len(path_parts) >= 2 and path_parts[0] == 'sites': + site_name = path_parts[1] + return { + "host_name": host_name, + "site_name": site_name + } + except Exception as e: + logger.warning(f"Could not parse SharePoint URL {self.sharepoint_url}: {e}") + + return None + + async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]: + """List all files using Microsoft Graph API - BaseConnector interface""" + try: + # Ensure authentication + if not await self.authenticate(): + raise RuntimeError("SharePoint authentication failed during file listing") + + files = [] + max_files_value = max_files if max_files is not None else 100 + + # Build Graph API URL for the site or fallback to user's OneDrive + site_info = self._parse_sharepoint_url() + if site_info: + base_url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root/children" + else: + base_url = f"{self._graph_base_url}/me/drive/root/children" + + params = dict(self._default_params) + params["$top"] = max_files_value + + if page_token: + params["$skiptoken"] = page_token + + response = await self._make_graph_request(base_url, params=params) + data = response.json() + + items = data.get("value", []) + for item in items: + # Only include files, not folders + if item.get("file"): + files.append({ + "id": item.get("id", ""), + "name": item.get("name", ""), + "path": f"/drive/items/{item.get('id')}", + "size": int(item.get("size", 0)), + "modified": item.get("lastModifiedDateTime"), + "created": item.get("createdDateTime"), + "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))), + "url": item.get("webUrl", ""), + "download_url": item.get("@microsoft.graph.downloadUrl") + }) + + # Check for next page + next_page_token = None + next_link = data.get("@odata.nextLink") + if next_link: + from urllib.parse import urlparse, parse_qs + parsed = urlparse(next_link) + query_params = parse_qs(parsed.query) + if "$skiptoken" in query_params: + next_page_token = query_params["$skiptoken"][0] + + return { + "files": files, + "next_page_token": next_page_token + } + + except Exception as e: + logger.error(f"Failed to list SharePoint files: {e}") + return {"files": [], "next_page_token": None} # Return empty result instead of raising + + async def get_file_content(self, file_id: str) -> ConnectorDocument: + """Get file content and metadata - BaseConnector interface""" + try: + # Ensure authentication + if not await self.authenticate(): + raise RuntimeError("SharePoint authentication failed during file content retrieval") + + # First get file metadata using Graph API + file_metadata = await self._get_file_metadata_by_id(file_id) + + if not file_metadata: + raise ValueError(f"File not found: {file_id}") + + # Download file content + download_url = file_metadata.get("download_url") + if download_url: + content = await self._download_file_from_url(download_url) + else: + content = await self._download_file_content(file_id) + + # Create ACL from metadata + acl = DocumentACL( + owner="", # Graph API requires additional calls for detailed permissions + user_permissions={}, + group_permissions={} ) - return resp.status_code in (200, 204) + + # Parse dates + modified_time = self._parse_graph_date(file_metadata.get("modified")) + created_time = self._parse_graph_date(file_metadata.get("created")) + + return ConnectorDocument( + id=file_id, + filename=file_metadata.get("name", ""), + mimetype=file_metadata.get("mime_type", "application/octet-stream"), + content=content, + source_url=file_metadata.get("url", ""), + acl=acl, + modified_time=modified_time, + created_time=created_time, + metadata={ + "sharepoint_path": file_metadata.get("path", ""), + "sharepoint_url": self.sharepoint_url, + "size": file_metadata.get("size", 0) + } + ) + + except Exception as e: + logger.error(f"Failed to get SharePoint file content {file_id}: {e}") + raise + + async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]: + """Get file metadata by ID using Graph API""" + try: + # Try site-specific path first, then fallback to user drive + site_info = self._parse_sharepoint_url() + if site_info: + url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}" + else: + url = f"{self._graph_base_url}/me/drive/items/{file_id}" + + params = dict(self._default_params) + + response = await self._make_graph_request(url, params=params) + item = response.json() + + if item.get("file"): + return { + "id": file_id, + "name": item.get("name", ""), + "path": f"/drive/items/{file_id}", + "size": int(item.get("size", 0)), + "modified": item.get("lastModifiedDateTime"), + "created": item.get("createdDateTime"), + "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))), + "url": item.get("webUrl", ""), + "download_url": item.get("@microsoft.graph.downloadUrl") + } + + return None + + except Exception as e: + logger.error(f"Failed to get file metadata for {file_id}: {e}") + return None + + async def _download_file_content(self, file_id: str) -> bytes: + """Download file content by file ID using Graph API""" + try: + site_info = self._parse_sharepoint_url() + if site_info: + url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}/content" + else: + url = f"{self._graph_base_url}/me/drive/items/{file_id}/content" + + token = self.oauth.get_access_token() + headers = {"Authorization": f"Bearer {token}"} + + async with httpx.AsyncClient() as client: + response = await client.get(url, headers=headers, timeout=60) + response.raise_for_status() + return response.content + + except Exception as e: + logger.error(f"Failed to download file content for {file_id}: {e}") + raise + + async def _download_file_from_url(self, download_url: str) -> bytes: + """Download file content from direct download URL""" + try: + async with httpx.AsyncClient() as client: + response = await client.get(download_url, timeout=60) + response.raise_for_status() + return response.content + except Exception as e: + logger.error(f"Failed to download from URL {download_url}: {e}") + raise + + def _parse_graph_date(self, date_str: Optional[str]) -> datetime: + """Parse Microsoft Graph date string to datetime""" + if not date_str: + return datetime.now() + + try: + if date_str.endswith('Z'): + return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None) + else: + return datetime.fromisoformat(date_str.replace('T', ' ')) + except (ValueError, AttributeError): + return datetime.now() + + async def _make_graph_request(self, url: str, method: str = "GET", + data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response: + """Make authenticated API request to Microsoft Graph""" + token = self.oauth.get_access_token() + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + + async with httpx.AsyncClient() as client: + if method.upper() == "GET": + response = await client.get(url, headers=headers, params=params, timeout=30) + elif method.upper() == "POST": + response = await client.post(url, headers=headers, json=data, timeout=30) + elif method.upper() == "DELETE": + response = await client.delete(url, headers=headers, timeout=30) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + response.raise_for_status() + return response + + def _get_mime_type(self, filename: str) -> str: + """Get MIME type based on file extension""" + import mimetypes + mime_type, _ = mimetypes.guess_type(filename) + return mime_type or "application/octet-stream" + + # Webhook methods - BaseConnector interface + def handle_webhook_validation(self, request_method: str, headers: Dict[str, str], + query_params: Dict[str, str]) -> Optional[str]: + """Handle webhook validation (Graph API specific)""" + if request_method == "POST" and "validationToken" in query_params: + return query_params["validationToken"] + return None + + def extract_webhook_channel_id(self, payload: Dict[str, Any], + headers: Dict[str, str]) -> Optional[str]: + """Extract channel/subscription ID from webhook payload""" + notifications = payload.get("value", []) + if notifications: + return notifications[0].get("subscriptionId") + return None + + async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: + """Handle webhook notification and return affected file IDs""" + affected_files = [] + + # Process Microsoft Graph webhook payload + notifications = payload.get("value", []) + for notification in notifications: + resource = notification.get("resource") + if resource and "/drive/items/" in resource: + file_id = resource.split("/drive/items/")[-1] + affected_files.append(file_id) + + return affected_files + + async def cleanup_subscription(self, subscription_id: str) -> bool: + """Clean up subscription - BaseConnector interface""" + if subscription_id == "no-webhook-configured": + logger.info("No subscription to cleanup (webhook was not configured)") + return True + + try: + # Ensure authentication + if not await self.authenticate(): + logger.error("SharePoint authentication failed during subscription cleanup") + return False + + token = self.oauth.get_access_token() + headers = {"Authorization": f"Bearer {token}"} + + url = f"{self._graph_base_url}/subscriptions/{subscription_id}" + + async with httpx.AsyncClient() as client: + response = await client.delete(url, headers=headers, timeout=30) + + if response.status_code in [200, 204, 404]: + logger.info(f"SharePoint subscription {subscription_id} cleaned up successfully") + return True + else: + logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}") + return False + + except Exception as e: + logger.error(f"Failed to cleanup SharePoint subscription {subscription_id}: {e}") + return False diff --git a/src/connectors/sharepoint/oauth.py b/src/connectors/sharepoint/oauth.py index 77ce1ecc..4a96581f 100644 --- a/src/connectors/sharepoint/oauth.py +++ b/src/connectors/sharepoint/oauth.py @@ -1,19 +1,28 @@ import os import json +import logging +from typing import Optional, Dict, Any + import aiofiles -from datetime import datetime -import httpx +import msal + +logger = logging.getLogger(__name__) class SharePointOAuth: - """Direct token management for SharePoint, bypassing MSAL cache format""" + """Handles Microsoft Graph OAuth authentication flow following Google Drive pattern.""" - SCOPES = [ - "offline_access", - "Files.Read.All", - "Sites.Read.All", - ] + # Reserved scopes that must NOT be sent on token or silent calls + RESERVED_SCOPES = {"openid", "profile", "offline_access"} + # For PERSONAL Microsoft Accounts (OneDrive consumer): + # - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance) + # - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths + AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"] + RESOURCE_SCOPES = ["User.Read", "Files.Read.All"] + SCOPES = AUTH_SCOPES # Backward compatibility alias + + # Kept for reference; MSAL derives endpoints from `authority` AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token" @@ -22,173 +31,299 @@ class SharePointOAuth: client_id: str, client_secret: str, token_file: str = "sharepoint_token.json", - authority: str = "https://login.microsoftonline.com/common", # Keep for compatibility + authority: str = "https://login.microsoftonline.com/common", + allow_json_refresh: bool = True, ): + """ + Initialize SharePointOAuth. + + Args: + client_id: Azure AD application (client) ID. + client_secret: Azure AD application client secret. + token_file: Path to persisted token cache file (MSAL cache format). + authority: Usually "https://login.microsoftonline.com/common" for MSA + org, + or tenant-specific for work/school. + allow_json_refresh: If True, permit one-time migration from legacy flat JSON + {"access_token","refresh_token",...}. Otherwise refuse it. + """ self.client_id = client_id self.client_secret = client_secret self.token_file = token_file - self.authority = authority # Keep for compatibility but not used - self._tokens = None - self._load_tokens() + self.authority = authority + self.allow_json_refresh = allow_json_refresh + self.token_cache = msal.SerializableTokenCache() + self._current_account = None - def _load_tokens(self): - """Load tokens from file""" - if os.path.exists(self.token_file): - with open(self.token_file, "r") as f: - self._tokens = json.loads(f.read()) - print(f"Loaded tokens from {self.token_file}") - else: - print(f"No token file found at {self.token_file}") + # Initialize MSAL Confidential Client + self.app = msal.ConfidentialClientApplication( + client_id=self.client_id, + client_credential=self.client_secret, + authority=self.authority, + token_cache=self.token_cache, + ) - async def save_cache(self): - """Persist tokens to file (renamed for compatibility)""" - await self._save_tokens() - - async def _save_tokens(self): - """Save tokens to file""" - if self._tokens: - async with aiofiles.open(self.token_file, "w") as f: - await f.write(json.dumps(self._tokens, indent=2)) - - def _is_token_expired(self) -> bool: - """Check if current access token is expired""" - if not self._tokens or 'expiry' not in self._tokens: - return True - - expiry_str = self._tokens['expiry'] - # Handle different expiry formats + async def load_credentials(self) -> bool: + """Load existing credentials from token file (async).""" try: - if expiry_str.endswith('Z'): - expiry_dt = datetime.fromisoformat(expiry_str[:-1]) - else: - expiry_dt = datetime.fromisoformat(expiry_str) - - # Add 5-minute buffer - import datetime as dt - now = datetime.now() - return now >= (expiry_dt - dt.timedelta(minutes=5)) - except: - return True + logger.debug(f"SharePoint OAuth loading credentials from: {self.token_file}") + if os.path.exists(self.token_file): + logger.debug(f"Token file exists, reading: {self.token_file}") + + # Read the token file + async with aiofiles.open(self.token_file, "r") as f: + cache_data = await f.read() + logger.debug(f"Read {len(cache_data)} chars from token file") + + if cache_data.strip(): + # 1) Try legacy flat JSON first + try: + json_data = json.loads(cache_data) + if isinstance(json_data, dict) and "refresh_token" in json_data: + if self.allow_json_refresh: + logger.debug( + "Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh" + ) + return await self._refresh_from_json_token(json_data) + else: + logger.warning( + "Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. " + "Delete the file and re-auth." + ) + return False + except json.JSONDecodeError: + logger.debug("Token file is not flat JSON; attempting MSAL cache format") + + # 2) Try MSAL cache format + logger.debug("Attempting MSAL cache deserialization") + self.token_cache.deserialize(cache_data) + + # Get accounts from loaded cache + accounts = self.app.get_accounts() + logger.debug(f"Found {len(accounts)} accounts in MSAL cache") + if accounts: + self._current_account = accounts[0] + logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}") + + # IMPORTANT: Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account) + logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}") + if result and "access_token" in result: + logger.debug("Silent token acquisition successful") + await self.save_cache() + return True + else: + error_msg = (result or {}).get("error") or "No result" + logger.warning(f"Silent token acquisition failed: {error_msg}") + else: + logger.debug(f"Token file {self.token_file} is empty") + else: + logger.debug(f"Token file does not exist: {self.token_file}") - async def _refresh_access_token(self) -> bool: - """Refresh the access token using refresh token""" - if not self._tokens or 'refresh_token' not in self._tokens: return False - data = { - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'refresh_token': self._tokens['refresh_token'], - 'grant_type': 'refresh_token', - 'scope': ' '.join(self.SCOPES) - } + except Exception as e: + logger.error(f"Failed to load SharePoint credentials: {e}") + import traceback + traceback.print_exc() + return False - async with httpx.AsyncClient() as client: - try: - response = await client.post(self.TOKEN_ENDPOINT, data=data) - response.raise_for_status() - token_data = response.json() + async def _refresh_from_json_token(self, token_data: dict) -> bool: + """ + Use refresh token from a legacy JSON file to get new tokens (one-time migration path). - # Update tokens - self._tokens['token'] = token_data['access_token'] - if 'refresh_token' in token_data: - self._tokens['refresh_token'] = token_data['refresh_token'] - - # Calculate expiry - expires_in = token_data.get('expires_in', 3600) - import datetime as dt - expiry = datetime.now() + dt.timedelta(seconds=expires_in) - self._tokens['expiry'] = expiry.isoformat() - - await self._save_tokens() - print("Access token refreshed successfully") - return True - - except Exception as e: - print(f"Failed to refresh token: {e}") + Notes: + - Prefer using an MSAL cache file and acquire_token_silent(). + - This path is only for migrating older refresh_token JSON files. + """ + try: + refresh_token = token_data.get("refresh_token") + if not refresh_token: + logger.error("No refresh_token found in JSON file - cannot refresh") + logger.error("You must re-authenticate interactively to obtain a valid token") return False - def create_authorization_url(self, redirect_uri: str) -> str: - """Create authorization URL for OAuth flow""" - from urllib.parse import urlencode - - params = { - 'client_id': self.client_id, - 'response_type': 'code', - 'redirect_uri': redirect_uri, - 'scope': ' '.join(self.SCOPES), - 'response_mode': 'query' + # Use only RESOURCE_SCOPES when refreshing (no reserved scopes) + refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES] + logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}") + + result = self.app.acquire_token_by_refresh_token( + refresh_token=refresh_token, + scopes=refresh_scopes, + ) + + if result and "access_token" in result: + logger.debug("Successfully refreshed token via legacy JSON path") + await self.save_cache() + + accounts = self.app.get_accounts() + logger.debug(f"After refresh, found {len(accounts)} accounts") + if accounts: + self._current_account = accounts[0] + logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}") + return True + + # Error handling + err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error" + logger.error(f"Refresh token failed: {err}") + + if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")): + logger.warning( + "Refresh denied due to unauthorized/expired scopes or invalid grant. " + "Delete the token file and perform interactive sign-in with correct scopes." + ) + + return False + + except Exception as e: + logger.error(f"Exception during refresh from JSON token: {e}") + import traceback + traceback.print_exc() + return False + + async def save_cache(self): + """Persist the token cache to file.""" + try: + # Ensure parent directory exists + parent = os.path.dirname(os.path.abspath(self.token_file)) + if parent and not os.path.exists(parent): + os.makedirs(parent, exist_ok=True) + + cache_data = self.token_cache.serialize() + if cache_data: + async with aiofiles.open(self.token_file, "w") as f: + await f.write(cache_data) + logger.debug(f"Token cache saved to {self.token_file}") + except Exception as e: + logger.error(f"Failed to save token cache: {e}") + + def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str: + """Create authorization URL for OAuth flow.""" + # Store redirect URI for later use in callback + self._redirect_uri = redirect_uri + + kwargs: Dict[str, Any] = { + # IMPORTANT: interactive auth includes offline_access + "scopes": self.AUTH_SCOPES, + "redirect_uri": redirect_uri, + "prompt": "consent", # ensure refresh token on first run } - - auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" - return f"{auth_url}?{urlencode(params)}" + if state: + kwargs["state"] = state # Optional CSRF protection + + auth_url = self.app.get_authorization_request_url(**kwargs) + + logger.debug(f"Generated auth URL: {auth_url}") + logger.debug(f"Auth scopes: {self.AUTH_SCOPES}") + + return auth_url async def handle_authorization_callback( self, authorization_code: str, redirect_uri: str ) -> bool: - """Handle OAuth callback and exchange code for tokens""" - data = { - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'code': authorization_code, - 'grant_type': 'authorization_code', - 'redirect_uri': redirect_uri, - 'scope': ' '.join(self.SCOPES) - } + """Handle OAuth callback and exchange code for tokens.""" + try: + # For code exchange, we pass the same auth scopes as used in the authorize step + result = self.app.acquire_token_by_authorization_code( + authorization_code, + scopes=self.AUTH_SCOPES, + redirect_uri=redirect_uri, + ) - async with httpx.AsyncClient() as client: - try: - response = await client.post(self.TOKEN_ENDPOINT, data=data) - response.raise_for_status() - token_data = response.json() + if result and "access_token" in result: + # Store the account for future use + accounts = self.app.get_accounts() + if accounts: + self._current_account = accounts[0] - # Store tokens in our format - import datetime as dt - expires_in = token_data.get('expires_in', 3600) - expiry = datetime.now() + dt.timedelta(seconds=expires_in) - - self._tokens = { - 'token': token_data['access_token'], - 'refresh_token': token_data['refresh_token'], - 'scopes': self.SCOPES, - 'expiry': expiry.isoformat() - } - - await self._save_tokens() - print("Authorization successful, tokens saved") + await self.save_cache() + logger.info("SharePoint OAuth authorization successful") return True - except Exception as e: - print(f"Authorization failed: {e}") - return False - - async def is_authenticated(self) -> bool: - """Check if we have valid credentials""" - if not self._tokens: + error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error" + logger.error(f"SharePoint OAuth authorization failed: {error_msg}") return False - # If token is expired, try to refresh - if self._is_token_expired(): - print("Token expired, attempting refresh...") - if await self._refresh_access_token(): + except Exception as e: + logger.error(f"Exception during SharePoint OAuth authorization: {e}") + return False + + async def is_authenticated(self) -> bool: + """Check if we have valid credentials (simplified like Google Drive).""" + try: + # First try to load credentials if we haven't already + if not self._current_account: + await self.load_credentials() + + # If we have an account, try to get a token (MSAL will refresh if needed) + if self._current_account: + # IMPORTANT: use RESOURCE_SCOPES here + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account) + if result and "access_token" in result: + return True + else: + error_msg = (result or {}).get("error") or "No result returned" + logger.debug(f"Token acquisition failed for current account: {error_msg}") + + # Fallback: try without specific account + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None) + if result and "access_token" in result: + # Update current account if this worked + accounts = self.app.get_accounts() + if accounts: + self._current_account = accounts[0] return True - else: - return False - - return True + + return False + + except Exception as e: + logger.error(f"Authentication check failed: {e}") + return False def get_access_token(self) -> str: - """Get current access token""" - if not self._tokens or 'token' not in self._tokens: - raise ValueError("No access token available") - - if self._is_token_expired(): - raise ValueError("Access token expired and refresh failed") - - return self._tokens['token'] + """Get an access token for Microsoft Graph (simplified like Google Drive).""" + try: + # Try with current account first + if self._current_account: + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account) + if result and "access_token" in result: + return result["access_token"] + + # Fallback: try without specific account + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None) + if result and "access_token" in result: + return result["access_token"] + + # If we get here, authentication has failed + error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication" + raise ValueError(f"Failed to acquire access token: {error_msg}") + + except Exception as e: + logger.error(f"Failed to get access token: {e}") + raise async def revoke_credentials(self): - """Clear tokens""" - self._tokens = None - if os.path.exists(self.token_file): - os.remove(self.token_file) + """Clear token cache and remove token file (like Google Drive).""" + try: + # Clear in-memory state + self._current_account = None + self.token_cache = msal.SerializableTokenCache() + + # Recreate MSAL app with fresh cache + self.app = msal.ConfidentialClientApplication( + client_id=self.client_id, + client_credential=self.client_secret, + authority=self.authority, + token_cache=self.token_cache, + ) + + # Remove token file + if os.path.exists(self.token_file): + os.remove(self.token_file) + logger.info(f"Removed SharePoint token file: {self.token_file}") + + except Exception as e: + logger.error(f"Failed to revoke SharePoint credentials: {e}") + + def get_service(self) -> str: + """Return an access token (Graph doesn't need a generated client like Google Drive).""" + return self.get_access_token()