Finally fix MSAL in onedrive/sharepoint
This commit is contained in:
parent
5e14c7f100
commit
f03889a2b3
6 changed files with 1649 additions and 750 deletions
|
|
@ -132,7 +132,10 @@ async def connector_status(request: Request, connector_service, session_manager)
|
||||||
for connection in connections:
|
for connection in connections:
|
||||||
try:
|
try:
|
||||||
connector = await connector_service._get_connector(connection.connection_id)
|
connector = await connector_service._get_connector(connection.connection_id)
|
||||||
connection_client_ids[connection.connection_id] = connector.get_client_id()
|
if connector is not None:
|
||||||
|
connection_client_ids[connection.connection_id] = connector.get_client_id()
|
||||||
|
else:
|
||||||
|
connection_client_ids[connection.connection_id] = None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Could not get connector for connection",
|
"Could not get connector for connection",
|
||||||
|
|
@ -338,8 +341,8 @@ async def connector_webhook(request: Request, connector_service, session_manager
|
||||||
)
|
)
|
||||||
|
|
||||||
async def connector_token(request: Request, connector_service, session_manager):
|
async def connector_token(request: Request, connector_service, session_manager):
|
||||||
"""Get access token for connector API calls (e.g., Google Picker)"""
|
"""Get access token for connector API calls (e.g., Pickers)."""
|
||||||
connector_type = request.path_params.get("connector_type")
|
url_connector_type = request.path_params.get("connector_type")
|
||||||
connection_id = request.query_params.get("connection_id")
|
connection_id = request.query_params.get("connection_id")
|
||||||
|
|
||||||
if not connection_id:
|
if not connection_id:
|
||||||
|
|
@ -348,37 +351,81 @@ async def connector_token(request: Request, connector_service, session_manager):
|
||||||
user = request.state.user
|
user = request.state.user
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get the connection and verify it belongs to the user
|
# 1) Load the connection and verify ownership
|
||||||
connection = await connector_service.connection_manager.get_connection(connection_id)
|
connection = await connector_service.connection_manager.get_connection(connection_id)
|
||||||
if not connection or connection.user_id != user.user_id:
|
if not connection or connection.user_id != user.user_id:
|
||||||
return JSONResponse({"error": "Connection not found"}, status_code=404)
|
return JSONResponse({"error": "Connection not found"}, status_code=404)
|
||||||
|
|
||||||
# Get the connector instance
|
# 2) Get the ACTUAL connector instance/type for this connection_id
|
||||||
connector = await connector_service._get_connector(connection_id)
|
connector = await connector_service._get_connector(connection_id)
|
||||||
if not connector:
|
if not connector:
|
||||||
return JSONResponse({"error": f"Connector not available - authentication may have failed for {connector_type}"}, status_code=404)
|
return JSONResponse(
|
||||||
|
{"error": f"Connector not available - authentication may have failed for {url_connector_type}"},
|
||||||
|
status_code=404,
|
||||||
|
)
|
||||||
|
|
||||||
# For Google Drive, get the access token
|
real_type = getattr(connector, "type", None) or getattr(connection, "connector_type", None)
|
||||||
if connector_type == "google_drive" and hasattr(connector, 'oauth'):
|
if real_type is None:
|
||||||
|
return JSONResponse({"error": "Unable to determine connector type"}, status_code=500)
|
||||||
|
|
||||||
|
# Optional: warn if URL path type disagrees with real type
|
||||||
|
if url_connector_type and url_connector_type != real_type:
|
||||||
|
# You can downgrade this to debug if you expect cross-routing.
|
||||||
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"error": "Connector type mismatch",
|
||||||
|
"detail": {
|
||||||
|
"requested_type": url_connector_type,
|
||||||
|
"actual_type": real_type,
|
||||||
|
"hint": "Call the token endpoint using the correct connector_type for this connection_id.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3) Branch by the actual connector type
|
||||||
|
# GOOGLE DRIVE (google-auth)
|
||||||
|
if real_type == "google_drive" and hasattr(connector, "oauth"):
|
||||||
await connector.oauth.load_credentials()
|
await connector.oauth.load_credentials()
|
||||||
if connector.oauth.creds and connector.oauth.creds.valid:
|
if connector.oauth.creds and connector.oauth.creds.valid:
|
||||||
return JSONResponse({
|
expires_in = None
|
||||||
"access_token": connector.oauth.creds.token,
|
try:
|
||||||
"expires_in": (connector.oauth.creds.expiry.timestamp() -
|
if connector.oauth.creds.expiry:
|
||||||
__import__('time').time()) if connector.oauth.creds.expiry else None
|
import time
|
||||||
})
|
expires_in = max(0, int(connector.oauth.creds.expiry.timestamp() - time.time()))
|
||||||
else:
|
except Exception:
|
||||||
return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
|
expires_in = None
|
||||||
|
|
||||||
# For OneDrive and SharePoint, get the access token
|
return JSONResponse(
|
||||||
elif connector_type in ["onedrive", "sharepoint"] and hasattr(connector, 'oauth'):
|
{
|
||||||
|
"access_token": connector.oauth.creds.token,
|
||||||
|
"expires_in": expires_in,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
|
||||||
|
|
||||||
|
# ONEDRIVE / SHAREPOINT (MSAL or custom)
|
||||||
|
if real_type in ("onedrive", "sharepoint") and hasattr(connector, "oauth"):
|
||||||
|
# Ensure cache/credentials are loaded before trying to use them
|
||||||
try:
|
try:
|
||||||
|
# Prefer a dedicated is_authenticated() that loads cache internally
|
||||||
|
if hasattr(connector.oauth, "is_authenticated"):
|
||||||
|
ok = await connector.oauth.is_authenticated()
|
||||||
|
else:
|
||||||
|
# Fallback: try to load credentials explicitly if available
|
||||||
|
ok = True
|
||||||
|
if hasattr(connector.oauth, "load_credentials"):
|
||||||
|
ok = await connector.oauth.load_credentials()
|
||||||
|
|
||||||
|
if not ok:
|
||||||
|
return JSONResponse({"error": "Not authenticated"}, status_code=401)
|
||||||
|
|
||||||
|
# Now safe to fetch access token
|
||||||
access_token = connector.oauth.get_access_token()
|
access_token = connector.oauth.get_access_token()
|
||||||
return JSONResponse({
|
# MSAL result has expiry, but we’re returning a raw token; keep expires_in None for simplicity
|
||||||
"access_token": access_token,
|
return JSONResponse({"access_token": access_token, "expires_in": None})
|
||||||
"expires_in": None # MSAL handles token expiry internally
|
|
||||||
})
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
# Typical when acquire_token_silent fails (e.g., needs re-auth)
|
||||||
return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401)
|
return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500)
|
return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500)
|
||||||
|
|
@ -386,7 +433,5 @@ async def connector_token(request: Request, connector_service, session_manager):
|
||||||
return JSONResponse({"error": "Token not available for this connector type"}, status_code=400)
|
return JSONResponse({"error": "Token not available for this connector type"}, status_code=400)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error getting connector token", error=str(e))
|
logger.error("Error getting connector token", exc_info=True)
|
||||||
return JSONResponse({"error": str(e)}, status_code=500)
|
return JSONResponse({"error": str(e)}, status_code=500)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -294,32 +294,39 @@ class ConnectionManager:
|
||||||
|
|
||||||
async def get_connector(self, connection_id: str) -> Optional[BaseConnector]:
|
async def get_connector(self, connection_id: str) -> Optional[BaseConnector]:
|
||||||
"""Get an active connector instance"""
|
"""Get an active connector instance"""
|
||||||
|
logger.debug(f"Getting connector for connection_id: {connection_id}")
|
||||||
|
|
||||||
# Return cached connector if available
|
# Return cached connector if available
|
||||||
if connection_id in self.active_connectors:
|
if connection_id in self.active_connectors:
|
||||||
connector = self.active_connectors[connection_id]
|
connector = self.active_connectors[connection_id]
|
||||||
if connector.is_authenticated:
|
if connector.is_authenticated:
|
||||||
|
logger.debug(f"Returning cached authenticated connector for {connection_id}")
|
||||||
return connector
|
return connector
|
||||||
else:
|
else:
|
||||||
# Remove unauthenticated connector from cache
|
# Remove unauthenticated connector from cache
|
||||||
|
logger.debug(f"Removing unauthenticated connector from cache for {connection_id}")
|
||||||
del self.active_connectors[connection_id]
|
del self.active_connectors[connection_id]
|
||||||
|
|
||||||
# Try to create and authenticate connector
|
# Try to create and authenticate connector
|
||||||
connection_config = self.connections.get(connection_id)
|
connection_config = self.connections.get(connection_id)
|
||||||
if not connection_config or not connection_config.is_active:
|
if not connection_config or not connection_config.is_active:
|
||||||
|
logger.debug(f"No active connection config found for {connection_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
logger.debug(f"Creating connector for {connection_config.connector_type}")
|
||||||
connector = self._create_connector(connection_config)
|
connector = self._create_connector(connection_config)
|
||||||
if await connector.authenticate():
|
|
||||||
|
logger.debug(f"Attempting authentication for {connection_id}")
|
||||||
|
auth_result = await connector.authenticate()
|
||||||
|
logger.debug(f"Authentication result for {connection_id}: {auth_result}")
|
||||||
|
|
||||||
|
if auth_result:
|
||||||
self.active_connectors[connection_id] = connector
|
self.active_connectors[connection_id] = connector
|
||||||
|
# ... rest of the method
|
||||||
# Setup webhook subscription if not already set up
|
|
||||||
await self._setup_webhook_if_needed(
|
|
||||||
connection_id, connection_config, connector
|
|
||||||
)
|
|
||||||
|
|
||||||
return connector
|
return connector
|
||||||
|
else:
|
||||||
return None
|
logger.warning(f"Authentication failed for {connection_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]:
|
def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]:
|
||||||
"""Get available connector types with their metadata"""
|
"""Get available connector types with their metadata"""
|
||||||
|
|
@ -363,20 +370,23 @@ class ConnectionManager:
|
||||||
|
|
||||||
def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
|
def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
|
||||||
"""Factory method to create connector instances"""
|
"""Factory method to create connector instances"""
|
||||||
if config.connector_type == "google_drive":
|
try:
|
||||||
return GoogleDriveConnector(config.config)
|
if config.connector_type == "google_drive":
|
||||||
elif config.connector_type == "sharepoint":
|
return GoogleDriveConnector(config.config)
|
||||||
return SharePointConnector(config.config)
|
elif config.connector_type == "sharepoint":
|
||||||
elif config.connector_type == "onedrive":
|
return SharePointConnector(config.config)
|
||||||
return OneDriveConnector(config.config)
|
elif config.connector_type == "onedrive":
|
||||||
elif config.connector_type == "box":
|
return OneDriveConnector(config.config)
|
||||||
# Future: BoxConnector(config.config)
|
elif config.connector_type == "box":
|
||||||
raise NotImplementedError("Box connector not implemented yet")
|
raise NotImplementedError("Box connector not implemented yet")
|
||||||
elif config.connector_type == "dropbox":
|
elif config.connector_type == "dropbox":
|
||||||
# Future: DropboxConnector(config.config)
|
raise NotImplementedError("Dropbox connector not implemented yet")
|
||||||
raise NotImplementedError("Dropbox connector not implemented yet")
|
else:
|
||||||
else:
|
raise ValueError(f"Unknown connector type: {config.connector_type}")
|
||||||
raise ValueError(f"Unknown connector type: {config.connector_type}")
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create {config.connector_type} connector: {e}")
|
||||||
|
# Re-raise the exception so caller can handle appropriately
|
||||||
|
raise
|
||||||
|
|
||||||
async def update_last_sync(self, connection_id: str):
|
async def update_last_sync(self, connection_id: str):
|
||||||
"""Update the last sync timestamp for a connection"""
|
"""Update the last sync timestamp for a connection"""
|
||||||
|
|
|
||||||
|
|
@ -1,235 +1,487 @@
|
||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from datetime import datetime
|
||||||
import httpx
|
import httpx
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Dict, List, Any, Optional
|
|
||||||
|
|
||||||
from ..base import BaseConnector, ConnectorDocument, DocumentACL
|
from ..base import BaseConnector, ConnectorDocument, DocumentACL
|
||||||
from .oauth import OneDriveOAuth
|
from .oauth import OneDriveOAuth
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OneDriveConnector(BaseConnector):
|
class OneDriveConnector(BaseConnector):
|
||||||
"""OneDrive connector using Microsoft Graph API"""
|
"""OneDrive connector using MSAL-based OAuth for authentication."""
|
||||||
|
|
||||||
|
# Required BaseConnector class attributes
|
||||||
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
|
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
|
||||||
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
|
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
|
||||||
|
|
||||||
# Connector metadata
|
# Connector metadata
|
||||||
CONNECTOR_NAME = "OneDrive"
|
CONNECTOR_NAME = "OneDrive"
|
||||||
CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents"
|
CONNECTOR_DESCRIPTION = "Connect to OneDrive (personal) to sync documents and files"
|
||||||
CONNECTOR_ICON = "onedrive"
|
CONNECTOR_ICON = "onedrive"
|
||||||
|
|
||||||
def __init__(self, config: Dict[str, Any]):
|
def __init__(self, config: Dict[str, Any]):
|
||||||
super().__init__(config)
|
logger.debug(f"OneDrive connector __init__ called with config type: {type(config)}")
|
||||||
|
logger.debug(f"OneDrive connector __init__ config value: {config}")
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
logger.debug("Config was None, using empty dict")
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug("Calling super().__init__")
|
||||||
|
super().__init__(config)
|
||||||
|
logger.debug("super().__init__ completed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"super().__init__ failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Initialize with defaults that allow the connector to be listed
|
||||||
|
self.client_id = None
|
||||||
|
self.client_secret = None
|
||||||
|
self.redirect_uri = config.get("redirect_uri", "http://localhost") # must match your app registration
|
||||||
|
|
||||||
|
# Try to get credentials, but don't fail if they're missing
|
||||||
|
try:
|
||||||
|
self.client_id = self.get_client_id()
|
||||||
|
logger.debug(f"Got client_id: {self.client_id is not None}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to get client_id: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client_secret = self.get_client_secret()
|
||||||
|
logger.debug(f"Got client_secret: {self.client_secret is not None}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to get client_secret: {e}")
|
||||||
|
|
||||||
|
# Token file setup
|
||||||
project_root = Path(__file__).resolve().parent.parent.parent.parent
|
project_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||||
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
|
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
|
||||||
self.oauth = OneDriveOAuth(
|
Path(token_file).parent.mkdir(parents=True, exist_ok=True)
|
||||||
client_id=self.get_client_id(),
|
|
||||||
client_secret=self.get_client_secret(),
|
|
||||||
token_file=token_file,
|
|
||||||
)
|
|
||||||
self.subscription_id = config.get("subscription_id") or config.get(
|
|
||||||
"webhook_channel_id"
|
|
||||||
)
|
|
||||||
self.base_url = "https://graph.microsoft.com/v1.0"
|
|
||||||
|
|
||||||
async def authenticate(self) -> bool:
|
# Only initialize OAuth if we have credentials
|
||||||
if await self.oauth.is_authenticated():
|
if self.client_id and self.client_secret:
|
||||||
self._authenticated = True
|
connection_id = config.get("connection_id", "default")
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def setup_subscription(self) -> str:
|
# Use token_file from config if provided, otherwise generate one
|
||||||
if not self._authenticated:
|
if config.get("token_file"):
|
||||||
raise ValueError("Not authenticated")
|
oauth_token_file = config["token_file"]
|
||||||
|
else:
|
||||||
|
# Use a per-connection cache file to avoid collisions with other connectors
|
||||||
|
oauth_token_file = f"onedrive_token_{connection_id}.json"
|
||||||
|
|
||||||
webhook_url = self.config.get("webhook_url")
|
# MSA & org both work via /common for OneDrive personal testing
|
||||||
if not webhook_url:
|
authority = "https://login.microsoftonline.com/common"
|
||||||
raise ValueError("webhook_url required in config for subscriptions")
|
|
||||||
|
|
||||||
expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
|
self.oauth = OneDriveOAuth(
|
||||||
body = {
|
client_id=self.client_id,
|
||||||
"changeType": "created,updated,deleted",
|
client_secret=self.client_secret,
|
||||||
"notificationUrl": webhook_url,
|
token_file=oauth_token_file,
|
||||||
"resource": "/me/drive/root",
|
authority=authority,
|
||||||
"expirationDateTime": expiration,
|
allow_json_refresh=True, # allows one-time migration from legacy JSON if present
|
||||||
"clientState": str(uuid.uuid4()),
|
)
|
||||||
|
else:
|
||||||
|
self.oauth = None
|
||||||
|
|
||||||
|
# Track subscription ID for webhooks (note: change notifications might not be available for personal accounts)
|
||||||
|
self._subscription_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Graph API defaults
|
||||||
|
self._graph_api_version = "v1.0"
|
||||||
|
self._default_params = {
|
||||||
|
"$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
|
||||||
}
|
}
|
||||||
|
|
||||||
token = self.oauth.get_access_token()
|
@property
|
||||||
async with httpx.AsyncClient() as client:
|
def _graph_base_url(self) -> str:
|
||||||
resp = await client.post(
|
"""Base URL for Microsoft Graph API calls."""
|
||||||
f"{self.base_url}/subscriptions",
|
return f"https://graph.microsoft.com/{self._graph_api_version}"
|
||||||
json=body,
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = resp.json()
|
|
||||||
|
|
||||||
self.subscription_id = data["id"]
|
def emit(self, doc: ConnectorDocument) -> None:
|
||||||
return self.subscription_id
|
"""Emit a ConnectorDocument instance (integrate with your pipeline here)."""
|
||||||
|
logger.debug(f"Emitting OneDrive document: {doc.id} ({doc.filename})")
|
||||||
|
|
||||||
async def list_files(
|
async def authenticate(self) -> bool:
|
||||||
self, page_token: Optional[str] = None, limit: int = 100
|
"""Test authentication - BaseConnector interface."""
|
||||||
) -> Dict[str, Any]:
|
logger.debug(f"OneDrive authenticate() called, oauth is None: {self.oauth is None}")
|
||||||
if not self._authenticated:
|
try:
|
||||||
raise ValueError("Not authenticated")
|
if not self.oauth:
|
||||||
|
logger.debug("OneDrive authentication failed: OAuth not initialized")
|
||||||
|
self._authenticated = False
|
||||||
|
return False
|
||||||
|
|
||||||
params = {"$top": str(limit)}
|
logger.debug("Loading OneDrive credentials...")
|
||||||
if page_token:
|
load_result = await self.oauth.load_credentials()
|
||||||
params["$skiptoken"] = page_token
|
logger.debug(f"Load credentials result: {load_result}")
|
||||||
|
|
||||||
token = self.oauth.get_access_token()
|
logger.debug("Checking OneDrive authentication status...")
|
||||||
async with httpx.AsyncClient() as client:
|
authenticated = await self.oauth.is_authenticated()
|
||||||
resp = await client.get(
|
logger.debug(f"OneDrive is_authenticated result: {authenticated}")
|
||||||
f"{self.base_url}/me/drive/root/children",
|
|
||||||
params=params,
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = resp.json()
|
|
||||||
|
|
||||||
files = []
|
self._authenticated = authenticated
|
||||||
for item in data.get("value", []):
|
return authenticated
|
||||||
if item.get("file"):
|
except Exception as e:
|
||||||
files.append(
|
logger.error(f"OneDrive authentication failed: {e}")
|
||||||
{
|
import traceback
|
||||||
"id": item["id"],
|
traceback.print_exc()
|
||||||
"name": item["name"],
|
self._authenticated = False
|
||||||
"mimeType": item.get("file", {}).get(
|
return False
|
||||||
"mimeType", "application/octet-stream"
|
|
||||||
),
|
|
||||||
"webViewLink": item.get("webUrl"),
|
|
||||||
"createdTime": item.get("createdDateTime"),
|
|
||||||
"modifiedTime": item.get("lastModifiedDateTime"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
next_token = None
|
def get_auth_url(self) -> str:
|
||||||
next_link = data.get("@odata.nextLink")
|
"""Get OAuth authorization URL."""
|
||||||
if next_link:
|
if not self.oauth:
|
||||||
from urllib.parse import urlparse, parse_qs
|
raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
|
||||||
|
return self.oauth.create_authorization_url(self.redirect_uri)
|
||||||
|
|
||||||
parsed = urlparse(next_link)
|
async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
|
||||||
next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
|
"""Handle OAuth callback."""
|
||||||
|
if not self.oauth:
|
||||||
|
raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
|
||||||
|
try:
|
||||||
|
success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
|
||||||
|
if success:
|
||||||
|
self._authenticated = True
|
||||||
|
return {"status": "success"}
|
||||||
|
else:
|
||||||
|
raise ValueError("OAuth callback failed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OAuth callback failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
return {"files": files, "nextPageToken": next_token}
|
def sync_once(self) -> None:
|
||||||
|
"""
|
||||||
|
Perform a one-shot sync of OneDrive files and emit documents.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def _async_sync():
|
||||||
|
try:
|
||||||
|
file_list = await self.list_files(max_files=1000)
|
||||||
|
files = file_list.get("files", [])
|
||||||
|
for file_info in files:
|
||||||
|
try:
|
||||||
|
file_id = file_info.get("id")
|
||||||
|
if not file_id:
|
||||||
|
continue
|
||||||
|
doc = await self.get_file_content(file_id)
|
||||||
|
self.emit(doc)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to sync OneDrive file {file_info.get('name', 'unknown')}: {e}")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OneDrive sync_once failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
if hasattr(asyncio, 'run'):
|
||||||
|
asyncio.run(_async_sync())
|
||||||
|
else:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(_async_sync())
|
||||||
|
|
||||||
|
async def setup_subscription(self) -> str:
|
||||||
|
"""
|
||||||
|
Set up real-time subscription for file changes.
|
||||||
|
NOTE: Change notifications may not be available for personal OneDrive accounts.
|
||||||
|
"""
|
||||||
|
webhook_url = self.config.get('webhook_url')
|
||||||
|
if not webhook_url:
|
||||||
|
logger.warning("No webhook URL configured, skipping OneDrive subscription setup")
|
||||||
|
return "no-webhook-configured"
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not await self.authenticate():
|
||||||
|
raise RuntimeError("OneDrive authentication failed during subscription setup")
|
||||||
|
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
|
||||||
|
# For OneDrive personal we target the user's drive
|
||||||
|
resource = "/me/drive/root"
|
||||||
|
|
||||||
|
subscription_data = {
|
||||||
|
"changeType": "created,updated,deleted",
|
||||||
|
"notificationUrl": f"{webhook_url}/webhook/onedrive",
|
||||||
|
"resource": resource,
|
||||||
|
"expirationDateTime": self._get_subscription_expiry(),
|
||||||
|
"clientState": "onedrive_personal",
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self._graph_base_url}/subscriptions"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
subscription_id = result.get("id")
|
||||||
|
|
||||||
|
if subscription_id:
|
||||||
|
self._subscription_id = subscription_id
|
||||||
|
logger.info(f"OneDrive subscription created: {subscription_id}")
|
||||||
|
return subscription_id
|
||||||
|
else:
|
||||||
|
raise ValueError("No subscription ID returned from Microsoft Graph")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to setup OneDrive subscription: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _get_subscription_expiry(self) -> str:
|
||||||
|
"""Get subscription expiry time (Graph caps duration; often <= 3 days)."""
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
expiry = datetime.utcnow() + timedelta(days=3)
|
||||||
|
return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
||||||
|
|
||||||
|
async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]:
|
||||||
|
"""List files from OneDrive using Microsoft Graph."""
|
||||||
|
try:
|
||||||
|
if not await self.authenticate():
|
||||||
|
raise RuntimeError("OneDrive authentication failed during file listing")
|
||||||
|
|
||||||
|
files: List[Dict[str, Any]] = []
|
||||||
|
max_files_value = max_files if max_files is not None else 100
|
||||||
|
|
||||||
|
base_url = f"{self._graph_base_url}/me/drive/root/children"
|
||||||
|
|
||||||
|
params = dict(self._default_params)
|
||||||
|
params["$top"] = max_files_value
|
||||||
|
|
||||||
|
if page_token:
|
||||||
|
params["$skiptoken"] = page_token
|
||||||
|
|
||||||
|
response = await self._make_graph_request(base_url, params=params)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
items = data.get("value", [])
|
||||||
|
for item in items:
|
||||||
|
if item.get("file"): # include files only
|
||||||
|
files.append({
|
||||||
|
"id": item.get("id", ""),
|
||||||
|
"name": item.get("name", ""),
|
||||||
|
"path": f"/drive/items/{item.get('id')}",
|
||||||
|
"size": int(item.get("size", 0)),
|
||||||
|
"modified": item.get("lastModifiedDateTime"),
|
||||||
|
"created": item.get("createdDateTime"),
|
||||||
|
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
|
||||||
|
"url": item.get("webUrl", ""),
|
||||||
|
"download_url": item.get("@microsoft.graph.downloadUrl"),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Next page
|
||||||
|
next_page_token = None
|
||||||
|
next_link = data.get("@odata.nextLink")
|
||||||
|
if next_link:
|
||||||
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
parsed = urlparse(next_link)
|
||||||
|
query_params = parse_qs(parsed.query)
|
||||||
|
if "$skiptoken" in query_params:
|
||||||
|
next_page_token = query_params["$skiptoken"][0]
|
||||||
|
|
||||||
|
return {"files": files, "next_page_token": next_page_token}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to list OneDrive files: {e}")
|
||||||
|
return {"files": [], "next_page_token": None}
|
||||||
|
|
||||||
async def get_file_content(self, file_id: str) -> ConnectorDocument:
|
async def get_file_content(self, file_id: str) -> ConnectorDocument:
|
||||||
if not self._authenticated:
|
"""Get file content and metadata."""
|
||||||
raise ValueError("Not authenticated")
|
try:
|
||||||
|
if not await self.authenticate():
|
||||||
|
raise RuntimeError("OneDrive authentication failed during file content retrieval")
|
||||||
|
|
||||||
token = self.oauth.get_access_token()
|
file_metadata = await self._get_file_metadata_by_id(file_id)
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
if not file_metadata:
|
||||||
async with httpx.AsyncClient() as client:
|
raise ValueError(f"File not found: {file_id}")
|
||||||
meta_resp = await client.get(
|
|
||||||
f"{self.base_url}/me/drive/items/{file_id}", headers=headers
|
|
||||||
)
|
|
||||||
meta_resp.raise_for_status()
|
|
||||||
metadata = meta_resp.json()
|
|
||||||
|
|
||||||
content_resp = await client.get(
|
download_url = file_metadata.get("download_url")
|
||||||
f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers
|
if download_url:
|
||||||
)
|
content = await self._download_file_from_url(download_url)
|
||||||
content = content_resp.content
|
|
||||||
|
|
||||||
# Handle the possibility of this being a redirect
|
|
||||||
if content_resp.status_code in (301, 302, 303, 307, 308):
|
|
||||||
redirect_url = content_resp.headers.get("Location")
|
|
||||||
if redirect_url:
|
|
||||||
content_resp = await client.get(redirect_url)
|
|
||||||
content_resp.raise_for_status()
|
|
||||||
content = content_resp.content
|
|
||||||
else:
|
else:
|
||||||
content_resp.raise_for_status()
|
content = await self._download_file_content(file_id)
|
||||||
|
|
||||||
perm_resp = await client.get(
|
acl = DocumentACL(
|
||||||
f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers
|
owner="",
|
||||||
|
user_permissions={},
|
||||||
|
group_permissions={},
|
||||||
)
|
)
|
||||||
perm_resp.raise_for_status()
|
|
||||||
permissions = perm_resp.json()
|
|
||||||
|
|
||||||
acl = self._parse_permissions(metadata, permissions)
|
modified_time = self._parse_graph_date(file_metadata.get("modified"))
|
||||||
modified = datetime.fromisoformat(
|
created_time = self._parse_graph_date(file_metadata.get("created"))
|
||||||
metadata["lastModifiedDateTime"].replace("Z", "+00:00")
|
|
||||||
).replace(tzinfo=None)
|
|
||||||
created = datetime.fromisoformat(
|
|
||||||
metadata["createdDateTime"].replace("Z", "+00:00")
|
|
||||||
).replace(tzinfo=None)
|
|
||||||
|
|
||||||
document = ConnectorDocument(
|
return ConnectorDocument(
|
||||||
id=metadata["id"],
|
id=file_id,
|
||||||
filename=metadata["name"],
|
filename=file_metadata.get("name", ""),
|
||||||
mimetype=metadata.get("file", {}).get(
|
mimetype=file_metadata.get("mime_type", "application/octet-stream"),
|
||||||
"mimeType", "application/octet-stream"
|
content=content,
|
||||||
),
|
source_url=file_metadata.get("url", ""),
|
||||||
content=content,
|
acl=acl,
|
||||||
source_url=metadata.get("webUrl"),
|
modified_time=modified_time,
|
||||||
acl=acl,
|
created_time=created_time,
|
||||||
modified_time=modified,
|
metadata={
|
||||||
created_time=created,
|
"onedrive_path": file_metadata.get("path", ""),
|
||||||
metadata={"size": metadata.get("size")},
|
"size": file_metadata.get("size", 0),
|
||||||
)
|
},
|
||||||
return document
|
|
||||||
|
|
||||||
def _parse_permissions(
|
|
||||||
self, metadata: Dict[str, Any], permissions: Dict[str, Any]
|
|
||||||
) -> DocumentACL:
|
|
||||||
acl = DocumentACL()
|
|
||||||
owner = metadata.get("createdBy", {}).get("user", {}).get("email")
|
|
||||||
if owner:
|
|
||||||
acl.owner = owner
|
|
||||||
for perm in permissions.get("value", []):
|
|
||||||
role = perm.get("roles", ["read"])[0]
|
|
||||||
grantee = perm.get("grantedToV2") or perm.get("grantedTo")
|
|
||||||
if not grantee:
|
|
||||||
continue
|
|
||||||
user = grantee.get("user")
|
|
||||||
if user and user.get("email"):
|
|
||||||
acl.user_permissions[user["email"]] = role
|
|
||||||
group = grantee.get("group")
|
|
||||||
if group and group.get("email"):
|
|
||||||
acl.group_permissions[group["email"]] = role
|
|
||||||
return acl
|
|
||||||
|
|
||||||
def handle_webhook_validation(
|
|
||||||
self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Handle Microsoft Graph webhook validation"""
|
|
||||||
if request_method == "GET":
|
|
||||||
validation_token = query_params.get("validationtoken") or query_params.get(
|
|
||||||
"validationToken"
|
|
||||||
)
|
)
|
||||||
if validation_token:
|
|
||||||
return validation_token
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get OneDrive file content {file_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get file metadata by ID using Graph API."""
|
||||||
|
try:
|
||||||
|
url = f"{self._graph_base_url}/me/drive/items/{file_id}"
|
||||||
|
params = dict(self._default_params)
|
||||||
|
|
||||||
|
response = await self._make_graph_request(url, params=params)
|
||||||
|
item = response.json()
|
||||||
|
|
||||||
|
if item.get("file"):
|
||||||
|
return {
|
||||||
|
"id": file_id,
|
||||||
|
"name": item.get("name", ""),
|
||||||
|
"path": f"/drive/items/{file_id}",
|
||||||
|
"size": int(item.get("size", 0)),
|
||||||
|
"modified": item.get("lastModifiedDateTime"),
|
||||||
|
"created": item.get("createdDateTime"),
|
||||||
|
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
|
||||||
|
"url": item.get("webUrl", ""),
|
||||||
|
"download_url": item.get("@microsoft.graph.downloadUrl"),
|
||||||
|
}
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get file metadata for {file_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _download_file_content(self, file_id: str) -> bytes:
|
||||||
|
"""Download file content by file ID using Graph API."""
|
||||||
|
try:
|
||||||
|
url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(url, headers=headers, timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download file content for {file_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _download_file_from_url(self, download_url: str) -> bytes:
|
||||||
|
"""Download file content from direct download URL."""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(download_url, timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.content
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download from URL {download_url}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
|
||||||
|
"""Parse Microsoft Graph date string to datetime."""
|
||||||
|
if not date_str:
|
||||||
|
return datetime.now()
|
||||||
|
try:
|
||||||
|
if date_str.endswith('Z'):
|
||||||
|
return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
|
||||||
|
else:
|
||||||
|
return datetime.fromisoformat(date_str.replace('T', ' '))
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
return datetime.now()
|
||||||
|
|
||||||
|
async def _make_graph_request(self, url: str, method: str = "GET",
|
||||||
|
data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
|
||||||
|
"""Make authenticated API request to Microsoft Graph."""
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
if method.upper() == "GET":
|
||||||
|
response = await client.get(url, headers=headers, params=params, timeout=30)
|
||||||
|
elif method.upper() == "POST":
|
||||||
|
response = await client.post(url, headers=headers, json=data, timeout=30)
|
||||||
|
elif method.upper() == "DELETE":
|
||||||
|
response = await client.delete(url, headers=headers, timeout=30)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _get_mime_type(self, filename: str) -> str:
|
||||||
|
"""Get MIME type based on file extension."""
|
||||||
|
import mimetypes
|
||||||
|
mime_type, _ = mimetypes.guess_type(filename)
|
||||||
|
return mime_type or "application/octet-stream"
|
||||||
|
|
||||||
|
# Webhook methods - BaseConnector interface
|
||||||
|
def handle_webhook_validation(self, request_method: str,
|
||||||
|
headers: Dict[str, str],
|
||||||
|
query_params: Dict[str, str]) -> Optional[str]:
|
||||||
|
"""Handle webhook validation (Graph API specific)."""
|
||||||
|
if request_method == "POST" and "validationToken" in query_params:
|
||||||
|
return query_params["validationToken"]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def extract_webhook_channel_id(
|
def extract_webhook_channel_id(self, payload: Dict[str, Any],
|
||||||
self, payload: Dict[str, Any], headers: Dict[str, str]
|
headers: Dict[str, str]) -> Optional[str]:
|
||||||
) -> Optional[str]:
|
"""Extract channel/subscription ID from webhook payload."""
|
||||||
"""Extract SharePoint subscription ID from webhook payload"""
|
notifications = payload.get("value", [])
|
||||||
values = payload.get("value", [])
|
if notifications:
|
||||||
return values[0].get("subscriptionId") if values else None
|
return notifications[0].get("subscriptionId")
|
||||||
|
return None
|
||||||
|
|
||||||
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
||||||
values = payload.get("value", [])
|
"""Handle webhook notification and return affected file IDs."""
|
||||||
file_ids = []
|
affected_files: List[str] = []
|
||||||
for item in values:
|
notifications = payload.get("value", [])
|
||||||
resource_data = item.get("resourceData", {})
|
for notification in notifications:
|
||||||
file_id = resource_data.get("id")
|
resource = notification.get("resource")
|
||||||
if file_id:
|
if resource and "/drive/items/" in resource:
|
||||||
file_ids.append(file_id)
|
file_id = resource.split("/drive/items/")[-1]
|
||||||
return file_ids
|
affected_files.append(file_id)
|
||||||
|
return affected_files
|
||||||
|
|
||||||
async def cleanup_subscription(
|
async def cleanup_subscription(self, subscription_id: str) -> bool:
|
||||||
self, subscription_id: str, resource_id: str = None
|
"""Clean up subscription - BaseConnector interface."""
|
||||||
) -> bool:
|
if subscription_id == "no-webhook-configured":
|
||||||
if not self._authenticated:
|
logger.info("No subscription to cleanup (webhook was not configured)")
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not await self.authenticate():
|
||||||
|
logger.error("OneDrive authentication failed during subscription cleanup")
|
||||||
|
return False
|
||||||
|
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.delete(url, headers=headers, timeout=30)
|
||||||
|
|
||||||
|
if response.status_code in [200, 204, 404]:
|
||||||
|
logger.info(f"OneDrive subscription {subscription_id} cleaned up successfully")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to cleanup OneDrive subscription {subscription_id}: {e}")
|
||||||
return False
|
return False
|
||||||
token = self.oauth.get_access_token()
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.delete(
|
|
||||||
f"{self.base_url}/subscriptions/{subscription_id}",
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
return resp.status_code in (200, 204)
|
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,28 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
from datetime import datetime
|
import msal
|
||||||
import httpx
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OneDriveOAuth:
|
class OneDriveOAuth:
|
||||||
"""Direct token management for OneDrive, bypassing MSAL cache format"""
|
"""Handles Microsoft Graph OAuth for OneDrive (personal Microsoft accounts by default)."""
|
||||||
|
|
||||||
SCOPES = [
|
# Reserved scopes that must NOT be sent on token or silent calls
|
||||||
"offline_access",
|
RESERVED_SCOPES = {"openid", "profile", "offline_access"}
|
||||||
"Files.Read.All",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
# For PERSONAL Microsoft Accounts (OneDrive consumer):
|
||||||
|
# - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
|
||||||
|
# - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
|
||||||
|
AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
|
||||||
|
RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
|
||||||
|
SCOPES = AUTH_SCOPES # Backward-compat alias if something references .SCOPES
|
||||||
|
|
||||||
|
# Kept for reference; MSAL derives endpoints from `authority`
|
||||||
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||||
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||||
|
|
||||||
|
|
@ -21,168 +31,292 @@ class OneDriveOAuth:
|
||||||
client_id: str,
|
client_id: str,
|
||||||
client_secret: str,
|
client_secret: str,
|
||||||
token_file: str = "onedrive_token.json",
|
token_file: str = "onedrive_token.json",
|
||||||
|
authority: str = "https://login.microsoftonline.com/common",
|
||||||
|
allow_json_refresh: bool = True,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize OneDriveOAuth.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_id: Azure AD application (client) ID.
|
||||||
|
client_secret: Azure AD application client secret.
|
||||||
|
token_file: Path to persisted token cache file (MSAL cache format).
|
||||||
|
authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
|
||||||
|
or tenant-specific for work/school.
|
||||||
|
allow_json_refresh: If True, permit one-time migration from legacy flat JSON
|
||||||
|
{"access_token","refresh_token",...}. Otherwise refuse it.
|
||||||
|
"""
|
||||||
self.client_id = client_id
|
self.client_id = client_id
|
||||||
self.client_secret = client_secret
|
self.client_secret = client_secret
|
||||||
self.token_file = token_file
|
self.token_file = token_file
|
||||||
self._tokens = None
|
self.authority = authority
|
||||||
self._load_tokens()
|
self.allow_json_refresh = allow_json_refresh
|
||||||
|
self.token_cache = msal.SerializableTokenCache()
|
||||||
|
self._current_account = None
|
||||||
|
|
||||||
def _load_tokens(self):
|
# Initialize MSAL Confidential Client
|
||||||
"""Load tokens from file"""
|
self.app = msal.ConfidentialClientApplication(
|
||||||
if os.path.exists(self.token_file):
|
client_id=self.client_id,
|
||||||
with open(self.token_file, "r") as f:
|
client_credential=self.client_secret,
|
||||||
self._tokens = json.loads(f.read())
|
authority=self.authority,
|
||||||
print(f"Loaded tokens from {self.token_file}")
|
token_cache=self.token_cache,
|
||||||
else:
|
)
|
||||||
print(f"No token file found at {self.token_file}")
|
|
||||||
|
|
||||||
async def _save_tokens(self):
|
async def load_credentials(self) -> bool:
|
||||||
"""Save tokens to file"""
|
"""Load existing credentials from token file (async)."""
|
||||||
if self._tokens:
|
|
||||||
async with aiofiles.open(self.token_file, "w") as f:
|
|
||||||
await f.write(json.dumps(self._tokens, indent=2))
|
|
||||||
|
|
||||||
def _is_token_expired(self) -> bool:
|
|
||||||
"""Check if current access token is expired"""
|
|
||||||
if not self._tokens or 'expiry' not in self._tokens:
|
|
||||||
return True
|
|
||||||
|
|
||||||
expiry_str = self._tokens['expiry']
|
|
||||||
# Handle different expiry formats
|
|
||||||
try:
|
try:
|
||||||
if expiry_str.endswith('Z'):
|
logger.debug(f"OneDrive OAuth loading credentials from: {self.token_file}")
|
||||||
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
|
if os.path.exists(self.token_file):
|
||||||
|
logger.debug(f"Token file exists, reading: {self.token_file}")
|
||||||
|
|
||||||
|
# Read the token file
|
||||||
|
async with aiofiles.open(self.token_file, "r") as f:
|
||||||
|
cache_data = await f.read()
|
||||||
|
logger.debug(f"Read {len(cache_data)} chars from token file")
|
||||||
|
|
||||||
|
if cache_data.strip():
|
||||||
|
# 1) Try legacy flat JSON first
|
||||||
|
try:
|
||||||
|
json_data = json.loads(cache_data)
|
||||||
|
if isinstance(json_data, dict) and "refresh_token" in json_data:
|
||||||
|
if self.allow_json_refresh:
|
||||||
|
logger.debug(
|
||||||
|
"Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
|
||||||
|
)
|
||||||
|
return await self._refresh_from_json_token(json_data)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
|
||||||
|
"Delete the file and re-auth."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.debug("Token file is not flat JSON; attempting MSAL cache format")
|
||||||
|
|
||||||
|
# 2) Try MSAL cache format
|
||||||
|
logger.debug("Attempting MSAL cache deserialization")
|
||||||
|
self.token_cache.deserialize(cache_data)
|
||||||
|
|
||||||
|
# Get accounts from loaded cache
|
||||||
|
accounts = self.app.get_accounts()
|
||||||
|
logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
|
||||||
|
if accounts:
|
||||||
|
self._current_account = accounts[0]
|
||||||
|
logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
|
||||||
|
|
||||||
|
# Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
|
||||||
|
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
|
||||||
|
logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
|
||||||
|
if result and "access_token" in result:
|
||||||
|
logger.debug("Silent token acquisition successful")
|
||||||
|
await self.save_cache()
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
error_msg = (result or {}).get("error") or "No result"
|
||||||
|
logger.warning(f"Silent token acquisition failed: {error_msg}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Token file {self.token_file} is empty")
|
||||||
else:
|
else:
|
||||||
expiry_dt = datetime.fromisoformat(expiry_str)
|
logger.debug(f"Token file does not exist: {self.token_file}")
|
||||||
|
|
||||||
# Add 5-minute buffer
|
|
||||||
import datetime as dt
|
|
||||||
now = datetime.now()
|
|
||||||
return now >= (expiry_dt - dt.timedelta(minutes=5))
|
|
||||||
except:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _refresh_access_token(self) -> bool:
|
|
||||||
"""Refresh the access token using refresh token"""
|
|
||||||
if not self._tokens or 'refresh_token' not in self._tokens:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
data = {
|
except Exception as e:
|
||||||
'client_id': self.client_id,
|
logger.error(f"Failed to load OneDrive credentials: {e}")
|
||||||
'client_secret': self.client_secret,
|
import traceback
|
||||||
'refresh_token': self._tokens['refresh_token'],
|
traceback.print_exc()
|
||||||
'grant_type': 'refresh_token',
|
|
||||||
'scope': ' '.join(self.SCOPES)
|
|
||||||
}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
try:
|
|
||||||
response = await client.post(self.TOKEN_ENDPOINT, data=data)
|
|
||||||
response.raise_for_status()
|
|
||||||
token_data = response.json()
|
|
||||||
|
|
||||||
# Update tokens
|
|
||||||
self._tokens['token'] = token_data['access_token']
|
|
||||||
if 'refresh_token' in token_data:
|
|
||||||
self._tokens['refresh_token'] = token_data['refresh_token']
|
|
||||||
|
|
||||||
# Calculate expiry
|
|
||||||
expires_in = token_data.get('expires_in', 3600)
|
|
||||||
import datetime as dt
|
|
||||||
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
|
|
||||||
self._tokens['expiry'] = expiry.isoformat()
|
|
||||||
|
|
||||||
await self._save_tokens()
|
|
||||||
print("Access token refreshed successfully")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to refresh token: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def is_authenticated(self) -> bool:
|
|
||||||
"""Check if we have valid credentials"""
|
|
||||||
if not self._tokens:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# If token is expired, try to refresh
|
async def _refresh_from_json_token(self, token_data: dict) -> bool:
|
||||||
if self._is_token_expired():
|
"""
|
||||||
print("Token expired, attempting refresh...")
|
Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
|
||||||
if await self._refresh_access_token():
|
Prefer using an MSAL cache file and acquire_token_silent(); this path is only for migrating older files.
|
||||||
return True
|
"""
|
||||||
else:
|
try:
|
||||||
|
refresh_token = token_data.get("refresh_token")
|
||||||
|
if not refresh_token:
|
||||||
|
logger.error("No refresh_token found in JSON file - cannot refresh")
|
||||||
|
logger.error("You must re-authenticate interactively to obtain a valid token")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
# Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
|
||||||
|
refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
|
||||||
|
logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
|
||||||
|
|
||||||
def get_access_token(self) -> str:
|
result = self.app.acquire_token_by_refresh_token(
|
||||||
"""Get current access token"""
|
refresh_token=refresh_token,
|
||||||
if not self._tokens or 'token' not in self._tokens:
|
scopes=refresh_scopes,
|
||||||
raise ValueError("No access token available")
|
)
|
||||||
|
|
||||||
if self._is_token_expired():
|
if result and "access_token" in result:
|
||||||
raise ValueError("Access token expired and refresh failed")
|
logger.debug("Successfully refreshed token via legacy JSON path")
|
||||||
|
await self.save_cache()
|
||||||
|
|
||||||
return self._tokens['token']
|
accounts = self.app.get_accounts()
|
||||||
|
logger.debug(f"After refresh, found {len(accounts)} accounts")
|
||||||
|
if accounts:
|
||||||
|
self._current_account = accounts[0]
|
||||||
|
logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
|
||||||
|
return True
|
||||||
|
|
||||||
async def revoke_credentials(self):
|
# Error handling
|
||||||
"""Clear tokens"""
|
err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
|
||||||
self._tokens = None
|
logger.error(f"Refresh token failed: {err}")
|
||||||
if os.path.exists(self.token_file):
|
|
||||||
os.remove(self.token_file)
|
|
||||||
|
|
||||||
# Keep these methods for compatibility with your existing OAuth flow
|
if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
|
||||||
def create_authorization_url(self, redirect_uri: str) -> str:
|
logger.warning(
|
||||||
"""Create authorization URL for OAuth flow"""
|
"Refresh denied due to unauthorized/expired scopes or invalid grant. "
|
||||||
from urllib.parse import urlencode
|
"Delete the token file and perform interactive sign-in with correct scopes."
|
||||||
|
)
|
||||||
|
|
||||||
params = {
|
return False
|
||||||
'client_id': self.client_id,
|
|
||||||
'response_type': 'code',
|
except Exception as e:
|
||||||
'redirect_uri': redirect_uri,
|
logger.error(f"Exception during refresh from JSON token: {e}")
|
||||||
'scope': ' '.join(self.SCOPES),
|
import traceback
|
||||||
'response_mode': 'query'
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def save_cache(self):
|
||||||
|
"""Persist the token cache to file."""
|
||||||
|
try:
|
||||||
|
# Ensure parent directory exists
|
||||||
|
parent = os.path.dirname(os.path.abspath(self.token_file))
|
||||||
|
if parent and not os.path.exists(parent):
|
||||||
|
os.makedirs(parent, exist_ok=True)
|
||||||
|
|
||||||
|
cache_data = self.token_cache.serialize()
|
||||||
|
if cache_data:
|
||||||
|
async with aiofiles.open(self.token_file, "w") as f:
|
||||||
|
await f.write(cache_data)
|
||||||
|
logger.debug(f"Token cache saved to {self.token_file}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save token cache: {e}")
|
||||||
|
|
||||||
|
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
|
||||||
|
"""Create authorization URL for OAuth flow."""
|
||||||
|
# Store redirect URI for later use in callback
|
||||||
|
self._redirect_uri = redirect_uri
|
||||||
|
|
||||||
|
kwargs: Dict[str, Any] = {
|
||||||
|
# Interactive auth includes offline_access
|
||||||
|
"scopes": self.AUTH_SCOPES,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"prompt": "consent", # ensure refresh token on first run
|
||||||
}
|
}
|
||||||
|
if state:
|
||||||
|
kwargs["state"] = state # Optional CSRF protection
|
||||||
|
|
||||||
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
auth_url = self.app.get_authorization_request_url(**kwargs)
|
||||||
return f"{auth_url}?{urlencode(params)}"
|
|
||||||
|
logger.debug(f"Generated auth URL: {auth_url}")
|
||||||
|
logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
|
||||||
|
|
||||||
|
return auth_url
|
||||||
|
|
||||||
async def handle_authorization_callback(
|
async def handle_authorization_callback(
|
||||||
self, authorization_code: str, redirect_uri: str
|
self, authorization_code: str, redirect_uri: str
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Handle OAuth callback and exchange code for tokens"""
|
"""Handle OAuth callback and exchange code for tokens."""
|
||||||
data = {
|
try:
|
||||||
'client_id': self.client_id,
|
result = self.app.acquire_token_by_authorization_code(
|
||||||
'client_secret': self.client_secret,
|
authorization_code,
|
||||||
'code': authorization_code,
|
scopes=self.AUTH_SCOPES, # same as authorize step
|
||||||
'grant_type': 'authorization_code',
|
redirect_uri=redirect_uri,
|
||||||
'redirect_uri': redirect_uri,
|
)
|
||||||
'scope': ' '.join(self.SCOPES)
|
|
||||||
}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
if result and "access_token" in result:
|
||||||
try:
|
accounts = self.app.get_accounts()
|
||||||
response = await client.post(self.TOKEN_ENDPOINT, data=data)
|
if accounts:
|
||||||
response.raise_for_status()
|
self._current_account = accounts[0]
|
||||||
token_data = response.json()
|
|
||||||
|
|
||||||
# Store tokens in our format
|
await self.save_cache()
|
||||||
import datetime as dt
|
logger.info("OneDrive OAuth authorization successful")
|
||||||
expires_in = token_data.get('expires_in', 3600)
|
|
||||||
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
|
|
||||||
|
|
||||||
self._tokens = {
|
|
||||||
'token': token_data['access_token'],
|
|
||||||
'refresh_token': token_data['refresh_token'],
|
|
||||||
'scopes': self.SCOPES,
|
|
||||||
'expiry': expiry.isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
await self._save_tokens()
|
|
||||||
print("Authorization successful, tokens saved")
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
|
||||||
print(f"Authorization failed: {e}")
|
logger.error(f"OneDrive OAuth authorization failed: {error_msg}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception during OneDrive OAuth authorization: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def is_authenticated(self) -> bool:
|
||||||
|
"""Check if we have valid credentials."""
|
||||||
|
try:
|
||||||
|
# First try to load credentials if we haven't already
|
||||||
|
if not self._current_account:
|
||||||
|
await self.load_credentials()
|
||||||
|
|
||||||
|
# Try to get a token (MSAL will refresh if needed)
|
||||||
|
if self._current_account:
|
||||||
|
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
|
||||||
|
if result and "access_token" in result:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
error_msg = (result or {}).get("error") or "No result returned"
|
||||||
|
logger.debug(f"Token acquisition failed for current account: {error_msg}")
|
||||||
|
|
||||||
|
# Fallback: try without specific account
|
||||||
|
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
|
||||||
|
if result and "access_token" in result:
|
||||||
|
accounts = self.app.get_accounts()
|
||||||
|
if accounts:
|
||||||
|
self._current_account = accounts[0]
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Authentication check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_access_token(self) -> str:
|
||||||
|
"""Get an access token for Microsoft Graph."""
|
||||||
|
try:
|
||||||
|
# Try with current account first
|
||||||
|
if self._current_account:
|
||||||
|
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
|
||||||
|
if result and "access_token" in result:
|
||||||
|
return result["access_token"]
|
||||||
|
|
||||||
|
# Fallback: try without specific account
|
||||||
|
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
|
||||||
|
if result and "access_token" in result:
|
||||||
|
return result["access_token"]
|
||||||
|
|
||||||
|
# If we get here, authentication has failed
|
||||||
|
error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
|
||||||
|
raise ValueError(f"Failed to acquire access token: {error_msg}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get access token: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def revoke_credentials(self):
|
||||||
|
"""Clear token cache and remove token file."""
|
||||||
|
try:
|
||||||
|
# Clear in-memory state
|
||||||
|
self._current_account = None
|
||||||
|
self.token_cache = msal.SerializableTokenCache()
|
||||||
|
|
||||||
|
# Recreate MSAL app with fresh cache
|
||||||
|
self.app = msal.ConfidentialClientApplication(
|
||||||
|
client_id=self.client_id,
|
||||||
|
client_credential=self.client_secret,
|
||||||
|
authority=self.authority,
|
||||||
|
token_cache=self.token_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove token file
|
||||||
|
if os.path.exists(self.token_file):
|
||||||
|
os.remove(self.token_file)
|
||||||
|
logger.info(f"Removed OneDrive token file: {self.token_file}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to revoke OneDrive credentials: {e}")
|
||||||
|
|
||||||
|
def get_service(self) -> str:
|
||||||
|
"""Return an access token (Graph client is just the bearer)."""
|
||||||
|
return self.get_access_token()
|
||||||
|
|
|
||||||
|
|
@ -1,241 +1,564 @@
|
||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from datetime import datetime
|
||||||
import httpx
|
import httpx
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Dict, List, Any, Optional
|
|
||||||
|
|
||||||
from ..base import BaseConnector, ConnectorDocument, DocumentACL
|
from ..base import BaseConnector, ConnectorDocument, DocumentACL
|
||||||
from .oauth import SharePointOAuth
|
from .oauth import SharePointOAuth
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SharePointConnector(BaseConnector):
|
class SharePointConnector(BaseConnector):
|
||||||
"""SharePoint Sites connector using Microsoft Graph API"""
|
"""SharePoint connector using MSAL-based OAuth for authentication"""
|
||||||
|
|
||||||
|
# Required BaseConnector class attributes
|
||||||
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
|
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
|
||||||
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
|
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
|
||||||
|
|
||||||
# Connector metadata
|
# Connector metadata
|
||||||
CONNECTOR_NAME = "SharePoint"
|
CONNECTOR_NAME = "SharePoint"
|
||||||
CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents"
|
CONNECTOR_DESCRIPTION = "Connect to SharePoint to sync documents and files"
|
||||||
CONNECTOR_ICON = "sharepoint"
|
CONNECTOR_ICON = "sharepoint"
|
||||||
|
|
||||||
def __init__(self, config: Dict[str, Any]):
|
def __init__(self, config: Dict[str, Any]):
|
||||||
super().__init__(config)
|
super().__init__(config) # Fix: Call parent init first
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any]):
|
||||||
|
logger.debug(f"SharePoint connector __init__ called with config type: {type(config)}")
|
||||||
|
logger.debug(f"SharePoint connector __init__ config value: {config}")
|
||||||
|
|
||||||
|
# Ensure we always pass a valid config to the base class
|
||||||
|
if config is None:
|
||||||
|
logger.debug("Config was None, using empty dict")
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug("Calling super().__init__")
|
||||||
|
super().__init__(config) # Now safe to call with empty dict instead of None
|
||||||
|
logger.debug("super().__init__ completed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"super().__init__ failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Initialize with defaults that allow the connector to be listed
|
||||||
|
self.client_id = None
|
||||||
|
self.client_secret = None
|
||||||
|
self.tenant_id = config.get("tenant_id", "common")
|
||||||
|
self.sharepoint_url = config.get("sharepoint_url")
|
||||||
|
self.redirect_uri = config.get("redirect_uri", "http://localhost")
|
||||||
|
|
||||||
|
# Try to get credentials, but don't fail if they're missing
|
||||||
|
try:
|
||||||
|
logger.debug("Attempting to get client_id")
|
||||||
|
self.client_id = self.get_client_id()
|
||||||
|
logger.debug(f"Got client_id: {self.client_id is not None}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to get client_id: {e}")
|
||||||
|
pass # Credentials not available, that's OK for listing
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.debug("Attempting to get client_secret")
|
||||||
|
self.client_secret = self.get_client_secret()
|
||||||
|
logger.debug(f"Got client_secret: {self.client_secret is not None}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to get client_secret: {e}")
|
||||||
|
pass # Credentials not available, that's OK for listing
|
||||||
|
|
||||||
|
# Token file setup
|
||||||
project_root = Path(__file__).resolve().parent.parent.parent.parent
|
project_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||||
token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
|
token_file = config.get("token_file") or str(project_root / "sharepoint_token.json")
|
||||||
self.oauth = SharePointOAuth(
|
Path(token_file).parent.mkdir(parents=True, exist_ok=True)
|
||||||
client_id=self.get_client_id(),
|
|
||||||
client_secret=self.get_client_secret(),
|
|
||||||
token_file=token_file,
|
|
||||||
)
|
|
||||||
self.subscription_id = config.get("subscription_id") or config.get(
|
|
||||||
"webhook_channel_id"
|
|
||||||
)
|
|
||||||
self.base_url = "https://graph.microsoft.com/v1.0"
|
|
||||||
|
|
||||||
# SharePoint site configuration
|
# Only initialize OAuth if we have credentials
|
||||||
self.site_id = config.get("site_id") # Required for SharePoint
|
if self.client_id and self.client_secret:
|
||||||
|
connection_id = config.get("connection_id", "default")
|
||||||
|
|
||||||
async def authenticate(self) -> bool:
|
# Use token_file from config if provided, otherwise generate one
|
||||||
if await self.oauth.is_authenticated():
|
if config.get("token_file"):
|
||||||
self._authenticated = True
|
oauth_token_file = config["token_file"]
|
||||||
return True
|
else:
|
||||||
return False
|
oauth_token_file = f"sharepoint_token_{connection_id}.json"
|
||||||
|
|
||||||
async def setup_subscription(self) -> str:
|
authority = f"https://login.microsoftonline.com/{self.tenant_id}" if self.tenant_id != "common" else "https://login.microsoftonline.com/common"
|
||||||
if not self._authenticated:
|
|
||||||
raise ValueError("Not authenticated")
|
|
||||||
|
|
||||||
webhook_url = self.config.get("webhook_url")
|
self.oauth = SharePointOAuth(
|
||||||
if not webhook_url:
|
client_id=self.client_id,
|
||||||
raise ValueError("webhook_url required in config for subscriptions")
|
client_secret=self.client_secret,
|
||||||
|
token_file=oauth_token_file,
|
||||||
|
authority=authority
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.oauth = None
|
||||||
|
|
||||||
expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
|
# Track subscription ID for webhooks
|
||||||
body = {
|
self._subscription_id: Optional[str] = None
|
||||||
"changeType": "created,updated,deleted",
|
|
||||||
"notificationUrl": webhook_url,
|
# Add Graph API defaults similar to Google Drive flags
|
||||||
"resource": f"/sites/{self.site_id}/drive/root",
|
self._graph_api_version = "v1.0"
|
||||||
"expirationDateTime": expiration,
|
self._default_params = {
|
||||||
"clientState": str(uuid.uuid4()),
|
"$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
|
||||||
}
|
}
|
||||||
|
|
||||||
token = self.oauth.get_access_token()
|
@property
|
||||||
async with httpx.AsyncClient() as client:
|
def _graph_base_url(self) -> str:
|
||||||
resp = await client.post(
|
"""Base URL for Microsoft Graph API calls"""
|
||||||
f"{self.base_url}/subscriptions",
|
return f"https://graph.microsoft.com/{self._graph_api_version}"
|
||||||
json=body,
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = resp.json()
|
|
||||||
|
|
||||||
self.subscription_id = data["id"]
|
def emit(self, doc: ConnectorDocument) -> None:
|
||||||
return self.subscription_id
|
"""
|
||||||
|
Emit a ConnectorDocument instance.
|
||||||
|
Override this method to integrate with your ingestion pipeline.
|
||||||
|
"""
|
||||||
|
logger.debug(f"Emitting SharePoint document: {doc.id} ({doc.filename})")
|
||||||
|
|
||||||
async def list_files(
|
async def authenticate(self) -> bool:
|
||||||
self, page_token: Optional[str] = None, limit: int = 100
|
"""Test authentication - BaseConnector interface"""
|
||||||
) -> Dict[str, Any]:
|
logger.debug(f"SharePoint authenticate() called, oauth is None: {self.oauth is None}")
|
||||||
if not self._authenticated:
|
try:
|
||||||
raise ValueError("Not authenticated")
|
if not self.oauth:
|
||||||
|
logger.debug("SharePoint authentication failed: OAuth not initialized")
|
||||||
|
self._authenticated = False
|
||||||
|
return False
|
||||||
|
|
||||||
params = {"$top": str(limit)}
|
logger.debug("Loading SharePoint credentials...")
|
||||||
if page_token:
|
# Try to load existing credentials first
|
||||||
params["$skiptoken"] = page_token
|
load_result = await self.oauth.load_credentials()
|
||||||
|
logger.debug(f"Load credentials result: {load_result}")
|
||||||
|
|
||||||
token = self.oauth.get_access_token()
|
logger.debug("Checking SharePoint authentication status...")
|
||||||
async with httpx.AsyncClient() as client:
|
authenticated = await self.oauth.is_authenticated()
|
||||||
resp = await client.get(
|
logger.debug(f"SharePoint is_authenticated result: {authenticated}")
|
||||||
f"{self.base_url}/sites/{self.site_id}/drive/root/children",
|
|
||||||
params=params,
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = resp.json()
|
|
||||||
|
|
||||||
files = []
|
self._authenticated = authenticated
|
||||||
for item in data.get("value", []):
|
return authenticated
|
||||||
if item.get("file"):
|
except Exception as e:
|
||||||
files.append(
|
logger.error(f"SharePoint authentication failed: {e}")
|
||||||
{
|
import traceback
|
||||||
"id": item["id"],
|
traceback.print_exc()
|
||||||
"name": item["name"],
|
self._authenticated = False
|
||||||
"mimeType": item.get("file", {}).get(
|
return False
|
||||||
"mimeType", "application/octet-stream"
|
|
||||||
),
|
|
||||||
"webViewLink": item.get("webUrl"),
|
|
||||||
"createdTime": item.get("createdDateTime"),
|
|
||||||
"modifiedTime": item.get("lastModifiedDateTime"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
next_token = None
|
def get_auth_url(self) -> str:
|
||||||
next_link = data.get("@odata.nextLink")
|
"""Get OAuth authorization URL"""
|
||||||
if next_link:
|
if not self.oauth:
|
||||||
from urllib.parse import urlparse, parse_qs
|
raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
|
||||||
|
return self.oauth.create_authorization_url(self.redirect_uri)
|
||||||
|
|
||||||
parsed = urlparse(next_link)
|
async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
|
||||||
next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
|
"""Handle OAuth callback"""
|
||||||
|
if not self.oauth:
|
||||||
return {"files": files, "nextPageToken": next_token}
|
raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
|
||||||
|
try:
|
||||||
async def get_file_content(self, file_id: str) -> ConnectorDocument:
|
success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
|
||||||
if not self._authenticated:
|
if success:
|
||||||
raise ValueError("Not authenticated")
|
self._authenticated = True
|
||||||
|
return {"status": "success"}
|
||||||
token = self.oauth.get_access_token()
|
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
meta_resp = await client.get(
|
|
||||||
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}",
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
meta_resp.raise_for_status()
|
|
||||||
metadata = meta_resp.json()
|
|
||||||
|
|
||||||
content_resp = await client.get(
|
|
||||||
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/content",
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
content = content_resp.content
|
|
||||||
|
|
||||||
# Handle the possibility of this being a redirect
|
|
||||||
if content_resp.status_code in (301, 302, 303, 307, 308):
|
|
||||||
redirect_url = content_resp.headers.get("Location")
|
|
||||||
if redirect_url:
|
|
||||||
content_resp = await client.get(redirect_url)
|
|
||||||
content_resp.raise_for_status()
|
|
||||||
content = content_resp.content
|
|
||||||
else:
|
else:
|
||||||
content_resp.raise_for_status()
|
raise ValueError("OAuth callback failed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OAuth callback failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
perm_resp = await client.get(
|
def sync_once(self) -> None:
|
||||||
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/permissions",
|
"""
|
||||||
headers=headers,
|
Perform a one-shot sync of SharePoint files and emit documents.
|
||||||
)
|
This method mirrors the Google Drive connector's sync_once functionality.
|
||||||
perm_resp.raise_for_status()
|
"""
|
||||||
permissions = perm_resp.json()
|
import asyncio
|
||||||
|
|
||||||
acl = self._parse_permissions(metadata, permissions)
|
async def _async_sync():
|
||||||
modified = datetime.fromisoformat(
|
try:
|
||||||
metadata["lastModifiedDateTime"].replace("Z", "+00:00")
|
# Get list of files
|
||||||
).replace(tzinfo=None)
|
file_list = await self.list_files(max_files=1000) # Adjust as needed
|
||||||
created = datetime.fromisoformat(
|
files = file_list.get("files", [])
|
||||||
metadata["createdDateTime"].replace("Z", "+00:00")
|
|
||||||
).replace(tzinfo=None)
|
|
||||||
|
|
||||||
document = ConnectorDocument(
|
for file_info in files:
|
||||||
id=metadata["id"],
|
try:
|
||||||
filename=metadata["name"],
|
file_id = file_info.get("id")
|
||||||
mimetype=metadata.get("file", {}).get(
|
if not file_id:
|
||||||
"mimeType", "application/octet-stream"
|
continue
|
||||||
),
|
|
||||||
content=content,
|
|
||||||
source_url=metadata.get("webUrl"),
|
|
||||||
acl=acl,
|
|
||||||
modified_time=modified,
|
|
||||||
created_time=created,
|
|
||||||
metadata={"size": metadata.get("size")},
|
|
||||||
)
|
|
||||||
return document
|
|
||||||
|
|
||||||
def _parse_permissions(
|
# Get full document content
|
||||||
self, metadata: Dict[str, Any], permissions: Dict[str, Any]
|
doc = await self.get_file_content(file_id)
|
||||||
) -> DocumentACL:
|
self.emit(doc)
|
||||||
acl = DocumentACL()
|
|
||||||
owner = metadata.get("createdBy", {}).get("user", {}).get("email")
|
except Exception as e:
|
||||||
if owner:
|
logger.error(f"Failed to sync SharePoint file {file_info.get('name', 'unknown')}: {e}")
|
||||||
acl.owner = owner
|
continue
|
||||||
for perm in permissions.get("value", []):
|
|
||||||
role = perm.get("roles", ["read"])[0]
|
except Exception as e:
|
||||||
grantee = perm.get("grantedToV2") or perm.get("grantedTo")
|
logger.error(f"SharePoint sync_once failed: {e}")
|
||||||
if not grantee:
|
raise
|
||||||
continue
|
|
||||||
user = grantee.get("user")
|
# Run the async sync
|
||||||
if user and user.get("email"):
|
if hasattr(asyncio, 'run'):
|
||||||
acl.user_permissions[user["email"]] = role
|
asyncio.run(_async_sync())
|
||||||
group = grantee.get("group")
|
else:
|
||||||
if group and group.get("email"):
|
# Python < 3.7 compatibility
|
||||||
acl.group_permissions[group["email"]] = role
|
loop = asyncio.get_event_loop()
|
||||||
return acl
|
loop.run_until_complete(_async_sync())
|
||||||
|
|
||||||
|
async def setup_subscription(self) -> str:
|
||||||
|
"""Set up real-time subscription for file changes - BaseConnector interface"""
|
||||||
|
webhook_url = self.config.get('webhook_url')
|
||||||
|
if not webhook_url:
|
||||||
|
logger.warning("No webhook URL configured, skipping SharePoint subscription setup")
|
||||||
|
return "no-webhook-configured"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Ensure we're authenticated
|
||||||
|
if not await self.authenticate():
|
||||||
|
raise RuntimeError("SharePoint authentication failed during subscription setup")
|
||||||
|
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
|
||||||
|
# Microsoft Graph subscription for SharePoint site
|
||||||
|
site_info = self._parse_sharepoint_url()
|
||||||
|
if site_info:
|
||||||
|
resource = f"sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root"
|
||||||
|
else:
|
||||||
|
resource = "/me/drive/root"
|
||||||
|
|
||||||
|
subscription_data = {
|
||||||
|
"changeType": "created,updated,deleted",
|
||||||
|
"notificationUrl": f"{webhook_url}/webhook/sharepoint",
|
||||||
|
"resource": resource,
|
||||||
|
"expirationDateTime": self._get_subscription_expiry(),
|
||||||
|
"clientState": f"sharepoint_{self.tenant_id}"
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self._graph_base_url}/subscriptions"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
subscription_id = result.get("id")
|
||||||
|
|
||||||
|
if subscription_id:
|
||||||
|
self._subscription_id = subscription_id
|
||||||
|
logger.info(f"SharePoint subscription created: {subscription_id}")
|
||||||
|
return subscription_id
|
||||||
|
else:
|
||||||
|
raise ValueError("No subscription ID returned from Microsoft Graph")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to setup SharePoint subscription: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _get_subscription_expiry(self) -> str:
|
||||||
|
"""Get subscription expiry time (max 3 days for Graph API)"""
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
expiry = datetime.utcnow() + timedelta(days=3) # 3 days max for Graph
|
||||||
|
return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
||||||
|
|
||||||
|
def _parse_sharepoint_url(self) -> Optional[Dict[str, str]]:
|
||||||
|
"""Parse SharePoint URL to extract site information for Graph API"""
|
||||||
|
if not self.sharepoint_url:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = urlparse(self.sharepoint_url)
|
||||||
|
# Extract hostname and site name from URL like: https://contoso.sharepoint.com/sites/teamsite
|
||||||
|
host_name = parsed.netloc
|
||||||
|
path_parts = parsed.path.strip('/').split('/')
|
||||||
|
|
||||||
|
if len(path_parts) >= 2 and path_parts[0] == 'sites':
|
||||||
|
site_name = path_parts[1]
|
||||||
|
return {
|
||||||
|
"host_name": host_name,
|
||||||
|
"site_name": site_name
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not parse SharePoint URL {self.sharepoint_url}: {e}")
|
||||||
|
|
||||||
def handle_webhook_validation(
|
|
||||||
self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Handle Microsoft Graph webhook validation"""
|
|
||||||
if request_method == "GET":
|
|
||||||
validation_token = query_params.get("validationtoken") or query_params.get(
|
|
||||||
"validationToken"
|
|
||||||
)
|
|
||||||
if validation_token:
|
|
||||||
return validation_token
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def extract_webhook_channel_id(
|
async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]:
|
||||||
self, payload: Dict[str, Any], headers: Dict[str, str]
|
"""List all files using Microsoft Graph API - BaseConnector interface"""
|
||||||
) -> Optional[str]:
|
try:
|
||||||
"""Extract SharePoint subscription ID from webhook payload"""
|
# Ensure authentication
|
||||||
values = payload.get("value", [])
|
if not await self.authenticate():
|
||||||
return values[0].get("subscriptionId") if values else None
|
raise RuntimeError("SharePoint authentication failed during file listing")
|
||||||
|
|
||||||
|
files = []
|
||||||
|
max_files_value = max_files if max_files is not None else 100
|
||||||
|
|
||||||
|
# Build Graph API URL for the site or fallback to user's OneDrive
|
||||||
|
site_info = self._parse_sharepoint_url()
|
||||||
|
if site_info:
|
||||||
|
base_url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root/children"
|
||||||
|
else:
|
||||||
|
base_url = f"{self._graph_base_url}/me/drive/root/children"
|
||||||
|
|
||||||
|
params = dict(self._default_params)
|
||||||
|
params["$top"] = max_files_value
|
||||||
|
|
||||||
|
if page_token:
|
||||||
|
params["$skiptoken"] = page_token
|
||||||
|
|
||||||
|
response = await self._make_graph_request(base_url, params=params)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
items = data.get("value", [])
|
||||||
|
for item in items:
|
||||||
|
# Only include files, not folders
|
||||||
|
if item.get("file"):
|
||||||
|
files.append({
|
||||||
|
"id": item.get("id", ""),
|
||||||
|
"name": item.get("name", ""),
|
||||||
|
"path": f"/drive/items/{item.get('id')}",
|
||||||
|
"size": int(item.get("size", 0)),
|
||||||
|
"modified": item.get("lastModifiedDateTime"),
|
||||||
|
"created": item.get("createdDateTime"),
|
||||||
|
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
|
||||||
|
"url": item.get("webUrl", ""),
|
||||||
|
"download_url": item.get("@microsoft.graph.downloadUrl")
|
||||||
|
})
|
||||||
|
|
||||||
|
# Check for next page
|
||||||
|
next_page_token = None
|
||||||
|
next_link = data.get("@odata.nextLink")
|
||||||
|
if next_link:
|
||||||
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
parsed = urlparse(next_link)
|
||||||
|
query_params = parse_qs(parsed.query)
|
||||||
|
if "$skiptoken" in query_params:
|
||||||
|
next_page_token = query_params["$skiptoken"][0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"files": files,
|
||||||
|
"next_page_token": next_page_token
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to list SharePoint files: {e}")
|
||||||
|
return {"files": [], "next_page_token": None} # Return empty result instead of raising
|
||||||
|
|
||||||
|
async def get_file_content(self, file_id: str) -> ConnectorDocument:
|
||||||
|
"""Get file content and metadata - BaseConnector interface"""
|
||||||
|
try:
|
||||||
|
# Ensure authentication
|
||||||
|
if not await self.authenticate():
|
||||||
|
raise RuntimeError("SharePoint authentication failed during file content retrieval")
|
||||||
|
|
||||||
|
# First get file metadata using Graph API
|
||||||
|
file_metadata = await self._get_file_metadata_by_id(file_id)
|
||||||
|
|
||||||
|
if not file_metadata:
|
||||||
|
raise ValueError(f"File not found: {file_id}")
|
||||||
|
|
||||||
|
# Download file content
|
||||||
|
download_url = file_metadata.get("download_url")
|
||||||
|
if download_url:
|
||||||
|
content = await self._download_file_from_url(download_url)
|
||||||
|
else:
|
||||||
|
content = await self._download_file_content(file_id)
|
||||||
|
|
||||||
|
# Create ACL from metadata
|
||||||
|
acl = DocumentACL(
|
||||||
|
owner="", # Graph API requires additional calls for detailed permissions
|
||||||
|
user_permissions={},
|
||||||
|
group_permissions={}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse dates
|
||||||
|
modified_time = self._parse_graph_date(file_metadata.get("modified"))
|
||||||
|
created_time = self._parse_graph_date(file_metadata.get("created"))
|
||||||
|
|
||||||
|
return ConnectorDocument(
|
||||||
|
id=file_id,
|
||||||
|
filename=file_metadata.get("name", ""),
|
||||||
|
mimetype=file_metadata.get("mime_type", "application/octet-stream"),
|
||||||
|
content=content,
|
||||||
|
source_url=file_metadata.get("url", ""),
|
||||||
|
acl=acl,
|
||||||
|
modified_time=modified_time,
|
||||||
|
created_time=created_time,
|
||||||
|
metadata={
|
||||||
|
"sharepoint_path": file_metadata.get("path", ""),
|
||||||
|
"sharepoint_url": self.sharepoint_url,
|
||||||
|
"size": file_metadata.get("size", 0)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get SharePoint file content {file_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get file metadata by ID using Graph API"""
|
||||||
|
try:
|
||||||
|
# Try site-specific path first, then fallback to user drive
|
||||||
|
site_info = self._parse_sharepoint_url()
|
||||||
|
if site_info:
|
||||||
|
url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}"
|
||||||
|
else:
|
||||||
|
url = f"{self._graph_base_url}/me/drive/items/{file_id}"
|
||||||
|
|
||||||
|
params = dict(self._default_params)
|
||||||
|
|
||||||
|
response = await self._make_graph_request(url, params=params)
|
||||||
|
item = response.json()
|
||||||
|
|
||||||
|
if item.get("file"):
|
||||||
|
return {
|
||||||
|
"id": file_id,
|
||||||
|
"name": item.get("name", ""),
|
||||||
|
"path": f"/drive/items/{file_id}",
|
||||||
|
"size": int(item.get("size", 0)),
|
||||||
|
"modified": item.get("lastModifiedDateTime"),
|
||||||
|
"created": item.get("createdDateTime"),
|
||||||
|
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
|
||||||
|
"url": item.get("webUrl", ""),
|
||||||
|
"download_url": item.get("@microsoft.graph.downloadUrl")
|
||||||
|
}
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get file metadata for {file_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _download_file_content(self, file_id: str) -> bytes:
|
||||||
|
"""Download file content by file ID using Graph API"""
|
||||||
|
try:
|
||||||
|
site_info = self._parse_sharepoint_url()
|
||||||
|
if site_info:
|
||||||
|
url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}/content"
|
||||||
|
else:
|
||||||
|
url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
|
||||||
|
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(url, headers=headers, timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download file content for {file_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _download_file_from_url(self, download_url: str) -> bytes:
|
||||||
|
"""Download file content from direct download URL"""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(download_url, timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.content
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download from URL {download_url}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
|
||||||
|
"""Parse Microsoft Graph date string to datetime"""
|
||||||
|
if not date_str:
|
||||||
|
return datetime.now()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if date_str.endswith('Z'):
|
||||||
|
return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
|
||||||
|
else:
|
||||||
|
return datetime.fromisoformat(date_str.replace('T', ' '))
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
return datetime.now()
|
||||||
|
|
||||||
|
async def _make_graph_request(self, url: str, method: str = "GET",
|
||||||
|
data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
|
||||||
|
"""Make authenticated API request to Microsoft Graph"""
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
if method.upper() == "GET":
|
||||||
|
response = await client.get(url, headers=headers, params=params, timeout=30)
|
||||||
|
elif method.upper() == "POST":
|
||||||
|
response = await client.post(url, headers=headers, json=data, timeout=30)
|
||||||
|
elif method.upper() == "DELETE":
|
||||||
|
response = await client.delete(url, headers=headers, timeout=30)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _get_mime_type(self, filename: str) -> str:
|
||||||
|
"""Get MIME type based on file extension"""
|
||||||
|
import mimetypes
|
||||||
|
mime_type, _ = mimetypes.guess_type(filename)
|
||||||
|
return mime_type or "application/octet-stream"
|
||||||
|
|
||||||
|
# Webhook methods - BaseConnector interface
|
||||||
|
def handle_webhook_validation(self, request_method: str, headers: Dict[str, str],
|
||||||
|
query_params: Dict[str, str]) -> Optional[str]:
|
||||||
|
"""Handle webhook validation (Graph API specific)"""
|
||||||
|
if request_method == "POST" and "validationToken" in query_params:
|
||||||
|
return query_params["validationToken"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def extract_webhook_channel_id(self, payload: Dict[str, Any],
|
||||||
|
headers: Dict[str, str]) -> Optional[str]:
|
||||||
|
"""Extract channel/subscription ID from webhook payload"""
|
||||||
|
notifications = payload.get("value", [])
|
||||||
|
if notifications:
|
||||||
|
return notifications[0].get("subscriptionId")
|
||||||
|
return None
|
||||||
|
|
||||||
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
|
||||||
values = payload.get("value", [])
|
"""Handle webhook notification and return affected file IDs"""
|
||||||
file_ids = []
|
affected_files = []
|
||||||
for item in values:
|
|
||||||
resource_data = item.get("resourceData", {})
|
|
||||||
file_id = resource_data.get("id")
|
|
||||||
if file_id:
|
|
||||||
file_ids.append(file_id)
|
|
||||||
return file_ids
|
|
||||||
|
|
||||||
async def cleanup_subscription(
|
# Process Microsoft Graph webhook payload
|
||||||
self, subscription_id: str, resource_id: str = None
|
notifications = payload.get("value", [])
|
||||||
) -> bool:
|
for notification in notifications:
|
||||||
if not self._authenticated:
|
resource = notification.get("resource")
|
||||||
|
if resource and "/drive/items/" in resource:
|
||||||
|
file_id = resource.split("/drive/items/")[-1]
|
||||||
|
affected_files.append(file_id)
|
||||||
|
|
||||||
|
return affected_files
|
||||||
|
|
||||||
|
async def cleanup_subscription(self, subscription_id: str) -> bool:
|
||||||
|
"""Clean up subscription - BaseConnector interface"""
|
||||||
|
if subscription_id == "no-webhook-configured":
|
||||||
|
logger.info("No subscription to cleanup (webhook was not configured)")
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Ensure authentication
|
||||||
|
if not await self.authenticate():
|
||||||
|
logger.error("SharePoint authentication failed during subscription cleanup")
|
||||||
|
return False
|
||||||
|
|
||||||
|
token = self.oauth.get_access_token()
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.delete(url, headers=headers, timeout=30)
|
||||||
|
|
||||||
|
if response.status_code in [200, 204, 404]:
|
||||||
|
logger.info(f"SharePoint subscription {subscription_id} cleaned up successfully")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to cleanup SharePoint subscription {subscription_id}: {e}")
|
||||||
return False
|
return False
|
||||||
token = self.oauth.get_access_token()
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
resp = await client.delete(
|
|
||||||
f"{self.base_url}/subscriptions/{subscription_id}",
|
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
|
||||||
)
|
|
||||||
return resp.status_code in (200, 204)
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,28 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
from datetime import datetime
|
import msal
|
||||||
import httpx
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SharePointOAuth:
|
class SharePointOAuth:
|
||||||
"""Direct token management for SharePoint, bypassing MSAL cache format"""
|
"""Handles Microsoft Graph OAuth authentication flow following Google Drive pattern."""
|
||||||
|
|
||||||
SCOPES = [
|
# Reserved scopes that must NOT be sent on token or silent calls
|
||||||
"offline_access",
|
RESERVED_SCOPES = {"openid", "profile", "offline_access"}
|
||||||
"Files.Read.All",
|
|
||||||
"Sites.Read.All",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
# For PERSONAL Microsoft Accounts (OneDrive consumer):
|
||||||
|
# - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
|
||||||
|
# - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
|
||||||
|
AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
|
||||||
|
RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
|
||||||
|
SCOPES = AUTH_SCOPES # Backward compatibility alias
|
||||||
|
|
||||||
|
# Kept for reference; MSAL derives endpoints from `authority`
|
||||||
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
||||||
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||||
|
|
||||||
|
|
@ -22,173 +31,299 @@ class SharePointOAuth:
|
||||||
client_id: str,
|
client_id: str,
|
||||||
client_secret: str,
|
client_secret: str,
|
||||||
token_file: str = "sharepoint_token.json",
|
token_file: str = "sharepoint_token.json",
|
||||||
authority: str = "https://login.microsoftonline.com/common", # Keep for compatibility
|
authority: str = "https://login.microsoftonline.com/common",
|
||||||
|
allow_json_refresh: bool = True,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize SharePointOAuth.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_id: Azure AD application (client) ID.
|
||||||
|
client_secret: Azure AD application client secret.
|
||||||
|
token_file: Path to persisted token cache file (MSAL cache format).
|
||||||
|
authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
|
||||||
|
or tenant-specific for work/school.
|
||||||
|
allow_json_refresh: If True, permit one-time migration from legacy flat JSON
|
||||||
|
{"access_token","refresh_token",...}. Otherwise refuse it.
|
||||||
|
"""
|
||||||
self.client_id = client_id
|
self.client_id = client_id
|
||||||
self.client_secret = client_secret
|
self.client_secret = client_secret
|
||||||
self.token_file = token_file
|
self.token_file = token_file
|
||||||
self.authority = authority # Keep for compatibility but not used
|
self.authority = authority
|
||||||
self._tokens = None
|
self.allow_json_refresh = allow_json_refresh
|
||||||
self._load_tokens()
|
self.token_cache = msal.SerializableTokenCache()
|
||||||
|
self._current_account = None
|
||||||
|
|
||||||
def _load_tokens(self):
|
# Initialize MSAL Confidential Client
|
||||||
"""Load tokens from file"""
|
self.app = msal.ConfidentialClientApplication(
|
||||||
if os.path.exists(self.token_file):
|
client_id=self.client_id,
|
||||||
with open(self.token_file, "r") as f:
|
client_credential=self.client_secret,
|
||||||
self._tokens = json.loads(f.read())
|
authority=self.authority,
|
||||||
print(f"Loaded tokens from {self.token_file}")
|
token_cache=self.token_cache,
|
||||||
else:
|
)
|
||||||
print(f"No token file found at {self.token_file}")
|
|
||||||
|
|
||||||
async def save_cache(self):
|
async def load_credentials(self) -> bool:
|
||||||
"""Persist tokens to file (renamed for compatibility)"""
|
"""Load existing credentials from token file (async)."""
|
||||||
await self._save_tokens()
|
|
||||||
|
|
||||||
async def _save_tokens(self):
|
|
||||||
"""Save tokens to file"""
|
|
||||||
if self._tokens:
|
|
||||||
async with aiofiles.open(self.token_file, "w") as f:
|
|
||||||
await f.write(json.dumps(self._tokens, indent=2))
|
|
||||||
|
|
||||||
def _is_token_expired(self) -> bool:
|
|
||||||
"""Check if current access token is expired"""
|
|
||||||
if not self._tokens or 'expiry' not in self._tokens:
|
|
||||||
return True
|
|
||||||
|
|
||||||
expiry_str = self._tokens['expiry']
|
|
||||||
# Handle different expiry formats
|
|
||||||
try:
|
try:
|
||||||
if expiry_str.endswith('Z'):
|
logger.debug(f"SharePoint OAuth loading credentials from: {self.token_file}")
|
||||||
expiry_dt = datetime.fromisoformat(expiry_str[:-1])
|
if os.path.exists(self.token_file):
|
||||||
|
logger.debug(f"Token file exists, reading: {self.token_file}")
|
||||||
|
|
||||||
|
# Read the token file
|
||||||
|
async with aiofiles.open(self.token_file, "r") as f:
|
||||||
|
cache_data = await f.read()
|
||||||
|
logger.debug(f"Read {len(cache_data)} chars from token file")
|
||||||
|
|
||||||
|
if cache_data.strip():
|
||||||
|
# 1) Try legacy flat JSON first
|
||||||
|
try:
|
||||||
|
json_data = json.loads(cache_data)
|
||||||
|
if isinstance(json_data, dict) and "refresh_token" in json_data:
|
||||||
|
if self.allow_json_refresh:
|
||||||
|
logger.debug(
|
||||||
|
"Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
|
||||||
|
)
|
||||||
|
return await self._refresh_from_json_token(json_data)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
|
||||||
|
"Delete the file and re-auth."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.debug("Token file is not flat JSON; attempting MSAL cache format")
|
||||||
|
|
||||||
|
# 2) Try MSAL cache format
|
||||||
|
logger.debug("Attempting MSAL cache deserialization")
|
||||||
|
self.token_cache.deserialize(cache_data)
|
||||||
|
|
||||||
|
# Get accounts from loaded cache
|
||||||
|
accounts = self.app.get_accounts()
|
||||||
|
logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
|
||||||
|
if accounts:
|
||||||
|
self._current_account = accounts[0]
|
||||||
|
logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
|
||||||
|
|
||||||
|
# IMPORTANT: Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
|
||||||
|
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
|
||||||
|
logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
|
||||||
|
if result and "access_token" in result:
|
||||||
|
logger.debug("Silent token acquisition successful")
|
||||||
|
await self.save_cache()
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
error_msg = (result or {}).get("error") or "No result"
|
||||||
|
logger.warning(f"Silent token acquisition failed: {error_msg}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Token file {self.token_file} is empty")
|
||||||
else:
|
else:
|
||||||
expiry_dt = datetime.fromisoformat(expiry_str)
|
logger.debug(f"Token file does not exist: {self.token_file}")
|
||||||
|
|
||||||
# Add 5-minute buffer
|
|
||||||
import datetime as dt
|
|
||||||
now = datetime.now()
|
|
||||||
return now >= (expiry_dt - dt.timedelta(minutes=5))
|
|
||||||
except:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _refresh_access_token(self) -> bool:
|
|
||||||
"""Refresh the access token using refresh token"""
|
|
||||||
if not self._tokens or 'refresh_token' not in self._tokens:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
data = {
|
except Exception as e:
|
||||||
'client_id': self.client_id,
|
logger.error(f"Failed to load SharePoint credentials: {e}")
|
||||||
'client_secret': self.client_secret,
|
import traceback
|
||||||
'refresh_token': self._tokens['refresh_token'],
|
traceback.print_exc()
|
||||||
'grant_type': 'refresh_token',
|
return False
|
||||||
'scope': ' '.join(self.SCOPES)
|
|
||||||
}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async def _refresh_from_json_token(self, token_data: dict) -> bool:
|
||||||
try:
|
"""
|
||||||
response = await client.post(self.TOKEN_ENDPOINT, data=data)
|
Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
|
||||||
response.raise_for_status()
|
|
||||||
token_data = response.json()
|
|
||||||
|
|
||||||
# Update tokens
|
Notes:
|
||||||
self._tokens['token'] = token_data['access_token']
|
- Prefer using an MSAL cache file and acquire_token_silent().
|
||||||
if 'refresh_token' in token_data:
|
- This path is only for migrating older refresh_token JSON files.
|
||||||
self._tokens['refresh_token'] = token_data['refresh_token']
|
"""
|
||||||
|
try:
|
||||||
# Calculate expiry
|
refresh_token = token_data.get("refresh_token")
|
||||||
expires_in = token_data.get('expires_in', 3600)
|
if not refresh_token:
|
||||||
import datetime as dt
|
logger.error("No refresh_token found in JSON file - cannot refresh")
|
||||||
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
|
logger.error("You must re-authenticate interactively to obtain a valid token")
|
||||||
self._tokens['expiry'] = expiry.isoformat()
|
|
||||||
|
|
||||||
await self._save_tokens()
|
|
||||||
print("Access token refreshed successfully")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to refresh token: {e}")
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def create_authorization_url(self, redirect_uri: str) -> str:
|
# Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
|
||||||
"""Create authorization URL for OAuth flow"""
|
refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
|
||||||
from urllib.parse import urlencode
|
logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
|
||||||
|
|
||||||
params = {
|
result = self.app.acquire_token_by_refresh_token(
|
||||||
'client_id': self.client_id,
|
refresh_token=refresh_token,
|
||||||
'response_type': 'code',
|
scopes=refresh_scopes,
|
||||||
'redirect_uri': redirect_uri,
|
)
|
||||||
'scope': ' '.join(self.SCOPES),
|
|
||||||
'response_mode': 'query'
|
if result and "access_token" in result:
|
||||||
|
logger.debug("Successfully refreshed token via legacy JSON path")
|
||||||
|
await self.save_cache()
|
||||||
|
|
||||||
|
accounts = self.app.get_accounts()
|
||||||
|
logger.debug(f"After refresh, found {len(accounts)} accounts")
|
||||||
|
if accounts:
|
||||||
|
self._current_account = accounts[0]
|
||||||
|
logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Error handling
|
||||||
|
err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
|
||||||
|
logger.error(f"Refresh token failed: {err}")
|
||||||
|
|
||||||
|
if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
|
||||||
|
logger.warning(
|
||||||
|
"Refresh denied due to unauthorized/expired scopes or invalid grant. "
|
||||||
|
"Delete the token file and perform interactive sign-in with correct scopes."
|
||||||
|
)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception during refresh from JSON token: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def save_cache(self):
|
||||||
|
"""Persist the token cache to file."""
|
||||||
|
try:
|
||||||
|
# Ensure parent directory exists
|
||||||
|
parent = os.path.dirname(os.path.abspath(self.token_file))
|
||||||
|
if parent and not os.path.exists(parent):
|
||||||
|
os.makedirs(parent, exist_ok=True)
|
||||||
|
|
||||||
|
cache_data = self.token_cache.serialize()
|
||||||
|
if cache_data:
|
||||||
|
async with aiofiles.open(self.token_file, "w") as f:
|
||||||
|
await f.write(cache_data)
|
||||||
|
logger.debug(f"Token cache saved to {self.token_file}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save token cache: {e}")
|
||||||
|
|
||||||
|
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
|
||||||
|
"""Create authorization URL for OAuth flow."""
|
||||||
|
# Store redirect URI for later use in callback
|
||||||
|
self._redirect_uri = redirect_uri
|
||||||
|
|
||||||
|
kwargs: Dict[str, Any] = {
|
||||||
|
# IMPORTANT: interactive auth includes offline_access
|
||||||
|
"scopes": self.AUTH_SCOPES,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"prompt": "consent", # ensure refresh token on first run
|
||||||
}
|
}
|
||||||
|
if state:
|
||||||
|
kwargs["state"] = state # Optional CSRF protection
|
||||||
|
|
||||||
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
|
auth_url = self.app.get_authorization_request_url(**kwargs)
|
||||||
return f"{auth_url}?{urlencode(params)}"
|
|
||||||
|
logger.debug(f"Generated auth URL: {auth_url}")
|
||||||
|
logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
|
||||||
|
|
||||||
|
return auth_url
|
||||||
|
|
||||||
async def handle_authorization_callback(
|
async def handle_authorization_callback(
|
||||||
self, authorization_code: str, redirect_uri: str
|
self, authorization_code: str, redirect_uri: str
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Handle OAuth callback and exchange code for tokens"""
|
"""Handle OAuth callback and exchange code for tokens."""
|
||||||
data = {
|
try:
|
||||||
'client_id': self.client_id,
|
# For code exchange, we pass the same auth scopes as used in the authorize step
|
||||||
'client_secret': self.client_secret,
|
result = self.app.acquire_token_by_authorization_code(
|
||||||
'code': authorization_code,
|
authorization_code,
|
||||||
'grant_type': 'authorization_code',
|
scopes=self.AUTH_SCOPES,
|
||||||
'redirect_uri': redirect_uri,
|
redirect_uri=redirect_uri,
|
||||||
'scope': ' '.join(self.SCOPES)
|
)
|
||||||
}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
if result and "access_token" in result:
|
||||||
try:
|
# Store the account for future use
|
||||||
response = await client.post(self.TOKEN_ENDPOINT, data=data)
|
accounts = self.app.get_accounts()
|
||||||
response.raise_for_status()
|
if accounts:
|
||||||
token_data = response.json()
|
self._current_account = accounts[0]
|
||||||
|
|
||||||
# Store tokens in our format
|
await self.save_cache()
|
||||||
import datetime as dt
|
logger.info("SharePoint OAuth authorization successful")
|
||||||
expires_in = token_data.get('expires_in', 3600)
|
|
||||||
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
|
|
||||||
|
|
||||||
self._tokens = {
|
|
||||||
'token': token_data['access_token'],
|
|
||||||
'refresh_token': token_data['refresh_token'],
|
|
||||||
'scopes': self.SCOPES,
|
|
||||||
'expiry': expiry.isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
await self._save_tokens()
|
|
||||||
print("Authorization successful, tokens saved")
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
|
||||||
print(f"Authorization failed: {e}")
|
logger.error(f"SharePoint OAuth authorization failed: {error_msg}")
|
||||||
return False
|
|
||||||
|
|
||||||
async def is_authenticated(self) -> bool:
|
|
||||||
"""Check if we have valid credentials"""
|
|
||||||
if not self._tokens:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# If token is expired, try to refresh
|
except Exception as e:
|
||||||
if self._is_token_expired():
|
logger.error(f"Exception during SharePoint OAuth authorization: {e}")
|
||||||
print("Token expired, attempting refresh...")
|
return False
|
||||||
if await self._refresh_access_token():
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
async def is_authenticated(self) -> bool:
|
||||||
|
"""Check if we have valid credentials (simplified like Google Drive)."""
|
||||||
|
try:
|
||||||
|
# First try to load credentials if we haven't already
|
||||||
|
if not self._current_account:
|
||||||
|
await self.load_credentials()
|
||||||
|
|
||||||
|
# If we have an account, try to get a token (MSAL will refresh if needed)
|
||||||
|
if self._current_account:
|
||||||
|
# IMPORTANT: use RESOURCE_SCOPES here
|
||||||
|
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
|
||||||
|
if result and "access_token" in result:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
error_msg = (result or {}).get("error") or "No result returned"
|
||||||
|
logger.debug(f"Token acquisition failed for current account: {error_msg}")
|
||||||
|
|
||||||
|
# Fallback: try without specific account
|
||||||
|
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
|
||||||
|
if result and "access_token" in result:
|
||||||
|
# Update current account if this worked
|
||||||
|
accounts = self.app.get_accounts()
|
||||||
|
if accounts:
|
||||||
|
self._current_account = accounts[0]
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Authentication check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
def get_access_token(self) -> str:
|
def get_access_token(self) -> str:
|
||||||
"""Get current access token"""
|
"""Get an access token for Microsoft Graph (simplified like Google Drive)."""
|
||||||
if not self._tokens or 'token' not in self._tokens:
|
try:
|
||||||
raise ValueError("No access token available")
|
# Try with current account first
|
||||||
|
if self._current_account:
|
||||||
|
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
|
||||||
|
if result and "access_token" in result:
|
||||||
|
return result["access_token"]
|
||||||
|
|
||||||
if self._is_token_expired():
|
# Fallback: try without specific account
|
||||||
raise ValueError("Access token expired and refresh failed")
|
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
|
||||||
|
if result and "access_token" in result:
|
||||||
|
return result["access_token"]
|
||||||
|
|
||||||
return self._tokens['token']
|
# If we get here, authentication has failed
|
||||||
|
error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
|
||||||
|
raise ValueError(f"Failed to acquire access token: {error_msg}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get access token: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
async def revoke_credentials(self):
|
async def revoke_credentials(self):
|
||||||
"""Clear tokens"""
|
"""Clear token cache and remove token file (like Google Drive)."""
|
||||||
self._tokens = None
|
try:
|
||||||
if os.path.exists(self.token_file):
|
# Clear in-memory state
|
||||||
os.remove(self.token_file)
|
self._current_account = None
|
||||||
|
self.token_cache = msal.SerializableTokenCache()
|
||||||
|
|
||||||
|
# Recreate MSAL app with fresh cache
|
||||||
|
self.app = msal.ConfidentialClientApplication(
|
||||||
|
client_id=self.client_id,
|
||||||
|
client_credential=self.client_secret,
|
||||||
|
authority=self.authority,
|
||||||
|
token_cache=self.token_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove token file
|
||||||
|
if os.path.exists(self.token_file):
|
||||||
|
os.remove(self.token_file)
|
||||||
|
logger.info(f"Removed SharePoint token file: {self.token_file}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to revoke SharePoint credentials: {e}")
|
||||||
|
|
||||||
|
def get_service(self) -> str:
|
||||||
|
"""Return an access token (Graph doesn't need a generated client like Google Drive)."""
|
||||||
|
return self.get_access_token()
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue