Finally fix MSAL in onedrive/sharepoint

This commit is contained in:
Eric Hare 2025-09-25 13:12:27 -07:00
parent acf19bd4dc
commit 349ad80ffd
6 changed files with 1649 additions and 750 deletions

View file

@ -132,7 +132,10 @@ async def connector_status(request: Request, connector_service, session_manager)
for connection in connections: for connection in connections:
try: try:
connector = await connector_service._get_connector(connection.connection_id) connector = await connector_service._get_connector(connection.connection_id)
if connector is not None:
connection_client_ids[connection.connection_id] = connector.get_client_id() connection_client_ids[connection.connection_id] = connector.get_client_id()
else:
connection_client_ids[connection.connection_id] = None
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Could not get connector for connection", "Could not get connector for connection",
@ -338,8 +341,8 @@ async def connector_webhook(request: Request, connector_service, session_manager
) )
async def connector_token(request: Request, connector_service, session_manager): async def connector_token(request: Request, connector_service, session_manager):
"""Get access token for connector API calls (e.g., Google Picker)""" """Get access token for connector API calls (e.g., Pickers)."""
connector_type = request.path_params.get("connector_type") url_connector_type = request.path_params.get("connector_type")
connection_id = request.query_params.get("connection_id") connection_id = request.query_params.get("connection_id")
if not connection_id: if not connection_id:
@ -348,37 +351,81 @@ async def connector_token(request: Request, connector_service, session_manager):
user = request.state.user user = request.state.user
try: try:
# Get the connection and verify it belongs to the user # 1) Load the connection and verify ownership
connection = await connector_service.connection_manager.get_connection(connection_id) connection = await connector_service.connection_manager.get_connection(connection_id)
if not connection or connection.user_id != user.user_id: if not connection or connection.user_id != user.user_id:
return JSONResponse({"error": "Connection not found"}, status_code=404) return JSONResponse({"error": "Connection not found"}, status_code=404)
# Get the connector instance # 2) Get the ACTUAL connector instance/type for this connection_id
connector = await connector_service._get_connector(connection_id) connector = await connector_service._get_connector(connection_id)
if not connector: if not connector:
return JSONResponse({"error": f"Connector not available - authentication may have failed for {connector_type}"}, status_code=404) return JSONResponse(
{"error": f"Connector not available - authentication may have failed for {url_connector_type}"},
status_code=404,
)
# For Google Drive, get the access token real_type = getattr(connector, "type", None) or getattr(connection, "connector_type", None)
if connector_type == "google_drive" and hasattr(connector, 'oauth'): if real_type is None:
return JSONResponse({"error": "Unable to determine connector type"}, status_code=500)
# Optional: warn if URL path type disagrees with real type
if url_connector_type and url_connector_type != real_type:
# You can downgrade this to debug if you expect cross-routing.
return JSONResponse(
{
"error": "Connector type mismatch",
"detail": {
"requested_type": url_connector_type,
"actual_type": real_type,
"hint": "Call the token endpoint using the correct connector_type for this connection_id.",
},
},
status_code=400,
)
# 3) Branch by the actual connector type
# GOOGLE DRIVE (google-auth)
if real_type == "google_drive" and hasattr(connector, "oauth"):
await connector.oauth.load_credentials() await connector.oauth.load_credentials()
if connector.oauth.creds and connector.oauth.creds.valid: if connector.oauth.creds and connector.oauth.creds.valid:
return JSONResponse({ expires_in = None
try:
if connector.oauth.creds.expiry:
import time
expires_in = max(0, int(connector.oauth.creds.expiry.timestamp() - time.time()))
except Exception:
expires_in = None
return JSONResponse(
{
"access_token": connector.oauth.creds.token, "access_token": connector.oauth.creds.token,
"expires_in": (connector.oauth.creds.expiry.timestamp() - "expires_in": expires_in,
__import__('time').time()) if connector.oauth.creds.expiry else None }
}) )
else:
return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401) return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
# For OneDrive and SharePoint, get the access token # ONEDRIVE / SHAREPOINT (MSAL or custom)
elif connector_type in ["onedrive", "sharepoint"] and hasattr(connector, 'oauth'): if real_type in ("onedrive", "sharepoint") and hasattr(connector, "oauth"):
# Ensure cache/credentials are loaded before trying to use them
try: try:
# Prefer a dedicated is_authenticated() that loads cache internally
if hasattr(connector.oauth, "is_authenticated"):
ok = await connector.oauth.is_authenticated()
else:
# Fallback: try to load credentials explicitly if available
ok = True
if hasattr(connector.oauth, "load_credentials"):
ok = await connector.oauth.load_credentials()
if not ok:
return JSONResponse({"error": "Not authenticated"}, status_code=401)
# Now safe to fetch access token
access_token = connector.oauth.get_access_token() access_token = connector.oauth.get_access_token()
return JSONResponse({ # MSAL result has expiry, but we're returning a raw token; keep expires_in None for simplicity
"access_token": access_token, return JSONResponse({"access_token": access_token, "expires_in": None})
"expires_in": None # MSAL handles token expiry internally
})
except ValueError as e: except ValueError as e:
# Typical when acquire_token_silent fails (e.g., needs re-auth)
return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401) return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401)
except Exception as e: except Exception as e:
return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500) return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500)
@ -386,7 +433,5 @@ async def connector_token(request: Request, connector_service, session_manager):
return JSONResponse({"error": "Token not available for this connector type"}, status_code=400) return JSONResponse({"error": "Token not available for this connector type"}, status_code=400)
except Exception as e: except Exception as e:
logger.error("Error getting connector token", error=str(e)) logger.error("Error getting connector token", exc_info=True)
return JSONResponse({"error": str(e)}, status_code=500) return JSONResponse({"error": str(e)}, status_code=500)

View file

@ -294,31 +294,38 @@ class ConnectionManager:
async def get_connector(self, connection_id: str) -> Optional[BaseConnector]: async def get_connector(self, connection_id: str) -> Optional[BaseConnector]:
"""Get an active connector instance""" """Get an active connector instance"""
logger.debug(f"Getting connector for connection_id: {connection_id}")
# Return cached connector if available # Return cached connector if available
if connection_id in self.active_connectors: if connection_id in self.active_connectors:
connector = self.active_connectors[connection_id] connector = self.active_connectors[connection_id]
if connector.is_authenticated: if connector.is_authenticated:
logger.debug(f"Returning cached authenticated connector for {connection_id}")
return connector return connector
else: else:
# Remove unauthenticated connector from cache # Remove unauthenticated connector from cache
logger.debug(f"Removing unauthenticated connector from cache for {connection_id}")
del self.active_connectors[connection_id] del self.active_connectors[connection_id]
# Try to create and authenticate connector # Try to create and authenticate connector
connection_config = self.connections.get(connection_id) connection_config = self.connections.get(connection_id)
if not connection_config or not connection_config.is_active: if not connection_config or not connection_config.is_active:
logger.debug(f"No active connection config found for {connection_id}")
return None return None
logger.debug(f"Creating connector for {connection_config.connector_type}")
connector = self._create_connector(connection_config) connector = self._create_connector(connection_config)
if await connector.authenticate():
logger.debug(f"Attempting authentication for {connection_id}")
auth_result = await connector.authenticate()
logger.debug(f"Authentication result for {connection_id}: {auth_result}")
if auth_result:
self.active_connectors[connection_id] = connector self.active_connectors[connection_id] = connector
# ... rest of the method
# Setup webhook subscription if not already set up
await self._setup_webhook_if_needed(
connection_id, connection_config, connector
)
return connector return connector
else:
logger.warning(f"Authentication failed for {connection_id}")
return None return None
def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]: def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]:
@ -363,6 +370,7 @@ class ConnectionManager:
def _create_connector(self, config: ConnectionConfig) -> BaseConnector: def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
"""Factory method to create connector instances""" """Factory method to create connector instances"""
try:
if config.connector_type == "google_drive": if config.connector_type == "google_drive":
return GoogleDriveConnector(config.config) return GoogleDriveConnector(config.config)
elif config.connector_type == "sharepoint": elif config.connector_type == "sharepoint":
@ -370,13 +378,15 @@ class ConnectionManager:
elif config.connector_type == "onedrive": elif config.connector_type == "onedrive":
return OneDriveConnector(config.config) return OneDriveConnector(config.config)
elif config.connector_type == "box": elif config.connector_type == "box":
# Future: BoxConnector(config.config)
raise NotImplementedError("Box connector not implemented yet") raise NotImplementedError("Box connector not implemented yet")
elif config.connector_type == "dropbox": elif config.connector_type == "dropbox":
# Future: DropboxConnector(config.config)
raise NotImplementedError("Dropbox connector not implemented yet") raise NotImplementedError("Dropbox connector not implemented yet")
else: else:
raise ValueError(f"Unknown connector type: {config.connector_type}") raise ValueError(f"Unknown connector type: {config.connector_type}")
except Exception as e:
logger.error(f"Failed to create {config.connector_type} connector: {e}")
# Re-raise the exception so caller can handle appropriately
raise
async def update_last_sync(self, connection_id: str): async def update_last_sync(self, connection_id: str):
"""Update the last sync timestamp for a connection""" """Update the last sync timestamp for a connection"""

View file

@ -1,235 +1,487 @@
import logging
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any, Optional
from datetime import datetime
import httpx import httpx
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from ..base import BaseConnector, ConnectorDocument, DocumentACL from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import OneDriveOAuth from .oauth import OneDriveOAuth
logger = logging.getLogger(__name__)
class OneDriveConnector(BaseConnector): class OneDriveConnector(BaseConnector):
"""OneDrive connector using Microsoft Graph API""" """OneDrive connector using MSAL-based OAuth for authentication."""
# Required BaseConnector class attributes
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID" CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET" CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
# Connector metadata # Connector metadata
CONNECTOR_NAME = "OneDrive" CONNECTOR_NAME = "OneDrive"
CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents" CONNECTOR_DESCRIPTION = "Connect to OneDrive (personal) to sync documents and files"
CONNECTOR_ICON = "onedrive" CONNECTOR_ICON = "onedrive"
def __init__(self, config: Dict[str, Any]): def __init__(self, config: Dict[str, Any]):
logger.debug(f"OneDrive connector __init__ called with config type: {type(config)}")
logger.debug(f"OneDrive connector __init__ config value: {config}")
if config is None:
logger.debug("Config was None, using empty dict")
config = {}
try:
logger.debug("Calling super().__init__")
super().__init__(config) super().__init__(config)
logger.debug("super().__init__ completed successfully")
except Exception as e:
logger.error(f"super().__init__ failed: {e}")
raise
# Initialize with defaults that allow the connector to be listed
self.client_id = None
self.client_secret = None
self.redirect_uri = config.get("redirect_uri", "http://localhost") # must match your app registration
# Try to get credentials, but don't fail if they're missing
try:
self.client_id = self.get_client_id()
logger.debug(f"Got client_id: {self.client_id is not None}")
except Exception as e:
logger.debug(f"Failed to get client_id: {e}")
try:
self.client_secret = self.get_client_secret()
logger.debug(f"Got client_secret: {self.client_secret is not None}")
except Exception as e:
logger.debug(f"Failed to get client_secret: {e}")
# Token file setup
project_root = Path(__file__).resolve().parent.parent.parent.parent project_root = Path(__file__).resolve().parent.parent.parent.parent
token_file = config.get("token_file") or str(project_root / "onedrive_token.json") token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
Path(token_file).parent.mkdir(parents=True, exist_ok=True)
# Only initialize OAuth if we have credentials
if self.client_id and self.client_secret:
connection_id = config.get("connection_id", "default")
# Use token_file from config if provided, otherwise generate one
if config.get("token_file"):
oauth_token_file = config["token_file"]
else:
# Use a per-connection cache file to avoid collisions with other connectors
oauth_token_file = f"onedrive_token_{connection_id}.json"
# MSA & org both work via /common for OneDrive personal testing
authority = "https://login.microsoftonline.com/common"
self.oauth = OneDriveOAuth( self.oauth = OneDriveOAuth(
client_id=self.get_client_id(), client_id=self.client_id,
client_secret=self.get_client_secret(), client_secret=self.client_secret,
token_file=token_file, token_file=oauth_token_file,
authority=authority,
allow_json_refresh=True, # allows one-time migration from legacy JSON if present
) )
self.subscription_id = config.get("subscription_id") or config.get( else:
"webhook_channel_id" self.oauth = None
)
self.base_url = "https://graph.microsoft.com/v1.0"
async def authenticate(self) -> bool: # Track subscription ID for webhooks (note: change notifications might not be available for personal accounts)
if await self.oauth.is_authenticated(): self._subscription_id: Optional[str] = None
self._authenticated = True
return True
return False
async def setup_subscription(self) -> str: # Graph API defaults
if not self._authenticated: self._graph_api_version = "v1.0"
raise ValueError("Not authenticated") self._default_params = {
"$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
webhook_url = self.config.get("webhook_url")
if not webhook_url:
raise ValueError("webhook_url required in config for subscriptions")
expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
body = {
"changeType": "created,updated,deleted",
"notificationUrl": webhook_url,
"resource": "/me/drive/root",
"expirationDateTime": expiration,
"clientState": str(uuid.uuid4()),
} }
@property
def _graph_base_url(self) -> str:
"""Base URL for Microsoft Graph API calls."""
return f"https://graph.microsoft.com/{self._graph_api_version}"
def emit(self, doc: ConnectorDocument) -> None:
    """Hand a synced document onward.

    At present this only writes a debug log line with the document id and
    filename; hook a real pipeline in here.
    """
    document_id, document_name = doc.id, doc.filename
    logger.debug(f"Emitting OneDrive document: {document_id} ({document_name})")
async def authenticate(self) -> bool:
    """Check whether the stored OneDrive credentials are usable (BaseConnector interface).

    Loads credentials through the OAuth helper, asks the helper whether the
    session is authenticated, and mirrors the answer in ``self._authenticated``.

    Returns:
        True when the OAuth helper reports an authenticated session. False when
        ``self.oauth`` is unset or when any step raises; the failure is logged
        and never propagated.
    """
    logger.debug(f"OneDrive authenticate() called, oauth is None: {self.oauth is None}")
    try:
        if not self.oauth:
            logger.debug("OneDrive authentication failed: OAuth not initialized")
            self._authenticated = False
            return False

        logger.debug("Loading OneDrive credentials...")
        load_result = await self.oauth.load_credentials()
        logger.debug(f"Load credentials result: {load_result}")

        logger.debug("Checking OneDrive authentication status...")
        authenticated = await self.oauth.is_authenticated()
        logger.debug(f"OneDrive is_authenticated result: {authenticated}")

        self._authenticated = authenticated
        return authenticated

    except Exception as e:
        # logger.exception records the traceback through the logging system;
        # traceback.print_exc() wrote it to stderr, outside the configured
        # handlers.
        logger.exception(f"OneDrive authentication failed: {e}")
        self._authenticated = False
        return False
def get_auth_url(self) -> str:
    """Build the OAuth authorization URL for this connector's redirect URI.

    Raises:
        RuntimeError: if ``self.oauth`` is unset (credentials were missing).
    """
    oauth_helper = self.oauth
    if oauth_helper:
        return oauth_helper.create_authorization_url(self.redirect_uri)
    raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
    """Complete the OAuth flow with the authorization code from the redirect.

    Returns:
        ``{"status": "success"}`` once the OAuth helper accepts the code; the
        connector is then flagged as authenticated.

    Raises:
        RuntimeError: if ``self.oauth`` is unset.
        ValueError: if the OAuth helper reports failure.
        Exception: anything the helper raises is logged and re-raised.
    """
    if not self.oauth:
        raise RuntimeError("OneDrive OAuth not initialized - missing credentials")

    try:
        completed = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
        if not completed:
            raise ValueError("OAuth callback failed")
        self._authenticated = True
        return {"status": "success"}
    except Exception as exc:
        logger.error(f"OAuth callback failed: {exc}")
        raise
def sync_once(self) -> None:
    """Run one blocking sync pass: list up to 1000 files and emit each one.

    A failure on an individual file is logged and that file is skipped. A
    failure of the listing itself is logged and re-raised.
    """
    import asyncio

    async def _run_pass():
        try:
            listing = await self.list_files(max_files=1000)
            for entry in listing.get("files", []):
                try:
                    entry_id = entry.get("id")
                    if entry_id:
                        self.emit(await self.get_file_content(entry_id))
                except Exception as e:
                    logger.error(f"Failed to sync OneDrive file {entry.get('name', 'unknown')}: {e}")
        except Exception as e:
            logger.error(f"OneDrive sync_once failed: {e}")
            raise

    # Fall back to an explicit event loop on interpreters without asyncio.run.
    if not hasattr(asyncio, 'run'):
        asyncio.get_event_loop().run_until_complete(_run_pass())
    else:
        asyncio.run(_run_pass())
async def setup_subscription(self) -> str:
"""
Set up real-time subscription for file changes.
NOTE: Change notifications may not be available for personal OneDrive accounts.
"""
webhook_url = self.config.get('webhook_url')
if not webhook_url:
logger.warning("No webhook URL configured, skipping OneDrive subscription setup")
return "no-webhook-configured"
try:
if not await self.authenticate():
raise RuntimeError("OneDrive authentication failed during subscription setup")
token = self.oauth.get_access_token() token = self.oauth.get_access_token()
# For OneDrive personal we target the user's drive
resource = "/me/drive/root"
subscription_data = {
"changeType": "created,updated,deleted",
"notificationUrl": f"{webhook_url}/webhook/onedrive",
"resource": resource,
"expirationDateTime": self._get_subscription_expiry(),
"clientState": "onedrive_personal",
}
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{self._graph_base_url}/subscriptions"
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
resp = await client.post( response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
f"{self.base_url}/subscriptions", response.raise_for_status()
json=body,
headers={"Authorization": f"Bearer {token}"},
)
resp.raise_for_status()
data = resp.json()
self.subscription_id = data["id"] result = response.json()
return self.subscription_id subscription_id = result.get("id")
async def list_files( if subscription_id:
self, page_token: Optional[str] = None, limit: int = 100 self._subscription_id = subscription_id
) -> Dict[str, Any]: logger.info(f"OneDrive subscription created: {subscription_id}")
if not self._authenticated: return subscription_id
raise ValueError("Not authenticated") else:
raise ValueError("No subscription ID returned from Microsoft Graph")
except Exception as e:
logger.error(f"Failed to setup OneDrive subscription: {e}")
raise
def _get_subscription_expiry(self) -> str:
"""Get subscription expiry time (Graph caps duration; often <= 3 days)."""
from datetime import datetime, timedelta
expiry = datetime.utcnow() + timedelta(days=3)
return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]:
"""List files from OneDrive using Microsoft Graph."""
try:
if not await self.authenticate():
raise RuntimeError("OneDrive authentication failed during file listing")
files: List[Dict[str, Any]] = []
max_files_value = max_files if max_files is not None else 100
base_url = f"{self._graph_base_url}/me/drive/root/children"
params = dict(self._default_params)
params["$top"] = max_files_value
params = {"$top": str(limit)}
if page_token: if page_token:
params["$skiptoken"] = page_token params["$skiptoken"] = page_token
token = self.oauth.get_access_token() response = await self._make_graph_request(base_url, params=params)
async with httpx.AsyncClient() as client: data = response.json()
resp = await client.get(
f"{self.base_url}/me/drive/root/children",
params=params,
headers={"Authorization": f"Bearer {token}"},
)
resp.raise_for_status()
data = resp.json()
files = [] items = data.get("value", [])
for item in data.get("value", []): for item in items:
if item.get("file"): if item.get("file"): # include files only
files.append( files.append({
{ "id": item.get("id", ""),
"id": item["id"], "name": item.get("name", ""),
"name": item["name"], "path": f"/drive/items/{item.get('id')}",
"mimeType": item.get("file", {}).get( "size": int(item.get("size", 0)),
"mimeType", "application/octet-stream" "modified": item.get("lastModifiedDateTime"),
), "created": item.get("createdDateTime"),
"webViewLink": item.get("webUrl"), "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
"createdTime": item.get("createdDateTime"), "url": item.get("webUrl", ""),
"modifiedTime": item.get("lastModifiedDateTime"), "download_url": item.get("@microsoft.graph.downloadUrl"),
} })
)
next_token = None # Next page
next_page_token = None
next_link = data.get("@odata.nextLink") next_link = data.get("@odata.nextLink")
if next_link: if next_link:
from urllib.parse import urlparse, parse_qs from urllib.parse import urlparse, parse_qs
parsed = urlparse(next_link) parsed = urlparse(next_link)
next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0] query_params = parse_qs(parsed.query)
if "$skiptoken" in query_params:
next_page_token = query_params["$skiptoken"][0]
return {"files": files, "nextPageToken": next_token} return {"files": files, "next_page_token": next_page_token}
except Exception as e:
logger.error(f"Failed to list OneDrive files: {e}")
return {"files": [], "next_page_token": None}
async def get_file_content(self, file_id: str) -> ConnectorDocument: async def get_file_content(self, file_id: str) -> ConnectorDocument:
if not self._authenticated: """Get file content and metadata."""
raise ValueError("Not authenticated") try:
if not await self.authenticate():
raise RuntimeError("OneDrive authentication failed during file content retrieval")
file_metadata = await self._get_file_metadata_by_id(file_id)
if not file_metadata:
raise ValueError(f"File not found: {file_id}")
download_url = file_metadata.get("download_url")
if download_url:
content = await self._download_file_from_url(download_url)
else:
content = await self._download_file_content(file_id)
acl = DocumentACL(
owner="",
user_permissions={},
group_permissions={},
)
modified_time = self._parse_graph_date(file_metadata.get("modified"))
created_time = self._parse_graph_date(file_metadata.get("created"))
return ConnectorDocument(
id=file_id,
filename=file_metadata.get("name", ""),
mimetype=file_metadata.get("mime_type", "application/octet-stream"),
content=content,
source_url=file_metadata.get("url", ""),
acl=acl,
modified_time=modified_time,
created_time=created_time,
metadata={
"onedrive_path": file_metadata.get("path", ""),
"size": file_metadata.get("size", 0),
},
)
except Exception as e:
logger.error(f"Failed to get OneDrive file content {file_id}: {e}")
raise
async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one drive item from Graph and normalise it into a metadata dict.

    Returns:
        A dict with the keys ``id``, ``name``, ``path``, ``size``, ``modified``,
        ``created``, ``mime_type``, ``url`` and ``download_url``. None when the
        item has no ``file`` facet, or when the request or parsing fails (the
        failure is logged, not raised).
    """
    try:
        item_url = f"{self._graph_base_url}/me/drive/items/{file_id}"
        response = await self._make_graph_request(item_url, params=dict(self._default_params))
        item = response.json()

        # Items without a "file" facet are not reported as files.
        if not item.get("file"):
            return None

        metadata: Dict[str, Any] = {"id": file_id}
        metadata["name"] = item.get("name", "")
        metadata["path"] = f"/drive/items/{file_id}"
        metadata["size"] = int(item.get("size", 0))
        metadata["modified"] = item.get("lastModifiedDateTime")
        metadata["created"] = item.get("createdDateTime")
        # Fall back to an extension-based guess when Graph gives no mimeType.
        fallback_mime = self._get_mime_type(item.get("name", ""))
        metadata["mime_type"] = item.get("file", {}).get("mimeType", fallback_mime)
        metadata["url"] = item.get("webUrl", "")
        metadata["download_url"] = item.get("@microsoft.graph.downloadUrl")
        return metadata

    except Exception as e:
        logger.error(f"Failed to get file metadata for {file_id}: {e}")
        return None
async def _download_file_content(self, file_id: str) -> bytes:
    """Download a file's bytes through the Graph ``/content`` endpoint.

    Args:
        file_id: Graph drive item id.

    Returns:
        The raw response body.

    Raises:
        Exception: any token, transport or HTTP status error is logged and
            re-raised.
    """
    try:
        url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
        token = self.oauth.get_access_token()
        headers = {"Authorization": f"Bearer {token}"}

        # Graph answers /content with a 302 to a pre-authenticated download
        # URL. httpx does not follow redirects unless asked, so without this
        # the call would raise or return the redirect body, not the file.
        async with httpx.AsyncClient(follow_redirects=True) as client:
            response = await client.get(url, headers=headers, timeout=60)
            response.raise_for_status()
            return response.content

    except Exception as e:
        logger.error(f"Failed to download file content for {file_id}: {e}")
        raise
async def _download_file_from_url(self, download_url: str) -> bytes:
    """Download file bytes from a direct download URL.

    The request carries no Authorization header, so the URL itself is the
    credential and must not be written to logs.

    Args:
        download_url: direct download URL for the file.

    Returns:
        The raw response body.

    Raises:
        Exception: any transport or HTTP status error is logged (without the
            URL) and re-raised.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(download_url, timeout=60)
            response.raise_for_status()
            return response.content
    except Exception as e:
        # Do not log the URL or str(e): httpx status errors embed the full
        # request URL, which grants access to the file.
        status = getattr(getattr(e, "response", None), "status_code", None)
        logger.error(
            f"Failed to download from pre-authenticated URL "
            f"(error={type(e).__name__}, status={status})"
        )
        raise
def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
"""Parse Microsoft Graph date string to datetime."""
if not date_str:
return datetime.now()
try:
if date_str.endswith('Z'):
return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
else:
return datetime.fromisoformat(date_str.replace('T', ' '))
except (ValueError, AttributeError):
return datetime.now()
async def _make_graph_request(self, url: str, method: str = "GET",
                              data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
    """Send an authenticated GET, POST or DELETE request to Microsoft Graph.

    GET forwards ``params`` as the query string, POST sends ``data`` as the
    JSON body, and DELETE sends neither.

    Raises:
        ValueError: for any other HTTP method.
        httpx.HTTPStatusError: via ``raise_for_status()`` on an error response.
    """
    access_token = self.oauth.get_access_token()
    request_headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
    }

    async with httpx.AsyncClient() as http:
        verb = method.upper()
        if verb not in ("GET", "POST", "DELETE"):
            raise ValueError(f"Unsupported HTTP method: {method}")

        # Per-verb extras; everything else is shared across the three calls.
        extras = {"GET": {"params": params}, "POST": {"json": data}, "DELETE": {}}[verb]
        send = getattr(http, verb.lower())
        graph_response = await send(url, headers=request_headers, timeout=30, **extras)

        graph_response.raise_for_status()
        return graph_response
def _get_mime_type(self, filename: str) -> str:
"""Get MIME type based on file extension."""
import mimetypes
mime_type, _ = mimetypes.guess_type(filename)
return mime_type or "application/octet-stream"
# Webhook methods - BaseConnector interface
def handle_webhook_validation(self, request_method: str,
                              headers: Dict[str, str],
                              query_params: Dict[str, str]) -> Optional[str]:
    """Return the validation token to echo back to the subscription handshake.

    Only POST requests are treated as validation requests. The method name is
    matched case-insensitively. The token is read from ``validationToken``,
    with the lower-cased ``validationtoken`` accepted as a fallback for
    frameworks that normalise query keys.

    Returns:
        The token value when present on a POST request, otherwise None.
    """
    if (request_method or "").upper() != "POST":
        return None
    for key in ("validationToken", "validationtoken"):
        if key in query_params:
            return query_params[key]
    return None
def extract_webhook_channel_id(self, payload: Dict[str, Any],
                               headers: Dict[str, str]) -> Optional[str]:
    """Return the ``subscriptionId`` of the first notification in the payload.

    Returns None when the payload's ``value`` list is missing or empty.
    """
    entries = payload.get("value", [])
    return entries[0].get("subscriptionId") if entries else None
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
    """Extract the ids of the drive items named in a webhook notification.

    Each notification's ``resource`` is scanned for ``/drive/items/<id>``. Only
    the id segment is kept: anything after the next ``/`` or ``?`` is dropped.
    Each id is returned once, in first-seen order.

    Args:
        payload: notification body; notifications are read from ``value``.

    Returns:
        The affected item ids. Empty when ``value`` is missing, null, or no
        notification references a drive item.
    """
    affected_files: List[str] = []
    seen = set()

    # "value" may be present but null, so do not rely on the .get default.
    for notification in payload.get("value") or []:
        resource = notification.get("resource") or ""
        _, marker, tail = resource.rpartition("/drive/items/")
        if not marker:
            continue
        file_id = tail.split("/", 1)[0].split("?", 1)[0]
        if file_id and file_id not in seen:
            seen.add(file_id)
            affected_files.append(file_id)

    return affected_files
async def cleanup_subscription(self, subscription_id: str) -> bool:
"""Clean up subscription - BaseConnector interface."""
if subscription_id == "no-webhook-configured":
logger.info("No subscription to cleanup (webhook was not configured)")
return True
try:
if not await self.authenticate():
logger.error("OneDrive authentication failed during subscription cleanup")
return False
token = self.oauth.get_access_token() token = self.oauth.get_access_token()
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
meta_resp = await client.get( response = await client.delete(url, headers=headers, timeout=30)
f"{self.base_url}/me/drive/items/{file_id}", headers=headers
)
meta_resp.raise_for_status()
metadata = meta_resp.json()
content_resp = await client.get( if response.status_code in [200, 204, 404]:
f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers logger.info(f"OneDrive subscription {subscription_id} cleaned up successfully")
) return True
content = content_resp.content
# Handle the possibility of this being a redirect
if content_resp.status_code in (301, 302, 303, 307, 308):
redirect_url = content_resp.headers.get("Location")
if redirect_url:
content_resp = await client.get(redirect_url)
content_resp.raise_for_status()
content = content_resp.content
else: else:
content_resp.raise_for_status() logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
return False
perm_resp = await client.get(
f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers except Exception as e:
) logger.error(f"Failed to cleanup OneDrive subscription {subscription_id}: {e}")
perm_resp.raise_for_status()
permissions = perm_resp.json()
acl = self._parse_permissions(metadata, permissions)
modified = datetime.fromisoformat(
metadata["lastModifiedDateTime"].replace("Z", "+00:00")
).replace(tzinfo=None)
created = datetime.fromisoformat(
metadata["createdDateTime"].replace("Z", "+00:00")
).replace(tzinfo=None)
document = ConnectorDocument(
id=metadata["id"],
filename=metadata["name"],
mimetype=metadata.get("file", {}).get(
"mimeType", "application/octet-stream"
),
content=content,
source_url=metadata.get("webUrl"),
acl=acl,
modified_time=modified,
created_time=created,
metadata={"size": metadata.get("size")},
)
return document
def _parse_permissions(
self, metadata: Dict[str, Any], permissions: Dict[str, Any]
) -> DocumentACL:
acl = DocumentACL()
owner = metadata.get("createdBy", {}).get("user", {}).get("email")
if owner:
acl.owner = owner
for perm in permissions.get("value", []):
role = perm.get("roles", ["read"])[0]
grantee = perm.get("grantedToV2") or perm.get("grantedTo")
if not grantee:
continue
user = grantee.get("user")
if user and user.get("email"):
acl.user_permissions[user["email"]] = role
group = grantee.get("group")
if group and group.get("email"):
acl.group_permissions[group["email"]] = role
return acl
def handle_webhook_validation(
    self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
) -> Optional[str]:
    """Handle Microsoft Graph webhook validation.

    Microsoft Graph validates a notification URL by sending a POST request
    that carries a ``validationToken`` query parameter; GET is still accepted
    for backward compatibility with existing callers.

    Args:
        request_method: HTTP method of the incoming request.
        headers: Request headers (unused).
        query_params: Parsed query string of the request.

    Returns:
        The validation token to echo back, or None when the request is not
        a validation handshake.
    """
    if request_method.upper() not in ("GET", "POST"):
        return None

    # Accept both spellings of the parameter name, as the original did.
    validation_token = query_params.get("validationtoken") or query_params.get(
        "validationToken"
    )
    return validation_token or None
def extract_webhook_channel_id(
    self, payload: Dict[str, Any], headers: Dict[str, str]
) -> Optional[str]:
    """Return the subscription ID of the first notification in the payload.

    Only the first entry of ``payload["value"]`` is inspected; None is
    returned when the list is missing or empty. ``headers`` is unused.
    """
    notifications = payload.get("value", [])
    if not notifications:
        return None
    first_notification = notifications[0]
    return first_notification.get("subscriptionId")
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
    """Collect the affected file IDs from a change-notification payload.

    Each entry of ``payload["value"]`` contributes its ``resourceData.id``
    when that value is present and truthy; order is preserved.
    """
    candidate_ids = (
        notification.get("resourceData", {}).get("id")
        for notification in payload.get("value", [])
    )
    return [file_id for file_id in candidate_ids if file_id]
async def cleanup_subscription(
    self, subscription_id: str, resource_id: str = None
) -> bool:
    """Delete a Microsoft Graph subscription.

    Returns False without issuing a request when the connector is not
    authenticated. Otherwise sends a DELETE for the subscription and returns
    True when the response status is 200 or 204. ``resource_id`` is unused.
    """
    if self._authenticated:
        bearer = self.oauth.get_access_token()
        async with httpx.AsyncClient() as http:
            outcome = await http.delete(
                f"{self.base_url}/subscriptions/{subscription_id}",
                headers={"Authorization": f"Bearer {bearer}"},
            )
            return outcome.status_code in (200, 204)
    return False

View file

@ -1,18 +1,28 @@
import os import os
import json import json
import logging
from typing import Optional, Dict, Any
import aiofiles import aiofiles
from datetime import datetime import msal
import httpx
logger = logging.getLogger(__name__)
class OneDriveOAuth: class OneDriveOAuth:
"""Direct token management for OneDrive, bypassing MSAL cache format""" """Handles Microsoft Graph OAuth for OneDrive (personal Microsoft accounts by default)."""
SCOPES = [ # Reserved scopes that must NOT be sent on token or silent calls
"offline_access", RESERVED_SCOPES = {"openid", "profile", "offline_access"}
"Files.Read.All",
]
# For PERSONAL Microsoft Accounts (OneDrive consumer):
# - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
# - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
SCOPES = AUTH_SCOPES # Backward-compat alias if something references .SCOPES
# Kept for reference; MSAL derives endpoints from `authority`
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token" TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
@ -21,168 +31,292 @@ class OneDriveOAuth:
client_id: str, client_id: str,
client_secret: str, client_secret: str,
token_file: str = "onedrive_token.json", token_file: str = "onedrive_token.json",
authority: str = "https://login.microsoftonline.com/common",
allow_json_refresh: bool = True,
): ):
"""
Initialize OneDriveOAuth.
Args:
client_id: Azure AD application (client) ID.
client_secret: Azure AD application client secret.
token_file: Path to persisted token cache file (MSAL cache format).
authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
or tenant-specific for work/school.
allow_json_refresh: If True, permit one-time migration from legacy flat JSON
{"access_token","refresh_token",...}. Otherwise refuse it.
"""
self.client_id = client_id self.client_id = client_id
self.client_secret = client_secret self.client_secret = client_secret
self.token_file = token_file self.token_file = token_file
self._tokens = None self.authority = authority
self._load_tokens() self.allow_json_refresh = allow_json_refresh
self.token_cache = msal.SerializableTokenCache()
self._current_account = None
def _load_tokens(self): # Initialize MSAL Confidential Client
"""Load tokens from file""" self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority,
token_cache=self.token_cache,
)
async def load_credentials(self) -> bool:
"""Load existing credentials from token file (async)."""
try:
logger.debug(f"OneDrive OAuth loading credentials from: {self.token_file}")
if os.path.exists(self.token_file): if os.path.exists(self.token_file):
with open(self.token_file, "r") as f: logger.debug(f"Token file exists, reading: {self.token_file}")
self._tokens = json.loads(f.read())
print(f"Loaded tokens from {self.token_file}")
else:
print(f"No token file found at {self.token_file}")
async def _save_tokens(self): # Read the token file
"""Save tokens to file""" async with aiofiles.open(self.token_file, "r") as f:
if self._tokens: cache_data = await f.read()
async with aiofiles.open(self.token_file, "w") as f: logger.debug(f"Read {len(cache_data)} chars from token file")
await f.write(json.dumps(self._tokens, indent=2))
def _is_token_expired(self) -> bool: if cache_data.strip():
"""Check if current access token is expired""" # 1) Try legacy flat JSON first
if not self._tokens or 'expiry' not in self._tokens:
return True
expiry_str = self._tokens['expiry']
# Handle different expiry formats
try: try:
if expiry_str.endswith('Z'): json_data = json.loads(cache_data)
expiry_dt = datetime.fromisoformat(expiry_str[:-1]) if isinstance(json_data, dict) and "refresh_token" in json_data:
if self.allow_json_refresh:
logger.debug(
"Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
)
return await self._refresh_from_json_token(json_data)
else: else:
expiry_dt = datetime.fromisoformat(expiry_str) logger.warning(
"Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
# Add 5-minute buffer "Delete the file and re-auth."
import datetime as dt )
now = datetime.now()
return now >= (expiry_dt - dt.timedelta(minutes=5))
except:
return True
async def _refresh_access_token(self) -> bool:
"""Refresh the access token using refresh token"""
if not self._tokens or 'refresh_token' not in self._tokens:
return False return False
except json.JSONDecodeError:
logger.debug("Token file is not flat JSON; attempting MSAL cache format")
data = { # 2) Try MSAL cache format
'client_id': self.client_id, logger.debug("Attempting MSAL cache deserialization")
'client_secret': self.client_secret, self.token_cache.deserialize(cache_data)
'refresh_token': self._tokens['refresh_token'],
'grant_type': 'refresh_token',
'scope': ' '.join(self.SCOPES)
}
async with httpx.AsyncClient() as client: # Get accounts from loaded cache
try: accounts = self.app.get_accounts()
response = await client.post(self.TOKEN_ENDPOINT, data=data) logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
response.raise_for_status() if accounts:
token_data = response.json() self._current_account = accounts[0]
logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
# Update tokens # Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
self._tokens['token'] = token_data['access_token'] result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
if 'refresh_token' in token_data: logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
self._tokens['refresh_token'] = token_data['refresh_token'] if result and "access_token" in result:
logger.debug("Silent token acquisition successful")
# Calculate expiry await self.save_cache()
expires_in = token_data.get('expires_in', 3600)
import datetime as dt
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens['expiry'] = expiry.isoformat()
await self._save_tokens()
print("Access token refreshed successfully")
return True return True
else:
error_msg = (result or {}).get("error") or "No result"
logger.warning(f"Silent token acquisition failed: {error_msg}")
else:
logger.debug(f"Token file {self.token_file} is empty")
else:
logger.debug(f"Token file does not exist: {self.token_file}")
return False
except Exception as e: except Exception as e:
print(f"Failed to refresh token: {e}") logger.error(f"Failed to load OneDrive credentials: {e}")
import traceback
traceback.print_exc()
return False return False
async def is_authenticated(self) -> bool: async def _refresh_from_json_token(self, token_data: dict) -> bool:
"""Check if we have valid credentials""" """
if not self._tokens: Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
Prefer using an MSAL cache file and acquire_token_silent(); this path is only for migrating older files.
"""
try:
refresh_token = token_data.get("refresh_token")
if not refresh_token:
logger.error("No refresh_token found in JSON file - cannot refresh")
logger.error("You must re-authenticate interactively to obtain a valid token")
return False return False
# If token is expired, try to refresh # Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
if self._is_token_expired(): refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
print("Token expired, attempting refresh...") logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
if await self._refresh_access_token():
return True
else:
return False
result = self.app.acquire_token_by_refresh_token(
refresh_token=refresh_token,
scopes=refresh_scopes,
)
if result and "access_token" in result:
logger.debug("Successfully refreshed token via legacy JSON path")
await self.save_cache()
accounts = self.app.get_accounts()
logger.debug(f"After refresh, found {len(accounts)} accounts")
if accounts:
self._current_account = accounts[0]
logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
return True return True
def get_access_token(self) -> str: # Error handling
"""Get current access token""" err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
if not self._tokens or 'token' not in self._tokens: logger.error(f"Refresh token failed: {err}")
raise ValueError("No access token available")
if self._is_token_expired(): if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
raise ValueError("Access token expired and refresh failed") logger.warning(
"Refresh denied due to unauthorized/expired scopes or invalid grant. "
"Delete the token file and perform interactive sign-in with correct scopes."
)
return self._tokens['token'] return False
async def revoke_credentials(self): except Exception as e:
"""Clear tokens""" logger.error(f"Exception during refresh from JSON token: {e}")
self._tokens = None import traceback
if os.path.exists(self.token_file): traceback.print_exc()
os.remove(self.token_file) return False
# Keep these methods for compatibility with your existing OAuth flow async def save_cache(self):
def create_authorization_url(self, redirect_uri: str) -> str: """Persist the token cache to file."""
"""Create authorization URL for OAuth flow""" try:
from urllib.parse import urlencode # Ensure parent directory exists
parent = os.path.dirname(os.path.abspath(self.token_file))
if parent and not os.path.exists(parent):
os.makedirs(parent, exist_ok=True)
params = { cache_data = self.token_cache.serialize()
'client_id': self.client_id, if cache_data:
'response_type': 'code', async with aiofiles.open(self.token_file, "w") as f:
'redirect_uri': redirect_uri, await f.write(cache_data)
'scope': ' '.join(self.SCOPES), logger.debug(f"Token cache saved to {self.token_file}")
'response_mode': 'query' except Exception as e:
logger.error(f"Failed to save token cache: {e}")
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
    """Create the authorization URL for the interactive OAuth flow.

    Args:
        redirect_uri: Redirect URI for the sign-in; also remembered on the
            instance as ``_redirect_uri`` for the callback step.
        state: Optional opaque value forwarded to the authorization request
            (CSRF protection).

    Returns:
        The URL produced by MSAL's ``get_authorization_request_url``.
    """
    # Store redirect URI for later use in callback
    self._redirect_uri = redirect_uri

    # MSAL requests the reserved scopes (openid/profile/offline_access) on its
    # own and raises ValueError if the caller passes any of them, so they are
    # stripped from the scope list handed to MSAL.
    scopes = [s for s in self.AUTH_SCOPES if s not in self.RESERVED_SCOPES]

    kwargs: Dict[str, Any] = {
        "scopes": scopes,
        "redirect_uri": redirect_uri,
        "prompt": "consent",  # ensure refresh token on first run
    }
    if state:
        kwargs["state"] = state  # Optional CSRF protection

    auth_url = self.app.get_authorization_request_url(**kwargs)

    logger.debug(f"Generated auth URL: {auth_url}")
    logger.debug(f"Auth scopes: {scopes}")

    return auth_url
async def handle_authorization_callback(
    self, authorization_code: str, redirect_uri: str
) -> bool:
    """Handle OAuth callback and exchange code for tokens.

    Args:
        authorization_code: Code returned to the redirect URI.
        redirect_uri: The redirect URI used in the authorize step.

    Returns:
        True when an access token was obtained and the cache was saved;
        False on any failure (errors are logged, not raised).
    """
    try:
        # Reserved scopes must not be passed to MSAL explicitly: it adds them
        # itself and raises ValueError otherwise, which the handler below
        # would turn into a silent "authorization failed".
        scopes = [s for s in self.AUTH_SCOPES if s not in self.RESERVED_SCOPES]

        result = self.app.acquire_token_by_authorization_code(
            authorization_code,
            scopes=scopes,
            redirect_uri=redirect_uri,
        )

        if result and "access_token" in result:
            accounts = self.app.get_accounts()
            if accounts:
                self._current_account = accounts[0]

            await self.save_cache()
            logger.info("OneDrive OAuth authorization successful")
            return True

        error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
        logger.error(f"OneDrive OAuth authorization failed: {error_msg}")
        return False

    except Exception as e:
        logger.error(f"Exception during OneDrive OAuth authorization: {e}")
        return False
async def is_authenticated(self) -> bool:
    """Report whether an access token can be acquired silently.

    Loads persisted credentials first when no account is selected, then
    tries a silent acquisition for the current account and finally without
    a specific account. Any exception is logged and reported as False.
    """

    def _has_token(outcome) -> bool:
        return bool(outcome) and "access_token" in outcome

    try:
        # Lazily load credentials if no account has been selected yet
        if not self._current_account:
            await self.load_credentials()

        account = self._current_account
        if account:
            outcome = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=account)
            if _has_token(outcome):
                return True
            error_msg = (outcome or {}).get("error") or "No result returned"
            logger.debug(f"Token acquisition failed for current account: {error_msg}")

        # Fallback: let MSAL look for a token without a specific account
        outcome = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
        if not _has_token(outcome):
            return False

        known_accounts = self.app.get_accounts()
        if known_accounts:
            self._current_account = known_accounts[0]
        return True

    except Exception as e:
        logger.error(f"Authentication check failed: {e}")
        return False
def get_access_token(self) -> str:
    """Get an access token for Microsoft Graph.

    Tries a silent acquisition for the current account (when one is set),
    then without a specific account.

    Raises:
        ValueError: when neither attempt yields an access token. Any
            exception is logged and re-raised.
    """
    try:
        # Current account first (if any), then the account-less fallback
        attempts = [self._current_account, None] if self._current_account else [None]

        outcome = None
        for account in attempts:
            outcome = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=account)
            if outcome and "access_token" in outcome:
                return outcome["access_token"]

        # Both attempts failed: surface whatever detail the last result had
        details = outcome or {}
        error_msg = details.get("error_description") or details.get("error") or "No valid authentication"
        raise ValueError(f"Failed to acquire access token: {error_msg}")

    except Exception as e:
        logger.error(f"Failed to get access token: {e}")
        raise
async def revoke_credentials(self):
    """Clear token cache and remove token file.

    Resets the selected account, swaps in an empty MSAL token cache with a
    freshly built client application, and deletes the token file if it
    exists. Failures are logged rather than raised.
    """
    try:
        # Drop in-memory state and rebuild the MSAL app around an empty cache
        self._current_account = None
        fresh_cache = msal.SerializableTokenCache()
        self.token_cache = fresh_cache
        self.app = msal.ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_secret,
            authority=self.authority,
            token_cache=self.token_cache,
        )

        # Remove the persisted token file, if there is one
        if not os.path.exists(self.token_file):
            return
        os.remove(self.token_file)
        logger.info(f"Removed OneDrive token file: {self.token_file}")

    except Exception as e:
        logger.error(f"Failed to revoke OneDrive credentials: {e}")
def get_service(self) -> str:
    """Return the "service" handle, which here is just a bearer access token."""
    bearer_token = self.get_access_token()
    return bearer_token

View file

@ -1,241 +1,564 @@
import logging
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any, Optional
from urllib.parse import urlparse
from datetime import datetime
import httpx import httpx
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from ..base import BaseConnector, ConnectorDocument, DocumentACL from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import SharePointOAuth from .oauth import SharePointOAuth
logger = logging.getLogger(__name__)
class SharePointConnector(BaseConnector): class SharePointConnector(BaseConnector):
"""SharePoint Sites connector using Microsoft Graph API""" """SharePoint connector using MSAL-based OAuth for authentication"""
# Required BaseConnector class attributes
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID" CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET" CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
# Connector metadata # Connector metadata
CONNECTOR_NAME = "SharePoint" CONNECTOR_NAME = "SharePoint"
CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents" CONNECTOR_DESCRIPTION = "Connect to SharePoint to sync documents and files"
CONNECTOR_ICON = "sharepoint" CONNECTOR_ICON = "sharepoint"
def __init__(self, config: Dict[str, Any]): def __init__(self, config: Dict[str, Any]):
super().__init__(config) super().__init__(config) # Fix: Call parent init first
def __init__(self, config: Dict[str, Any]):
logger.debug(f"SharePoint connector __init__ called with config type: {type(config)}")
logger.debug(f"SharePoint connector __init__ config value: {config}")
# Ensure we always pass a valid config to the base class
if config is None:
logger.debug("Config was None, using empty dict")
config = {}
try:
logger.debug("Calling super().__init__")
super().__init__(config) # Now safe to call with empty dict instead of None
logger.debug("super().__init__ completed successfully")
except Exception as e:
logger.error(f"super().__init__ failed: {e}")
raise
# Initialize with defaults that allow the connector to be listed
self.client_id = None
self.client_secret = None
self.tenant_id = config.get("tenant_id", "common")
self.sharepoint_url = config.get("sharepoint_url")
self.redirect_uri = config.get("redirect_uri", "http://localhost")
# Try to get credentials, but don't fail if they're missing
try:
logger.debug("Attempting to get client_id")
self.client_id = self.get_client_id()
logger.debug(f"Got client_id: {self.client_id is not None}")
except Exception as e:
logger.debug(f"Failed to get client_id: {e}")
pass # Credentials not available, that's OK for listing
try:
logger.debug("Attempting to get client_secret")
self.client_secret = self.get_client_secret()
logger.debug(f"Got client_secret: {self.client_secret is not None}")
except Exception as e:
logger.debug(f"Failed to get client_secret: {e}")
pass # Credentials not available, that's OK for listing
# Token file setup
project_root = Path(__file__).resolve().parent.parent.parent.parent project_root = Path(__file__).resolve().parent.parent.parent.parent
token_file = config.get("token_file") or str(project_root / "onedrive_token.json") token_file = config.get("token_file") or str(project_root / "sharepoint_token.json")
Path(token_file).parent.mkdir(parents=True, exist_ok=True)
# Only initialize OAuth if we have credentials
if self.client_id and self.client_secret:
connection_id = config.get("connection_id", "default")
# Use token_file from config if provided, otherwise generate one
if config.get("token_file"):
oauth_token_file = config["token_file"]
else:
oauth_token_file = f"sharepoint_token_{connection_id}.json"
authority = f"https://login.microsoftonline.com/{self.tenant_id}" if self.tenant_id != "common" else "https://login.microsoftonline.com/common"
self.oauth = SharePointOAuth( self.oauth = SharePointOAuth(
client_id=self.get_client_id(), client_id=self.client_id,
client_secret=self.get_client_secret(), client_secret=self.client_secret,
token_file=token_file, token_file=oauth_token_file,
authority=authority
) )
self.subscription_id = config.get("subscription_id") or config.get( else:
"webhook_channel_id" self.oauth = None
)
self.base_url = "https://graph.microsoft.com/v1.0"
# SharePoint site configuration # Track subscription ID for webhooks
self.site_id = config.get("site_id") # Required for SharePoint self._subscription_id: Optional[str] = None
async def authenticate(self) -> bool: # Add Graph API defaults similar to Google Drive flags
if await self.oauth.is_authenticated(): self._graph_api_version = "v1.0"
self._authenticated = True self._default_params = {
return True "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
return False
async def setup_subscription(self) -> str:
if not self._authenticated:
raise ValueError("Not authenticated")
webhook_url = self.config.get("webhook_url")
if not webhook_url:
raise ValueError("webhook_url required in config for subscriptions")
expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
body = {
"changeType": "created,updated,deleted",
"notificationUrl": webhook_url,
"resource": f"/sites/{self.site_id}/drive/root",
"expirationDateTime": expiration,
"clientState": str(uuid.uuid4()),
} }
@property
def _graph_base_url(self) -> str:
"""Base URL for Microsoft Graph API calls"""
return f"https://graph.microsoft.com/{self._graph_api_version}"
def emit(self, doc: ConnectorDocument) -> None:
    """Hand a ConnectorDocument to the ingestion pipeline.

    This implementation only writes a debug log line with the document's
    id and filename; override it to integrate with a real pipeline.
    """
    doc_id, doc_name = doc.id, doc.filename
    logger.debug(f"Emitting SharePoint document: {doc_id} ({doc_name})")
async def authenticate(self) -> bool:
    """Test authentication - BaseConnector interface.

    Loads persisted credentials through the OAuth helper and asks it whether
    a valid token is available. The outcome is stored in
    ``self._authenticated`` and returned.

    Returns:
        True when the OAuth helper reports valid credentials; False when
        OAuth was never initialized or any step raises.
    """
    logger.debug(f"SharePoint authenticate() called, oauth is None: {self.oauth is None}")
    try:
        if not self.oauth:
            logger.debug("SharePoint authentication failed: OAuth not initialized")
            self._authenticated = False
            return False

        logger.debug("Loading SharePoint credentials...")
        # Try to load existing credentials first
        load_result = await self.oauth.load_credentials()
        logger.debug(f"Load credentials result: {load_result}")

        logger.debug("Checking SharePoint authentication status...")
        authenticated = await self.oauth.is_authenticated()
        logger.debug(f"SharePoint is_authenticated result: {authenticated}")

        self._authenticated = authenticated
        return authenticated

    except Exception as e:
        # logger.exception attaches the traceback to the log record, so it
        # follows the logging configuration instead of going to raw stderr.
        logger.exception(f"SharePoint authentication failed: {e}")
        self._authenticated = False
        return False
def get_auth_url(self) -> str:
    """Return the OAuth authorization URL for this connector's redirect URI.

    Raises:
        RuntimeError: when the OAuth helper was not initialized.
    """
    oauth_client = self.oauth
    if oauth_client:
        return oauth_client.create_authorization_url(self.redirect_uri)
    raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
    """Complete the OAuth flow with the code received on the redirect URI.

    Returns:
        ``{"status": "success"}`` after marking the connector authenticated.

    Raises:
        RuntimeError: when the OAuth helper was not initialized.
        ValueError: when the code exchange reports failure. Any exception
            from the exchange is logged and re-raised.
    """
    if not self.oauth:
        raise RuntimeError("SharePoint OAuth not initialized - missing credentials")

    try:
        exchanged = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
        if not exchanged:
            raise ValueError("OAuth callback failed")
        self._authenticated = True
        return {"status": "success"}
    except Exception as e:
        logger.error(f"OAuth callback failed: {e}")
        raise
def sync_once(self) -> None:
    """
    Perform a one-shot sync of SharePoint files and emit documents.

    Walks every page returned by ``list_files`` (following
    ``next_page_token``), fetches each file's content and passes the
    resulting document to ``emit``. A failure on a single file is logged and
    skipped; a failure of the listing itself is logged and re-raised.

    Must be called from synchronous code: it drives the work with
    ``asyncio.run``.
    """
    import asyncio

    async def _async_sync():
        try:
            page_token = None
            seen_tokens = set()

            while True:
                file_list = await self.list_files(page_token=page_token, max_files=1000)

                for file_info in file_list.get("files", []):
                    try:
                        file_id = file_info.get("id")
                        if not file_id:
                            continue

                        # Get full document content
                        doc = await self.get_file_content(file_id)
                        self.emit(doc)

                    except Exception as e:
                        logger.error(f"Failed to sync SharePoint file {file_info.get('name', 'unknown')}: {e}")
                        continue

                # Continue with the next page; stop when there is none or when
                # a token repeats (guards against an endless loop).
                page_token = file_list.get("next_page_token")
                if not page_token or page_token in seen_tokens:
                    break
                seen_tokens.add(page_token)

        except Exception as e:
            logger.error(f"SharePoint sync_once failed: {e}")
            raise

    # asyncio.run is available on every supported Python version
    asyncio.run(_async_sync())
async def setup_subscription(self) -> str:
"""Set up real-time subscription for file changes - BaseConnector interface"""
webhook_url = self.config.get('webhook_url')
if not webhook_url:
logger.warning("No webhook URL configured, skipping SharePoint subscription setup")
return "no-webhook-configured"
try:
# Ensure we're authenticated
if not await self.authenticate():
raise RuntimeError("SharePoint authentication failed during subscription setup")
token = self.oauth.get_access_token() token = self.oauth.get_access_token()
# Microsoft Graph subscription for SharePoint site
site_info = self._parse_sharepoint_url()
if site_info:
resource = f"sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root"
else:
resource = "/me/drive/root"
subscription_data = {
"changeType": "created,updated,deleted",
"notificationUrl": f"{webhook_url}/webhook/sharepoint",
"resource": resource,
"expirationDateTime": self._get_subscription_expiry(),
"clientState": f"sharepoint_{self.tenant_id}"
}
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
url = f"{self._graph_base_url}/subscriptions"
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
resp = await client.post( response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
f"{self.base_url}/subscriptions", response.raise_for_status()
json=body,
headers={"Authorization": f"Bearer {token}"},
)
resp.raise_for_status()
data = resp.json()
self.subscription_id = data["id"] result = response.json()
return self.subscription_id subscription_id = result.get("id")
async def list_files( if subscription_id:
self, page_token: Optional[str] = None, limit: int = 100 self._subscription_id = subscription_id
) -> Dict[str, Any]: logger.info(f"SharePoint subscription created: {subscription_id}")
if not self._authenticated: return subscription_id
raise ValueError("Not authenticated") else:
raise ValueError("No subscription ID returned from Microsoft Graph")
except Exception as e:
logger.error(f"Failed to setup SharePoint subscription: {e}")
raise
def _get_subscription_expiry(self) -> str:
"""Get subscription expiry time (max 3 days for Graph API)"""
from datetime import datetime, timedelta
expiry = datetime.utcnow() + timedelta(days=3) # 3 days max for Graph
return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
def _parse_sharepoint_url(self) -> Optional[Dict[str, str]]:
"""Parse SharePoint URL to extract site information for Graph API"""
if not self.sharepoint_url:
return None
try:
parsed = urlparse(self.sharepoint_url)
# Extract hostname and site name from URL like: https://contoso.sharepoint.com/sites/teamsite
host_name = parsed.netloc
path_parts = parsed.path.strip('/').split('/')
if len(path_parts) >= 2 and path_parts[0] == 'sites':
site_name = path_parts[1]
return {
"host_name": host_name,
"site_name": site_name
}
except Exception as e:
logger.warning(f"Could not parse SharePoint URL {self.sharepoint_url}: {e}")
return None
async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]:
"""List all files using Microsoft Graph API - BaseConnector interface"""
try:
# Ensure authentication
if not await self.authenticate():
raise RuntimeError("SharePoint authentication failed during file listing")
files = []
max_files_value = max_files if max_files is not None else 100
# Build Graph API URL for the site or fallback to user's OneDrive
site_info = self._parse_sharepoint_url()
if site_info:
base_url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root/children"
else:
base_url = f"{self._graph_base_url}/me/drive/root/children"
params = dict(self._default_params)
params["$top"] = max_files_value
params = {"$top": str(limit)}
if page_token: if page_token:
params["$skiptoken"] = page_token params["$skiptoken"] = page_token
token = self.oauth.get_access_token() response = await self._make_graph_request(base_url, params=params)
async with httpx.AsyncClient() as client: data = response.json()
resp = await client.get(
f"{self.base_url}/sites/{self.site_id}/drive/root/children",
params=params,
headers={"Authorization": f"Bearer {token}"},
)
resp.raise_for_status()
data = resp.json()
files = [] items = data.get("value", [])
for item in data.get("value", []): for item in items:
# Only include files, not folders
if item.get("file"): if item.get("file"):
files.append( files.append({
{ "id": item.get("id", ""),
"id": item["id"], "name": item.get("name", ""),
"name": item["name"], "path": f"/drive/items/{item.get('id')}",
"mimeType": item.get("file", {}).get( "size": int(item.get("size", 0)),
"mimeType", "application/octet-stream" "modified": item.get("lastModifiedDateTime"),
), "created": item.get("createdDateTime"),
"webViewLink": item.get("webUrl"), "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
"createdTime": item.get("createdDateTime"), "url": item.get("webUrl", ""),
"modifiedTime": item.get("lastModifiedDateTime"), "download_url": item.get("@microsoft.graph.downloadUrl")
} })
)
next_token = None # Check for next page
next_page_token = None
next_link = data.get("@odata.nextLink") next_link = data.get("@odata.nextLink")
if next_link: if next_link:
from urllib.parse import urlparse, parse_qs from urllib.parse import urlparse, parse_qs
parsed = urlparse(next_link) parsed = urlparse(next_link)
next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0] query_params = parse_qs(parsed.query)
if "$skiptoken" in query_params:
next_page_token = query_params["$skiptoken"][0]
return {"files": files, "nextPageToken": next_token} return {
"files": files,
"next_page_token": next_page_token
}
except Exception as e:
logger.error(f"Failed to list SharePoint files: {e}")
return {"files": [], "next_page_token": None} # Return empty result instead of raising
async def get_file_content(self, file_id: str) -> ConnectorDocument: async def get_file_content(self, file_id: str) -> ConnectorDocument:
if not self._authenticated: """Get file content and metadata - BaseConnector interface"""
raise ValueError("Not authenticated") try:
# Ensure authentication
if not await self.authenticate():
raise RuntimeError("SharePoint authentication failed during file content retrieval")
# First get file metadata using Graph API
file_metadata = await self._get_file_metadata_by_id(file_id)
if not file_metadata:
raise ValueError(f"File not found: {file_id}")
# Download file content
download_url = file_metadata.get("download_url")
if download_url:
content = await self._download_file_from_url(download_url)
else:
content = await self._download_file_content(file_id)
# Create ACL from metadata
acl = DocumentACL(
owner="", # Graph API requires additional calls for detailed permissions
user_permissions={},
group_permissions={}
)
# Parse dates
modified_time = self._parse_graph_date(file_metadata.get("modified"))
created_time = self._parse_graph_date(file_metadata.get("created"))
return ConnectorDocument(
id=file_id,
filename=file_metadata.get("name", ""),
mimetype=file_metadata.get("mime_type", "application/octet-stream"),
content=content,
source_url=file_metadata.get("url", ""),
acl=acl,
modified_time=modified_time,
created_time=created_time,
metadata={
"sharepoint_path": file_metadata.get("path", ""),
"sharepoint_url": self.sharepoint_url,
"size": file_metadata.get("size", 0)
}
)
except Exception as e:
logger.error(f"Failed to get SharePoint file content {file_id}: {e}")
raise
async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
    """Look up a drive item through Graph and flatten it into a metadata dict.

    Args:
        file_id: Graph drive-item ID.

    Returns:
        A dict with ``id``, ``name``, ``path``, ``size``, ``modified``,
        ``created``, ``mime_type``, ``url`` and ``download_url`` when the item
        has a ``file`` facet; ``None`` when it does not, or when anything
        fails (failures are logged, never raised).
    """
    try:
        # Use the site-scoped drive when a SharePoint URL was parsed, else the user's drive.
        site = self._parse_sharepoint_url()
        if site:
            endpoint = f"{self._graph_base_url}/sites/{site['host_name']}:/sites/{site['site_name']}:/drive/items/{file_id}"
        else:
            endpoint = f"{self._graph_base_url}/me/drive/items/{file_id}"

        # Copy the defaults so the request cannot mutate the shared dict.
        query = dict(self._default_params)
        reply = await self._make_graph_request(endpoint, params=query)
        item = reply.json()

        # Items without a "file" facet are not reported as files.
        if not item.get("file"):
            return None

        return {
            "id": file_id,
            "name": item.get("name", ""),
            "path": f"/drive/items/{file_id}",
            "size": int(item.get("size", 0)),
            "modified": item.get("lastModifiedDateTime"),
            "created": item.get("createdDateTime"),
            # Fall back to an extension-based guess when Graph gives no MIME type.
            "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
            "url": item.get("webUrl", ""),
            "download_url": item.get("@microsoft.graph.downloadUrl"),
        }
    except Exception as e:
        logger.error(f"Failed to get file metadata for {file_id}: {e}")
        return None
async def _download_file_content(self, file_id: str) -> bytes:
    """Download a file's bytes through the Graph ``/content`` endpoint.

    Args:
        file_id: Graph drive-item ID of the file.

    Returns:
        The raw response body.

    Raises:
        Exception: Any failure (token acquisition, transport error, non-success
            HTTP status) is logged and re-raised unchanged.
    """
    try:
        # Site-scoped drive when a SharePoint URL was parsed, else the user's drive.
        site_info = self._parse_sharepoint_url()
        if site_info:
            url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}/content"
        else:
            url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"

        token = self.oauth.get_access_token()
        headers = {"Authorization": f"Bearer {token}"}

        # Graph answers /content with "302 Found" pointing at a short-lived
        # pre-authenticated download URL. httpx does not follow redirects by
        # default, so without follow_redirects=True raise_for_status() would
        # fail on the 302 itself and no file could ever be downloaded here.
        async with httpx.AsyncClient(follow_redirects=True) as client:
            response = await client.get(url, headers=headers, timeout=60)
            response.raise_for_status()
            return response.content
    except Exception as e:
        logger.error(f"Failed to download file content for {file_id}: {e}")
        raise
async def _download_file_from_url(self, download_url: str) -> bytes:
    """Fetch raw bytes from a direct download URL.

    The request is sent without an Authorization header.

    Args:
        download_url: Absolute URL to fetch.

    Returns:
        The raw response body.

    Raises:
        Exception: Any transport error or non-success status is logged and
            re-raised unchanged.
    """
    try:
        async with httpx.AsyncClient() as http:
            reply = await http.get(download_url, timeout=60)
            reply.raise_for_status()
            body = reply.content
        return body
    except Exception as e:
        logger.error(f"Failed to download from URL {download_url}: {e}")
        raise
def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
    """Parse a Microsoft Graph ISO-8601 timestamp into a naive UTC datetime.

    Handles a trailing ``Z``, explicit UTC offsets, and fractional seconds of
    any length (the fraction is normalised to six digits so that
    ``datetime.fromisoformat`` accepts it on every supported Python version).

    Args:
        date_str: Timestamp text such as ``"2024-01-15T10:30:00Z"``; may be
            ``None`` or empty.

    Returns:
        A timezone-naive datetime. Inputs that carry an offset are converted
        to UTC first. Missing or unparseable input falls back to
        ``datetime.now()``, as before.
    """
    import re
    from datetime import timezone

    if not date_str:
        return datetime.now()

    try:
        text = date_str.strip()
        # "Z" means UTC; spell it as an offset fromisoformat understands everywhere.
        if text[-1:] in ("Z", "z"):
            text = text[:-1] + "+00:00"
        # Pad or truncate the seconds fraction to exactly six digits.
        text = re.sub(
            r"\.(\d+)",
            lambda m: "." + m.group(1)[:6].ljust(6, "0"),
            text,
            count=1,
        )
        parsed = datetime.fromisoformat(text)
    except (ValueError, AttributeError, TypeError):
        return datetime.now()

    # Always hand back naive UTC so callers never mix aware and naive values.
    if parsed.tzinfo is not None:
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed
async def _make_graph_request(self, url: str, method: str = "GET",
                              data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
    """Send an authenticated request to Microsoft Graph.

    Args:
        url: Absolute request URL.
        method: ``"GET"``, ``"POST"`` or ``"DELETE"`` (case-insensitive).
        data: JSON body; sent for POST only.
        params: Query-string parameters; sent for every method.

    Returns:
        The successful ``httpx.Response``.

    Raises:
        ValueError: If ``method`` is not supported. Raised before any token
            is acquired or any connection is opened.
        httpx.HTTPStatusError: If the response status is not a success.
    """
    # Validate up front so a bad method costs neither a token fetch nor a client.
    verb = method.upper()
    if verb not in ("GET", "POST", "DELETE"):
        raise ValueError(f"Unsupported HTTP method: {method}")

    token = self.oauth.get_access_token()
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }

    # params used to be dropped silently for POST/DELETE; pass them for all verbs.
    request_kwargs = {"headers": headers, "params": params, "timeout": 30}
    if verb == "POST":
        request_kwargs["json"] = data

    async with httpx.AsyncClient() as client:
        response = await client.request(verb, url, **request_kwargs)
        response.raise_for_status()
        return response
def _get_mime_type(self, filename: str) -> str:
    """Guess a MIME type from the filename's extension.

    Args:
        filename: File name (or path) whose extension is inspected.

    Returns:
        The guessed MIME type, or ``"application/octet-stream"`` when the
        extension is unknown.
    """
    import mimetypes

    guessed = mimetypes.guess_type(filename)[0]
    if guessed:
        return guessed
    return "application/octet-stream"
# Webhook methods - BaseConnector interface
def handle_webhook_validation(self, request_method: str, headers: Dict[str, str],
                              query_params: Dict[str, str]) -> Optional[str]:
    """Answer the Graph subscription validation handshake.

    Args:
        request_method: HTTP method of the incoming request (case-insensitive).
        headers: Request headers (unused).
        query_params: Query-string parameters of the request.

    Returns:
        The validation token to echo back when this is a POST carrying one,
        otherwise ``None``.
    """
    # Only POST requests take part in the handshake; compare case-insensitively
    # so a lower-cased method string from the web layer is still recognised.
    if not request_method or request_method.upper() != "POST":
        return None

    # Accept the canonical key and the lower-cased variant the previous
    # implementation also recognised.
    for key in ("validationToken", "validationtoken"):
        if key in query_params:
            return query_params[key]
    return None
def extract_webhook_channel_id(self, payload: Dict[str, Any],
                               headers: Dict[str, str]) -> Optional[str]:
    """Return the ``subscriptionId`` of the first notification in the payload.

    Args:
        payload: Parsed webhook body; notifications are read from ``"value"``.
        headers: Request headers (unused).

    Returns:
        The first notification's ``subscriptionId``, or ``None`` when there
        are no notifications or the first one has no such key.
    """
    entries = payload.get("value", [])
    if not entries:
        return None
    first = entries[0]
    return first.get("subscriptionId")
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
    """Extract the IDs of drive items referenced by a Graph webhook payload.

    Args:
        payload: Parsed webhook body; notifications are read from ``"value"``.

    Returns:
        File IDs taken from every notification whose ``resource`` contains
        ``/drive/items/``, in first-seen order, without duplicates and
        without empty IDs. An absent or null ``"value"`` yields ``[]``.
    """
    marker = "/drive/items/"
    affected_files: List[str] = []
    seen = set()

    # "value" may be missing or explicitly null; both mean "no notifications".
    for notification in payload.get("value") or []:
        resource = notification.get("resource")
        if not resource or marker not in resource:
            continue
        file_id = resource.split(marker)[-1]
        # Skip resources that end right after the marker, and report each
        # file once even if several notifications mention it.
        if file_id and file_id not in seen:
            seen.add(file_id)
            affected_files.append(file_id)

    return affected_files
async def cleanup_subscription(self, subscription_id: str) -> bool:
    """Delete a Graph webhook subscription (BaseConnector interface).

    Args:
        subscription_id: Subscription to delete; the placeholder
            ``"no-webhook-configured"`` is treated as nothing to do.

    Returns:
        ``True`` when there was nothing to delete or Graph answered 200, 204
        or 404; ``False`` on authentication failure, any other status, or
        any exception (all of which are logged, never raised).
    """
    if subscription_id == "no-webhook-configured":
        logger.info("No subscription to cleanup (webhook was not configured)")
        return True

    try:
        authenticated = await self.authenticate()
        if not authenticated:
            logger.error("SharePoint authentication failed during subscription cleanup")
            return False

        token = self.oauth.get_access_token()
        headers = {"Authorization": f"Bearer {token}"}
        url = f"{self._graph_base_url}/subscriptions/{subscription_id}"

        async with httpx.AsyncClient() as client:
            response = await client.delete(url, headers=headers, timeout=30)

        # 404 counts as success: the subscription is already gone.
        if response.status_code not in [200, 204, 404]:
            logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
            return False

        logger.info(f"SharePoint subscription {subscription_id} cleaned up successfully")
        return True
    except Exception as e:
        logger.error(f"Failed to cleanup SharePoint subscription {subscription_id}: {e}")
        return False

View file

@ -1,19 +1,28 @@
import os import os
import json import json
import logging
from typing import Optional, Dict, Any
import aiofiles import aiofiles
from datetime import datetime import msal
import httpx
logger = logging.getLogger(__name__)
class SharePointOAuth: class SharePointOAuth:
"""Direct token management for SharePoint, bypassing MSAL cache format""" """Handles Microsoft Graph OAuth authentication flow following Google Drive pattern."""
SCOPES = [ # Reserved scopes that must NOT be sent on token or silent calls
"offline_access", RESERVED_SCOPES = {"openid", "profile", "offline_access"}
"Files.Read.All",
"Sites.Read.All",
]
# For PERSONAL Microsoft Accounts (OneDrive consumer):
# - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
# - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
SCOPES = AUTH_SCOPES # Backward compatibility alias
# Kept for reference; MSAL derives endpoints from `authority`
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token" TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
@ -22,173 +31,299 @@ class SharePointOAuth:
client_id: str, client_id: str,
client_secret: str, client_secret: str,
token_file: str = "sharepoint_token.json", token_file: str = "sharepoint_token.json",
authority: str = "https://login.microsoftonline.com/common", # Keep for compatibility authority: str = "https://login.microsoftonline.com/common",
allow_json_refresh: bool = True,
): ):
"""
Initialize SharePointOAuth.
Args:
client_id: Azure AD application (client) ID.
client_secret: Azure AD application client secret.
token_file: Path to persisted token cache file (MSAL cache format).
authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
or tenant-specific for work/school.
allow_json_refresh: If True, permit one-time migration from legacy flat JSON
{"access_token","refresh_token",...}. Otherwise refuse it.
"""
self.client_id = client_id self.client_id = client_id
self.client_secret = client_secret self.client_secret = client_secret
self.token_file = token_file self.token_file = token_file
self.authority = authority # Keep for compatibility but not used self.authority = authority
self._tokens = None self.allow_json_refresh = allow_json_refresh
self._load_tokens() self.token_cache = msal.SerializableTokenCache()
self._current_account = None
def _load_tokens(self): # Initialize MSAL Confidential Client
"""Load tokens from file""" self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority,
token_cache=self.token_cache,
)
async def load_credentials(self) -> bool:
    """Load persisted credentials from the token file.

    The file is first checked for the legacy flat-JSON layout (a dict with a
    ``refresh_token`` key); if found, it is either migrated through
    ``_refresh_from_json_token`` or refused, depending on
    ``allow_json_refresh``. Otherwise the content is deserialized as an MSAL
    token cache and a silent token acquisition is attempted for the first
    cached account.

    Returns:
        ``True`` when a token was obtained (silently or via the legacy
        migration); ``False`` when the file is missing, empty, refused, holds
        no account, the silent acquisition fails, or any exception occurs
        (exceptions are logged, never raised).
    """
    try:
        logger.debug(f"SharePoint OAuth loading credentials from: {self.token_file}")

        if not os.path.exists(self.token_file):
            logger.debug(f"Token file does not exist: {self.token_file}")
            return False

        logger.debug(f"Token file exists, reading: {self.token_file}")
        async with aiofiles.open(self.token_file, "r") as f:
            cache_data = await f.read()
        logger.debug(f"Read {len(cache_data)} chars from token file")

        if not cache_data.strip():
            logger.debug(f"Token file {self.token_file} is empty")
            return False

        # Step 1: legacy flat JSON ({"refresh_token": ...}) takes precedence.
        try:
            json_data = json.loads(cache_data)
            is_legacy = isinstance(json_data, dict) and "refresh_token" in json_data
            if is_legacy:
                if not self.allow_json_refresh:
                    logger.warning(
                        "Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
                        "Delete the file and re-auth."
                    )
                    return False
                logger.debug(
                    "Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
                )
                return await self._refresh_from_json_token(json_data)
        except json.JSONDecodeError:
            logger.debug("Token file is not flat JSON; attempting MSAL cache format")

        # Step 2: treat the content as a serialized MSAL cache.
        logger.debug("Attempting MSAL cache deserialization")
        self.token_cache.deserialize(cache_data)

        accounts = self.app.get_accounts()
        logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
        if not accounts:
            return False

        self._current_account = accounts[0]
        logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")

        # Silent acquisition uses RESOURCE_SCOPES (no reserved scopes).
        result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
        logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")

        if result and "access_token" in result:
            logger.debug("Silent token acquisition successful")
            await self.save_cache()
            return True

        error_msg = (result or {}).get("error") or "No result"
        logger.warning(f"Silent token acquisition failed: {error_msg}")
        return False
    except Exception as e:
        logger.error(f"Failed to load SharePoint credentials: {e}")
        return False
async def _refresh_from_json_token(self, token_data: dict) -> bool:
    """Exchange a refresh token from a legacy flat-JSON file for new tokens.

    This is the migration path for older ``{"refresh_token": ...}`` files;
    on success the token cache is persisted through ``save_cache()`` and the
    first cached account becomes the current account.

    Args:
        token_data: Parsed legacy JSON; only ``"refresh_token"`` is read.

    Returns:
        ``True`` when an access token was obtained, ``False`` otherwise
        (missing refresh token, refusal by the token endpoint, or any
        exception; all are logged, never raised).
    """
    try:
        refresh_token = token_data.get("refresh_token")
        if not refresh_token:
            logger.error("No refresh_token found in JSON file - cannot refresh")
            logger.error("You must re-authenticate interactively to obtain a valid token")
            return False

        # Reserved scopes must not be sent on refresh calls.
        refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
        logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")

        result = self.app.acquire_token_by_refresh_token(
            refresh_token=refresh_token,
            scopes=refresh_scopes,
        )

        if result and "access_token" in result:
            logger.debug("Successfully refreshed token via legacy JSON path")
            await self.save_cache()

            accounts = self.app.get_accounts()
            logger.debug(f"After refresh, found {len(accounts)} accounts")
            if accounts:
                self._current_account = accounts[0]
                logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
            return True

        # Report the most descriptive error field available.
        err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
        logger.error(f"Refresh token failed: {err}")

        if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
            logger.warning(
                "Refresh denied due to unauthorized/expired scopes or invalid grant. "
                "Delete the token file and perform interactive sign-in with correct scopes."
            )
        return False
    except Exception as e:
        # exc_info routes the traceback through logging instead of printing
        # it straight to stderr, so it follows the configured handlers.
        logger.error(f"Exception during refresh from JSON token: {e}", exc_info=True)
        return False
async def save_cache(self):
    """Write the serialized token cache to ``self.token_file``.

    The parent directory is created when missing. Nothing is written when
    the serialized cache is empty. Failures are logged and swallowed.
    """
    try:
        # Make sure the directory that will hold the token file exists.
        target_dir = os.path.dirname(os.path.abspath(self.token_file))
        if target_dir and not os.path.exists(target_dir):
            os.makedirs(target_dir, exist_ok=True)

        serialized = self.token_cache.serialize()
        if not serialized:
            return

        async with aiofiles.open(self.token_file, "w") as f:
            await f.write(serialized)
        logger.debug(f"Token cache saved to {self.token_file}")
    except Exception as e:
        logger.error(f"Failed to save token cache: {e}")
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
    """Build the sign-in URL that starts the OAuth authorization-code flow.

    Args:
        redirect_uri: URI the identity provider redirects back to; also
            stored on the instance for later use in the callback.
        state: Optional opaque value round-tripped for CSRF protection.

    Returns:
        The authorization request URL produced by MSAL.
    """
    # Store redirect URI for later use in callback.
    self._redirect_uri = redirect_uri

    # MSAL raises ValueError when a reserved scope (openid / profile /
    # offline_access) is passed explicitly, and requests those scopes on its
    # own. AUTH_SCOPES contains offline_access, so filter it out here, the
    # same way the refresh path filters RESOURCE_SCOPES.
    scopes = [s for s in self.AUTH_SCOPES if s not in self.RESERVED_SCOPES]

    kwargs: Dict[str, Any] = {
        "scopes": scopes,
        "redirect_uri": redirect_uri,
        "prompt": "consent",  # ensure refresh token on first run
    }
    if state:
        kwargs["state"] = state  # Optional CSRF protection

    auth_url = self.app.get_authorization_request_url(**kwargs)

    logger.debug(f"Generated auth URL: {auth_url}")
    logger.debug(f"Auth scopes: {scopes}")
    return auth_url
async def handle_authorization_callback(
    self, authorization_code: str, redirect_uri: str
) -> bool:
    """Exchange the OAuth authorization code for tokens and persist them.

    Args:
        authorization_code: Code received on the redirect.
        redirect_uri: The redirect URI used in the authorization request.

    Returns:
        ``True`` when an access token was obtained (the first cached account
        becomes the current account and the cache is saved); ``False`` when
        the exchange is refused or raises (both are logged, never raised).
    """
    try:
        # MSAL raises ValueError if a reserved scope (openid / profile /
        # offline_access) is passed explicitly and adds them itself, so send
        # AUTH_SCOPES minus the reserved entries.
        scopes = [s for s in self.AUTH_SCOPES if s not in self.RESERVED_SCOPES]

        result = self.app.acquire_token_by_authorization_code(
            authorization_code,
            scopes=scopes,
            redirect_uri=redirect_uri,
        )

        if result and "access_token" in result:
            # Store the account for future silent acquisitions.
            accounts = self.app.get_accounts()
            if accounts:
                self._current_account = accounts[0]

            await self.save_cache()
            logger.info("SharePoint OAuth authorization successful")
            return True

        error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
        logger.error(f"SharePoint OAuth authorization failed: {error_msg}")
        return False
    except Exception as e:
        logger.error(f"Exception during SharePoint OAuth authorization: {e}")
        return False
async def is_authenticated(self) -> bool:
    """Report whether an access token can currently be obtained silently.

    Loads persisted credentials first when no account is set, tries the
    current account, then falls back to a silent acquisition without a
    specific account.

    Returns:
        ``True`` when either attempt yields an access token; ``False``
        otherwise, including when any exception occurs (logged, not raised).
    """
    try:
        # Lazily load persisted credentials the first time through.
        if not self._current_account:
            await self.load_credentials()

        if self._current_account:
            account_result = self.app.acquire_token_silent(
                self.RESOURCE_SCOPES, account=self._current_account
            )
            if account_result and "access_token" in account_result:
                return True
            error_msg = (account_result or {}).get("error") or "No result returned"
            logger.debug(f"Token acquisition failed for current account: {error_msg}")

        # Fallback: try without a specific account.
        fallback_result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
        if not (fallback_result and "access_token" in fallback_result):
            return False

        # The fallback worked; adopt the first cached account, if any.
        known_accounts = self.app.get_accounts()
        if known_accounts:
            self._current_account = known_accounts[0]
        return True
    except Exception as e:
        logger.error(f"Authentication check failed: {e}")
        return False
def get_access_token(self) -> str:
    """Return an access token for Microsoft Graph via silent acquisition.

    The current account is tried first (when one is set), then a silent
    acquisition without a specific account.

    Returns:
        The access token string.

    Raises:
        ValueError: If neither attempt yields an access token.
        Exception: Any failure is logged and re-raised unchanged.
    """
    try:
        # Attempt order: the current account (if any), then no specific account.
        attempts = [self._current_account, None] if self._current_account else [None]

        result = None
        for account in attempts:
            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=account)
            if result and "access_token" in result:
                return result["access_token"]

        # Every attempt failed; report what the last one returned.
        error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
        raise ValueError(f"Failed to acquire access token: {error_msg}")
    except Exception as e:
        logger.error(f"Failed to get access token: {e}")
        raise
async def revoke_credentials(self):
    """Forget all credentials: reset in-memory state and delete the token file.

    The current account is cleared, a fresh token cache and a fresh MSAL
    client bound to it replace the old ones, and the token file is removed
    if present. Failures are logged and swallowed.
    """
    try:
        # Drop in-memory state first.
        self._current_account = None
        fresh_cache = msal.SerializableTokenCache()
        self.token_cache = fresh_cache

        # Rebuild the MSAL client so it no longer references the old cache.
        self.app = msal.ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_secret,
            authority=self.authority,
            token_cache=fresh_cache,
        )

        # Finally remove the persisted cache, if there is one.
        if os.path.exists(self.token_file):
            os.remove(self.token_file)
            logger.info(f"Removed SharePoint token file: {self.token_file}")
    except Exception as e:
        logger.error(f"Failed to revoke SharePoint credentials: {e}")
def get_service(self) -> str:
    """Return an access token; this class exposes no separate service client.

    Returns:
        Whatever ``get_access_token()`` returns; its exceptions propagate.
    """
    access_token = self.get_access_token()
    return access_token