Finally fix MSAL in onedrive/sharepoint

This commit is contained in:
Eric Hare 2025-09-25 13:12:27 -07:00
parent 5e14c7f100
commit f03889a2b3
No known key found for this signature in database
GPG key ID: A73DF73724270AB7
6 changed files with 1649 additions and 750 deletions

View file

@ -132,7 +132,10 @@ async def connector_status(request: Request, connector_service, session_manager)
for connection in connections:
try:
connector = await connector_service._get_connector(connection.connection_id)
connection_client_ids[connection.connection_id] = connector.get_client_id()
if connector is not None:
connection_client_ids[connection.connection_id] = connector.get_client_id()
else:
connection_client_ids[connection.connection_id] = None
except Exception as e:
logger.warning(
"Could not get connector for connection",
@ -338,8 +341,8 @@ async def connector_webhook(request: Request, connector_service, session_manager
)
async def connector_token(request: Request, connector_service, session_manager):
"""Get access token for connector API calls (e.g., Google Picker)"""
connector_type = request.path_params.get("connector_type")
"""Get access token for connector API calls (e.g., Pickers)."""
url_connector_type = request.path_params.get("connector_type")
connection_id = request.query_params.get("connection_id")
if not connection_id:
@ -348,37 +351,81 @@ async def connector_token(request: Request, connector_service, session_manager):
user = request.state.user
try:
# Get the connection and verify it belongs to the user
# 1) Load the connection and verify ownership
connection = await connector_service.connection_manager.get_connection(connection_id)
if not connection or connection.user_id != user.user_id:
return JSONResponse({"error": "Connection not found"}, status_code=404)
# Get the connector instance
# 2) Get the ACTUAL connector instance/type for this connection_id
connector = await connector_service._get_connector(connection_id)
if not connector:
return JSONResponse({"error": f"Connector not available - authentication may have failed for {connector_type}"}, status_code=404)
return JSONResponse(
{"error": f"Connector not available - authentication may have failed for {url_connector_type}"},
status_code=404,
)
# For Google Drive, get the access token
if connector_type == "google_drive" and hasattr(connector, 'oauth'):
real_type = getattr(connector, "type", None) or getattr(connection, "connector_type", None)
if real_type is None:
return JSONResponse({"error": "Unable to determine connector type"}, status_code=500)
# Optional: warn if URL path type disagrees with real type
if url_connector_type and url_connector_type != real_type:
# NOTE: downgrade this to a debug-level log instead of a 400 if cross-routing between connector types is expected.
return JSONResponse(
{
"error": "Connector type mismatch",
"detail": {
"requested_type": url_connector_type,
"actual_type": real_type,
"hint": "Call the token endpoint using the correct connector_type for this connection_id.",
},
},
status_code=400,
)
# 3) Branch by the actual connector type
# GOOGLE DRIVE (google-auth)
if real_type == "google_drive" and hasattr(connector, "oauth"):
await connector.oauth.load_credentials()
if connector.oauth.creds and connector.oauth.creds.valid:
return JSONResponse({
"access_token": connector.oauth.creds.token,
"expires_in": (connector.oauth.creds.expiry.timestamp() -
__import__('time').time()) if connector.oauth.creds.expiry else None
})
else:
return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
# For OneDrive and SharePoint, get the access token
elif connector_type in ["onedrive", "sharepoint"] and hasattr(connector, 'oauth'):
expires_in = None
try:
if connector.oauth.creds.expiry:
import time
expires_in = max(0, int(connector.oauth.creds.expiry.timestamp() - time.time()))
except Exception:
expires_in = None
return JSONResponse(
{
"access_token": connector.oauth.creds.token,
"expires_in": expires_in,
}
)
return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
# ONEDRIVE / SHAREPOINT (MSAL or custom)
if real_type in ("onedrive", "sharepoint") and hasattr(connector, "oauth"):
# Ensure cache/credentials are loaded before trying to use them
try:
# Prefer a dedicated is_authenticated() that loads cache internally
if hasattr(connector.oauth, "is_authenticated"):
ok = await connector.oauth.is_authenticated()
else:
# Fallback: try to load credentials explicitly if available
ok = True
if hasattr(connector.oauth, "load_credentials"):
ok = await connector.oauth.load_credentials()
if not ok:
return JSONResponse({"error": "Not authenticated"}, status_code=401)
# Now safe to fetch access token
access_token = connector.oauth.get_access_token()
return JSONResponse({
"access_token": access_token,
"expires_in": None # MSAL handles token expiry internally
})
# MSAL result has expiry, but we're returning a raw token; keep expires_in None for simplicity
return JSONResponse({"access_token": access_token, "expires_in": None})
except ValueError as e:
# Typical when acquire_token_silent fails (e.g., needs re-auth)
return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401)
except Exception as e:
return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500)
@ -386,7 +433,5 @@ async def connector_token(request: Request, connector_service, session_manager):
return JSONResponse({"error": "Token not available for this connector type"}, status_code=400)
except Exception as e:
logger.error("Error getting connector token", error=str(e))
logger.error("Error getting connector token", exc_info=True)
return JSONResponse({"error": str(e)}, status_code=500)

View file

@ -294,32 +294,39 @@ class ConnectionManager:
async def get_connector(self, connection_id: str) -> Optional[BaseConnector]:
"""Get an active connector instance"""
logger.debug(f"Getting connector for connection_id: {connection_id}")
# Return cached connector if available
if connection_id in self.active_connectors:
connector = self.active_connectors[connection_id]
if connector.is_authenticated:
logger.debug(f"Returning cached authenticated connector for {connection_id}")
return connector
else:
# Remove unauthenticated connector from cache
logger.debug(f"Removing unauthenticated connector from cache for {connection_id}")
del self.active_connectors[connection_id]
# Try to create and authenticate connector
connection_config = self.connections.get(connection_id)
if not connection_config or not connection_config.is_active:
logger.debug(f"No active connection config found for {connection_id}")
return None
logger.debug(f"Creating connector for {connection_config.connector_type}")
connector = self._create_connector(connection_config)
if await connector.authenticate():
logger.debug(f"Attempting authentication for {connection_id}")
auth_result = await connector.authenticate()
logger.debug(f"Authentication result for {connection_id}: {auth_result}")
if auth_result:
self.active_connectors[connection_id] = connector
# Setup webhook subscription if not already set up
await self._setup_webhook_if_needed(
connection_id, connection_config, connector
)
# ... rest of the method
return connector
return None
else:
logger.warning(f"Authentication failed for {connection_id}")
return None
def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]:
"""Get available connector types with their metadata"""
@ -363,20 +370,23 @@ class ConnectionManager:
def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
"""Factory method to create connector instances"""
if config.connector_type == "google_drive":
return GoogleDriveConnector(config.config)
elif config.connector_type == "sharepoint":
return SharePointConnector(config.config)
elif config.connector_type == "onedrive":
return OneDriveConnector(config.config)
elif config.connector_type == "box":
# Future: BoxConnector(config.config)
raise NotImplementedError("Box connector not implemented yet")
elif config.connector_type == "dropbox":
# Future: DropboxConnector(config.config)
raise NotImplementedError("Dropbox connector not implemented yet")
else:
raise ValueError(f"Unknown connector type: {config.connector_type}")
try:
if config.connector_type == "google_drive":
return GoogleDriveConnector(config.config)
elif config.connector_type == "sharepoint":
return SharePointConnector(config.config)
elif config.connector_type == "onedrive":
return OneDriveConnector(config.config)
elif config.connector_type == "box":
raise NotImplementedError("Box connector not implemented yet")
elif config.connector_type == "dropbox":
raise NotImplementedError("Dropbox connector not implemented yet")
else:
raise ValueError(f"Unknown connector type: {config.connector_type}")
except Exception as e:
logger.error(f"Failed to create {config.connector_type} connector: {e}")
# Re-raise the exception so caller can handle appropriately
raise
async def update_last_sync(self, connection_id: str):
"""Update the last sync timestamp for a connection"""

View file

@ -1,235 +1,487 @@
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
from datetime import datetime
import httpx
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import OneDriveOAuth
logger = logging.getLogger(__name__)
class OneDriveConnector(BaseConnector):
"""OneDrive connector using Microsoft Graph API"""
"""OneDrive connector using MSAL-based OAuth for authentication."""
# Required BaseConnector class attributes
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
# Connector metadata
CONNECTOR_NAME = "OneDrive"
CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents"
CONNECTOR_DESCRIPTION = "Connect to OneDrive (personal) to sync documents and files"
CONNECTOR_ICON = "onedrive"
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
logger.debug(f"OneDrive connector __init__ called with config type: {type(config)}")
logger.debug(f"OneDrive connector __init__ config value: {config}")
if config is None:
logger.debug("Config was None, using empty dict")
config = {}
try:
logger.debug("Calling super().__init__")
super().__init__(config)
logger.debug("super().__init__ completed successfully")
except Exception as e:
logger.error(f"super().__init__ failed: {e}")
raise
# Initialize with defaults that allow the connector to be listed
self.client_id = None
self.client_secret = None
self.redirect_uri = config.get("redirect_uri", "http://localhost") # must match your app registration
# Try to get credentials, but don't fail if they're missing
try:
self.client_id = self.get_client_id()
logger.debug(f"Got client_id: {self.client_id is not None}")
except Exception as e:
logger.debug(f"Failed to get client_id: {e}")
try:
self.client_secret = self.get_client_secret()
logger.debug(f"Got client_secret: {self.client_secret is not None}")
except Exception as e:
logger.debug(f"Failed to get client_secret: {e}")
# Token file setup
project_root = Path(__file__).resolve().parent.parent.parent.parent
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
self.oauth = OneDriveOAuth(
client_id=self.get_client_id(),
client_secret=self.get_client_secret(),
token_file=token_file,
)
self.subscription_id = config.get("subscription_id") or config.get(
"webhook_channel_id"
)
self.base_url = "https://graph.microsoft.com/v1.0"
Path(token_file).parent.mkdir(parents=True, exist_ok=True)
async def authenticate(self) -> bool:
if await self.oauth.is_authenticated():
self._authenticated = True
return True
return False
# Only initialize OAuth if we have credentials
if self.client_id and self.client_secret:
connection_id = config.get("connection_id", "default")
async def setup_subscription(self) -> str:
if not self._authenticated:
raise ValueError("Not authenticated")
# Use token_file from config if provided, otherwise generate one
if config.get("token_file"):
oauth_token_file = config["token_file"]
else:
# Use a per-connection cache file to avoid collisions with other connectors
oauth_token_file = f"onedrive_token_{connection_id}.json"
webhook_url = self.config.get("webhook_url")
if not webhook_url:
raise ValueError("webhook_url required in config for subscriptions")
# MSA & org both work via /common for OneDrive personal testing
authority = "https://login.microsoftonline.com/common"
expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
body = {
"changeType": "created,updated,deleted",
"notificationUrl": webhook_url,
"resource": "/me/drive/root",
"expirationDateTime": expiration,
"clientState": str(uuid.uuid4()),
self.oauth = OneDriveOAuth(
client_id=self.client_id,
client_secret=self.client_secret,
token_file=oauth_token_file,
authority=authority,
allow_json_refresh=True, # allows one-time migration from legacy JSON if present
)
else:
self.oauth = None
# Track subscription ID for webhooks (note: change notifications might not be available for personal accounts)
self._subscription_id: Optional[str] = None
# Graph API defaults
self._graph_api_version = "v1.0"
self._default_params = {
"$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
}
token = self.oauth.get_access_token()
async with httpx.AsyncClient() as client:
resp = await client.post(
f"{self.base_url}/subscriptions",
json=body,
headers={"Authorization": f"Bearer {token}"},
)
resp.raise_for_status()
data = resp.json()
@property
def _graph_base_url(self) -> str:
"""Base URL for Microsoft Graph API calls."""
return f"https://graph.microsoft.com/{self._graph_api_version}"
self.subscription_id = data["id"]
return self.subscription_id
def emit(self, doc: ConnectorDocument) -> None:
"""Emit a ConnectorDocument instance (integrate with your pipeline here)."""
logger.debug(f"Emitting OneDrive document: {doc.id} ({doc.filename})")
async def list_files(
self, page_token: Optional[str] = None, limit: int = 100
) -> Dict[str, Any]:
if not self._authenticated:
raise ValueError("Not authenticated")
async def authenticate(self) -> bool:
"""Test authentication - BaseConnector interface."""
logger.debug(f"OneDrive authenticate() called, oauth is None: {self.oauth is None}")
try:
if not self.oauth:
logger.debug("OneDrive authentication failed: OAuth not initialized")
self._authenticated = False
return False
params = {"$top": str(limit)}
if page_token:
params["$skiptoken"] = page_token
logger.debug("Loading OneDrive credentials...")
load_result = await self.oauth.load_credentials()
logger.debug(f"Load credentials result: {load_result}")
token = self.oauth.get_access_token()
async with httpx.AsyncClient() as client:
resp = await client.get(
f"{self.base_url}/me/drive/root/children",
params=params,
headers={"Authorization": f"Bearer {token}"},
)
resp.raise_for_status()
data = resp.json()
logger.debug("Checking OneDrive authentication status...")
authenticated = await self.oauth.is_authenticated()
logger.debug(f"OneDrive is_authenticated result: {authenticated}")
files = []
for item in data.get("value", []):
if item.get("file"):
files.append(
{
"id": item["id"],
"name": item["name"],
"mimeType": item.get("file", {}).get(
"mimeType", "application/octet-stream"
),
"webViewLink": item.get("webUrl"),
"createdTime": item.get("createdDateTime"),
"modifiedTime": item.get("lastModifiedDateTime"),
}
)
self._authenticated = authenticated
return authenticated
except Exception as e:
logger.error(f"OneDrive authentication failed: {e}")
import traceback
traceback.print_exc()
self._authenticated = False
return False
next_token = None
next_link = data.get("@odata.nextLink")
if next_link:
from urllib.parse import urlparse, parse_qs
def get_auth_url(self) -> str:
"""Get OAuth authorization URL."""
if not self.oauth:
raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
return self.oauth.create_authorization_url(self.redirect_uri)
parsed = urlparse(next_link)
next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
"""Handle OAuth callback."""
if not self.oauth:
raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
try:
success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
if success:
self._authenticated = True
return {"status": "success"}
else:
raise ValueError("OAuth callback failed")
except Exception as e:
logger.error(f"OAuth callback failed: {e}")
raise
return {"files": files, "nextPageToken": next_token}
def sync_once(self) -> None:
"""
Perform a one-shot sync of OneDrive files and emit documents.
"""
import asyncio
async def _async_sync():
try:
file_list = await self.list_files(max_files=1000)
files = file_list.get("files", [])
for file_info in files:
try:
file_id = file_info.get("id")
if not file_id:
continue
doc = await self.get_file_content(file_id)
self.emit(doc)
except Exception as e:
logger.error(f"Failed to sync OneDrive file {file_info.get('name', 'unknown')}: {e}")
continue
except Exception as e:
logger.error(f"OneDrive sync_once failed: {e}")
raise
if hasattr(asyncio, 'run'):
asyncio.run(_async_sync())
else:
loop = asyncio.get_event_loop()
loop.run_until_complete(_async_sync())
async def setup_subscription(self) -> str:
"""
Set up real-time subscription for file changes.
NOTE: Change notifications may not be available for personal OneDrive accounts.
"""
webhook_url = self.config.get('webhook_url')
if not webhook_url:
logger.warning("No webhook URL configured, skipping OneDrive subscription setup")
return "no-webhook-configured"
try:
if not await self.authenticate():
raise RuntimeError("OneDrive authentication failed during subscription setup")
token = self.oauth.get_access_token()
# For OneDrive personal we target the user's drive
resource = "/me/drive/root"
subscription_data = {
"changeType": "created,updated,deleted",
"notificationUrl": f"{webhook_url}/webhook/onedrive",
"resource": resource,
"expirationDateTime": self._get_subscription_expiry(),
"clientState": "onedrive_personal",
}
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{self._graph_base_url}/subscriptions"
async with httpx.AsyncClient() as client:
response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
response.raise_for_status()
result = response.json()
subscription_id = result.get("id")
if subscription_id:
self._subscription_id = subscription_id
logger.info(f"OneDrive subscription created: {subscription_id}")
return subscription_id
else:
raise ValueError("No subscription ID returned from Microsoft Graph")
except Exception as e:
logger.error(f"Failed to setup OneDrive subscription: {e}")
raise
def _get_subscription_expiry(self) -> str:
"""Get subscription expiry time (Graph caps duration; often <= 3 days)."""
from datetime import datetime, timedelta
expiry = datetime.utcnow() + timedelta(days=3)
return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]:
"""List files from OneDrive using Microsoft Graph."""
try:
if not await self.authenticate():
raise RuntimeError("OneDrive authentication failed during file listing")
files: List[Dict[str, Any]] = []
max_files_value = max_files if max_files is not None else 100
base_url = f"{self._graph_base_url}/me/drive/root/children"
params = dict(self._default_params)
params["$top"] = max_files_value
if page_token:
params["$skiptoken"] = page_token
response = await self._make_graph_request(base_url, params=params)
data = response.json()
items = data.get("value", [])
for item in items:
if item.get("file"): # include files only
files.append({
"id": item.get("id", ""),
"name": item.get("name", ""),
"path": f"/drive/items/{item.get('id')}",
"size": int(item.get("size", 0)),
"modified": item.get("lastModifiedDateTime"),
"created": item.get("createdDateTime"),
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
"url": item.get("webUrl", ""),
"download_url": item.get("@microsoft.graph.downloadUrl"),
})
# Next page
next_page_token = None
next_link = data.get("@odata.nextLink")
if next_link:
from urllib.parse import urlparse, parse_qs
parsed = urlparse(next_link)
query_params = parse_qs(parsed.query)
if "$skiptoken" in query_params:
next_page_token = query_params["$skiptoken"][0]
return {"files": files, "next_page_token": next_page_token}
except Exception as e:
logger.error(f"Failed to list OneDrive files: {e}")
return {"files": [], "next_page_token": None}
async def get_file_content(self, file_id: str) -> ConnectorDocument:
if not self._authenticated:
raise ValueError("Not authenticated")
"""Get file content and metadata."""
try:
if not await self.authenticate():
raise RuntimeError("OneDrive authentication failed during file content retrieval")
token = self.oauth.get_access_token()
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient() as client:
meta_resp = await client.get(
f"{self.base_url}/me/drive/items/{file_id}", headers=headers
)
meta_resp.raise_for_status()
metadata = meta_resp.json()
file_metadata = await self._get_file_metadata_by_id(file_id)
if not file_metadata:
raise ValueError(f"File not found: {file_id}")
content_resp = await client.get(
f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers
)
content = content_resp.content
# Handle the possibility of this being a redirect
if content_resp.status_code in (301, 302, 303, 307, 308):
redirect_url = content_resp.headers.get("Location")
if redirect_url:
content_resp = await client.get(redirect_url)
content_resp.raise_for_status()
content = content_resp.content
download_url = file_metadata.get("download_url")
if download_url:
content = await self._download_file_from_url(download_url)
else:
content_resp.raise_for_status()
content = await self._download_file_content(file_id)
perm_resp = await client.get(
f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers
acl = DocumentACL(
owner="",
user_permissions={},
group_permissions={},
)
perm_resp.raise_for_status()
permissions = perm_resp.json()
acl = self._parse_permissions(metadata, permissions)
modified = datetime.fromisoformat(
metadata["lastModifiedDateTime"].replace("Z", "+00:00")
).replace(tzinfo=None)
created = datetime.fromisoformat(
metadata["createdDateTime"].replace("Z", "+00:00")
).replace(tzinfo=None)
modified_time = self._parse_graph_date(file_metadata.get("modified"))
created_time = self._parse_graph_date(file_metadata.get("created"))
document = ConnectorDocument(
id=metadata["id"],
filename=metadata["name"],
mimetype=metadata.get("file", {}).get(
"mimeType", "application/octet-stream"
),
content=content,
source_url=metadata.get("webUrl"),
acl=acl,
modified_time=modified,
created_time=created,
metadata={"size": metadata.get("size")},
)
return document
def _parse_permissions(
self, metadata: Dict[str, Any], permissions: Dict[str, Any]
) -> DocumentACL:
acl = DocumentACL()
owner = metadata.get("createdBy", {}).get("user", {}).get("email")
if owner:
acl.owner = owner
for perm in permissions.get("value", []):
role = perm.get("roles", ["read"])[0]
grantee = perm.get("grantedToV2") or perm.get("grantedTo")
if not grantee:
continue
user = grantee.get("user")
if user and user.get("email"):
acl.user_permissions[user["email"]] = role
group = grantee.get("group")
if group and group.get("email"):
acl.group_permissions[group["email"]] = role
return acl
def handle_webhook_validation(
self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
) -> Optional[str]:
"""Handle Microsoft Graph webhook validation"""
if request_method == "GET":
validation_token = query_params.get("validationtoken") or query_params.get(
"validationToken"
return ConnectorDocument(
id=file_id,
filename=file_metadata.get("name", ""),
mimetype=file_metadata.get("mime_type", "application/octet-stream"),
content=content,
source_url=file_metadata.get("url", ""),
acl=acl,
modified_time=modified_time,
created_time=created_time,
metadata={
"onedrive_path": file_metadata.get("path", ""),
"size": file_metadata.get("size", 0),
},
)
if validation_token:
return validation_token
except Exception as e:
logger.error(f"Failed to get OneDrive file content {file_id}: {e}")
raise
async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
"""Get file metadata by ID using Graph API."""
try:
url = f"{self._graph_base_url}/me/drive/items/{file_id}"
params = dict(self._default_params)
response = await self._make_graph_request(url, params=params)
item = response.json()
if item.get("file"):
return {
"id": file_id,
"name": item.get("name", ""),
"path": f"/drive/items/{file_id}",
"size": int(item.get("size", 0)),
"modified": item.get("lastModifiedDateTime"),
"created": item.get("createdDateTime"),
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
"url": item.get("webUrl", ""),
"download_url": item.get("@microsoft.graph.downloadUrl"),
}
return None
except Exception as e:
logger.error(f"Failed to get file metadata for {file_id}: {e}")
return None
async def _download_file_content(self, file_id: str) -> bytes:
"""Download file content by file ID using Graph API."""
try:
url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
token = self.oauth.get_access_token()
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient() as client:
response = await client.get(url, headers=headers, timeout=60)
response.raise_for_status()
return response.content
except Exception as e:
logger.error(f"Failed to download file content for {file_id}: {e}")
raise
async def _download_file_from_url(self, download_url: str) -> bytes:
"""Download file content from direct download URL."""
try:
async with httpx.AsyncClient() as client:
response = await client.get(download_url, timeout=60)
response.raise_for_status()
return response.content
except Exception as e:
logger.error(f"Failed to download from URL {download_url}: {e}")
raise
def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
"""Parse Microsoft Graph date string to datetime."""
if not date_str:
return datetime.now()
try:
if date_str.endswith('Z'):
return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
else:
return datetime.fromisoformat(date_str.replace('T', ' '))
except (ValueError, AttributeError):
return datetime.now()
async def _make_graph_request(self, url: str, method: str = "GET",
data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
"""Make authenticated API request to Microsoft Graph."""
token = self.oauth.get_access_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
async with httpx.AsyncClient() as client:
if method.upper() == "GET":
response = await client.get(url, headers=headers, params=params, timeout=30)
elif method.upper() == "POST":
response = await client.post(url, headers=headers, json=data, timeout=30)
elif method.upper() == "DELETE":
response = await client.delete(url, headers=headers, timeout=30)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response
def _get_mime_type(self, filename: str) -> str:
"""Get MIME type based on file extension."""
import mimetypes
mime_type, _ = mimetypes.guess_type(filename)
return mime_type or "application/octet-stream"
# Webhook methods - BaseConnector interface
def handle_webhook_validation(self, request_method: str,
headers: Dict[str, str],
query_params: Dict[str, str]) -> Optional[str]:
"""Handle webhook validation (Graph API specific)."""
if request_method == "POST" and "validationToken" in query_params:
return query_params["validationToken"]
return None
def extract_webhook_channel_id(
self, payload: Dict[str, Any], headers: Dict[str, str]
) -> Optional[str]:
"""Extract SharePoint subscription ID from webhook payload"""
values = payload.get("value", [])
return values[0].get("subscriptionId") if values else None
def extract_webhook_channel_id(self, payload: Dict[str, Any],
headers: Dict[str, str]) -> Optional[str]:
"""Extract channel/subscription ID from webhook payload."""
notifications = payload.get("value", [])
if notifications:
return notifications[0].get("subscriptionId")
return None
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
values = payload.get("value", [])
file_ids = []
for item in values:
resource_data = item.get("resourceData", {})
file_id = resource_data.get("id")
if file_id:
file_ids.append(file_id)
return file_ids
"""Handle webhook notification and return affected file IDs."""
affected_files: List[str] = []
notifications = payload.get("value", [])
for notification in notifications:
resource = notification.get("resource")
if resource and "/drive/items/" in resource:
file_id = resource.split("/drive/items/")[-1]
affected_files.append(file_id)
return affected_files
async def cleanup_subscription(
self, subscription_id: str, resource_id: str = None
) -> bool:
if not self._authenticated:
async def cleanup_subscription(self, subscription_id: str) -> bool:
"""Clean up subscription - BaseConnector interface."""
if subscription_id == "no-webhook-configured":
logger.info("No subscription to cleanup (webhook was not configured)")
return True
try:
if not await self.authenticate():
logger.error("OneDrive authentication failed during subscription cleanup")
return False
token = self.oauth.get_access_token()
headers = {"Authorization": f"Bearer {token}"}
url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
async with httpx.AsyncClient() as client:
response = await client.delete(url, headers=headers, timeout=30)
if response.status_code in [200, 204, 404]:
logger.info(f"OneDrive subscription {subscription_id} cleaned up successfully")
return True
else:
logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
return False
except Exception as e:
logger.error(f"Failed to cleanup OneDrive subscription {subscription_id}: {e}")
return False
token = self.oauth.get_access_token()
async with httpx.AsyncClient() as client:
resp = await client.delete(
f"{self.base_url}/subscriptions/{subscription_id}",
headers={"Authorization": f"Bearer {token}"},
)
return resp.status_code in (200, 204)

View file

@ -1,18 +1,28 @@
import os
import json
import logging
from typing import Optional, Dict, Any
import aiofiles
from datetime import datetime
import httpx
import msal
logger = logging.getLogger(__name__)
class OneDriveOAuth:
"""Direct token management for OneDrive, bypassing MSAL cache format"""
"""Handles Microsoft Graph OAuth for OneDrive (personal Microsoft accounts by default)."""
SCOPES = [
"offline_access",
"Files.Read.All",
]
# Reserved scopes that must NOT be sent on token or silent calls
RESERVED_SCOPES = {"openid", "profile", "offline_access"}
# For PERSONAL Microsoft Accounts (OneDrive consumer):
# - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
# - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
SCOPES = AUTH_SCOPES # Backward-compat alias if something references .SCOPES
# Kept for reference; MSAL derives endpoints from `authority`
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
@ -21,168 +31,292 @@ class OneDriveOAuth:
client_id: str,
client_secret: str,
token_file: str = "onedrive_token.json",
authority: str = "https://login.microsoftonline.com/common",
allow_json_refresh: bool = True,
):
"""
Initialize OneDriveOAuth.
Args:
client_id: Azure AD application (client) ID.
client_secret: Azure AD application client secret.
token_file: Path to persisted token cache file (MSAL cache format).
authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
or tenant-specific for work/school.
allow_json_refresh: If True, permit one-time migration from legacy flat JSON
{"access_token","refresh_token",...}. Otherwise refuse it.
"""
self.client_id = client_id
self.client_secret = client_secret
self.token_file = token_file
self._tokens = None
self._load_tokens()
self.authority = authority
self.allow_json_refresh = allow_json_refresh
self.token_cache = msal.SerializableTokenCache()
self._current_account = None
def _load_tokens(self):
"""Load tokens from file"""
if os.path.exists(self.token_file):
with open(self.token_file, "r") as f:
self._tokens = json.loads(f.read())
print(f"Loaded tokens from {self.token_file}")
else:
print(f"No token file found at {self.token_file}")
# Initialize MSAL Confidential Client
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority,
token_cache=self.token_cache,
)
async def _save_tokens(self):
"""Save tokens to file"""
if self._tokens:
async with aiofiles.open(self.token_file, "w") as f:
await f.write(json.dumps(self._tokens, indent=2))
def _is_token_expired(self) -> bool:
"""Check if current access token is expired"""
if not self._tokens or 'expiry' not in self._tokens:
return True
expiry_str = self._tokens['expiry']
# Handle different expiry formats
async def load_credentials(self) -> bool:
"""Load existing credentials from token file (async)."""
try:
if expiry_str.endswith('Z'):
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
else:
expiry_dt = datetime.fromisoformat(expiry_str)
# Add 5-minute buffer
import datetime as dt
now = datetime.now()
return now >= (expiry_dt - dt.timedelta(minutes=5))
except:
return True
logger.debug(f"OneDrive OAuth loading credentials from: {self.token_file}")
if os.path.exists(self.token_file):
logger.debug(f"Token file exists, reading: {self.token_file}")
# Read the token file
async with aiofiles.open(self.token_file, "r") as f:
cache_data = await f.read()
logger.debug(f"Read {len(cache_data)} chars from token file")
if cache_data.strip():
# 1) Try legacy flat JSON first
try:
json_data = json.loads(cache_data)
if isinstance(json_data, dict) and "refresh_token" in json_data:
if self.allow_json_refresh:
logger.debug(
"Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
)
return await self._refresh_from_json_token(json_data)
else:
logger.warning(
"Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
"Delete the file and re-auth."
)
return False
except json.JSONDecodeError:
logger.debug("Token file is not flat JSON; attempting MSAL cache format")
# 2) Try MSAL cache format
logger.debug("Attempting MSAL cache deserialization")
self.token_cache.deserialize(cache_data)
# Get accounts from loaded cache
accounts = self.app.get_accounts()
logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
if accounts:
self._current_account = accounts[0]
logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
# Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
if result and "access_token" in result:
logger.debug("Silent token acquisition successful")
await self.save_cache()
return True
else:
error_msg = (result or {}).get("error") or "No result"
logger.warning(f"Silent token acquisition failed: {error_msg}")
else:
logger.debug(f"Token file {self.token_file} is empty")
else:
logger.debug(f"Token file does not exist: {self.token_file}")
async def _refresh_access_token(self) -> bool:
    """Refresh the access token using refresh token.

    Legacy flat-JSON token path: posts the stored refresh_token to the
    v2.0 token endpoint and rewrites ``self._tokens`` in place.

    Returns:
        True if a new access token was obtained and persisted, False on
        any failure (missing refresh token, HTTP error, etc.).
    """
    # No refresh token on file -> cannot refresh; caller must re-auth.
    if not self._tokens or 'refresh_token' not in self._tokens:
        return False
    data = {
        'client_id': self.client_id,
        'client_secret': self.client_secret,
        'refresh_token': self._tokens['refresh_token'],
        'grant_type': 'refresh_token',
        # NOTE(review): SCOPES includes offline_access, which Azure AD may
        # reject on token calls for some account types — confirm.
        'scope': ' '.join(self.SCOPES)
    }
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(self.TOKEN_ENDPOINT, data=data)
            response.raise_for_status()
            token_data = response.json()
            # Update tokens (stored under the legacy 'token' key, not 'access_token')
            self._tokens['token'] = token_data['access_token']
            # Azure AD rotates refresh tokens; keep the newest one if issued
            if 'refresh_token' in token_data:
                self._tokens['refresh_token'] = token_data['refresh_token']
            # Calculate expiry from expires_in (seconds), defaulting to 1 hour
            expires_in = token_data.get('expires_in', 3600)
            import datetime as dt
            expiry = datetime.now() + dt.timedelta(seconds=expires_in)
            self._tokens['expiry'] = expiry.isoformat()
            await self._save_tokens()
            # NOTE(review): uses print() while newer code in this file uses logger
            print("Access token refreshed successfully")
            return True
        except Exception as e:
            print(f"Failed to refresh token: {e}")
            return False
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials"""
if not self._tokens:
except Exception as e:
logger.error(f"Failed to load OneDrive credentials: {e}")
import traceback
traceback.print_exc()
return False
# If token is expired, try to refresh
if self._is_token_expired():
print("Token expired, attempting refresh...")
if await self._refresh_access_token():
return True
else:
async def _refresh_from_json_token(self, token_data: dict) -> bool:
"""
Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
Prefer using an MSAL cache file and acquire_token_silent(); this path is only for migrating older files.
"""
try:
refresh_token = token_data.get("refresh_token")
if not refresh_token:
logger.error("No refresh_token found in JSON file - cannot refresh")
logger.error("You must re-authenticate interactively to obtain a valid token")
return False
return True
def get_access_token(self) -> str:
    """Return the cached bearer token from the legacy token dict.

    Raises:
        ValueError: when no token has been stored, or the stored token
            has passed its expiry and refresh did not succeed.
    """
    tokens = self._tokens or {}
    if 'token' not in tokens:
        raise ValueError("No access token available")
    if self._is_token_expired():
        raise ValueError("Access token expired and refresh failed")
    return tokens['token']
# Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
async def revoke_credentials(self):
    """Clear tokens.

    Drops the in-memory token dict and deletes the on-disk token file,
    forcing a fresh interactive authentication on next use.
    """
    self._tokens = None
    if os.path.exists(self.token_file):
        os.remove(self.token_file)
result = self.app.acquire_token_by_refresh_token(
refresh_token=refresh_token,
scopes=refresh_scopes,
)
# Keep these methods for compatibility with your existing OAuth flow
def create_authorization_url(self, redirect_uri: str) -> str:
"""Create authorization URL for OAuth flow"""
from urllib.parse import urlencode
params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': redirect_uri,
'scope': ' '.join(self.SCOPES),
'response_mode': 'query'
if result and "access_token" in result:
logger.debug("Successfully refreshed token via legacy JSON path")
await self.save_cache()
accounts = self.app.get_accounts()
logger.debug(f"After refresh, found {len(accounts)} accounts")
if accounts:
self._current_account = accounts[0]
logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
return True
# Error handling
err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
logger.error(f"Refresh token failed: {err}")
if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
logger.warning(
"Refresh denied due to unauthorized/expired scopes or invalid grant. "
"Delete the token file and perform interactive sign-in with correct scopes."
)
return False
except Exception as e:
logger.error(f"Exception during refresh from JSON token: {e}")
import traceback
traceback.print_exc()
return False
async def save_cache(self):
    """Persist the token cache to file.

    Serializes the MSAL cache and writes it atomically enough for this
    use case; an empty serialization (nothing cached yet) is skipped.
    Failures are logged, never raised.
    """
    try:
        # Make sure the directory holding the token file exists
        target_dir = os.path.dirname(os.path.abspath(self.token_file))
        if target_dir and not os.path.exists(target_dir):
            os.makedirs(target_dir, exist_ok=True)
        serialized = self.token_cache.serialize()
        if not serialized:
            return
        async with aiofiles.open(self.token_file, "w") as fh:
            await fh.write(serialized)
        logger.debug(f"Token cache saved to {self.token_file}")
    except Exception as e:
        logger.error(f"Failed to save token cache: {e}")
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
"""Create authorization URL for OAuth flow."""
# Store redirect URI for later use in callback
self._redirect_uri = redirect_uri
kwargs: Dict[str, Any] = {
# Interactive auth includes offline_access
"scopes": self.AUTH_SCOPES,
"redirect_uri": redirect_uri,
"prompt": "consent", # ensure refresh token on first run
}
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
return f"{auth_url}?{urlencode(params)}"
if state:
kwargs["state"] = state # Optional CSRF protection
auth_url = self.app.get_authorization_request_url(**kwargs)
logger.debug(f"Generated auth URL: {auth_url}")
logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
return auth_url
async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str
) -> bool:
"""Handle OAuth callback and exchange code for tokens"""
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': authorization_code,
'grant_type': 'authorization_code',
'redirect_uri': redirect_uri,
'scope': ' '.join(self.SCOPES)
}
"""Handle OAuth callback and exchange code for tokens."""
try:
result = self.app.acquire_token_by_authorization_code(
authorization_code,
scopes=self.AUTH_SCOPES, # same as authorize step
redirect_uri=redirect_uri,
)
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.TOKEN_ENDPOINT, data=data)
response.raise_for_status()
token_data = response.json()
if result and "access_token" in result:
accounts = self.app.get_accounts()
if accounts:
self._current_account = accounts[0]
# Store tokens in our format
import datetime as dt
expires_in = token_data.get('expires_in', 3600)
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens = {
'token': token_data['access_token'],
'refresh_token': token_data['refresh_token'],
'scopes': self.SCOPES,
'expiry': expiry.isoformat()
}
await self._save_tokens()
print("Authorization successful, tokens saved")
await self.save_cache()
logger.info("OneDrive OAuth authorization successful")
return True
except Exception as e:
print(f"Authorization failed: {e}")
return False
error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
logger.error(f"OneDrive OAuth authorization failed: {error_msg}")
return False
except Exception as e:
logger.error(f"Exception during OneDrive OAuth authorization: {e}")
return False
async def is_authenticated(self) -> bool:
    """Check if we have valid credentials.

    Order matters here: load the persisted cache first, then try a
    silent acquisition bound to the known account, and only then fall
    back to an account-less silent call (which may match any cached
    account). Never raises; any error means "not authenticated".
    """
    try:
        # First try to load credentials if we haven't already
        if not self._current_account:
            await self.load_credentials()
        # Try to get a token (MSAL will refresh if needed)
        if self._current_account:
            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
            if result and "access_token" in result:
                return True
            else:
                error_msg = (result or {}).get("error") or "No result returned"
                logger.debug(f"Token acquisition failed for current account: {error_msg}")
        # Fallback: try without specific account
        result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
        if result and "access_token" in result:
            # Adopt whichever account the cache matched for later calls
            accounts = self.app.get_accounts()
            if accounts:
                self._current_account = accounts[0]
            return True
        return False
    except Exception as e:
        logger.error(f"Authentication check failed: {e}")
        return False
def get_access_token(self) -> str:
    """Get an access token for Microsoft Graph.

    Attempts a silent acquisition bound to the tracked account first,
    then an account-less one. Raises ValueError if neither yields a
    token; any exception is logged before propagating.
    """
    try:
        result = None
        # Preferred path: token for the account we already know about
        if self._current_account:
            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
            if result and "access_token" in result:
                return result["access_token"]
        # Fallback: let MSAL pick any cached account
        result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
        if result and "access_token" in result:
            return result["access_token"]
        # If we get here, authentication has failed
        error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
        raise ValueError(f"Failed to acquire access token: {error_msg}")
    except Exception as e:
        logger.error(f"Failed to get access token: {e}")
        raise
async def revoke_credentials(self):
    """Clear token cache and remove token file.

    The MSAL app keeps a reference to its cache, so after swapping in a
    fresh SerializableTokenCache the ConfidentialClientApplication must
    be rebuilt as well — otherwise the old tokens stay reachable.
    Failures are logged, never raised.
    """
    try:
        # Clear in-memory state
        self._current_account = None
        self.token_cache = msal.SerializableTokenCache()
        # Recreate MSAL app with fresh cache
        self.app = msal.ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_secret,
            authority=self.authority,
            token_cache=self.token_cache,
        )
        # Remove token file
        if os.path.exists(self.token_file):
            os.remove(self.token_file)
            logger.info(f"Removed OneDrive token file: {self.token_file}")
    except Exception as e:
        logger.error(f"Failed to revoke OneDrive credentials: {e}")
def get_service(self) -> str:
    """Return an access token (Graph client is just the bearer).

    NOTE(review): despite the name, this returns a raw bearer-token
    string, not a client object; callers put it in an Authorization
    header themselves.
    """
    return self.get_access_token()

View file

@ -1,241 +1,564 @@
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
from urllib.parse import urlparse
from datetime import datetime
import httpx
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import SharePointOAuth
logger = logging.getLogger(__name__)
class SharePointConnector(BaseConnector):
"""SharePoint Sites connector using Microsoft Graph API"""
"""SharePoint connector using MSAL-based OAuth for authentication"""
# Required BaseConnector class attributes
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
# Connector metadata
CONNECTOR_NAME = "SharePoint"
CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents"
CONNECTOR_DESCRIPTION = "Connect to SharePoint to sync documents and files"
CONNECTOR_ICON = "sharepoint"
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
super().__init__(config) # Fix: Call parent init first
def __init__(self, config: Dict[str, Any]):
logger.debug(f"SharePoint connector __init__ called with config type: {type(config)}")
logger.debug(f"SharePoint connector __init__ config value: {config}")
# Ensure we always pass a valid config to the base class
if config is None:
logger.debug("Config was None, using empty dict")
config = {}
try:
logger.debug("Calling super().__init__")
super().__init__(config) # Now safe to call with empty dict instead of None
logger.debug("super().__init__ completed successfully")
except Exception as e:
logger.error(f"super().__init__ failed: {e}")
raise
# Initialize with defaults that allow the connector to be listed
self.client_id = None
self.client_secret = None
self.tenant_id = config.get("tenant_id", "common")
self.sharepoint_url = config.get("sharepoint_url")
self.redirect_uri = config.get("redirect_uri", "http://localhost")
# Try to get credentials, but don't fail if they're missing
try:
logger.debug("Attempting to get client_id")
self.client_id = self.get_client_id()
logger.debug(f"Got client_id: {self.client_id is not None}")
except Exception as e:
logger.debug(f"Failed to get client_id: {e}")
pass # Credentials not available, that's OK for listing
try:
logger.debug("Attempting to get client_secret")
self.client_secret = self.get_client_secret()
logger.debug(f"Got client_secret: {self.client_secret is not None}")
except Exception as e:
logger.debug(f"Failed to get client_secret: {e}")
pass # Credentials not available, that's OK for listing
# Token file setup
project_root = Path(__file__).resolve().parent.parent.parent.parent
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
self.oauth = SharePointOAuth(
client_id=self.get_client_id(),
client_secret=self.get_client_secret(),
token_file=token_file,
)
self.subscription_id = config.get("subscription_id") or config.get(
"webhook_channel_id"
)
self.base_url = "https://graph.microsoft.com/v1.0"
# SharePoint site configuration
self.site_id = config.get("site_id") # Required for SharePoint
async def authenticate(self) -> bool:
    """Mark the connector authenticated when the OAuth layer holds valid tokens.

    Note: leaves ``self._authenticated`` untouched on failure.
    """
    if not await self.oauth.is_authenticated():
        return False
    self._authenticated = True
    return True
async def setup_subscription(self) -> str:
    """Create a Microsoft Graph change subscription on the site drive root.

    Returns:
        The Graph subscription id (also stored on ``self.subscription_id``).

    Raises:
        ValueError: if not authenticated or no webhook_url configured.
        httpx.HTTPStatusError: if Graph rejects the request.
    """
    if not self._authenticated:
        raise ValueError("Not authenticated")
    webhook_url = self.config.get("webhook_url")
    if not webhook_url:
        raise ValueError("webhook_url required in config for subscriptions")
    # Graph requires an explicit expiration; 2 days stays under the cap
    expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
    body = {
        "changeType": "created,updated,deleted",
        "notificationUrl": webhook_url,
        "resource": f"/sites/{self.site_id}/drive/root",
        "expirationDateTime": expiration,
        # NOTE(review): random clientState is not persisted anywhere, so
        # incoming notifications can't be validated against it — confirm.
        "clientState": str(uuid.uuid4()),
    }
    token = self.oauth.get_access_token()
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{self.base_url}/subscriptions",
            json=body,
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        data = resp.json()
        self.subscription_id = data["id"]
        return self.subscription_id
async def list_files(
    self, page_token: Optional[str] = None, limit: int = 100
) -> Dict[str, Any]:
    """List files (not folders) in the SharePoint site's drive root.

    Args:
        page_token: Graph ``$skiptoken`` from a previous page, if any.
        limit: maximum items per page (Graph ``$top``).

    Returns:
        ``{"files": [...], "nextPageToken": str | None}`` where each file
        dict carries id, name, mimeType, webViewLink and timestamps.

    Raises:
        ValueError: if not authenticated.
        httpx.HTTPStatusError: on Graph API errors.
    """
    if not self._authenticated:
        raise ValueError("Not authenticated")
    params = {"$top": str(limit)}
    if page_token:
        params["$skiptoken"] = page_token
    token = self.oauth.get_access_token()
    async with httpx.AsyncClient() as client:
        resp = await client.get(
            f"{self.base_url}/sites/{self.site_id}/drive/root/children",
            params=params,
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        data = resp.json()
    files = []
    for item in data.get("value", []):
        # Items with a "file" facet are files; folders are skipped
        if item.get("file"):
            files.append(
                {
                    "id": item["id"],
                    "name": item["name"],
                    "mimeType": item.get("file", {}).get(
                        "mimeType", "application/octet-stream"
                    ),
                    "webViewLink": item.get("webUrl"),
                    "createdTime": item.get("createdDateTime"),
                    "modifiedTime": item.get("lastModifiedDateTime"),
                }
            )
    # Extract the continuation token from @odata.nextLink, if present
    next_token = None
    next_link = data.get("@odata.nextLink")
    if next_link:
        from urllib.parse import urlparse, parse_qs
        parsed = urlparse(next_link)
        next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
    return {"files": files, "nextPageToken": next_token}
async def get_file_content(self, file_id: str) -> ConnectorDocument:
if not self._authenticated:
raise ValueError("Not authenticated")
token = self.oauth.get_access_token()
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient() as client:
meta_resp = await client.get(
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}",
headers=headers,
)
meta_resp.raise_for_status()
metadata = meta_resp.json()
content_resp = await client.get(
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/content",
headers=headers,
)
content = content_resp.content
# Handle the possibility of this being a redirect
if content_resp.status_code in (301, 302, 303, 307, 308):
redirect_url = content_resp.headers.get("Location")
if redirect_url:
content_resp = await client.get(redirect_url)
content_resp.raise_for_status()
content = content_resp.content
token_file = config.get("token_file") or str(project_root / "sharepoint_token.json")
Path(token_file).parent.mkdir(parents=True, exist_ok=True)
# Only initialize OAuth if we have credentials
if self.client_id and self.client_secret:
connection_id = config.get("connection_id", "default")
# Use token_file from config if provided, otherwise generate one
if config.get("token_file"):
oauth_token_file = config["token_file"]
else:
content_resp.raise_for_status()
perm_resp = await client.get(
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/permissions",
headers=headers,
oauth_token_file = f"sharepoint_token_{connection_id}.json"
authority = f"https://login.microsoftonline.com/{self.tenant_id}" if self.tenant_id != "common" else "https://login.microsoftonline.com/common"
self.oauth = SharePointOAuth(
client_id=self.client_id,
client_secret=self.client_secret,
token_file=oauth_token_file,
authority=authority
)
perm_resp.raise_for_status()
permissions = perm_resp.json()
acl = self._parse_permissions(metadata, permissions)
modified = datetime.fromisoformat(
metadata["lastModifiedDateTime"].replace("Z", "+00:00")
).replace(tzinfo=None)
created = datetime.fromisoformat(
metadata["createdDateTime"].replace("Z", "+00:00")
).replace(tzinfo=None)
document = ConnectorDocument(
id=metadata["id"],
filename=metadata["name"],
mimetype=metadata.get("file", {}).get(
"mimeType", "application/octet-stream"
),
content=content,
source_url=metadata.get("webUrl"),
acl=acl,
modified_time=modified,
created_time=created,
metadata={"size": metadata.get("size")},
)
return document
def _parse_permissions(
    self, metadata: Dict[str, Any], permissions: Dict[str, Any]
) -> DocumentACL:
    """Build a DocumentACL from Graph item metadata + permissions payloads.

    Args:
        metadata: the driveItem JSON (used for ``createdBy`` -> owner).
        permissions: the ``/permissions`` response (``value`` list).

    Returns:
        DocumentACL with owner plus per-user / per-group role mappings.
    """
    acl = DocumentACL()
    owner = metadata.get("createdBy", {}).get("user", {}).get("email")
    if owner:
        acl.owner = owner
    for perm in permissions.get("value", []):
        # Fix: Graph can return "roles": [] — ``.get("roles", ["read"])[0]``
        # would raise IndexError on that; fall back to "read" instead.
        roles = perm.get("roles") or ["read"]
        role = roles[0]
        grantee = perm.get("grantedToV2") or perm.get("grantedTo")
        if not grantee:
            continue
        user = grantee.get("user")
        if user and user.get("email"):
            acl.user_permissions[user["email"]] = role
        group = grantee.get("group")
        if group and group.get("email"):
            acl.group_permissions[group["email"]] = role
    return acl
def handle_webhook_validation(
    self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
) -> Optional[str]:
    """Handle Microsoft Graph webhook validation.

    Graph validates a notification URL with a GET carrying a validation
    token; echoing it back completes the handshake. Non-GET requests and
    missing tokens yield None.
    """
    if request_method != "GET":
        return None
    # Graph clients vary in header casing of the query parameter
    return (
        query_params.get("validationtoken")
        or query_params.get("validationToken")
        or None
    )
def extract_webhook_channel_id(
    self, payload: Dict[str, Any], headers: Dict[str, str]
) -> Optional[str]:
    """Extract SharePoint subscription ID from webhook payload.

    Graph batches notifications under "value"; the subscription id of
    the first entry identifies the channel. Returns None when empty.
    """
    notifications = payload.get("value") or []
    if not notifications:
        return None
    return notifications[0].get("subscriptionId")
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
    """Collect the drive-item IDs referenced by a Graph change notification.

    Entries without a resourceData id are skipped.
    """
    return [
        entry["resourceData"]["id"]
        for entry in payload.get("value", [])
        if entry.get("resourceData", {}).get("id")
    ]
async def cleanup_subscription(
self, subscription_id: str, resource_id: str = None
) -> bool:
if not self._authenticated:
else:
self.oauth = None
# Track subscription ID for webhooks
self._subscription_id: Optional[str] = None
# Add Graph API defaults similar to Google Drive flags
self._graph_api_version = "v1.0"
self._default_params = {
"$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
}
@property
def _graph_base_url(self) -> str:
"""Base URL for Microsoft Graph API calls"""
return f"https://graph.microsoft.com/{self._graph_api_version}"
def emit(self, doc: ConnectorDocument) -> None:
    """
    Emit a ConnectorDocument instance.
    Override this method to integrate with your ingestion pipeline.
    """
    # Default implementation only logs; integrations override this hook.
    logger.debug(f"Emitting SharePoint document: {doc.id} ({doc.filename})")
async def authenticate(self) -> bool:
    """Test authentication - BaseConnector interface.

    Loads persisted credentials first, then asks the OAuth layer
    whether a valid token is available. Updates ``self._authenticated``
    and returns the result; never raises.
    """
    logger.debug(f"SharePoint authenticate() called, oauth is None: {self.oauth is None}")
    try:
        # OAuth is only constructed when client credentials were configured
        if not self.oauth:
            logger.debug("SharePoint authentication failed: OAuth not initialized")
            self._authenticated = False
            return False
        logger.debug("Loading SharePoint credentials...")
        # Try to load existing credentials first
        load_result = await self.oauth.load_credentials()
        logger.debug(f"Load credentials result: {load_result}")
        logger.debug("Checking SharePoint authentication status...")
        authenticated = await self.oauth.is_authenticated()
        logger.debug(f"SharePoint is_authenticated result: {authenticated}")
        self._authenticated = authenticated
        return authenticated
    except Exception as e:
        logger.error(f"SharePoint authentication failed: {e}")
        # NOTE(review): prints traceback to stderr in addition to logging
        import traceback
        traceback.print_exc()
        self._authenticated = False
        return False
token = self.oauth.get_access_token()
async with httpx.AsyncClient() as client:
resp = await client.delete(
f"{self.base_url}/subscriptions/{subscription_id}",
headers={"Authorization": f"Bearer {token}"},
def get_auth_url(self) -> str:
    """Get OAuth authorization URL.

    Raises:
        RuntimeError: when OAuth was never initialized (missing credentials).
    """
    if not self.oauth:
        raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
    return self.oauth.create_authorization_url(self.redirect_uri)
async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
    """Handle OAuth callback.

    Exchanges the authorization code for tokens and marks the connector
    authenticated. Failures are logged and re-raised to the caller.
    """
    if not self.oauth:
        raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
    try:
        exchanged = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
        if not exchanged:
            raise ValueError("OAuth callback failed")
    except Exception as e:
        logger.error(f"OAuth callback failed: {e}")
        raise
    self._authenticated = True
    return {"status": "success"}
def sync_once(self) -> None:
    """
    Perform a one-shot sync of SharePoint files and emit documents.

    Lists up to 1000 files via Microsoft Graph, downloads each one and
    passes the resulting ConnectorDocument to ``self.emit()``. Per-file
    failures are logged and skipped; listing failures propagate.

    Note: must be called from synchronous code — ``asyncio.run()``
    raises RuntimeError if an event loop is already running.
    """
    import asyncio

    async def _async_sync():
        try:
            # Get list of files (cap bounds the work of a single pass)
            file_list = await self.list_files(max_files=1000)
            for file_info in file_list.get("files", []):
                try:
                    file_id = file_info.get("id")
                    if not file_id:
                        continue
                    # Get full document content and hand it to the pipeline
                    doc = await self.get_file_content(file_id)
                    self.emit(doc)
                except Exception as e:
                    logger.error(f"Failed to sync SharePoint file {file_info.get('name', 'unknown')}: {e}")
                    continue
        except Exception as e:
            logger.error(f"SharePoint sync_once failed: {e}")
            raise

    # Fix: asyncio.run() exists on every supported Python (3.7+); the old
    # hasattr() check and deprecated get_event_loop() fallback were dead code.
    asyncio.run(_async_sync())
async def setup_subscription(self) -> str:
    """Set up real-time subscription for file changes - BaseConnector interface.

    Creates a Microsoft Graph change subscription on the configured
    site's drive root, or on the signed-in user's OneDrive root when no
    SharePoint URL is set. Returns the subscription id, or the sentinel
    string "no-webhook-configured" when no webhook URL is available.

    Raises:
        RuntimeError: if authentication fails.
        httpx.HTTPStatusError: if Graph rejects the subscription request.
        ValueError: if Graph returns no subscription id.
    """
    webhook_url = self.config.get('webhook_url')
    if not webhook_url:
        logger.warning("No webhook URL configured, skipping SharePoint subscription setup")
        return "no-webhook-configured"
    try:
        # Ensure we're authenticated
        if not await self.authenticate():
            raise RuntimeError("SharePoint authentication failed during subscription setup")
        token = self.oauth.get_access_token()
        # Microsoft Graph subscription for SharePoint site
        site_info = self._parse_sharepoint_url()
        if site_info:
            resource = f"sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root"
        else:
            # Fall back to the authenticated user's own drive
            resource = "/me/drive/root"
        subscription_data = {
            "changeType": "created,updated,deleted",
            "notificationUrl": f"{webhook_url}/webhook/sharepoint",
            "resource": resource,
            "expirationDateTime": self._get_subscription_expiry(),
            # clientState is echoed back by Graph on each notification
            "clientState": f"sharepoint_{self.tenant_id}"
        }
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json"
        }
        url = f"{self._graph_base_url}/subscriptions"
        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
            response.raise_for_status()
            result = response.json()
        subscription_id = result.get("id")
        if subscription_id:
            self._subscription_id = subscription_id
            logger.info(f"SharePoint subscription created: {subscription_id}")
            return subscription_id
        else:
            raise ValueError("No subscription ID returned from Microsoft Graph")
    except Exception as e:
        logger.error(f"Failed to setup SharePoint subscription: {e}")
        raise
def _get_subscription_expiry(self) -> str:
"""Get subscription expiry time (max 3 days for Graph API)"""
from datetime import datetime, timedelta
expiry = datetime.utcnow() + timedelta(days=3) # 3 days max for Graph
return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
def _parse_sharepoint_url(self) -> Optional[Dict[str, str]]:
"""Parse SharePoint URL to extract site information for Graph API"""
if not self.sharepoint_url:
return None
try:
parsed = urlparse(self.sharepoint_url)
# Extract hostname and site name from URL like: https://contoso.sharepoint.com/sites/teamsite
host_name = parsed.netloc
path_parts = parsed.path.strip('/').split('/')
if len(path_parts) >= 2 and path_parts[0] == 'sites':
site_name = path_parts[1]
return {
"host_name": host_name,
"site_name": site_name
}
except Exception as e:
logger.warning(f"Could not parse SharePoint URL {self.sharepoint_url}: {e}")
return None
async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]:
    """List all files using Microsoft Graph API - BaseConnector interface.

    Args:
        page_token: Graph ``$skiptoken`` from a previous page, if any.
        max_files: page size (``$top``); defaults to 100.

    Returns:
        ``{"files": [...], "next_page_token": str | None}``. On any
        failure an empty result is returned instead of raising.
    """
    try:
        # Ensure authentication
        if not await self.authenticate():
            raise RuntimeError("SharePoint authentication failed during file listing")
        files = []
        max_files_value = max_files if max_files is not None else 100
        # Build Graph API URL for the site or fallback to user's OneDrive
        site_info = self._parse_sharepoint_url()
        if site_info:
            base_url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root/children"
        else:
            base_url = f"{self._graph_base_url}/me/drive/root/children"
        params = dict(self._default_params)
        params["$top"] = max_files_value
        if page_token:
            params["$skiptoken"] = page_token
        response = await self._make_graph_request(base_url, params=params)
        data = response.json()
        items = data.get("value", [])
        for item in items:
            # Only include files, not folders
            if item.get("file"):
                files.append({
                    "id": item.get("id", ""),
                    "name": item.get("name", ""),
                    "path": f"/drive/items/{item.get('id')}",
                    "size": int(item.get("size", 0)),
                    "modified": item.get("lastModifiedDateTime"),
                    "created": item.get("createdDateTime"),
                    "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
                    "url": item.get("webUrl", ""),
                    "download_url": item.get("@microsoft.graph.downloadUrl")
                })
        # Check for next page
        next_page_token = None
        next_link = data.get("@odata.nextLink")
        if next_link:
            from urllib.parse import urlparse, parse_qs
            parsed = urlparse(next_link)
            query_params = parse_qs(parsed.query)
            if "$skiptoken" in query_params:
                next_page_token = query_params["$skiptoken"][0]
        return {
            "files": files,
            "next_page_token": next_page_token
        }
    except Exception as e:
        logger.error(f"Failed to list SharePoint files: {e}")
        return {"files": [], "next_page_token": None}  # Return empty result instead of raising
async def get_file_content(self, file_id: str) -> ConnectorDocument:
    """Get file content and metadata - BaseConnector interface.

    Fetches the driveItem metadata, downloads the content (preferring
    the pre-authenticated @microsoft.graph.downloadUrl), and wraps both
    in a ConnectorDocument. A stray merge-artifact line
    (``return resp.status_code in (200, 204)`` — ``resp`` was not even
    defined here) used to abort this method before the document was
    built; it has been removed.

    Raises:
        RuntimeError: if authentication fails.
        ValueError: if the file does not exist.
    """
    try:
        # Ensure authentication
        if not await self.authenticate():
            raise RuntimeError("SharePoint authentication failed during file content retrieval")
        # First get file metadata using Graph API
        file_metadata = await self._get_file_metadata_by_id(file_id)
        if not file_metadata:
            raise ValueError(f"File not found: {file_id}")
        # Download file content
        download_url = file_metadata.get("download_url")
        if download_url:
            content = await self._download_file_from_url(download_url)
        else:
            content = await self._download_file_content(file_id)
        # Create ACL from metadata
        acl = DocumentACL(
            owner="",  # Graph API requires additional calls for detailed permissions
            user_permissions={},
            group_permissions={}
        )
        # Parse dates
        modified_time = self._parse_graph_date(file_metadata.get("modified"))
        created_time = self._parse_graph_date(file_metadata.get("created"))
        return ConnectorDocument(
            id=file_id,
            filename=file_metadata.get("name", ""),
            mimetype=file_metadata.get("mime_type", "application/octet-stream"),
            content=content,
            source_url=file_metadata.get("url", ""),
            acl=acl,
            modified_time=modified_time,
            created_time=created_time,
            metadata={
                "sharepoint_path": file_metadata.get("path", ""),
                "sharepoint_url": self.sharepoint_url,
                "size": file_metadata.get("size", 0)
            }
        )
    except Exception as e:
        logger.error(f"Failed to get SharePoint file content {file_id}: {e}")
        raise
async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
    """Get file metadata by ID using Graph API.

    Queries the configured site drive (or the user's OneDrive when no
    SharePoint URL is set) and normalizes the driveItem into the flat
    dict shape used by list_files(). Returns None for folders, missing
    items, or on any error.
    """
    try:
        # Try site-specific path first, then fallback to user drive
        site_info = self._parse_sharepoint_url()
        if site_info:
            url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}"
        else:
            url = f"{self._graph_base_url}/me/drive/items/{file_id}"
        params = dict(self._default_params)
        response = await self._make_graph_request(url, params=params)
        item = response.json()
        # Only items with a "file" facet are real files (folders lack it)
        if item.get("file"):
            return {
                "id": file_id,
                "name": item.get("name", ""),
                "path": f"/drive/items/{file_id}",
                "size": int(item.get("size", 0)),
                "modified": item.get("lastModifiedDateTime"),
                "created": item.get("createdDateTime"),
                "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
                "url": item.get("webUrl", ""),
                "download_url": item.get("@microsoft.graph.downloadUrl")
            }
        return None
    except Exception as e:
        logger.error(f"Failed to get file metadata for {file_id}: {e}")
        return None
async def _download_file_content(self, file_id: str) -> bytes:
    """Download a drive item's raw bytes via the Graph ``/content`` endpoint.

    Re-raises any HTTP or token-acquisition failure after logging it.
    """
    try:
        site_info = self._parse_sharepoint_url()
        if site_info:
            url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}/content"
        else:
            url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
        auth_headers = {"Authorization": f"Bearer {self.oauth.get_access_token()}"}
        async with httpx.AsyncClient() as http_client:
            resp = await http_client.get(url, headers=auth_headers, timeout=60)
            resp.raise_for_status()
            return resp.content
    except Exception as e:
        logger.error(f"Failed to download file content for {file_id}: {e}")
        raise
async def _download_file_from_url(self, download_url: str) -> bytes:
    """Fetch bytes from a direct download URL (no Authorization header is sent).

    Re-raises any HTTP failure after logging it.
    """
    try:
        async with httpx.AsyncClient() as http_client:
            resp = await http_client.get(download_url, timeout=60)
            resp.raise_for_status()
            return resp.content
    except Exception as e:
        logger.error(f"Failed to download from URL {download_url}: {e}")
        raise
def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
"""Parse Microsoft Graph date string to datetime"""
if not date_str:
return datetime.now()
try:
if date_str.endswith('Z'):
return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
else:
return datetime.fromisoformat(date_str.replace('T', ' '))
except (ValueError, AttributeError):
return datetime.now()
async def _make_graph_request(self, url: str, method: str = "GET",
                              data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
    """Issue an authenticated Microsoft Graph request and return the response.

    Supports GET/POST/DELETE; raises ValueError for any other verb and
    ``httpx.HTTPStatusError`` for non-2xx responses.
    """
    request_headers = {
        "Authorization": f"Bearer {self.oauth.get_access_token()}",
        "Content-Type": "application/json"
    }
    verb = method.upper()
    async with httpx.AsyncClient() as http_client:
        if verb == "GET":
            resp = await http_client.get(url, headers=request_headers, params=params, timeout=30)
        elif verb == "POST":
            resp = await http_client.post(url, headers=request_headers, json=data, timeout=30)
        elif verb == "DELETE":
            resp = await http_client.delete(url, headers=request_headers, timeout=30)
        else:
            raise ValueError(f"Unsupported HTTP method: {method}")
        resp.raise_for_status()
        return resp
def _get_mime_type(self, filename: str) -> str:
"""Get MIME type based on file extension"""
import mimetypes
mime_type, _ = mimetypes.guess_type(filename)
return mime_type or "application/octet-stream"
# Webhook methods - BaseConnector interface
def handle_webhook_validation(self, request_method: str, headers: Dict[str, str],
                              query_params: Dict[str, str]) -> Optional[str]:
    """Return the Graph validation token to echo back, or None.

    Graph subscription validation arrives as a POST carrying a
    ``validationToken`` query parameter; anything else is not a
    validation request.
    """
    if request_method != "POST":
        return None
    return query_params.get("validationToken")
def extract_webhook_channel_id(self, payload: Dict[str, Any],
                               headers: Dict[str, str]) -> Optional[str]:
    """Return the subscription id of the first notification in *payload*, if any."""
    notifications = payload.get("value", [])
    # Graph batches notifications under "value"; all entries in one
    # delivery share a subscription, so the first one suffices.
    return notifications[0].get("subscriptionId") if notifications else None
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
    """Extract affected drive-item ids from a Graph change-notification payload.

    Each notification's ``resource`` path is expected to contain
    ``/drive/items/<id>``; notifications without such a resource are ignored.
    """
    return [
        resource.split("/drive/items/")[-1]
        for notification in payload.get("value", [])
        if (resource := notification.get("resource")) and "/drive/items/" in resource
    ]
async def cleanup_subscription(self, subscription_id: str) -> bool:
    """Delete a Graph webhook subscription (BaseConnector interface).

    Returns True on success (including the sentinel id for "webhook was
    never configured" and a 404, which means the subscription is already
    gone); False on auth failure, unexpected status, or exception.
    """
    # Sentinel id used when no webhook was ever registered: nothing to do.
    if subscription_id == "no-webhook-configured":
        logger.info("No subscription to cleanup (webhook was not configured)")
        return True
    try:
        # Ensure authentication before calling Graph.
        if not await self.authenticate():
            logger.error("SharePoint authentication failed during subscription cleanup")
            return False
        token = self.oauth.get_access_token()
        headers = {"Authorization": f"Bearer {token}"}
        url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
        async with httpx.AsyncClient() as client:
            response = await client.delete(url, headers=headers, timeout=30)
            # 404 counts as success: the subscription no longer exists.
            if response.status_code in [200, 204, 404]:
                logger.info(f"SharePoint subscription {subscription_id} cleaned up successfully")
                return True
            else:
                logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
                return False
    except Exception as e:
        logger.error(f"Failed to cleanup SharePoint subscription {subscription_id}: {e}")
        return False

View file

@ -1,19 +1,28 @@
import os
import json
import logging
from typing import Optional, Dict, Any
import aiofiles
from datetime import datetime
import httpx
import msal
logger = logging.getLogger(__name__)
class SharePointOAuth:
"""Direct token management for SharePoint, bypassing MSAL cache format"""
"""Handles Microsoft Graph OAuth authentication flow following Google Drive pattern."""
SCOPES = [
"offline_access",
"Files.Read.All",
"Sites.Read.All",
]
# Reserved scopes that must NOT be sent on token or silent calls
RESERVED_SCOPES = {"openid", "profile", "offline_access"}
# For PERSONAL Microsoft Accounts (OneDrive consumer):
# - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
# - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
SCOPES = AUTH_SCOPES # Backward compatibility alias
# Kept for reference; MSAL derives endpoints from `authority`
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
@ -22,173 +31,299 @@ class SharePointOAuth:
client_id: str,
client_secret: str,
token_file: str = "sharepoint_token.json",
authority: str = "https://login.microsoftonline.com/common", # Keep for compatibility
authority: str = "https://login.microsoftonline.com/common",
allow_json_refresh: bool = True,
):
"""
Initialize SharePointOAuth.
Args:
client_id: Azure AD application (client) ID.
client_secret: Azure AD application client secret.
token_file: Path to persisted token cache file (MSAL cache format).
authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
or tenant-specific for work/school.
allow_json_refresh: If True, permit one-time migration from legacy flat JSON
{"access_token","refresh_token",...}. Otherwise refuse it.
"""
self.client_id = client_id
self.client_secret = client_secret
self.token_file = token_file
self.authority = authority # Keep for compatibility but not used
self._tokens = None
self._load_tokens()
self.authority = authority
self.allow_json_refresh = allow_json_refresh
self.token_cache = msal.SerializableTokenCache()
self._current_account = None
def _load_tokens(self):
"""Load tokens from file"""
if os.path.exists(self.token_file):
with open(self.token_file, "r") as f:
self._tokens = json.loads(f.read())
print(f"Loaded tokens from {self.token_file}")
else:
print(f"No token file found at {self.token_file}")
# Initialize MSAL Confidential Client
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority,
token_cache=self.token_cache,
)
async def save_cache(self):
"""Persist tokens to file (renamed for compatibility)"""
await self._save_tokens()
async def _save_tokens(self):
"""Save tokens to file"""
if self._tokens:
async with aiofiles.open(self.token_file, "w") as f:
await f.write(json.dumps(self._tokens, indent=2))
def _is_token_expired(self) -> bool:
"""Check if current access token is expired"""
if not self._tokens or 'expiry' not in self._tokens:
return True
expiry_str = self._tokens['expiry']
# Handle different expiry formats
async def load_credentials(self) -> bool:
"""Load existing credentials from token file (async)."""
try:
if expiry_str.endswith('Z'):
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
else:
expiry_dt = datetime.fromisoformat(expiry_str)
# Add 5-minute buffer
import datetime as dt
now = datetime.now()
return now >= (expiry_dt - dt.timedelta(minutes=5))
except:
return True
logger.debug(f"SharePoint OAuth loading credentials from: {self.token_file}")
if os.path.exists(self.token_file):
logger.debug(f"Token file exists, reading: {self.token_file}")
# Read the token file
async with aiofiles.open(self.token_file, "r") as f:
cache_data = await f.read()
logger.debug(f"Read {len(cache_data)} chars from token file")
if cache_data.strip():
# 1) Try legacy flat JSON first
try:
json_data = json.loads(cache_data)
if isinstance(json_data, dict) and "refresh_token" in json_data:
if self.allow_json_refresh:
logger.debug(
"Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
)
return await self._refresh_from_json_token(json_data)
else:
logger.warning(
"Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
"Delete the file and re-auth."
)
return False
except json.JSONDecodeError:
logger.debug("Token file is not flat JSON; attempting MSAL cache format")
# 2) Try MSAL cache format
logger.debug("Attempting MSAL cache deserialization")
self.token_cache.deserialize(cache_data)
# Get accounts from loaded cache
accounts = self.app.get_accounts()
logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
if accounts:
self._current_account = accounts[0]
logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
# IMPORTANT: Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
if result and "access_token" in result:
logger.debug("Silent token acquisition successful")
await self.save_cache()
return True
else:
error_msg = (result or {}).get("error") or "No result"
logger.warning(f"Silent token acquisition failed: {error_msg}")
else:
logger.debug(f"Token file {self.token_file} is empty")
else:
logger.debug(f"Token file does not exist: {self.token_file}")
async def _refresh_access_token(self) -> bool:
"""Refresh the access token using refresh token"""
if not self._tokens or 'refresh_token' not in self._tokens:
return False
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'refresh_token': self._tokens['refresh_token'],
'grant_type': 'refresh_token',
'scope': ' '.join(self.SCOPES)
}
except Exception as e:
logger.error(f"Failed to load SharePoint credentials: {e}")
import traceback
traceback.print_exc()
return False
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.TOKEN_ENDPOINT, data=data)
response.raise_for_status()
token_data = response.json()
async def _refresh_from_json_token(self, token_data: dict) -> bool:
"""
Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
# Update tokens
self._tokens['token'] = token_data['access_token']
if 'refresh_token' in token_data:
self._tokens['refresh_token'] = token_data['refresh_token']
# Calculate expiry
expires_in = token_data.get('expires_in', 3600)
import datetime as dt
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens['expiry'] = expiry.isoformat()
await self._save_tokens()
print("Access token refreshed successfully")
return True
except Exception as e:
print(f"Failed to refresh token: {e}")
Notes:
- Prefer using an MSAL cache file and acquire_token_silent().
- This path is only for migrating older refresh_token JSON files.
"""
try:
refresh_token = token_data.get("refresh_token")
if not refresh_token:
logger.error("No refresh_token found in JSON file - cannot refresh")
logger.error("You must re-authenticate interactively to obtain a valid token")
return False
def create_authorization_url(self, redirect_uri: str) -> str:
"""Create authorization URL for OAuth flow"""
from urllib.parse import urlencode
params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': redirect_uri,
'scope': ' '.join(self.SCOPES),
'response_mode': 'query'
# Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
result = self.app.acquire_token_by_refresh_token(
refresh_token=refresh_token,
scopes=refresh_scopes,
)
if result and "access_token" in result:
logger.debug("Successfully refreshed token via legacy JSON path")
await self.save_cache()
accounts = self.app.get_accounts()
logger.debug(f"After refresh, found {len(accounts)} accounts")
if accounts:
self._current_account = accounts[0]
logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
return True
# Error handling
err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
logger.error(f"Refresh token failed: {err}")
if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
logger.warning(
"Refresh denied due to unauthorized/expired scopes or invalid grant. "
"Delete the token file and perform interactive sign-in with correct scopes."
)
return False
except Exception as e:
logger.error(f"Exception during refresh from JSON token: {e}")
import traceback
traceback.print_exc()
return False
async def save_cache(self):
    """Persist the MSAL token cache to ``self.token_file`` (best-effort).

    Failures are logged, never raised, so a cache-write problem does not
    break an otherwise successful token operation.
    """
    try:
        # Ensure parent directory exists
        parent = os.path.dirname(os.path.abspath(self.token_file))
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)
        # serialize() returns the cache as a string; an empty result means
        # there is nothing worth persisting, so the file is left untouched.
        cache_data = self.token_cache.serialize()
        if cache_data:
            async with aiofiles.open(self.token_file, "w") as f:
                await f.write(cache_data)
            logger.debug(f"Token cache saved to {self.token_file}")
    except Exception as e:
        logger.error(f"Failed to save token cache: {e}")
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
"""Create authorization URL for OAuth flow."""
# Store redirect URI for later use in callback
self._redirect_uri = redirect_uri
kwargs: Dict[str, Any] = {
# IMPORTANT: interactive auth includes offline_access
"scopes": self.AUTH_SCOPES,
"redirect_uri": redirect_uri,
"prompt": "consent", # ensure refresh token on first run
}
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
return f"{auth_url}?{urlencode(params)}"
if state:
kwargs["state"] = state # Optional CSRF protection
auth_url = self.app.get_authorization_request_url(**kwargs)
logger.debug(f"Generated auth URL: {auth_url}")
logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
return auth_url
async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str
) -> bool:
"""Handle OAuth callback and exchange code for tokens"""
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': authorization_code,
'grant_type': 'authorization_code',
'redirect_uri': redirect_uri,
'scope': ' '.join(self.SCOPES)
}
"""Handle OAuth callback and exchange code for tokens."""
try:
# For code exchange, we pass the same auth scopes as used in the authorize step
result = self.app.acquire_token_by_authorization_code(
authorization_code,
scopes=self.AUTH_SCOPES,
redirect_uri=redirect_uri,
)
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.TOKEN_ENDPOINT, data=data)
response.raise_for_status()
token_data = response.json()
if result and "access_token" in result:
# Store the account for future use
accounts = self.app.get_accounts()
if accounts:
self._current_account = accounts[0]
# Store tokens in our format
import datetime as dt
expires_in = token_data.get('expires_in', 3600)
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens = {
'token': token_data['access_token'],
'refresh_token': token_data['refresh_token'],
'scopes': self.SCOPES,
'expiry': expiry.isoformat()
}
await self._save_tokens()
print("Authorization successful, tokens saved")
await self.save_cache()
logger.info("SharePoint OAuth authorization successful")
return True
except Exception as e:
print(f"Authorization failed: {e}")
return False
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials"""
if not self._tokens:
error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
logger.error(f"SharePoint OAuth authorization failed: {error_msg}")
return False
# If token is expired, try to refresh
if self._is_token_expired():
print("Token expired, attempting refresh...")
if await self._refresh_access_token():
except Exception as e:
logger.error(f"Exception during SharePoint OAuth authorization: {e}")
return False
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials (simplified like Google Drive)."""
try:
# First try to load credentials if we haven't already
if not self._current_account:
await self.load_credentials()
# If we have an account, try to get a token (MSAL will refresh if needed)
if self._current_account:
# IMPORTANT: use RESOURCE_SCOPES here
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
if result and "access_token" in result:
return True
else:
error_msg = (result or {}).get("error") or "No result returned"
logger.debug(f"Token acquisition failed for current account: {error_msg}")
# Fallback: try without specific account
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
if result and "access_token" in result:
# Update current account if this worked
accounts = self.app.get_accounts()
if accounts:
self._current_account = accounts[0]
return True
else:
return False
return True
return False
except Exception as e:
logger.error(f"Authentication check failed: {e}")
return False
def get_access_token(self) -> str:
"""Get current access token"""
if not self._tokens or 'token' not in self._tokens:
raise ValueError("No access token available")
if self._is_token_expired():
raise ValueError("Access token expired and refresh failed")
return self._tokens['token']
"""Get an access token for Microsoft Graph (simplified like Google Drive)."""
try:
# Try with current account first
if self._current_account:
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
if result and "access_token" in result:
return result["access_token"]
# Fallback: try without specific account
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
if result and "access_token" in result:
return result["access_token"]
# If we get here, authentication has failed
error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
raise ValueError(f"Failed to acquire access token: {error_msg}")
except Exception as e:
logger.error(f"Failed to get access token: {e}")
raise
async def revoke_credentials(self):
"""Clear tokens"""
self._tokens = None
if os.path.exists(self.token_file):
os.remove(self.token_file)
"""Clear token cache and remove token file (like Google Drive)."""
try:
# Clear in-memory state
self._current_account = None
self.token_cache = msal.SerializableTokenCache()
# Recreate MSAL app with fresh cache
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority,
token_cache=self.token_cache,
)
# Remove token file
if os.path.exists(self.token_file):
os.remove(self.token_file)
logger.info(f"Removed SharePoint token file: {self.token_file}")
except Exception as e:
logger.error(f"Failed to revoke SharePoint credentials: {e}")
def get_service(self) -> str:
    """Return an access token (Graph doesn't need a generated client like Google Drive)."""
    # Kept for interface parity with connectors that build an API client
    # here; for Microsoft Graph a bearer token is all callers need.
    # NOTE(review): raises (via get_access_token) when no valid token is
    # available — callers should be prepared for that.
    return self.get_access_token()