Finally fix MSAL in onedrive/sharepoint
This commit is contained in:
parent
5e14c7f100
commit
f03889a2b3
6 changed files with 1649 additions and 750 deletions
|
|
@ -132,7 +132,10 @@ async def connector_status(request: Request, connector_service, session_manager)
|
|||
for connection in connections:
|
||||
try:
|
||||
connector = await connector_service._get_connector(connection.connection_id)
|
||||
connection_client_ids[connection.connection_id] = connector.get_client_id()
|
||||
if connector is not None:
|
||||
connection_client_ids[connection.connection_id] = connector.get_client_id()
|
||||
else:
|
||||
connection_client_ids[connection.connection_id] = None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Could not get connector for connection",
|
||||
|
|
@ -338,8 +341,8 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
|||
)
|
||||
|
||||
async def connector_token(request: Request, connector_service, session_manager):
|
||||
"""Get access token for connector API calls (e.g., Google Picker)"""
|
||||
connector_type = request.path_params.get("connector_type")
|
||||
"""Get access token for connector API calls (e.g., Pickers)."""
|
||||
url_connector_type = request.path_params.get("connector_type")
|
||||
connection_id = request.query_params.get("connection_id")
|
||||
|
||||
if not connection_id:
|
||||
|
|
@ -348,37 +351,81 @@ async def connector_token(request: Request, connector_service, session_manager):
|
|||
user = request.state.user
|
||||
|
||||
try:
|
||||
# Get the connection and verify it belongs to the user
|
||||
# 1) Load the connection and verify ownership
|
||||
connection = await connector_service.connection_manager.get_connection(connection_id)
|
||||
if not connection or connection.user_id != user.user_id:
|
||||
return JSONResponse({"error": "Connection not found"}, status_code=404)
|
||||
|
||||
# Get the connector instance
|
||||
# 2) Get the ACTUAL connector instance/type for this connection_id
|
||||
connector = await connector_service._get_connector(connection_id)
|
||||
if not connector:
|
||||
return JSONResponse({"error": f"Connector not available - authentication may have failed for {connector_type}"}, status_code=404)
|
||||
return JSONResponse(
|
||||
{"error": f"Connector not available - authentication may have failed for {url_connector_type}"},
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
# For Google Drive, get the access token
|
||||
if connector_type == "google_drive" and hasattr(connector, 'oauth'):
|
||||
real_type = getattr(connector, "type", None) or getattr(connection, "connector_type", None)
|
||||
if real_type is None:
|
||||
return JSONResponse({"error": "Unable to determine connector type"}, status_code=500)
|
||||
|
||||
# Optional: warn if URL path type disagrees with real type
|
||||
if url_connector_type and url_connector_type != real_type:
|
||||
# You can downgrade this to debug if you expect cross-routing.
|
||||
return JSONResponse(
|
||||
{
|
||||
"error": "Connector type mismatch",
|
||||
"detail": {
|
||||
"requested_type": url_connector_type,
|
||||
"actual_type": real_type,
|
||||
"hint": "Call the token endpoint using the correct connector_type for this connection_id.",
|
||||
},
|
||||
},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# 3) Branch by the actual connector type
|
||||
# GOOGLE DRIVE (google-auth)
|
||||
if real_type == "google_drive" and hasattr(connector, "oauth"):
|
||||
await connector.oauth.load_credentials()
|
||||
if connector.oauth.creds and connector.oauth.creds.valid:
|
||||
return JSONResponse({
|
||||
"access_token": connector.oauth.creds.token,
|
||||
"expires_in": (connector.oauth.creds.expiry.timestamp() -
|
||||
__import__('time').time()) if connector.oauth.creds.expiry else None
|
||||
})
|
||||
else:
|
||||
return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
|
||||
|
||||
# For OneDrive and SharePoint, get the access token
|
||||
elif connector_type in ["onedrive", "sharepoint"] and hasattr(connector, 'oauth'):
|
||||
expires_in = None
|
||||
try:
|
||||
if connector.oauth.creds.expiry:
|
||||
import time
|
||||
expires_in = max(0, int(connector.oauth.creds.expiry.timestamp() - time.time()))
|
||||
except Exception:
|
||||
expires_in = None
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"access_token": connector.oauth.creds.token,
|
||||
"expires_in": expires_in,
|
||||
}
|
||||
)
|
||||
return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
|
||||
|
||||
# ONEDRIVE / SHAREPOINT (MSAL or custom)
|
||||
if real_type in ("onedrive", "sharepoint") and hasattr(connector, "oauth"):
|
||||
# Ensure cache/credentials are loaded before trying to use them
|
||||
try:
|
||||
# Prefer a dedicated is_authenticated() that loads cache internally
|
||||
if hasattr(connector.oauth, "is_authenticated"):
|
||||
ok = await connector.oauth.is_authenticated()
|
||||
else:
|
||||
# Fallback: try to load credentials explicitly if available
|
||||
ok = True
|
||||
if hasattr(connector.oauth, "load_credentials"):
|
||||
ok = await connector.oauth.load_credentials()
|
||||
|
||||
if not ok:
|
||||
return JSONResponse({"error": "Not authenticated"}, status_code=401)
|
||||
|
||||
# Now safe to fetch access token
|
||||
access_token = connector.oauth.get_access_token()
|
||||
return JSONResponse({
|
||||
"access_token": access_token,
|
||||
"expires_in": None # MSAL handles token expiry internally
|
||||
})
|
||||
# MSAL result has expiry, but we’re returning a raw token; keep expires_in None for simplicity
|
||||
return JSONResponse({"access_token": access_token, "expires_in": None})
|
||||
except ValueError as e:
|
||||
# Typical when acquire_token_silent fails (e.g., needs re-auth)
|
||||
return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401)
|
||||
except Exception as e:
|
||||
return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500)
|
||||
|
|
@ -386,7 +433,5 @@ async def connector_token(request: Request, connector_service, session_manager):
|
|||
return JSONResponse({"error": "Token not available for this connector type"}, status_code=400)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting connector token", error=str(e))
|
||||
logger.error("Error getting connector token", exc_info=True)
|
||||
return JSONResponse({"error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -294,32 +294,39 @@ class ConnectionManager:
|
|||
|
||||
async def get_connector(self, connection_id: str) -> Optional[BaseConnector]:
|
||||
"""Get an active connector instance"""
|
||||
logger.debug(f"Getting connector for connection_id: {connection_id}")
|
||||
|
||||
# Return cached connector if available
|
||||
if connection_id in self.active_connectors:
|
||||
connector = self.active_connectors[connection_id]
|
||||
if connector.is_authenticated:
|
||||
logger.debug(f"Returning cached authenticated connector for {connection_id}")
|
||||
return connector
|
||||
else:
|
||||
# Remove unauthenticated connector from cache
|
||||
logger.debug(f"Removing unauthenticated connector from cache for {connection_id}")
|
||||
del self.active_connectors[connection_id]
|
||||
|
||||
# Try to create and authenticate connector
|
||||
connection_config = self.connections.get(connection_id)
|
||||
if not connection_config or not connection_config.is_active:
|
||||
logger.debug(f"No active connection config found for {connection_id}")
|
||||
return None
|
||||
|
||||
logger.debug(f"Creating connector for {connection_config.connector_type}")
|
||||
connector = self._create_connector(connection_config)
|
||||
if await connector.authenticate():
|
||||
|
||||
logger.debug(f"Attempting authentication for {connection_id}")
|
||||
auth_result = await connector.authenticate()
|
||||
logger.debug(f"Authentication result for {connection_id}: {auth_result}")
|
||||
|
||||
if auth_result:
|
||||
self.active_connectors[connection_id] = connector
|
||||
|
||||
# Setup webhook subscription if not already set up
|
||||
await self._setup_webhook_if_needed(
|
||||
connection_id, connection_config, connector
|
||||
)
|
||||
|
||||
# ... rest of the method
|
||||
return connector
|
||||
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"Authentication failed for {connection_id}")
|
||||
return None
|
||||
|
||||
def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get available connector types with their metadata"""
|
||||
|
|
@ -363,20 +370,23 @@ class ConnectionManager:
|
|||
|
||||
def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
|
||||
"""Factory method to create connector instances"""
|
||||
if config.connector_type == "google_drive":
|
||||
return GoogleDriveConnector(config.config)
|
||||
elif config.connector_type == "sharepoint":
|
||||
return SharePointConnector(config.config)
|
||||
elif config.connector_type == "onedrive":
|
||||
return OneDriveConnector(config.config)
|
||||
elif config.connector_type == "box":
|
||||
# Future: BoxConnector(config.config)
|
||||
raise NotImplementedError("Box connector not implemented yet")
|
||||
elif config.connector_type == "dropbox":
|
||||
# Future: DropboxConnector(config.config)
|
||||
raise NotImplementedError("Dropbox connector not implemented yet")
|
||||
else:
|
||||
raise ValueError(f"Unknown connector type: {config.connector_type}")
|
||||
try:
|
||||
if config.connector_type == "google_drive":
|
||||
return GoogleDriveConnector(config.config)
|
||||
elif config.connector_type == "sharepoint":
|
||||
return SharePointConnector(config.config)
|
||||
elif config.connector_type == "onedrive":
|
||||
return OneDriveConnector(config.config)
|
||||
elif config.connector_type == "box":
|
||||
raise NotImplementedError("Box connector not implemented yet")
|
||||
elif config.connector_type == "dropbox":
|
||||
raise NotImplementedError("Dropbox connector not implemented yet")
|
||||
else:
|
||||
raise ValueError(f"Unknown connector type: {config.connector_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create {config.connector_type} connector: {e}")
|
||||
# Re-raise the exception so caller can handle appropriately
|
||||
raise
|
||||
|
||||
async def update_last_sync(self, connection_id: str):
|
||||
"""Update the last sync timestamp for a connection"""
|
||||
|
|
|
|||
|
|
@ -1,235 +1,487 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from ..base import BaseConnector, ConnectorDocument, DocumentACL
|
||||
from .oauth import OneDriveOAuth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OneDriveConnector(BaseConnector):
|
||||
"""OneDrive connector using Microsoft Graph API"""
|
||||
"""OneDrive connector using MSAL-based OAuth for authentication."""
|
||||
|
||||
# Required BaseConnector class attributes
|
||||
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
|
||||
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
|
||||
|
||||
# Connector metadata
|
||||
CONNECTOR_NAME = "OneDrive"
|
||||
CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents"
|
||||
CONNECTOR_DESCRIPTION = "Connect to OneDrive (personal) to sync documents and files"
|
||||
CONNECTOR_ICON = "onedrive"
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
logger.debug(f"OneDrive connector __init__ called with config type: {type(config)}")
|
||||
logger.debug(f"OneDrive connector __init__ config value: {config}")
|
||||
|
||||
if config is None:
|
||||
logger.debug("Config was None, using empty dict")
|
||||
config = {}
|
||||
|
||||
try:
|
||||
logger.debug("Calling super().__init__")
|
||||
super().__init__(config)
|
||||
logger.debug("super().__init__ completed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"super().__init__ failed: {e}")
|
||||
raise
|
||||
|
||||
# Initialize with defaults that allow the connector to be listed
|
||||
self.client_id = None
|
||||
self.client_secret = None
|
||||
self.redirect_uri = config.get("redirect_uri", "http://localhost") # must match your app registration
|
||||
|
||||
# Try to get credentials, but don't fail if they're missing
|
||||
try:
|
||||
self.client_id = self.get_client_id()
|
||||
logger.debug(f"Got client_id: {self.client_id is not None}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get client_id: {e}")
|
||||
|
||||
try:
|
||||
self.client_secret = self.get_client_secret()
|
||||
logger.debug(f"Got client_secret: {self.client_secret is not None}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get client_secret: {e}")
|
||||
|
||||
# Token file setup
|
||||
project_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
|
||||
self.oauth = OneDriveOAuth(
|
||||
client_id=self.get_client_id(),
|
||||
client_secret=self.get_client_secret(),
|
||||
token_file=token_file,
|
||||
)
|
||||
self.subscription_id = config.get("subscription_id") or config.get(
|
||||
"webhook_channel_id"
|
||||
)
|
||||
self.base_url = "https://graph.microsoft.com/v1.0"
|
||||
Path(token_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def authenticate(self) -> bool:
|
||||
if await self.oauth.is_authenticated():
|
||||
self._authenticated = True
|
||||
return True
|
||||
return False
|
||||
# Only initialize OAuth if we have credentials
|
||||
if self.client_id and self.client_secret:
|
||||
connection_id = config.get("connection_id", "default")
|
||||
|
||||
async def setup_subscription(self) -> str:
|
||||
if not self._authenticated:
|
||||
raise ValueError("Not authenticated")
|
||||
# Use token_file from config if provided, otherwise generate one
|
||||
if config.get("token_file"):
|
||||
oauth_token_file = config["token_file"]
|
||||
else:
|
||||
# Use a per-connection cache file to avoid collisions with other connectors
|
||||
oauth_token_file = f"onedrive_token_{connection_id}.json"
|
||||
|
||||
webhook_url = self.config.get("webhook_url")
|
||||
if not webhook_url:
|
||||
raise ValueError("webhook_url required in config for subscriptions")
|
||||
# MSA & org both work via /common for OneDrive personal testing
|
||||
authority = "https://login.microsoftonline.com/common"
|
||||
|
||||
expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
|
||||
body = {
|
||||
"changeType": "created,updated,deleted",
|
||||
"notificationUrl": webhook_url,
|
||||
"resource": "/me/drive/root",
|
||||
"expirationDateTime": expiration,
|
||||
"clientState": str(uuid.uuid4()),
|
||||
self.oauth = OneDriveOAuth(
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
token_file=oauth_token_file,
|
||||
authority=authority,
|
||||
allow_json_refresh=True, # allows one-time migration from legacy JSON if present
|
||||
)
|
||||
else:
|
||||
self.oauth = None
|
||||
|
||||
# Track subscription ID for webhooks (note: change notifications might not be available for personal accounts)
|
||||
self._subscription_id: Optional[str] = None
|
||||
|
||||
# Graph API defaults
|
||||
self._graph_api_version = "v1.0"
|
||||
self._default_params = {
|
||||
"$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
|
||||
}
|
||||
|
||||
token = self.oauth.get_access_token()
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/subscriptions",
|
||||
json=body,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
@property
|
||||
def _graph_base_url(self) -> str:
|
||||
"""Base URL for Microsoft Graph API calls."""
|
||||
return f"https://graph.microsoft.com/{self._graph_api_version}"
|
||||
|
||||
self.subscription_id = data["id"]
|
||||
return self.subscription_id
|
||||
def emit(self, doc: ConnectorDocument) -> None:
|
||||
"""Emit a ConnectorDocument instance (integrate with your pipeline here)."""
|
||||
logger.debug(f"Emitting OneDrive document: {doc.id} ({doc.filename})")
|
||||
|
||||
async def list_files(
|
||||
self, page_token: Optional[str] = None, limit: int = 100
|
||||
) -> Dict[str, Any]:
|
||||
if not self._authenticated:
|
||||
raise ValueError("Not authenticated")
|
||||
async def authenticate(self) -> bool:
|
||||
"""Test authentication - BaseConnector interface."""
|
||||
logger.debug(f"OneDrive authenticate() called, oauth is None: {self.oauth is None}")
|
||||
try:
|
||||
if not self.oauth:
|
||||
logger.debug("OneDrive authentication failed: OAuth not initialized")
|
||||
self._authenticated = False
|
||||
return False
|
||||
|
||||
params = {"$top": str(limit)}
|
||||
if page_token:
|
||||
params["$skiptoken"] = page_token
|
||||
logger.debug("Loading OneDrive credentials...")
|
||||
load_result = await self.oauth.load_credentials()
|
||||
logger.debug(f"Load credentials result: {load_result}")
|
||||
|
||||
token = self.oauth.get_access_token()
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/me/drive/root/children",
|
||||
params=params,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
logger.debug("Checking OneDrive authentication status...")
|
||||
authenticated = await self.oauth.is_authenticated()
|
||||
logger.debug(f"OneDrive is_authenticated result: {authenticated}")
|
||||
|
||||
files = []
|
||||
for item in data.get("value", []):
|
||||
if item.get("file"):
|
||||
files.append(
|
||||
{
|
||||
"id": item["id"],
|
||||
"name": item["name"],
|
||||
"mimeType": item.get("file", {}).get(
|
||||
"mimeType", "application/octet-stream"
|
||||
),
|
||||
"webViewLink": item.get("webUrl"),
|
||||
"createdTime": item.get("createdDateTime"),
|
||||
"modifiedTime": item.get("lastModifiedDateTime"),
|
||||
}
|
||||
)
|
||||
self._authenticated = authenticated
|
||||
return authenticated
|
||||
except Exception as e:
|
||||
logger.error(f"OneDrive authentication failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
self._authenticated = False
|
||||
return False
|
||||
|
||||
next_token = None
|
||||
next_link = data.get("@odata.nextLink")
|
||||
if next_link:
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
def get_auth_url(self) -> str:
|
||||
"""Get OAuth authorization URL."""
|
||||
if not self.oauth:
|
||||
raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
|
||||
return self.oauth.create_authorization_url(self.redirect_uri)
|
||||
|
||||
parsed = urlparse(next_link)
|
||||
next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
|
||||
async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
|
||||
"""Handle OAuth callback."""
|
||||
if not self.oauth:
|
||||
raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
|
||||
try:
|
||||
success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
|
||||
if success:
|
||||
self._authenticated = True
|
||||
return {"status": "success"}
|
||||
else:
|
||||
raise ValueError("OAuth callback failed")
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth callback failed: {e}")
|
||||
raise
|
||||
|
||||
return {"files": files, "nextPageToken": next_token}
|
||||
def sync_once(self) -> None:
|
||||
"""
|
||||
Perform a one-shot sync of OneDrive files and emit documents.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
async def _async_sync():
|
||||
try:
|
||||
file_list = await self.list_files(max_files=1000)
|
||||
files = file_list.get("files", [])
|
||||
for file_info in files:
|
||||
try:
|
||||
file_id = file_info.get("id")
|
||||
if not file_id:
|
||||
continue
|
||||
doc = await self.get_file_content(file_id)
|
||||
self.emit(doc)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync OneDrive file {file_info.get('name', 'unknown')}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"OneDrive sync_once failed: {e}")
|
||||
raise
|
||||
|
||||
if hasattr(asyncio, 'run'):
|
||||
asyncio.run(_async_sync())
|
||||
else:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(_async_sync())
|
||||
|
||||
async def setup_subscription(self) -> str:
|
||||
"""
|
||||
Set up real-time subscription for file changes.
|
||||
NOTE: Change notifications may not be available for personal OneDrive accounts.
|
||||
"""
|
||||
webhook_url = self.config.get('webhook_url')
|
||||
if not webhook_url:
|
||||
logger.warning("No webhook URL configured, skipping OneDrive subscription setup")
|
||||
return "no-webhook-configured"
|
||||
|
||||
try:
|
||||
if not await self.authenticate():
|
||||
raise RuntimeError("OneDrive authentication failed during subscription setup")
|
||||
|
||||
token = self.oauth.get_access_token()
|
||||
|
||||
# For OneDrive personal we target the user's drive
|
||||
resource = "/me/drive/root"
|
||||
|
||||
subscription_data = {
|
||||
"changeType": "created,updated,deleted",
|
||||
"notificationUrl": f"{webhook_url}/webhook/onedrive",
|
||||
"resource": resource,
|
||||
"expirationDateTime": self._get_subscription_expiry(),
|
||||
"clientState": "onedrive_personal",
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = f"{self._graph_base_url}/subscriptions"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
subscription_id = result.get("id")
|
||||
|
||||
if subscription_id:
|
||||
self._subscription_id = subscription_id
|
||||
logger.info(f"OneDrive subscription created: {subscription_id}")
|
||||
return subscription_id
|
||||
else:
|
||||
raise ValueError("No subscription ID returned from Microsoft Graph")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup OneDrive subscription: {e}")
|
||||
raise
|
||||
|
||||
def _get_subscription_expiry(self) -> str:
|
||||
"""Get subscription expiry time (Graph caps duration; often <= 3 days)."""
|
||||
from datetime import datetime, timedelta
|
||||
expiry = datetime.utcnow() + timedelta(days=3)
|
||||
return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
|
||||
async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""List files from OneDrive using Microsoft Graph."""
|
||||
try:
|
||||
if not await self.authenticate():
|
||||
raise RuntimeError("OneDrive authentication failed during file listing")
|
||||
|
||||
files: List[Dict[str, Any]] = []
|
||||
max_files_value = max_files if max_files is not None else 100
|
||||
|
||||
base_url = f"{self._graph_base_url}/me/drive/root/children"
|
||||
|
||||
params = dict(self._default_params)
|
||||
params["$top"] = max_files_value
|
||||
|
||||
if page_token:
|
||||
params["$skiptoken"] = page_token
|
||||
|
||||
response = await self._make_graph_request(base_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
items = data.get("value", [])
|
||||
for item in items:
|
||||
if item.get("file"): # include files only
|
||||
files.append({
|
||||
"id": item.get("id", ""),
|
||||
"name": item.get("name", ""),
|
||||
"path": f"/drive/items/{item.get('id')}",
|
||||
"size": int(item.get("size", 0)),
|
||||
"modified": item.get("lastModifiedDateTime"),
|
||||
"created": item.get("createdDateTime"),
|
||||
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
|
||||
"url": item.get("webUrl", ""),
|
||||
"download_url": item.get("@microsoft.graph.downloadUrl"),
|
||||
})
|
||||
|
||||
# Next page
|
||||
next_page_token = None
|
||||
next_link = data.get("@odata.nextLink")
|
||||
if next_link:
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
parsed = urlparse(next_link)
|
||||
query_params = parse_qs(parsed.query)
|
||||
if "$skiptoken" in query_params:
|
||||
next_page_token = query_params["$skiptoken"][0]
|
||||
|
||||
return {"files": files, "next_page_token": next_page_token}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list OneDrive files: {e}")
|
||||
return {"files": [], "next_page_token": None}
|
||||
|
||||
async def get_file_content(self, file_id: str) -> ConnectorDocument:
|
||||
if not self._authenticated:
|
||||
raise ValueError("Not authenticated")
|
||||
"""Get file content and metadata."""
|
||||
try:
|
||||
if not await self.authenticate():
|
||||
raise RuntimeError("OneDrive authentication failed during file content retrieval")
|
||||
|
||||
token = self.oauth.get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient() as client:
|
||||
meta_resp = await client.get(
|
||||
f"{self.base_url}/me/drive/items/{file_id}", headers=headers
|
||||
)
|
||||
meta_resp.raise_for_status()
|
||||
metadata = meta_resp.json()
|
||||
file_metadata = await self._get_file_metadata_by_id(file_id)
|
||||
if not file_metadata:
|
||||
raise ValueError(f"File not found: {file_id}")
|
||||
|
||||
content_resp = await client.get(
|
||||
f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers
|
||||
)
|
||||
content = content_resp.content
|
||||
|
||||
# Handle the possibility of this being a redirect
|
||||
if content_resp.status_code in (301, 302, 303, 307, 308):
|
||||
redirect_url = content_resp.headers.get("Location")
|
||||
if redirect_url:
|
||||
content_resp = await client.get(redirect_url)
|
||||
content_resp.raise_for_status()
|
||||
content = content_resp.content
|
||||
download_url = file_metadata.get("download_url")
|
||||
if download_url:
|
||||
content = await self._download_file_from_url(download_url)
|
||||
else:
|
||||
content_resp.raise_for_status()
|
||||
content = await self._download_file_content(file_id)
|
||||
|
||||
perm_resp = await client.get(
|
||||
f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers
|
||||
acl = DocumentACL(
|
||||
owner="",
|
||||
user_permissions={},
|
||||
group_permissions={},
|
||||
)
|
||||
perm_resp.raise_for_status()
|
||||
permissions = perm_resp.json()
|
||||
|
||||
acl = self._parse_permissions(metadata, permissions)
|
||||
modified = datetime.fromisoformat(
|
||||
metadata["lastModifiedDateTime"].replace("Z", "+00:00")
|
||||
).replace(tzinfo=None)
|
||||
created = datetime.fromisoformat(
|
||||
metadata["createdDateTime"].replace("Z", "+00:00")
|
||||
).replace(tzinfo=None)
|
||||
modified_time = self._parse_graph_date(file_metadata.get("modified"))
|
||||
created_time = self._parse_graph_date(file_metadata.get("created"))
|
||||
|
||||
document = ConnectorDocument(
|
||||
id=metadata["id"],
|
||||
filename=metadata["name"],
|
||||
mimetype=metadata.get("file", {}).get(
|
||||
"mimeType", "application/octet-stream"
|
||||
),
|
||||
content=content,
|
||||
source_url=metadata.get("webUrl"),
|
||||
acl=acl,
|
||||
modified_time=modified,
|
||||
created_time=created,
|
||||
metadata={"size": metadata.get("size")},
|
||||
)
|
||||
return document
|
||||
|
||||
def _parse_permissions(
|
||||
self, metadata: Dict[str, Any], permissions: Dict[str, Any]
|
||||
) -> DocumentACL:
|
||||
acl = DocumentACL()
|
||||
owner = metadata.get("createdBy", {}).get("user", {}).get("email")
|
||||
if owner:
|
||||
acl.owner = owner
|
||||
for perm in permissions.get("value", []):
|
||||
role = perm.get("roles", ["read"])[0]
|
||||
grantee = perm.get("grantedToV2") or perm.get("grantedTo")
|
||||
if not grantee:
|
||||
continue
|
||||
user = grantee.get("user")
|
||||
if user and user.get("email"):
|
||||
acl.user_permissions[user["email"]] = role
|
||||
group = grantee.get("group")
|
||||
if group and group.get("email"):
|
||||
acl.group_permissions[group["email"]] = role
|
||||
return acl
|
||||
|
||||
def handle_webhook_validation(
|
||||
self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
|
||||
) -> Optional[str]:
|
||||
"""Handle Microsoft Graph webhook validation"""
|
||||
if request_method == "GET":
|
||||
validation_token = query_params.get("validationtoken") or query_params.get(
|
||||
"validationToken"
|
||||
return ConnectorDocument(
|
||||
id=file_id,
|
||||
filename=file_metadata.get("name", ""),
|
||||
mimetype=file_metadata.get("mime_type", "application/octet-stream"),
|
||||
content=content,
|
||||
source_url=file_metadata.get("url", ""),
|
||||
acl=acl,
|
||||
modified_time=modified_time,
|
||||
created_time=created_time,
|
||||
metadata={
|
||||
"onedrive_path": file_metadata.get("path", ""),
|
||||
"size": file_metadata.get("size", 0),
|
||||
},
|
||||
)
|
||||
if validation_token:
|
||||
return validation_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get OneDrive file content {file_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get file metadata by ID using Graph API."""
|
||||
try:
|
||||
url = f"{self._graph_base_url}/me/drive/items/{file_id}"
|
||||
params = dict(self._default_params)
|
||||
|
||||
response = await self._make_graph_request(url, params=params)
|
||||
item = response.json()
|
||||
|
||||
if item.get("file"):
|
||||
return {
|
||||
"id": file_id,
|
||||
"name": item.get("name", ""),
|
||||
"path": f"/drive/items/{file_id}",
|
||||
"size": int(item.get("size", 0)),
|
||||
"modified": item.get("lastModifiedDateTime"),
|
||||
"created": item.get("createdDateTime"),
|
||||
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
|
||||
"url": item.get("webUrl", ""),
|
||||
"download_url": item.get("@microsoft.graph.downloadUrl"),
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file metadata for {file_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _download_file_content(self, file_id: str) -> bytes:
|
||||
"""Download file content by file ID using Graph API."""
|
||||
try:
|
||||
url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
|
||||
token = self.oauth.get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, headers=headers, timeout=60)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download file content for {file_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _download_file_from_url(self, download_url: str) -> bytes:
|
||||
"""Download file content from direct download URL."""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(download_url, timeout=60)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download from URL {download_url}: {e}")
|
||||
raise
|
||||
|
||||
def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
|
||||
"""Parse Microsoft Graph date string to datetime."""
|
||||
if not date_str:
|
||||
return datetime.now()
|
||||
try:
|
||||
if date_str.endswith('Z'):
|
||||
return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
|
||||
else:
|
||||
return datetime.fromisoformat(date_str.replace('T', ' '))
|
||||
except (ValueError, AttributeError):
|
||||
return datetime.now()
|
||||
|
||||
async def _make_graph_request(self, url: str, method: str = "GET",
|
||||
data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
|
||||
"""Make authenticated API request to Microsoft Graph."""
|
||||
token = self.oauth.get_access_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
if method.upper() == "GET":
|
||||
response = await client.get(url, headers=headers, params=params, timeout=30)
|
||||
elif method.upper() == "POST":
|
||||
response = await client.post(url, headers=headers, json=data, timeout=30)
|
||||
elif method.upper() == "DELETE":
|
||||
response = await client.delete(url, headers=headers, timeout=30)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
def _get_mime_type(self, filename: str) -> str:
|
||||
"""Get MIME type based on file extension."""
|
||||
import mimetypes
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
return mime_type or "application/octet-stream"
|
||||
|
||||
# Webhook methods - BaseConnector interface
|
||||
def handle_webhook_validation(self, request_method: str,
|
||||
headers: Dict[str, str],
|
||||
query_params: Dict[str, str]) -> Optional[str]:
|
||||
"""Handle webhook validation (Graph API specific)."""
|
||||
if request_method == "POST" and "validationToken" in query_params:
|
||||
return query_params["validationToken"]
|
||||
return None
|
||||
|
||||
def extract_webhook_channel_id(
|
||||
self, payload: Dict[str, Any], headers: Dict[str, str]
|
||||
) -> Optional[str]:
|
||||
"""Extract SharePoint subscription ID from webhook payload"""
|
||||
values = payload.get("value", [])
|
||||
return values[0].get("subscriptionId") if values else None
|
||||
def extract_webhook_channel_id(self, payload: Dict[str, Any],
|
||||
headers: Dict[str, str]) -> Optional[str]:
|
||||
"""Extract channel/subscription ID from webhook payload."""
|
||||
notifications = payload.get("value", [])
|
||||
if notifications:
|
||||
return notifications[0].get("subscriptionId")
|
||||
return None
|
||||
|
||||
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
||||
values = payload.get("value", [])
|
||||
file_ids = []
|
||||
for item in values:
|
||||
resource_data = item.get("resourceData", {})
|
||||
file_id = resource_data.get("id")
|
||||
if file_id:
|
||||
file_ids.append(file_id)
|
||||
return file_ids
|
||||
"""Handle webhook notification and return affected file IDs."""
|
||||
affected_files: List[str] = []
|
||||
notifications = payload.get("value", [])
|
||||
for notification in notifications:
|
||||
resource = notification.get("resource")
|
||||
if resource and "/drive/items/" in resource:
|
||||
file_id = resource.split("/drive/items/")[-1]
|
||||
affected_files.append(file_id)
|
||||
return affected_files
|
||||
|
||||
async def cleanup_subscription(
|
||||
self, subscription_id: str, resource_id: str = None
|
||||
) -> bool:
|
||||
if not self._authenticated:
|
||||
async def cleanup_subscription(self, subscription_id: str) -> bool:
|
||||
"""Clean up subscription - BaseConnector interface."""
|
||||
if subscription_id == "no-webhook-configured":
|
||||
logger.info("No subscription to cleanup (webhook was not configured)")
|
||||
return True
|
||||
|
||||
try:
|
||||
if not await self.authenticate():
|
||||
logger.error("OneDrive authentication failed during subscription cleanup")
|
||||
return False
|
||||
|
||||
token = self.oauth.get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.delete(url, headers=headers, timeout=30)
|
||||
|
||||
if response.status_code in [200, 204, 404]:
|
||||
logger.info(f"OneDrive subscription {subscription_id} cleaned up successfully")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup OneDrive subscription {subscription_id}: {e}")
|
||||
return False
|
||||
token = self.oauth.get_access_token()
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.delete(
|
||||
f"{self.base_url}/subscriptions/{subscription_id}",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
return resp.status_code in (200, 204)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,28 @@
|
|||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import aiofiles
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
import msal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OneDriveOAuth:
|
||||
"""Direct token management for OneDrive, bypassing MSAL cache format"""
|
||||
"""Handles Microsoft Graph OAuth for OneDrive (personal Microsoft accounts by default)."""
|
||||
|
||||
SCOPES = [
|
||||
"offline_access",
|
||||
"Files.Read.All",
|
||||
]
|
||||
# Reserved scopes that must NOT be sent on token or silent calls
|
||||
RESERVED_SCOPES = {"openid", "profile", "offline_access"}
|
||||
|
||||
# For PERSONAL Microsoft Accounts (OneDrive consumer):
|
||||
# - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
|
||||
# - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
|
||||
AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
|
||||
RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
|
||||
SCOPES = AUTH_SCOPES # Backward-compat alias if something references .SCOPES
|
||||
|
||||
# Kept for reference; MSAL derives endpoints from `authority`
|
||||
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||
|
||||
|
|
@ -21,168 +31,292 @@ class OneDriveOAuth:
|
|||
client_id: str,
|
||||
client_secret: str,
|
||||
token_file: str = "onedrive_token.json",
|
||||
authority: str = "https://login.microsoftonline.com/common",
|
||||
allow_json_refresh: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize OneDriveOAuth.
|
||||
|
||||
Args:
|
||||
client_id: Azure AD application (client) ID.
|
||||
client_secret: Azure AD application client secret.
|
||||
token_file: Path to persisted token cache file (MSAL cache format).
|
||||
authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
|
||||
or tenant-specific for work/school.
|
||||
allow_json_refresh: If True, permit one-time migration from legacy flat JSON
|
||||
{"access_token","refresh_token",...}. Otherwise refuse it.
|
||||
"""
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.token_file = token_file
|
||||
self._tokens = None
|
||||
self._load_tokens()
|
||||
self.authority = authority
|
||||
self.allow_json_refresh = allow_json_refresh
|
||||
self.token_cache = msal.SerializableTokenCache()
|
||||
self._current_account = None
|
||||
|
||||
def _load_tokens(self):
|
||||
"""Load tokens from file"""
|
||||
if os.path.exists(self.token_file):
|
||||
with open(self.token_file, "r") as f:
|
||||
self._tokens = json.loads(f.read())
|
||||
print(f"Loaded tokens from {self.token_file}")
|
||||
else:
|
||||
print(f"No token file found at {self.token_file}")
|
||||
# Initialize MSAL Confidential Client
|
||||
self.app = msal.ConfidentialClientApplication(
|
||||
client_id=self.client_id,
|
||||
client_credential=self.client_secret,
|
||||
authority=self.authority,
|
||||
token_cache=self.token_cache,
|
||||
)
|
||||
|
||||
async def _save_tokens(self):
|
||||
"""Save tokens to file"""
|
||||
if self._tokens:
|
||||
async with aiofiles.open(self.token_file, "w") as f:
|
||||
await f.write(json.dumps(self._tokens, indent=2))
|
||||
|
||||
def _is_token_expired(self) -> bool:
|
||||
"""Check if current access token is expired"""
|
||||
if not self._tokens or 'expiry' not in self._tokens:
|
||||
return True
|
||||
|
||||
expiry_str = self._tokens['expiry']
|
||||
# Handle different expiry formats
|
||||
async def load_credentials(self) -> bool:
|
||||
"""Load existing credentials from token file (async)."""
|
||||
try:
|
||||
if expiry_str.endswith('Z'):
|
||||
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
|
||||
else:
|
||||
expiry_dt = datetime.fromisoformat(expiry_str)
|
||||
|
||||
# Add 5-minute buffer
|
||||
import datetime as dt
|
||||
now = datetime.now()
|
||||
return now >= (expiry_dt - dt.timedelta(minutes=5))
|
||||
except:
|
||||
return True
|
||||
logger.debug(f"OneDrive OAuth loading credentials from: {self.token_file}")
|
||||
if os.path.exists(self.token_file):
|
||||
logger.debug(f"Token file exists, reading: {self.token_file}")
|
||||
|
||||
# Read the token file
|
||||
async with aiofiles.open(self.token_file, "r") as f:
|
||||
cache_data = await f.read()
|
||||
logger.debug(f"Read {len(cache_data)} chars from token file")
|
||||
|
||||
if cache_data.strip():
|
||||
# 1) Try legacy flat JSON first
|
||||
try:
|
||||
json_data = json.loads(cache_data)
|
||||
if isinstance(json_data, dict) and "refresh_token" in json_data:
|
||||
if self.allow_json_refresh:
|
||||
logger.debug(
|
||||
"Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
|
||||
)
|
||||
return await self._refresh_from_json_token(json_data)
|
||||
else:
|
||||
logger.warning(
|
||||
"Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
|
||||
"Delete the file and re-auth."
|
||||
)
|
||||
return False
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Token file is not flat JSON; attempting MSAL cache format")
|
||||
|
||||
# 2) Try MSAL cache format
|
||||
logger.debug("Attempting MSAL cache deserialization")
|
||||
self.token_cache.deserialize(cache_data)
|
||||
|
||||
# Get accounts from loaded cache
|
||||
accounts = self.app.get_accounts()
|
||||
logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
|
||||
if accounts:
|
||||
self._current_account = accounts[0]
|
||||
logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
|
||||
|
||||
# Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
|
||||
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
|
||||
logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
|
||||
if result and "access_token" in result:
|
||||
logger.debug("Silent token acquisition successful")
|
||||
await self.save_cache()
|
||||
return True
|
||||
else:
|
||||
error_msg = (result or {}).get("error") or "No result"
|
||||
logger.warning(f"Silent token acquisition failed: {error_msg}")
|
||||
else:
|
||||
logger.debug(f"Token file {self.token_file} is empty")
|
||||
else:
|
||||
logger.debug(f"Token file does not exist: {self.token_file}")
|
||||
|
||||
async def _refresh_access_token(self) -> bool:
|
||||
"""Refresh the access token using refresh token"""
|
||||
if not self._tokens or 'refresh_token' not in self._tokens:
|
||||
return False
|
||||
|
||||
data = {
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'refresh_token': self._tokens['refresh_token'],
|
||||
'grant_type': 'refresh_token',
|
||||
'scope': ' '.join(self.SCOPES)
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(self.TOKEN_ENDPOINT, data=data)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
|
||||
# Update tokens
|
||||
self._tokens['token'] = token_data['access_token']
|
||||
if 'refresh_token' in token_data:
|
||||
self._tokens['refresh_token'] = token_data['refresh_token']
|
||||
|
||||
# Calculate expiry
|
||||
expires_in = token_data.get('expires_in', 3600)
|
||||
import datetime as dt
|
||||
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
|
||||
self._tokens['expiry'] = expiry.isoformat()
|
||||
|
||||
await self._save_tokens()
|
||||
print("Access token refreshed successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to refresh token: {e}")
|
||||
return False
|
||||
|
||||
async def is_authenticated(self) -> bool:
|
||||
"""Check if we have valid credentials"""
|
||||
if not self._tokens:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load OneDrive credentials: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# If token is expired, try to refresh
|
||||
if self._is_token_expired():
|
||||
print("Token expired, attempting refresh...")
|
||||
if await self._refresh_access_token():
|
||||
return True
|
||||
else:
|
||||
async def _refresh_from_json_token(self, token_data: dict) -> bool:
|
||||
"""
|
||||
Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
|
||||
Prefer using an MSAL cache file and acquire_token_silent(); this path is only for migrating older files.
|
||||
"""
|
||||
try:
|
||||
refresh_token = token_data.get("refresh_token")
|
||||
if not refresh_token:
|
||||
logger.error("No refresh_token found in JSON file - cannot refresh")
|
||||
logger.error("You must re-authenticate interactively to obtain a valid token")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_access_token(self) -> str:
|
||||
"""Get current access token"""
|
||||
if not self._tokens or 'token' not in self._tokens:
|
||||
raise ValueError("No access token available")
|
||||
|
||||
if self._is_token_expired():
|
||||
raise ValueError("Access token expired and refresh failed")
|
||||
|
||||
return self._tokens['token']
|
||||
# Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
|
||||
refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
|
||||
logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
|
||||
|
||||
async def revoke_credentials(self):
|
||||
"""Clear tokens"""
|
||||
self._tokens = None
|
||||
if os.path.exists(self.token_file):
|
||||
os.remove(self.token_file)
|
||||
result = self.app.acquire_token_by_refresh_token(
|
||||
refresh_token=refresh_token,
|
||||
scopes=refresh_scopes,
|
||||
)
|
||||
|
||||
# Keep these methods for compatibility with your existing OAuth flow
|
||||
def create_authorization_url(self, redirect_uri: str) -> str:
|
||||
"""Create authorization URL for OAuth flow"""
|
||||
from urllib.parse import urlencode
|
||||
|
||||
params = {
|
||||
'client_id': self.client_id,
|
||||
'response_type': 'code',
|
||||
'redirect_uri': redirect_uri,
|
||||
'scope': ' '.join(self.SCOPES),
|
||||
'response_mode': 'query'
|
||||
if result and "access_token" in result:
|
||||
logger.debug("Successfully refreshed token via legacy JSON path")
|
||||
await self.save_cache()
|
||||
|
||||
accounts = self.app.get_accounts()
|
||||
logger.debug(f"After refresh, found {len(accounts)} accounts")
|
||||
if accounts:
|
||||
self._current_account = accounts[0]
|
||||
logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
|
||||
return True
|
||||
|
||||
# Error handling
|
||||
err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
|
||||
logger.error(f"Refresh token failed: {err}")
|
||||
|
||||
if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
|
||||
logger.warning(
|
||||
"Refresh denied due to unauthorized/expired scopes or invalid grant. "
|
||||
"Delete the token file and perform interactive sign-in with correct scopes."
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during refresh from JSON token: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def save_cache(self):
|
||||
"""Persist the token cache to file."""
|
||||
try:
|
||||
# Ensure parent directory exists
|
||||
parent = os.path.dirname(os.path.abspath(self.token_file))
|
||||
if parent and not os.path.exists(parent):
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
|
||||
cache_data = self.token_cache.serialize()
|
||||
if cache_data:
|
||||
async with aiofiles.open(self.token_file, "w") as f:
|
||||
await f.write(cache_data)
|
||||
logger.debug(f"Token cache saved to {self.token_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save token cache: {e}")
|
||||
|
||||
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
|
||||
"""Create authorization URL for OAuth flow."""
|
||||
# Store redirect URI for later use in callback
|
||||
self._redirect_uri = redirect_uri
|
||||
|
||||
kwargs: Dict[str, Any] = {
|
||||
# Interactive auth includes offline_access
|
||||
"scopes": self.AUTH_SCOPES,
|
||||
"redirect_uri": redirect_uri,
|
||||
"prompt": "consent", # ensure refresh token on first run
|
||||
}
|
||||
|
||||
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||
return f"{auth_url}?{urlencode(params)}"
|
||||
if state:
|
||||
kwargs["state"] = state # Optional CSRF protection
|
||||
|
||||
auth_url = self.app.get_authorization_request_url(**kwargs)
|
||||
|
||||
logger.debug(f"Generated auth URL: {auth_url}")
|
||||
logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
|
||||
|
||||
return auth_url
|
||||
|
||||
async def handle_authorization_callback(
|
||||
self, authorization_code: str, redirect_uri: str
|
||||
) -> bool:
|
||||
"""Handle OAuth callback and exchange code for tokens"""
|
||||
data = {
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'code': authorization_code,
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': redirect_uri,
|
||||
'scope': ' '.join(self.SCOPES)
|
||||
}
|
||||
"""Handle OAuth callback and exchange code for tokens."""
|
||||
try:
|
||||
result = self.app.acquire_token_by_authorization_code(
|
||||
authorization_code,
|
||||
scopes=self.AUTH_SCOPES, # same as authorize step
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(self.TOKEN_ENDPOINT, data=data)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
if result and "access_token" in result:
|
||||
accounts = self.app.get_accounts()
|
||||
if accounts:
|
||||
self._current_account = accounts[0]
|
||||
|
||||
# Store tokens in our format
|
||||
import datetime as dt
|
||||
expires_in = token_data.get('expires_in', 3600)
|
||||
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
|
||||
|
||||
self._tokens = {
|
||||
'token': token_data['access_token'],
|
||||
'refresh_token': token_data['refresh_token'],
|
||||
'scopes': self.SCOPES,
|
||||
'expiry': expiry.isoformat()
|
||||
}
|
||||
|
||||
await self._save_tokens()
|
||||
print("Authorization successful, tokens saved")
|
||||
await self.save_cache()
|
||||
logger.info("OneDrive OAuth authorization successful")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Authorization failed: {e}")
|
||||
return False
|
||||
error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
|
||||
logger.error(f"OneDrive OAuth authorization failed: {error_msg}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during OneDrive OAuth authorization: {e}")
|
||||
return False
|
||||
|
||||
async def is_authenticated(self) -> bool:
|
||||
"""Check if we have valid credentials."""
|
||||
try:
|
||||
# First try to load credentials if we haven't already
|
||||
if not self._current_account:
|
||||
await self.load_credentials()
|
||||
|
||||
# Try to get a token (MSAL will refresh if needed)
|
||||
if self._current_account:
|
||||
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
|
||||
if result and "access_token" in result:
|
||||
return True
|
||||
else:
|
||||
error_msg = (result or {}).get("error") or "No result returned"
|
||||
logger.debug(f"Token acquisition failed for current account: {error_msg}")
|
||||
|
||||
# Fallback: try without specific account
|
||||
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
|
||||
if result and "access_token" in result:
|
||||
accounts = self.app.get_accounts()
|
||||
if accounts:
|
||||
self._current_account = accounts[0]
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication check failed: {e}")
|
||||
return False
|
||||
|
||||
def get_access_token(self) -> str:
|
||||
"""Get an access token for Microsoft Graph."""
|
||||
try:
|
||||
# Try with current account first
|
||||
if self._current_account:
|
||||
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
|
||||
if result and "access_token" in result:
|
||||
return result["access_token"]
|
||||
|
||||
# Fallback: try without specific account
|
||||
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
|
||||
if result and "access_token" in result:
|
||||
return result["access_token"]
|
||||
|
||||
# If we get here, authentication has failed
|
||||
error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
|
||||
raise ValueError(f"Failed to acquire access token: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get access token: {e}")
|
||||
raise
|
||||
|
||||
async def revoke_credentials(self):
|
||||
"""Clear token cache and remove token file."""
|
||||
try:
|
||||
# Clear in-memory state
|
||||
self._current_account = None
|
||||
self.token_cache = msal.SerializableTokenCache()
|
||||
|
||||
# Recreate MSAL app with fresh cache
|
||||
self.app = msal.ConfidentialClientApplication(
|
||||
client_id=self.client_id,
|
||||
client_credential=self.client_secret,
|
||||
authority=self.authority,
|
||||
token_cache=self.token_cache,
|
||||
)
|
||||
|
||||
# Remove token file
|
||||
if os.path.exists(self.token_file):
|
||||
os.remove(self.token_file)
|
||||
logger.info(f"Removed OneDrive token file: {self.token_file}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke OneDrive credentials: {e}")
|
||||
|
||||
def get_service(self) -> str:
|
||||
"""Return an access token (Graph client is just the bearer)."""
|
||||
return self.get_access_token()
|
||||
|
|
|
|||
|
|
@ -1,241 +1,564 @@
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from ..base import BaseConnector, ConnectorDocument, DocumentACL
|
||||
from .oauth import SharePointOAuth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharePointConnector(BaseConnector):
|
||||
"""SharePoint Sites connector using Microsoft Graph API"""
|
||||
"""SharePoint connector using MSAL-based OAuth for authentication"""
|
||||
|
||||
# Required BaseConnector class attributes
|
||||
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
|
||||
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
|
||||
|
||||
|
||||
# Connector metadata
|
||||
CONNECTOR_NAME = "SharePoint"
|
||||
CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents"
|
||||
CONNECTOR_DESCRIPTION = "Connect to SharePoint to sync documents and files"
|
||||
CONNECTOR_ICON = "sharepoint"
|
||||
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
super().__init__(config) # Fix: Call parent init first
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
logger.debug(f"SharePoint connector __init__ called with config type: {type(config)}")
|
||||
logger.debug(f"SharePoint connector __init__ config value: {config}")
|
||||
|
||||
# Ensure we always pass a valid config to the base class
|
||||
if config is None:
|
||||
logger.debug("Config was None, using empty dict")
|
||||
config = {}
|
||||
|
||||
try:
|
||||
logger.debug("Calling super().__init__")
|
||||
super().__init__(config) # Now safe to call with empty dict instead of None
|
||||
logger.debug("super().__init__ completed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"super().__init__ failed: {e}")
|
||||
raise
|
||||
|
||||
# Initialize with defaults that allow the connector to be listed
|
||||
self.client_id = None
|
||||
self.client_secret = None
|
||||
self.tenant_id = config.get("tenant_id", "common")
|
||||
self.sharepoint_url = config.get("sharepoint_url")
|
||||
self.redirect_uri = config.get("redirect_uri", "http://localhost")
|
||||
|
||||
# Try to get credentials, but don't fail if they're missing
|
||||
try:
|
||||
logger.debug("Attempting to get client_id")
|
||||
self.client_id = self.get_client_id()
|
||||
logger.debug(f"Got client_id: {self.client_id is not None}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get client_id: {e}")
|
||||
pass # Credentials not available, that's OK for listing
|
||||
|
||||
try:
|
||||
logger.debug("Attempting to get client_secret")
|
||||
self.client_secret = self.get_client_secret()
|
||||
logger.debug(f"Got client_secret: {self.client_secret is not None}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get client_secret: {e}")
|
||||
pass # Credentials not available, that's OK for listing
|
||||
|
||||
# Token file setup
|
||||
project_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
|
||||
self.oauth = SharePointOAuth(
|
||||
client_id=self.get_client_id(),
|
||||
client_secret=self.get_client_secret(),
|
||||
token_file=token_file,
|
||||
)
|
||||
self.subscription_id = config.get("subscription_id") or config.get(
|
||||
"webhook_channel_id"
|
||||
)
|
||||
self.base_url = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
# SharePoint site configuration
|
||||
self.site_id = config.get("site_id") # Required for SharePoint
|
||||
|
||||
async def authenticate(self) -> bool:
|
||||
if await self.oauth.is_authenticated():
|
||||
self._authenticated = True
|
||||
return True
|
||||
return False
|
||||
|
||||
async def setup_subscription(self) -> str:
|
||||
if not self._authenticated:
|
||||
raise ValueError("Not authenticated")
|
||||
|
||||
webhook_url = self.config.get("webhook_url")
|
||||
if not webhook_url:
|
||||
raise ValueError("webhook_url required in config for subscriptions")
|
||||
|
||||
expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
|
||||
body = {
|
||||
"changeType": "created,updated,deleted",
|
||||
"notificationUrl": webhook_url,
|
||||
"resource": f"/sites/{self.site_id}/drive/root",
|
||||
"expirationDateTime": expiration,
|
||||
"clientState": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
token = self.oauth.get_access_token()
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/subscriptions",
|
||||
json=body,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
self.subscription_id = data["id"]
|
||||
return self.subscription_id
|
||||
|
||||
async def list_files(
|
||||
self, page_token: Optional[str] = None, limit: int = 100
|
||||
) -> Dict[str, Any]:
|
||||
if not self._authenticated:
|
||||
raise ValueError("Not authenticated")
|
||||
|
||||
params = {"$top": str(limit)}
|
||||
if page_token:
|
||||
params["$skiptoken"] = page_token
|
||||
|
||||
token = self.oauth.get_access_token()
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/sites/{self.site_id}/drive/root/children",
|
||||
params=params,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
files = []
|
||||
for item in data.get("value", []):
|
||||
if item.get("file"):
|
||||
files.append(
|
||||
{
|
||||
"id": item["id"],
|
||||
"name": item["name"],
|
||||
"mimeType": item.get("file", {}).get(
|
||||
"mimeType", "application/octet-stream"
|
||||
),
|
||||
"webViewLink": item.get("webUrl"),
|
||||
"createdTime": item.get("createdDateTime"),
|
||||
"modifiedTime": item.get("lastModifiedDateTime"),
|
||||
}
|
||||
)
|
||||
|
||||
next_token = None
|
||||
next_link = data.get("@odata.nextLink")
|
||||
if next_link:
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
|
||||
parsed = urlparse(next_link)
|
||||
next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
|
||||
|
||||
return {"files": files, "nextPageToken": next_token}
|
||||
|
||||
async def get_file_content(self, file_id: str) -> ConnectorDocument:
|
||||
if not self._authenticated:
|
||||
raise ValueError("Not authenticated")
|
||||
|
||||
token = self.oauth.get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
async with httpx.AsyncClient() as client:
|
||||
meta_resp = await client.get(
|
||||
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}",
|
||||
headers=headers,
|
||||
)
|
||||
meta_resp.raise_for_status()
|
||||
metadata = meta_resp.json()
|
||||
|
||||
content_resp = await client.get(
|
||||
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/content",
|
||||
headers=headers,
|
||||
)
|
||||
content = content_resp.content
|
||||
|
||||
# Handle the possibility of this being a redirect
|
||||
if content_resp.status_code in (301, 302, 303, 307, 308):
|
||||
redirect_url = content_resp.headers.get("Location")
|
||||
if redirect_url:
|
||||
content_resp = await client.get(redirect_url)
|
||||
content_resp.raise_for_status()
|
||||
content = content_resp.content
|
||||
token_file = config.get("token_file") or str(project_root / "sharepoint_token.json")
|
||||
Path(token_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Only initialize OAuth if we have credentials
|
||||
if self.client_id and self.client_secret:
|
||||
connection_id = config.get("connection_id", "default")
|
||||
|
||||
# Use token_file from config if provided, otherwise generate one
|
||||
if config.get("token_file"):
|
||||
oauth_token_file = config["token_file"]
|
||||
else:
|
||||
content_resp.raise_for_status()
|
||||
|
||||
perm_resp = await client.get(
|
||||
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/permissions",
|
||||
headers=headers,
|
||||
oauth_token_file = f"sharepoint_token_{connection_id}.json"
|
||||
|
||||
authority = f"https://login.microsoftonline.com/{self.tenant_id}" if self.tenant_id != "common" else "https://login.microsoftonline.com/common"
|
||||
|
||||
self.oauth = SharePointOAuth(
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
token_file=oauth_token_file,
|
||||
authority=authority
|
||||
)
|
||||
perm_resp.raise_for_status()
|
||||
permissions = perm_resp.json()
|
||||
|
||||
acl = self._parse_permissions(metadata, permissions)
|
||||
modified = datetime.fromisoformat(
|
||||
metadata["lastModifiedDateTime"].replace("Z", "+00:00")
|
||||
).replace(tzinfo=None)
|
||||
created = datetime.fromisoformat(
|
||||
metadata["createdDateTime"].replace("Z", "+00:00")
|
||||
).replace(tzinfo=None)
|
||||
|
||||
document = ConnectorDocument(
|
||||
id=metadata["id"],
|
||||
filename=metadata["name"],
|
||||
mimetype=metadata.get("file", {}).get(
|
||||
"mimeType", "application/octet-stream"
|
||||
),
|
||||
content=content,
|
||||
source_url=metadata.get("webUrl"),
|
||||
acl=acl,
|
||||
modified_time=modified,
|
||||
created_time=created,
|
||||
metadata={"size": metadata.get("size")},
|
||||
)
|
||||
return document
|
||||
|
||||
def _parse_permissions(
|
||||
self, metadata: Dict[str, Any], permissions: Dict[str, Any]
|
||||
) -> DocumentACL:
|
||||
acl = DocumentACL()
|
||||
owner = metadata.get("createdBy", {}).get("user", {}).get("email")
|
||||
if owner:
|
||||
acl.owner = owner
|
||||
for perm in permissions.get("value", []):
|
||||
role = perm.get("roles", ["read"])[0]
|
||||
grantee = perm.get("grantedToV2") or perm.get("grantedTo")
|
||||
if not grantee:
|
||||
continue
|
||||
user = grantee.get("user")
|
||||
if user and user.get("email"):
|
||||
acl.user_permissions[user["email"]] = role
|
||||
group = grantee.get("group")
|
||||
if group and group.get("email"):
|
||||
acl.group_permissions[group["email"]] = role
|
||||
return acl
|
||||
|
||||
def handle_webhook_validation(
|
||||
self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
|
||||
) -> Optional[str]:
|
||||
"""Handle Microsoft Graph webhook validation"""
|
||||
if request_method == "GET":
|
||||
validation_token = query_params.get("validationtoken") or query_params.get(
|
||||
"validationToken"
|
||||
)
|
||||
if validation_token:
|
||||
return validation_token
|
||||
return None
|
||||
|
||||
def extract_webhook_channel_id(
|
||||
self, payload: Dict[str, Any], headers: Dict[str, str]
|
||||
) -> Optional[str]:
|
||||
"""Extract SharePoint subscription ID from webhook payload"""
|
||||
values = payload.get("value", [])
|
||||
return values[0].get("subscriptionId") if values else None
|
||||
|
||||
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
||||
values = payload.get("value", [])
|
||||
file_ids = []
|
||||
for item in values:
|
||||
resource_data = item.get("resourceData", {})
|
||||
file_id = resource_data.get("id")
|
||||
if file_id:
|
||||
file_ids.append(file_id)
|
||||
return file_ids
|
||||
|
||||
async def cleanup_subscription(
|
||||
self, subscription_id: str, resource_id: str = None
|
||||
) -> bool:
|
||||
if not self._authenticated:
|
||||
else:
|
||||
self.oauth = None
|
||||
|
||||
# Track subscription ID for webhooks
|
||||
self._subscription_id: Optional[str] = None
|
||||
|
||||
# Add Graph API defaults similar to Google Drive flags
|
||||
self._graph_api_version = "v1.0"
|
||||
self._default_params = {
|
||||
"$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
|
||||
}
|
||||
|
||||
@property
|
||||
def _graph_base_url(self) -> str:
|
||||
"""Base URL for Microsoft Graph API calls"""
|
||||
return f"https://graph.microsoft.com/{self._graph_api_version}"
|
||||
|
||||
def emit(self, doc: ConnectorDocument) -> None:
|
||||
"""
|
||||
Emit a ConnectorDocument instance.
|
||||
Override this method to integrate with your ingestion pipeline.
|
||||
"""
|
||||
logger.debug(f"Emitting SharePoint document: {doc.id} ({doc.filename})")
|
||||
|
||||
async def authenticate(self) -> bool:
|
||||
"""Test authentication - BaseConnector interface"""
|
||||
logger.debug(f"SharePoint authenticate() called, oauth is None: {self.oauth is None}")
|
||||
try:
|
||||
if not self.oauth:
|
||||
logger.debug("SharePoint authentication failed: OAuth not initialized")
|
||||
self._authenticated = False
|
||||
return False
|
||||
|
||||
logger.debug("Loading SharePoint credentials...")
|
||||
# Try to load existing credentials first
|
||||
load_result = await self.oauth.load_credentials()
|
||||
logger.debug(f"Load credentials result: {load_result}")
|
||||
|
||||
logger.debug("Checking SharePoint authentication status...")
|
||||
authenticated = await self.oauth.is_authenticated()
|
||||
logger.debug(f"SharePoint is_authenticated result: {authenticated}")
|
||||
|
||||
self._authenticated = authenticated
|
||||
return authenticated
|
||||
except Exception as e:
|
||||
logger.error(f"SharePoint authentication failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
self._authenticated = False
|
||||
return False
|
||||
token = self.oauth.get_access_token()
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.delete(
|
||||
f"{self.base_url}/subscriptions/{subscription_id}",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
|
||||
def get_auth_url(self) -> str:
|
||||
"""Get OAuth authorization URL"""
|
||||
if not self.oauth:
|
||||
raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
|
||||
return self.oauth.create_authorization_url(self.redirect_uri)
|
||||
|
||||
async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
|
||||
"""Handle OAuth callback"""
|
||||
if not self.oauth:
|
||||
raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
|
||||
try:
|
||||
success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
|
||||
if success:
|
||||
self._authenticated = True
|
||||
return {"status": "success"}
|
||||
else:
|
||||
raise ValueError("OAuth callback failed")
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth callback failed: {e}")
|
||||
raise
|
||||
|
||||
def sync_once(self) -> None:
|
||||
"""
|
||||
Perform a one-shot sync of SharePoint files and emit documents.
|
||||
This method mirrors the Google Drive connector's sync_once functionality.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
async def _async_sync():
|
||||
try:
|
||||
# Get list of files
|
||||
file_list = await self.list_files(max_files=1000) # Adjust as needed
|
||||
files = file_list.get("files", [])
|
||||
|
||||
for file_info in files:
|
||||
try:
|
||||
file_id = file_info.get("id")
|
||||
if not file_id:
|
||||
continue
|
||||
|
||||
# Get full document content
|
||||
doc = await self.get_file_content(file_id)
|
||||
self.emit(doc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync SharePoint file {file_info.get('name', 'unknown')}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SharePoint sync_once failed: {e}")
|
||||
raise
|
||||
|
||||
# Run the async sync
|
||||
if hasattr(asyncio, 'run'):
|
||||
asyncio.run(_async_sync())
|
||||
else:
|
||||
# Python < 3.7 compatibility
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(_async_sync())
|
||||
|
||||
async def setup_subscription(self) -> str:
|
||||
"""Set up real-time subscription for file changes - BaseConnector interface"""
|
||||
webhook_url = self.config.get('webhook_url')
|
||||
if not webhook_url:
|
||||
logger.warning("No webhook URL configured, skipping SharePoint subscription setup")
|
||||
return "no-webhook-configured"
|
||||
|
||||
try:
|
||||
# Ensure we're authenticated
|
||||
if not await self.authenticate():
|
||||
raise RuntimeError("SharePoint authentication failed during subscription setup")
|
||||
|
||||
token = self.oauth.get_access_token()
|
||||
|
||||
# Microsoft Graph subscription for SharePoint site
|
||||
site_info = self._parse_sharepoint_url()
|
||||
if site_info:
|
||||
resource = f"sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root"
|
||||
else:
|
||||
resource = "/me/drive/root"
|
||||
|
||||
subscription_data = {
|
||||
"changeType": "created,updated,deleted",
|
||||
"notificationUrl": f"{webhook_url}/webhook/sharepoint",
|
||||
"resource": resource,
|
||||
"expirationDateTime": self._get_subscription_expiry(),
|
||||
"clientState": f"sharepoint_{self.tenant_id}"
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
url = f"{self._graph_base_url}/subscriptions"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
subscription_id = result.get("id")
|
||||
|
||||
if subscription_id:
|
||||
self._subscription_id = subscription_id
|
||||
logger.info(f"SharePoint subscription created: {subscription_id}")
|
||||
return subscription_id
|
||||
else:
|
||||
raise ValueError("No subscription ID returned from Microsoft Graph")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup SharePoint subscription: {e}")
|
||||
raise
|
||||
|
||||
def _get_subscription_expiry(self) -> str:
|
||||
"""Get subscription expiry time (max 3 days for Graph API)"""
|
||||
from datetime import datetime, timedelta
|
||||
expiry = datetime.utcnow() + timedelta(days=3) # 3 days max for Graph
|
||||
return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
|
||||
def _parse_sharepoint_url(self) -> Optional[Dict[str, str]]:
|
||||
"""Parse SharePoint URL to extract site information for Graph API"""
|
||||
if not self.sharepoint_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = urlparse(self.sharepoint_url)
|
||||
# Extract hostname and site name from URL like: https://contoso.sharepoint.com/sites/teamsite
|
||||
host_name = parsed.netloc
|
||||
path_parts = parsed.path.strip('/').split('/')
|
||||
|
||||
if len(path_parts) >= 2 and path_parts[0] == 'sites':
|
||||
site_name = path_parts[1]
|
||||
return {
|
||||
"host_name": host_name,
|
||||
"site_name": site_name
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse SharePoint URL {self.sharepoint_url}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""List all files using Microsoft Graph API - BaseConnector interface"""
|
||||
try:
|
||||
# Ensure authentication
|
||||
if not await self.authenticate():
|
||||
raise RuntimeError("SharePoint authentication failed during file listing")
|
||||
|
||||
files = []
|
||||
max_files_value = max_files if max_files is not None else 100
|
||||
|
||||
# Build Graph API URL for the site or fallback to user's OneDrive
|
||||
site_info = self._parse_sharepoint_url()
|
||||
if site_info:
|
||||
base_url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root/children"
|
||||
else:
|
||||
base_url = f"{self._graph_base_url}/me/drive/root/children"
|
||||
|
||||
params = dict(self._default_params)
|
||||
params["$top"] = max_files_value
|
||||
|
||||
if page_token:
|
||||
params["$skiptoken"] = page_token
|
||||
|
||||
response = await self._make_graph_request(base_url, params=params)
|
||||
data = response.json()
|
||||
|
||||
items = data.get("value", [])
|
||||
for item in items:
|
||||
# Only include files, not folders
|
||||
if item.get("file"):
|
||||
files.append({
|
||||
"id": item.get("id", ""),
|
||||
"name": item.get("name", ""),
|
||||
"path": f"/drive/items/{item.get('id')}",
|
||||
"size": int(item.get("size", 0)),
|
||||
"modified": item.get("lastModifiedDateTime"),
|
||||
"created": item.get("createdDateTime"),
|
||||
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
|
||||
"url": item.get("webUrl", ""),
|
||||
"download_url": item.get("@microsoft.graph.downloadUrl")
|
||||
})
|
||||
|
||||
# Check for next page
|
||||
next_page_token = None
|
||||
next_link = data.get("@odata.nextLink")
|
||||
if next_link:
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
parsed = urlparse(next_link)
|
||||
query_params = parse_qs(parsed.query)
|
||||
if "$skiptoken" in query_params:
|
||||
next_page_token = query_params["$skiptoken"][0]
|
||||
|
||||
return {
|
||||
"files": files,
|
||||
"next_page_token": next_page_token
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list SharePoint files: {e}")
|
||||
return {"files": [], "next_page_token": None} # Return empty result instead of raising
|
||||
|
||||
async def get_file_content(self, file_id: str) -> ConnectorDocument:
|
||||
"""Get file content and metadata - BaseConnector interface"""
|
||||
try:
|
||||
# Ensure authentication
|
||||
if not await self.authenticate():
|
||||
raise RuntimeError("SharePoint authentication failed during file content retrieval")
|
||||
|
||||
# First get file metadata using Graph API
|
||||
file_metadata = await self._get_file_metadata_by_id(file_id)
|
||||
|
||||
if not file_metadata:
|
||||
raise ValueError(f"File not found: {file_id}")
|
||||
|
||||
# Download file content
|
||||
download_url = file_metadata.get("download_url")
|
||||
if download_url:
|
||||
content = await self._download_file_from_url(download_url)
|
||||
else:
|
||||
content = await self._download_file_content(file_id)
|
||||
|
||||
# Create ACL from metadata
|
||||
acl = DocumentACL(
|
||||
owner="", # Graph API requires additional calls for detailed permissions
|
||||
user_permissions={},
|
||||
group_permissions={}
|
||||
)
|
||||
return resp.status_code in (200, 204)
|
||||
|
||||
# Parse dates
|
||||
modified_time = self._parse_graph_date(file_metadata.get("modified"))
|
||||
created_time = self._parse_graph_date(file_metadata.get("created"))
|
||||
|
||||
return ConnectorDocument(
|
||||
id=file_id,
|
||||
filename=file_metadata.get("name", ""),
|
||||
mimetype=file_metadata.get("mime_type", "application/octet-stream"),
|
||||
content=content,
|
||||
source_url=file_metadata.get("url", ""),
|
||||
acl=acl,
|
||||
modified_time=modified_time,
|
||||
created_time=created_time,
|
||||
metadata={
|
||||
"sharepoint_path": file_metadata.get("path", ""),
|
||||
"sharepoint_url": self.sharepoint_url,
|
||||
"size": file_metadata.get("size", 0)
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get SharePoint file content {file_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get file metadata by ID using Graph API"""
|
||||
try:
|
||||
# Try site-specific path first, then fallback to user drive
|
||||
site_info = self._parse_sharepoint_url()
|
||||
if site_info:
|
||||
url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}"
|
||||
else:
|
||||
url = f"{self._graph_base_url}/me/drive/items/{file_id}"
|
||||
|
||||
params = dict(self._default_params)
|
||||
|
||||
response = await self._make_graph_request(url, params=params)
|
||||
item = response.json()
|
||||
|
||||
if item.get("file"):
|
||||
return {
|
||||
"id": file_id,
|
||||
"name": item.get("name", ""),
|
||||
"path": f"/drive/items/{file_id}",
|
||||
"size": int(item.get("size", 0)),
|
||||
"modified": item.get("lastModifiedDateTime"),
|
||||
"created": item.get("createdDateTime"),
|
||||
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
|
||||
"url": item.get("webUrl", ""),
|
||||
"download_url": item.get("@microsoft.graph.downloadUrl")
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get file metadata for {file_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _download_file_content(self, file_id: str) -> bytes:
|
||||
"""Download file content by file ID using Graph API"""
|
||||
try:
|
||||
site_info = self._parse_sharepoint_url()
|
||||
if site_info:
|
||||
url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}/content"
|
||||
else:
|
||||
url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
|
||||
|
||||
token = self.oauth.get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, headers=headers, timeout=60)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download file content for {file_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _download_file_from_url(self, download_url: str) -> bytes:
|
||||
"""Download file content from direct download URL"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(download_url, timeout=60)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download from URL {download_url}: {e}")
|
||||
raise
|
||||
|
||||
def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
|
||||
"""Parse Microsoft Graph date string to datetime"""
|
||||
if not date_str:
|
||||
return datetime.now()
|
||||
|
||||
try:
|
||||
if date_str.endswith('Z'):
|
||||
return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
|
||||
else:
|
||||
return datetime.fromisoformat(date_str.replace('T', ' '))
|
||||
except (ValueError, AttributeError):
|
||||
return datetime.now()
|
||||
|
||||
async def _make_graph_request(self, url: str, method: str = "GET",
|
||||
data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
|
||||
"""Make authenticated API request to Microsoft Graph"""
|
||||
token = self.oauth.get_access_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
if method.upper() == "GET":
|
||||
response = await client.get(url, headers=headers, params=params, timeout=30)
|
||||
elif method.upper() == "POST":
|
||||
response = await client.post(url, headers=headers, json=data, timeout=30)
|
||||
elif method.upper() == "DELETE":
|
||||
response = await client.delete(url, headers=headers, timeout=30)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
def _get_mime_type(self, filename: str) -> str:
|
||||
"""Get MIME type based on file extension"""
|
||||
import mimetypes
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
return mime_type or "application/octet-stream"
|
||||
|
||||
# Webhook methods - BaseConnector interface
|
||||
def handle_webhook_validation(self, request_method: str, headers: Dict[str, str],
|
||||
query_params: Dict[str, str]) -> Optional[str]:
|
||||
"""Handle webhook validation (Graph API specific)"""
|
||||
if request_method == "POST" and "validationToken" in query_params:
|
||||
return query_params["validationToken"]
|
||||
return None
|
||||
|
||||
def extract_webhook_channel_id(self, payload: Dict[str, Any],
|
||||
headers: Dict[str, str]) -> Optional[str]:
|
||||
"""Extract channel/subscription ID from webhook payload"""
|
||||
notifications = payload.get("value", [])
|
||||
if notifications:
|
||||
return notifications[0].get("subscriptionId")
|
||||
return None
|
||||
|
||||
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
||||
"""Handle webhook notification and return affected file IDs"""
|
||||
affected_files = []
|
||||
|
||||
# Process Microsoft Graph webhook payload
|
||||
notifications = payload.get("value", [])
|
||||
for notification in notifications:
|
||||
resource = notification.get("resource")
|
||||
if resource and "/drive/items/" in resource:
|
||||
file_id = resource.split("/drive/items/")[-1]
|
||||
affected_files.append(file_id)
|
||||
|
||||
return affected_files
|
||||
|
||||
async def cleanup_subscription(self, subscription_id: str) -> bool:
|
||||
"""Clean up subscription - BaseConnector interface"""
|
||||
if subscription_id == "no-webhook-configured":
|
||||
logger.info("No subscription to cleanup (webhook was not configured)")
|
||||
return True
|
||||
|
||||
try:
|
||||
# Ensure authentication
|
||||
if not await self.authenticate():
|
||||
logger.error("SharePoint authentication failed during subscription cleanup")
|
||||
return False
|
||||
|
||||
token = self.oauth.get_access_token()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.delete(url, headers=headers, timeout=30)
|
||||
|
||||
if response.status_code in [200, 204, 404]:
|
||||
logger.info(f"SharePoint subscription {subscription_id} cleaned up successfully")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup SharePoint subscription {subscription_id}: {e}")
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -1,19 +1,28 @@
|
|||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import aiofiles
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
import msal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharePointOAuth:
|
||||
"""Direct token management for SharePoint, bypassing MSAL cache format"""
|
||||
"""Handles Microsoft Graph OAuth authentication flow following Google Drive pattern."""
|
||||
|
||||
SCOPES = [
|
||||
"offline_access",
|
||||
"Files.Read.All",
|
||||
"Sites.Read.All",
|
||||
]
|
||||
# Reserved scopes that must NOT be sent on token or silent calls
|
||||
RESERVED_SCOPES = {"openid", "profile", "offline_access"}
|
||||
|
||||
# For PERSONAL Microsoft Accounts (OneDrive consumer):
|
||||
# - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
|
||||
# - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
|
||||
AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
|
||||
RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
|
||||
SCOPES = AUTH_SCOPES # Backward compatibility alias
|
||||
|
||||
# Kept for reference; MSAL derives endpoints from `authority`
|
||||
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||
|
||||
|
|
@ -22,173 +31,299 @@ class SharePointOAuth:
|
|||
client_id: str,
|
||||
client_secret: str,
|
||||
token_file: str = "sharepoint_token.json",
|
||||
authority: str = "https://login.microsoftonline.com/common", # Keep for compatibility
|
||||
authority: str = "https://login.microsoftonline.com/common",
|
||||
allow_json_refresh: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize SharePointOAuth.
|
||||
|
||||
Args:
|
||||
client_id: Azure AD application (client) ID.
|
||||
client_secret: Azure AD application client secret.
|
||||
token_file: Path to persisted token cache file (MSAL cache format).
|
||||
authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
|
||||
or tenant-specific for work/school.
|
||||
allow_json_refresh: If True, permit one-time migration from legacy flat JSON
|
||||
{"access_token","refresh_token",...}. Otherwise refuse it.
|
||||
"""
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.token_file = token_file
|
||||
self.authority = authority # Keep for compatibility but not used
|
||||
self._tokens = None
|
||||
self._load_tokens()
|
||||
self.authority = authority
|
||||
self.allow_json_refresh = allow_json_refresh
|
||||
self.token_cache = msal.SerializableTokenCache()
|
||||
self._current_account = None
|
||||
|
||||
def _load_tokens(self):
|
||||
"""Load tokens from file"""
|
||||
if os.path.exists(self.token_file):
|
||||
with open(self.token_file, "r") as f:
|
||||
self._tokens = json.loads(f.read())
|
||||
print(f"Loaded tokens from {self.token_file}")
|
||||
else:
|
||||
print(f"No token file found at {self.token_file}")
|
||||
# Initialize MSAL Confidential Client
|
||||
self.app = msal.ConfidentialClientApplication(
|
||||
client_id=self.client_id,
|
||||
client_credential=self.client_secret,
|
||||
authority=self.authority,
|
||||
token_cache=self.token_cache,
|
||||
)
|
||||
|
||||
async def save_cache(self):
|
||||
"""Persist tokens to file (renamed for compatibility)"""
|
||||
await self._save_tokens()
|
||||
|
||||
async def _save_tokens(self):
|
||||
"""Save tokens to file"""
|
||||
if self._tokens:
|
||||
async with aiofiles.open(self.token_file, "w") as f:
|
||||
await f.write(json.dumps(self._tokens, indent=2))
|
||||
|
||||
def _is_token_expired(self) -> bool:
|
||||
"""Check if current access token is expired"""
|
||||
if not self._tokens or 'expiry' not in self._tokens:
|
||||
return True
|
||||
|
||||
expiry_str = self._tokens['expiry']
|
||||
# Handle different expiry formats
|
||||
async def load_credentials(self) -> bool:
|
||||
"""Load existing credentials from token file (async)."""
|
||||
try:
|
||||
if expiry_str.endswith('Z'):
|
||||
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
|
||||
else:
|
||||
expiry_dt = datetime.fromisoformat(expiry_str)
|
||||
|
||||
# Add 5-minute buffer
|
||||
import datetime as dt
|
||||
now = datetime.now()
|
||||
return now >= (expiry_dt - dt.timedelta(minutes=5))
|
||||
except:
|
||||
return True
|
||||
logger.debug(f"SharePoint OAuth loading credentials from: {self.token_file}")
|
||||
if os.path.exists(self.token_file):
|
||||
logger.debug(f"Token file exists, reading: {self.token_file}")
|
||||
|
||||
# Read the token file
|
||||
async with aiofiles.open(self.token_file, "r") as f:
|
||||
cache_data = await f.read()
|
||||
logger.debug(f"Read {len(cache_data)} chars from token file")
|
||||
|
||||
if cache_data.strip():
|
||||
# 1) Try legacy flat JSON first
|
||||
try:
|
||||
json_data = json.loads(cache_data)
|
||||
if isinstance(json_data, dict) and "refresh_token" in json_data:
|
||||
if self.allow_json_refresh:
|
||||
logger.debug(
|
||||
"Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
|
||||
)
|
||||
return await self._refresh_from_json_token(json_data)
|
||||
else:
|
||||
logger.warning(
|
||||
"Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
|
||||
"Delete the file and re-auth."
|
||||
)
|
||||
return False
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Token file is not flat JSON; attempting MSAL cache format")
|
||||
|
||||
# 2) Try MSAL cache format
|
||||
logger.debug("Attempting MSAL cache deserialization")
|
||||
self.token_cache.deserialize(cache_data)
|
||||
|
||||
# Get accounts from loaded cache
|
||||
accounts = self.app.get_accounts()
|
||||
logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
|
||||
if accounts:
|
||||
self._current_account = accounts[0]
|
||||
logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
|
||||
|
||||
# IMPORTANT: Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
|
||||
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
|
||||
logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
|
||||
if result and "access_token" in result:
|
||||
logger.debug("Silent token acquisition successful")
|
||||
await self.save_cache()
|
||||
return True
|
||||
else:
|
||||
error_msg = (result or {}).get("error") or "No result"
|
||||
logger.warning(f"Silent token acquisition failed: {error_msg}")
|
||||
else:
|
||||
logger.debug(f"Token file {self.token_file} is empty")
|
||||
else:
|
||||
logger.debug(f"Token file does not exist: {self.token_file}")
|
||||
|
||||
async def _refresh_access_token(self) -> bool:
|
||||
"""Refresh the access token using refresh token"""
|
||||
if not self._tokens or 'refresh_token' not in self._tokens:
|
||||
return False
|
||||
|
||||
data = {
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'refresh_token': self._tokens['refresh_token'],
|
||||
'grant_type': 'refresh_token',
|
||||
'scope': ' '.join(self.SCOPES)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load SharePoint credentials: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(self.TOKEN_ENDPOINT, data=data)
|
||||
response.raise_for_status()
|
||||
token_data = response.json()
|
||||
async def _refresh_from_json_token(self, token_data: dict) -> bool:
|
||||
"""
|
||||
Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
|
||||
|
||||
# Update tokens
|
||||
self._tokens['token'] = token_data['access_token']
|
||||
if 'refresh_token' in token_data:
|
||||
self._tokens['refresh_token'] = token_data['refresh_token']
|
||||
|
||||
# Calculate expiry
|
||||
expires_in = token_data.get('expires_in', 3600)
|
||||
import datetime as dt
|
||||
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
|
||||
self._tokens['expiry'] = expiry.isoformat()
|
||||
|
||||
await self._save_tokens()
|
||||
print("Access token refreshed successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to refresh token: {e}")
|
||||
Notes:
|
||||
- Prefer using an MSAL cache file and acquire_token_silent().
|
||||
- This path is only for migrating older refresh_token JSON files.
|
||||
"""
|
||||
try:
|
||||
refresh_token = token_data.get("refresh_token")
|
||||
if not refresh_token:
|
||||
logger.error("No refresh_token found in JSON file - cannot refresh")
|
||||
logger.error("You must re-authenticate interactively to obtain a valid token")
|
||||
return False
|
||||
|
||||
def create_authorization_url(self, redirect_uri: str) -> str:
|
||||
"""Create authorization URL for OAuth flow"""
|
||||
from urllib.parse import urlencode
|
||||
|
||||
params = {
|
||||
'client_id': self.client_id,
|
||||
'response_type': 'code',
|
||||
'redirect_uri': redirect_uri,
|
||||
'scope': ' '.join(self.SCOPES),
|
||||
'response_mode': 'query'
|
||||
# Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
|
||||
refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
|
||||
logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
|
||||
|
||||
result = self.app.acquire_token_by_refresh_token(
|
||||
refresh_token=refresh_token,
|
||||
scopes=refresh_scopes,
|
||||
)
|
||||
|
||||
if result and "access_token" in result:
|
||||
logger.debug("Successfully refreshed token via legacy JSON path")
|
||||
await self.save_cache()
|
||||
|
||||
accounts = self.app.get_accounts()
|
||||
logger.debug(f"After refresh, found {len(accounts)} accounts")
|
||||
if accounts:
|
||||
self._current_account = accounts[0]
|
||||
logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
|
||||
return True
|
||||
|
||||
# Error handling
|
||||
err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
|
||||
logger.error(f"Refresh token failed: {err}")
|
||||
|
||||
if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
|
||||
logger.warning(
|
||||
"Refresh denied due to unauthorized/expired scopes or invalid grant. "
|
||||
"Delete the token file and perform interactive sign-in with correct scopes."
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during refresh from JSON token: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def save_cache(self):
|
||||
"""Persist the token cache to file."""
|
||||
try:
|
||||
# Ensure parent directory exists
|
||||
parent = os.path.dirname(os.path.abspath(self.token_file))
|
||||
if parent and not os.path.exists(parent):
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
|
||||
cache_data = self.token_cache.serialize()
|
||||
if cache_data:
|
||||
async with aiofiles.open(self.token_file, "w") as f:
|
||||
await f.write(cache_data)
|
||||
logger.debug(f"Token cache saved to {self.token_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save token cache: {e}")
|
||||
|
||||
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
|
||||
"""Create authorization URL for OAuth flow."""
|
||||
# Store redirect URI for later use in callback
|
||||
self._redirect_uri = redirect_uri
|
||||
|
||||
kwargs: Dict[str, Any] = {
|
||||
# IMPORTANT: interactive auth includes offline_access
|
||||
"scopes": self.AUTH_SCOPES,
|
||||
"redirect_uri": redirect_uri,
|
||||
"prompt": "consent", # ensure refresh token on first run
|
||||
}
|
||||
|
||||
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||
return f"{auth_url}?{urlencode(params)}"
|
||||
if state:
|
||||
kwargs["state"] = state # Optional CSRF protection
|
||||
|
||||
auth_url = self.app.get_authorization_request_url(**kwargs)
|
||||
|
||||
logger.debug(f"Generated auth URL: {auth_url}")
|
||||
logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
|
||||
|
||||
return auth_url
|
||||
|
||||
async def handle_authorization_callback(
    self, authorization_code: str, redirect_uri: str
) -> bool:
    """Handle OAuth callback and exchange code for tokens.

    Exchanges the authorization code through MSAL (which also populates the
    serializable token cache), records the signed-in account for later
    silent acquisition, and persists the cache to disk.

    Args:
        authorization_code: Code returned by the authorize endpoint.
        redirect_uri: Must match the URI used in the authorize request.

    Returns:
        True on a successful exchange, False otherwise (never raises).
    """
    # Fix: removed the dead legacy path that hand-built a token request
    # (`data` dict + httpx POST to TOKEN_ENDPOINT) and stored tokens in
    # `self._tokens` — it duplicated the MSAL exchange below and its
    # result was never consistent with the MSAL cache.
    try:
        # For code exchange, we pass the same auth scopes as used in the
        # authorize step so the issued tokens match what was consented.
        result = self.app.acquire_token_by_authorization_code(
            authorization_code,
            scopes=self.AUTH_SCOPES,
            redirect_uri=redirect_uri,
        )

        if result and "access_token" in result:
            # Store the account for future silent token acquisition.
            accounts = self.app.get_accounts()
            if accounts:
                self._current_account = accounts[0]

            await self.save_cache()
            logger.info("SharePoint OAuth authorization successful")
            return True

        error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
        logger.error(f"SharePoint OAuth authorization failed: {error_msg}")
        return False

    except Exception as e:
        logger.error(f"Exception during SharePoint OAuth authorization: {e}")
        return False
async def is_authenticated(self) -> bool:
    """Check if we have valid credentials (simplified like Google Drive).

    Loads cached credentials if needed, then asks MSAL for a token
    silently; MSAL refreshes under the hood when a refresh token exists.

    Returns:
        True when a silent token acquisition succeeds, False otherwise.
    """
    try:
        # First try to load credentials if we haven't already.
        if not self._current_account:
            await self.load_credentials()

        if self._current_account:
            # IMPORTANT: use RESOURCE_SCOPES here (not the auth scopes).
            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
            if result and "access_token" in result:
                return True
            error_msg = (result or {}).get("error") or "No result returned"
            logger.debug(f"Token acquisition failed for current account: {error_msg}")

        # Fallback: try without pinning a specific account.
        result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
        if result and "access_token" in result:
            # Update current account if this worked.
            accounts = self.app.get_accounts()
            if accounts:
                self._current_account = accounts[0]
            # Fix: a token in hand means we ARE authenticated; the tangled
            # duplicate return True/return False statements (diff residue)
            # could report failure even after a successful acquisition.
            return True

        return False

    except Exception as e:
        logger.error(f"Authentication check failed: {e}")
        return False
def get_access_token(self) -> str:
    """Get an access token for Microsoft Graph (simplified like Google Drive).

    Returns:
        A bearer token string for Graph API calls.

    Raises:
        ValueError: If no token can be acquired silently.
    """
    # Fix: removed the stale `self._tokens` body that raised before the
    # MSAL implementation below could ever run.
    try:
        # Try with current account first.
        if self._current_account:
            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
            if result and "access_token" in result:
                return result["access_token"]

        # Fallback: try without specific account.
        result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
        if result and "access_token" in result:
            return result["access_token"]

        # If we get here, authentication has failed.
        error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
        raise ValueError(f"Failed to acquire access token: {error_msg}")

    except Exception as e:
        logger.error(f"Failed to get access token: {e}")
        raise
async def revoke_credentials(self):
    """Clear token cache and remove token file (like Google Drive).

    Resets all in-memory MSAL state and deletes the persisted token file;
    failures are logged, never raised.
    """
    # Fix: removed the stale legacy body (`self._tokens = None` plus an
    # unconditional file removal) that shadowed this implementation.
    try:
        # Clear in-memory state.
        self._current_account = None
        self.token_cache = msal.SerializableTokenCache()

        # Recreate the MSAL app with a fresh cache so no stale entries
        # from the old cache object can be reused.
        self.app = msal.ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_secret,
            authority=self.authority,
            token_cache=self.token_cache,
        )

        # Remove the persisted token file.
        if os.path.exists(self.token_file):
            os.remove(self.token_file)
            logger.info(f"Removed SharePoint token file: {self.token_file}")

    except Exception as e:
        logger.error(f"Failed to revoke SharePoint credentials: {e}")
def get_service(self) -> str:
    """Hand back a bearer token; Graph needs no generated client object (unlike Google Drive)."""
    token = self.get_access_token()
    return token
|
|
|
|||
Loading…
Add table
Reference in a new issue