Finally fix MSAL in onedrive/sharepoint

This commit is contained in:
Eric Hare 2025-09-25 13:12:27 -07:00
parent 5e14c7f100
commit f03889a2b3
No known key found for this signature in database
GPG key ID: A73DF73724270AB7
6 changed files with 1649 additions and 750 deletions

View file

@ -132,7 +132,10 @@ async def connector_status(request: Request, connector_service, session_manager)
for connection in connections: for connection in connections:
try: try:
connector = await connector_service._get_connector(connection.connection_id) connector = await connector_service._get_connector(connection.connection_id)
connection_client_ids[connection.connection_id] = connector.get_client_id() if connector is not None:
connection_client_ids[connection.connection_id] = connector.get_client_id()
else:
connection_client_ids[connection.connection_id] = None
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Could not get connector for connection", "Could not get connector for connection",
@ -338,8 +341,8 @@ async def connector_webhook(request: Request, connector_service, session_manager
) )
async def connector_token(request: Request, connector_service, session_manager): async def connector_token(request: Request, connector_service, session_manager):
"""Get access token for connector API calls (e.g., Google Picker)""" """Get access token for connector API calls (e.g., Pickers)."""
connector_type = request.path_params.get("connector_type") url_connector_type = request.path_params.get("connector_type")
connection_id = request.query_params.get("connection_id") connection_id = request.query_params.get("connection_id")
if not connection_id: if not connection_id:
@ -348,37 +351,81 @@ async def connector_token(request: Request, connector_service, session_manager):
user = request.state.user user = request.state.user
try: try:
# Get the connection and verify it belongs to the user # 1) Load the connection and verify ownership
connection = await connector_service.connection_manager.get_connection(connection_id) connection = await connector_service.connection_manager.get_connection(connection_id)
if not connection or connection.user_id != user.user_id: if not connection or connection.user_id != user.user_id:
return JSONResponse({"error": "Connection not found"}, status_code=404) return JSONResponse({"error": "Connection not found"}, status_code=404)
# Get the connector instance # 2) Get the ACTUAL connector instance/type for this connection_id
connector = await connector_service._get_connector(connection_id) connector = await connector_service._get_connector(connection_id)
if not connector: if not connector:
return JSONResponse({"error": f"Connector not available - authentication may have failed for {connector_type}"}, status_code=404) return JSONResponse(
{"error": f"Connector not available - authentication may have failed for {url_connector_type}"},
status_code=404,
)
# For Google Drive, get the access token real_type = getattr(connector, "type", None) or getattr(connection, "connector_type", None)
if connector_type == "google_drive" and hasattr(connector, 'oauth'): if real_type is None:
return JSONResponse({"error": "Unable to determine connector type"}, status_code=500)
# Optional: warn if URL path type disagrees with real type
if url_connector_type and url_connector_type != real_type:
# You can downgrade this to debug if you expect cross-routing.
return JSONResponse(
{
"error": "Connector type mismatch",
"detail": {
"requested_type": url_connector_type,
"actual_type": real_type,
"hint": "Call the token endpoint using the correct connector_type for this connection_id.",
},
},
status_code=400,
)
# 3) Branch by the actual connector type
# GOOGLE DRIVE (google-auth)
if real_type == "google_drive" and hasattr(connector, "oauth"):
await connector.oauth.load_credentials() await connector.oauth.load_credentials()
if connector.oauth.creds and connector.oauth.creds.valid: if connector.oauth.creds and connector.oauth.creds.valid:
return JSONResponse({ expires_in = None
"access_token": connector.oauth.creds.token, try:
"expires_in": (connector.oauth.creds.expiry.timestamp() - if connector.oauth.creds.expiry:
__import__('time').time()) if connector.oauth.creds.expiry else None import time
}) expires_in = max(0, int(connector.oauth.creds.expiry.timestamp() - time.time()))
else: except Exception:
return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401) expires_in = None
# For OneDrive and SharePoint, get the access token return JSONResponse(
elif connector_type in ["onedrive", "sharepoint"] and hasattr(connector, 'oauth'): {
"access_token": connector.oauth.creds.token,
"expires_in": expires_in,
}
)
return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401)
# ONEDRIVE / SHAREPOINT (MSAL or custom)
if real_type in ("onedrive", "sharepoint") and hasattr(connector, "oauth"):
# Ensure cache/credentials are loaded before trying to use them
try: try:
# Prefer a dedicated is_authenticated() that loads cache internally
if hasattr(connector.oauth, "is_authenticated"):
ok = await connector.oauth.is_authenticated()
else:
# Fallback: try to load credentials explicitly if available
ok = True
if hasattr(connector.oauth, "load_credentials"):
ok = await connector.oauth.load_credentials()
if not ok:
return JSONResponse({"error": "Not authenticated"}, status_code=401)
# Now safe to fetch access token
access_token = connector.oauth.get_access_token() access_token = connector.oauth.get_access_token()
return JSONResponse({ # MSAL result has expiry, but were returning a raw token; keep expires_in None for simplicity
"access_token": access_token, return JSONResponse({"access_token": access_token, "expires_in": None})
"expires_in": None # MSAL handles token expiry internally
})
except ValueError as e: except ValueError as e:
# Typical when acquire_token_silent fails (e.g., needs re-auth)
return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401) return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401)
except Exception as e: except Exception as e:
return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500) return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500)
@ -386,7 +433,5 @@ async def connector_token(request: Request, connector_service, session_manager):
return JSONResponse({"error": "Token not available for this connector type"}, status_code=400) return JSONResponse({"error": "Token not available for this connector type"}, status_code=400)
except Exception as e: except Exception as e:
logger.error("Error getting connector token", error=str(e)) logger.error("Error getting connector token", exc_info=True)
return JSONResponse({"error": str(e)}, status_code=500) return JSONResponse({"error": str(e)}, status_code=500)

View file

@ -294,32 +294,39 @@ class ConnectionManager:
async def get_connector(self, connection_id: str) -> Optional[BaseConnector]: async def get_connector(self, connection_id: str) -> Optional[BaseConnector]:
"""Get an active connector instance""" """Get an active connector instance"""
logger.debug(f"Getting connector for connection_id: {connection_id}")
# Return cached connector if available # Return cached connector if available
if connection_id in self.active_connectors: if connection_id in self.active_connectors:
connector = self.active_connectors[connection_id] connector = self.active_connectors[connection_id]
if connector.is_authenticated: if connector.is_authenticated:
logger.debug(f"Returning cached authenticated connector for {connection_id}")
return connector return connector
else: else:
# Remove unauthenticated connector from cache # Remove unauthenticated connector from cache
logger.debug(f"Removing unauthenticated connector from cache for {connection_id}")
del self.active_connectors[connection_id] del self.active_connectors[connection_id]
# Try to create and authenticate connector # Try to create and authenticate connector
connection_config = self.connections.get(connection_id) connection_config = self.connections.get(connection_id)
if not connection_config or not connection_config.is_active: if not connection_config or not connection_config.is_active:
logger.debug(f"No active connection config found for {connection_id}")
return None return None
logger.debug(f"Creating connector for {connection_config.connector_type}")
connector = self._create_connector(connection_config) connector = self._create_connector(connection_config)
if await connector.authenticate():
logger.debug(f"Attempting authentication for {connection_id}")
auth_result = await connector.authenticate()
logger.debug(f"Authentication result for {connection_id}: {auth_result}")
if auth_result:
self.active_connectors[connection_id] = connector self.active_connectors[connection_id] = connector
# ... rest of the method
# Setup webhook subscription if not already set up
await self._setup_webhook_if_needed(
connection_id, connection_config, connector
)
return connector return connector
else:
return None logger.warning(f"Authentication failed for {connection_id}")
return None
def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]: def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]:
"""Get available connector types with their metadata""" """Get available connector types with their metadata"""
@ -363,20 +370,23 @@ class ConnectionManager:
def _create_connector(self, config: ConnectionConfig) -> BaseConnector: def _create_connector(self, config: ConnectionConfig) -> BaseConnector:
"""Factory method to create connector instances""" """Factory method to create connector instances"""
if config.connector_type == "google_drive": try:
return GoogleDriveConnector(config.config) if config.connector_type == "google_drive":
elif config.connector_type == "sharepoint": return GoogleDriveConnector(config.config)
return SharePointConnector(config.config) elif config.connector_type == "sharepoint":
elif config.connector_type == "onedrive": return SharePointConnector(config.config)
return OneDriveConnector(config.config) elif config.connector_type == "onedrive":
elif config.connector_type == "box": return OneDriveConnector(config.config)
# Future: BoxConnector(config.config) elif config.connector_type == "box":
raise NotImplementedError("Box connector not implemented yet") raise NotImplementedError("Box connector not implemented yet")
elif config.connector_type == "dropbox": elif config.connector_type == "dropbox":
# Future: DropboxConnector(config.config) raise NotImplementedError("Dropbox connector not implemented yet")
raise NotImplementedError("Dropbox connector not implemented yet") else:
else: raise ValueError(f"Unknown connector type: {config.connector_type}")
raise ValueError(f"Unknown connector type: {config.connector_type}") except Exception as e:
logger.error(f"Failed to create {config.connector_type} connector: {e}")
# Re-raise the exception so caller can handle appropriately
raise
async def update_last_sync(self, connection_id: str): async def update_last_sync(self, connection_id: str):
"""Update the last sync timestamp for a connection""" """Update the last sync timestamp for a connection"""

View file

@ -1,235 +1,487 @@
import logging
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any, Optional
from datetime import datetime
import httpx import httpx
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from ..base import BaseConnector, ConnectorDocument, DocumentACL from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import OneDriveOAuth from .oauth import OneDriveOAuth
logger = logging.getLogger(__name__)
class OneDriveConnector(BaseConnector): class OneDriveConnector(BaseConnector):
"""OneDrive connector using Microsoft Graph API""" """OneDrive connector using MSAL-based OAuth for authentication."""
# Required BaseConnector class attributes
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID" CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET" CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
# Connector metadata # Connector metadata
CONNECTOR_NAME = "OneDrive" CONNECTOR_NAME = "OneDrive"
CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents" CONNECTOR_DESCRIPTION = "Connect to OneDrive (personal) to sync documents and files"
CONNECTOR_ICON = "onedrive" CONNECTOR_ICON = "onedrive"
def __init__(self, config: Dict[str, Any]): def __init__(self, config: Dict[str, Any]):
super().__init__(config) logger.debug(f"OneDrive connector __init__ called with config type: {type(config)}")
logger.debug(f"OneDrive connector __init__ config value: {config}")
if config is None:
logger.debug("Config was None, using empty dict")
config = {}
try:
logger.debug("Calling super().__init__")
super().__init__(config)
logger.debug("super().__init__ completed successfully")
except Exception as e:
logger.error(f"super().__init__ failed: {e}")
raise
# Initialize with defaults that allow the connector to be listed
self.client_id = None
self.client_secret = None
self.redirect_uri = config.get("redirect_uri", "http://localhost") # must match your app registration
# Try to get credentials, but don't fail if they're missing
try:
self.client_id = self.get_client_id()
logger.debug(f"Got client_id: {self.client_id is not None}")
except Exception as e:
logger.debug(f"Failed to get client_id: {e}")
try:
self.client_secret = self.get_client_secret()
logger.debug(f"Got client_secret: {self.client_secret is not None}")
except Exception as e:
logger.debug(f"Failed to get client_secret: {e}")
# Token file setup
project_root = Path(__file__).resolve().parent.parent.parent.parent project_root = Path(__file__).resolve().parent.parent.parent.parent
token_file = config.get("token_file") or str(project_root / "onedrive_token.json") token_file = config.get("token_file") or str(project_root / "onedrive_token.json")
self.oauth = OneDriveOAuth( Path(token_file).parent.mkdir(parents=True, exist_ok=True)
client_id=self.get_client_id(),
client_secret=self.get_client_secret(),
token_file=token_file,
)
self.subscription_id = config.get("subscription_id") or config.get(
"webhook_channel_id"
)
self.base_url = "https://graph.microsoft.com/v1.0"
async def authenticate(self) -> bool: # Only initialize OAuth if we have credentials
if await self.oauth.is_authenticated(): if self.client_id and self.client_secret:
self._authenticated = True connection_id = config.get("connection_id", "default")
return True
return False
async def setup_subscription(self) -> str: # Use token_file from config if provided, otherwise generate one
if not self._authenticated: if config.get("token_file"):
raise ValueError("Not authenticated") oauth_token_file = config["token_file"]
else:
# Use a per-connection cache file to avoid collisions with other connectors
oauth_token_file = f"onedrive_token_{connection_id}.json"
webhook_url = self.config.get("webhook_url") # MSA & org both work via /common for OneDrive personal testing
if not webhook_url: authority = "https://login.microsoftonline.com/common"
raise ValueError("webhook_url required in config for subscriptions")
expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z" self.oauth = OneDriveOAuth(
body = { client_id=self.client_id,
"changeType": "created,updated,deleted", client_secret=self.client_secret,
"notificationUrl": webhook_url, token_file=oauth_token_file,
"resource": "/me/drive/root", authority=authority,
"expirationDateTime": expiration, allow_json_refresh=True, # allows one-time migration from legacy JSON if present
"clientState": str(uuid.uuid4()), )
else:
self.oauth = None
# Track subscription ID for webhooks (note: change notifications might not be available for personal accounts)
self._subscription_id: Optional[str] = None
# Graph API defaults
self._graph_api_version = "v1.0"
self._default_params = {
"$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
} }
token = self.oauth.get_access_token() @property
async with httpx.AsyncClient() as client: def _graph_base_url(self) -> str:
resp = await client.post( """Base URL for Microsoft Graph API calls."""
f"{self.base_url}/subscriptions", return f"https://graph.microsoft.com/{self._graph_api_version}"
json=body,
headers={"Authorization": f"Bearer {token}"},
)
resp.raise_for_status()
data = resp.json()
self.subscription_id = data["id"] def emit(self, doc: ConnectorDocument) -> None:
return self.subscription_id """Emit a ConnectorDocument instance (integrate with your pipeline here)."""
logger.debug(f"Emitting OneDrive document: {doc.id} ({doc.filename})")
async def list_files( async def authenticate(self) -> bool:
self, page_token: Optional[str] = None, limit: int = 100 """Test authentication - BaseConnector interface."""
) -> Dict[str, Any]: logger.debug(f"OneDrive authenticate() called, oauth is None: {self.oauth is None}")
if not self._authenticated: try:
raise ValueError("Not authenticated") if not self.oauth:
logger.debug("OneDrive authentication failed: OAuth not initialized")
self._authenticated = False
return False
params = {"$top": str(limit)} logger.debug("Loading OneDrive credentials...")
if page_token: load_result = await self.oauth.load_credentials()
params["$skiptoken"] = page_token logger.debug(f"Load credentials result: {load_result}")
token = self.oauth.get_access_token() logger.debug("Checking OneDrive authentication status...")
async with httpx.AsyncClient() as client: authenticated = await self.oauth.is_authenticated()
resp = await client.get( logger.debug(f"OneDrive is_authenticated result: {authenticated}")
f"{self.base_url}/me/drive/root/children",
params=params,
headers={"Authorization": f"Bearer {token}"},
)
resp.raise_for_status()
data = resp.json()
files = [] self._authenticated = authenticated
for item in data.get("value", []): return authenticated
if item.get("file"): except Exception as e:
files.append( logger.error(f"OneDrive authentication failed: {e}")
{ import traceback
"id": item["id"], traceback.print_exc()
"name": item["name"], self._authenticated = False
"mimeType": item.get("file", {}).get( return False
"mimeType", "application/octet-stream"
),
"webViewLink": item.get("webUrl"),
"createdTime": item.get("createdDateTime"),
"modifiedTime": item.get("lastModifiedDateTime"),
}
)
next_token = None def get_auth_url(self) -> str:
next_link = data.get("@odata.nextLink") """Get OAuth authorization URL."""
if next_link: if not self.oauth:
from urllib.parse import urlparse, parse_qs raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
return self.oauth.create_authorization_url(self.redirect_uri)
parsed = urlparse(next_link) async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0] """Handle OAuth callback."""
if not self.oauth:
raise RuntimeError("OneDrive OAuth not initialized - missing credentials")
try:
success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
if success:
self._authenticated = True
return {"status": "success"}
else:
raise ValueError("OAuth callback failed")
except Exception as e:
logger.error(f"OAuth callback failed: {e}")
raise
return {"files": files, "nextPageToken": next_token} def sync_once(self) -> None:
"""
Perform a one-shot sync of OneDrive files and emit documents.
"""
import asyncio
async def _async_sync():
try:
file_list = await self.list_files(max_files=1000)
files = file_list.get("files", [])
for file_info in files:
try:
file_id = file_info.get("id")
if not file_id:
continue
doc = await self.get_file_content(file_id)
self.emit(doc)
except Exception as e:
logger.error(f"Failed to sync OneDrive file {file_info.get('name', 'unknown')}: {e}")
continue
except Exception as e:
logger.error(f"OneDrive sync_once failed: {e}")
raise
if hasattr(asyncio, 'run'):
asyncio.run(_async_sync())
else:
loop = asyncio.get_event_loop()
loop.run_until_complete(_async_sync())
async def setup_subscription(self) -> str:
"""
Set up real-time subscription for file changes.
NOTE: Change notifications may not be available for personal OneDrive accounts.
"""
webhook_url = self.config.get('webhook_url')
if not webhook_url:
logger.warning("No webhook URL configured, skipping OneDrive subscription setup")
return "no-webhook-configured"
try:
if not await self.authenticate():
raise RuntimeError("OneDrive authentication failed during subscription setup")
token = self.oauth.get_access_token()
# For OneDrive personal we target the user's drive
resource = "/me/drive/root"
subscription_data = {
"changeType": "created,updated,deleted",
"notificationUrl": f"{webhook_url}/webhook/onedrive",
"resource": resource,
"expirationDateTime": self._get_subscription_expiry(),
"clientState": "onedrive_personal",
}
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{self._graph_base_url}/subscriptions"
async with httpx.AsyncClient() as client:
response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
response.raise_for_status()
result = response.json()
subscription_id = result.get("id")
if subscription_id:
self._subscription_id = subscription_id
logger.info(f"OneDrive subscription created: {subscription_id}")
return subscription_id
else:
raise ValueError("No subscription ID returned from Microsoft Graph")
except Exception as e:
logger.error(f"Failed to setup OneDrive subscription: {e}")
raise
def _get_subscription_expiry(self) -> str:
"""Get subscription expiry time (Graph caps duration; often <= 3 days)."""
from datetime import datetime, timedelta
expiry = datetime.utcnow() + timedelta(days=3)
return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]:
"""List files from OneDrive using Microsoft Graph."""
try:
if not await self.authenticate():
raise RuntimeError("OneDrive authentication failed during file listing")
files: List[Dict[str, Any]] = []
max_files_value = max_files if max_files is not None else 100
base_url = f"{self._graph_base_url}/me/drive/root/children"
params = dict(self._default_params)
params["$top"] = max_files_value
if page_token:
params["$skiptoken"] = page_token
response = await self._make_graph_request(base_url, params=params)
data = response.json()
items = data.get("value", [])
for item in items:
if item.get("file"): # include files only
files.append({
"id": item.get("id", ""),
"name": item.get("name", ""),
"path": f"/drive/items/{item.get('id')}",
"size": int(item.get("size", 0)),
"modified": item.get("lastModifiedDateTime"),
"created": item.get("createdDateTime"),
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
"url": item.get("webUrl", ""),
"download_url": item.get("@microsoft.graph.downloadUrl"),
})
# Next page
next_page_token = None
next_link = data.get("@odata.nextLink")
if next_link:
from urllib.parse import urlparse, parse_qs
parsed = urlparse(next_link)
query_params = parse_qs(parsed.query)
if "$skiptoken" in query_params:
next_page_token = query_params["$skiptoken"][0]
return {"files": files, "next_page_token": next_page_token}
except Exception as e:
logger.error(f"Failed to list OneDrive files: {e}")
return {"files": [], "next_page_token": None}
async def get_file_content(self, file_id: str) -> ConnectorDocument: async def get_file_content(self, file_id: str) -> ConnectorDocument:
if not self._authenticated: """Get file content and metadata."""
raise ValueError("Not authenticated") try:
if not await self.authenticate():
raise RuntimeError("OneDrive authentication failed during file content retrieval")
token = self.oauth.get_access_token() file_metadata = await self._get_file_metadata_by_id(file_id)
headers = {"Authorization": f"Bearer {token}"} if not file_metadata:
async with httpx.AsyncClient() as client: raise ValueError(f"File not found: {file_id}")
meta_resp = await client.get(
f"{self.base_url}/me/drive/items/{file_id}", headers=headers
)
meta_resp.raise_for_status()
metadata = meta_resp.json()
content_resp = await client.get( download_url = file_metadata.get("download_url")
f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers if download_url:
) content = await self._download_file_from_url(download_url)
content = content_resp.content
# Handle the possibility of this being a redirect
if content_resp.status_code in (301, 302, 303, 307, 308):
redirect_url = content_resp.headers.get("Location")
if redirect_url:
content_resp = await client.get(redirect_url)
content_resp.raise_for_status()
content = content_resp.content
else: else:
content_resp.raise_for_status() content = await self._download_file_content(file_id)
perm_resp = await client.get( acl = DocumentACL(
f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers owner="",
user_permissions={},
group_permissions={},
) )
perm_resp.raise_for_status()
permissions = perm_resp.json()
acl = self._parse_permissions(metadata, permissions) modified_time = self._parse_graph_date(file_metadata.get("modified"))
modified = datetime.fromisoformat( created_time = self._parse_graph_date(file_metadata.get("created"))
metadata["lastModifiedDateTime"].replace("Z", "+00:00")
).replace(tzinfo=None)
created = datetime.fromisoformat(
metadata["createdDateTime"].replace("Z", "+00:00")
).replace(tzinfo=None)
document = ConnectorDocument( return ConnectorDocument(
id=metadata["id"], id=file_id,
filename=metadata["name"], filename=file_metadata.get("name", ""),
mimetype=metadata.get("file", {}).get( mimetype=file_metadata.get("mime_type", "application/octet-stream"),
"mimeType", "application/octet-stream" content=content,
), source_url=file_metadata.get("url", ""),
content=content, acl=acl,
source_url=metadata.get("webUrl"), modified_time=modified_time,
acl=acl, created_time=created_time,
modified_time=modified, metadata={
created_time=created, "onedrive_path": file_metadata.get("path", ""),
metadata={"size": metadata.get("size")}, "size": file_metadata.get("size", 0),
) },
return document
def _parse_permissions(
self, metadata: Dict[str, Any], permissions: Dict[str, Any]
) -> DocumentACL:
acl = DocumentACL()
owner = metadata.get("createdBy", {}).get("user", {}).get("email")
if owner:
acl.owner = owner
for perm in permissions.get("value", []):
role = perm.get("roles", ["read"])[0]
grantee = perm.get("grantedToV2") or perm.get("grantedTo")
if not grantee:
continue
user = grantee.get("user")
if user and user.get("email"):
acl.user_permissions[user["email"]] = role
group = grantee.get("group")
if group and group.get("email"):
acl.group_permissions[group["email"]] = role
return acl
def handle_webhook_validation(
self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]
) -> Optional[str]:
"""Handle Microsoft Graph webhook validation"""
if request_method == "GET":
validation_token = query_params.get("validationtoken") or query_params.get(
"validationToken"
) )
if validation_token:
return validation_token except Exception as e:
logger.error(f"Failed to get OneDrive file content {file_id}: {e}")
raise
async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
"""Get file metadata by ID using Graph API."""
try:
url = f"{self._graph_base_url}/me/drive/items/{file_id}"
params = dict(self._default_params)
response = await self._make_graph_request(url, params=params)
item = response.json()
if item.get("file"):
return {
"id": file_id,
"name": item.get("name", ""),
"path": f"/drive/items/{file_id}",
"size": int(item.get("size", 0)),
"modified": item.get("lastModifiedDateTime"),
"created": item.get("createdDateTime"),
"mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
"url": item.get("webUrl", ""),
"download_url": item.get("@microsoft.graph.downloadUrl"),
}
return None
except Exception as e:
logger.error(f"Failed to get file metadata for {file_id}: {e}")
return None
async def _download_file_content(self, file_id: str) -> bytes:
"""Download file content by file ID using Graph API."""
try:
url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
token = self.oauth.get_access_token()
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient() as client:
response = await client.get(url, headers=headers, timeout=60)
response.raise_for_status()
return response.content
except Exception as e:
logger.error(f"Failed to download file content for {file_id}: {e}")
raise
async def _download_file_from_url(self, download_url: str) -> bytes:
"""Download file content from direct download URL."""
try:
async with httpx.AsyncClient() as client:
response = await client.get(download_url, timeout=60)
response.raise_for_status()
return response.content
except Exception as e:
logger.error(f"Failed to download from URL {download_url}: {e}")
raise
def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
"""Parse Microsoft Graph date string to datetime."""
if not date_str:
return datetime.now()
try:
if date_str.endswith('Z'):
return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
else:
return datetime.fromisoformat(date_str.replace('T', ' '))
except (ValueError, AttributeError):
return datetime.now()
async def _make_graph_request(self, url: str, method: str = "GET",
data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
"""Make authenticated API request to Microsoft Graph."""
token = self.oauth.get_access_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
async with httpx.AsyncClient() as client:
if method.upper() == "GET":
response = await client.get(url, headers=headers, params=params, timeout=30)
elif method.upper() == "POST":
response = await client.post(url, headers=headers, json=data, timeout=30)
elif method.upper() == "DELETE":
response = await client.delete(url, headers=headers, timeout=30)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response
def _get_mime_type(self, filename: str) -> str:
"""Get MIME type based on file extension."""
import mimetypes
mime_type, _ = mimetypes.guess_type(filename)
return mime_type or "application/octet-stream"
# Webhook methods - BaseConnector interface
def handle_webhook_validation(self, request_method: str,
headers: Dict[str, str],
query_params: Dict[str, str]) -> Optional[str]:
"""Handle webhook validation (Graph API specific)."""
if request_method == "POST" and "validationToken" in query_params:
return query_params["validationToken"]
return None return None
def extract_webhook_channel_id( def extract_webhook_channel_id(self, payload: Dict[str, Any],
self, payload: Dict[str, Any], headers: Dict[str, str] headers: Dict[str, str]) -> Optional[str]:
) -> Optional[str]: """Extract channel/subscription ID from webhook payload."""
"""Extract SharePoint subscription ID from webhook payload""" notifications = payload.get("value", [])
values = payload.get("value", []) if notifications:
return values[0].get("subscriptionId") if values else None return notifications[0].get("subscriptionId")
return None
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
values = payload.get("value", []) """Handle webhook notification and return affected file IDs."""
file_ids = [] affected_files: List[str] = []
for item in values: notifications = payload.get("value", [])
resource_data = item.get("resourceData", {}) for notification in notifications:
file_id = resource_data.get("id") resource = notification.get("resource")
if file_id: if resource and "/drive/items/" in resource:
file_ids.append(file_id) file_id = resource.split("/drive/items/")[-1]
return file_ids affected_files.append(file_id)
return affected_files
async def cleanup_subscription( async def cleanup_subscription(self, subscription_id: str) -> bool:
self, subscription_id: str, resource_id: str = None """Clean up subscription - BaseConnector interface."""
) -> bool: if subscription_id == "no-webhook-configured":
if not self._authenticated: logger.info("No subscription to cleanup (webhook was not configured)")
return True
try:
if not await self.authenticate():
logger.error("OneDrive authentication failed during subscription cleanup")
return False
token = self.oauth.get_access_token()
headers = {"Authorization": f"Bearer {token}"}
url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
async with httpx.AsyncClient() as client:
response = await client.delete(url, headers=headers, timeout=30)
if response.status_code in [200, 204, 404]:
logger.info(f"OneDrive subscription {subscription_id} cleaned up successfully")
return True
else:
logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
return False
except Exception as e:
logger.error(f"Failed to cleanup OneDrive subscription {subscription_id}: {e}")
return False return False
token = self.oauth.get_access_token()
async with httpx.AsyncClient() as client:
resp = await client.delete(
f"{self.base_url}/subscriptions/{subscription_id}",
headers={"Authorization": f"Bearer {token}"},
)
return resp.status_code in (200, 204)

View file

@ -1,18 +1,28 @@
import os import os
import json import json
import logging
from typing import Optional, Dict, Any
import aiofiles import aiofiles
from datetime import datetime import msal
import httpx
logger = logging.getLogger(__name__)
class OneDriveOAuth: class OneDriveOAuth:
"""Direct token management for OneDrive, bypassing MSAL cache format""" """Handles Microsoft Graph OAuth for OneDrive (personal Microsoft accounts by default)."""
SCOPES = [ # Reserved scopes that must NOT be sent on token or silent calls
"offline_access", RESERVED_SCOPES = {"openid", "profile", "offline_access"}
"Files.Read.All",
]
# For PERSONAL Microsoft Accounts (OneDrive consumer):
# - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
# - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
SCOPES = AUTH_SCOPES # Backward-compat alias if something references .SCOPES
# Kept for reference; MSAL derives endpoints from `authority`
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token" TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
@ -21,168 +31,292 @@ class OneDriveOAuth:
client_id: str, client_id: str,
client_secret: str, client_secret: str,
token_file: str = "onedrive_token.json", token_file: str = "onedrive_token.json",
authority: str = "https://login.microsoftonline.com/common",
allow_json_refresh: bool = True,
): ):
"""
Initialize OneDriveOAuth.
Args:
client_id: Azure AD application (client) ID.
client_secret: Azure AD application client secret.
token_file: Path to persisted token cache file (MSAL cache format).
authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
or tenant-specific for work/school.
allow_json_refresh: If True, permit one-time migration from legacy flat JSON
{"access_token","refresh_token",...}. Otherwise refuse it.
"""
self.client_id = client_id self.client_id = client_id
self.client_secret = client_secret self.client_secret = client_secret
self.token_file = token_file self.token_file = token_file
self._tokens = None self.authority = authority
self._load_tokens() self.allow_json_refresh = allow_json_refresh
self.token_cache = msal.SerializableTokenCache()
self._current_account = None
def _load_tokens(self): # Initialize MSAL Confidential Client
"""Load tokens from file""" self.app = msal.ConfidentialClientApplication(
if os.path.exists(self.token_file): client_id=self.client_id,
with open(self.token_file, "r") as f: client_credential=self.client_secret,
self._tokens = json.loads(f.read()) authority=self.authority,
print(f"Loaded tokens from {self.token_file}") token_cache=self.token_cache,
else: )
print(f"No token file found at {self.token_file}")
async def _save_tokens(self): async def load_credentials(self) -> bool:
"""Save tokens to file""" """Load existing credentials from token file (async)."""
if self._tokens:
async with aiofiles.open(self.token_file, "w") as f:
await f.write(json.dumps(self._tokens, indent=2))
def _is_token_expired(self) -> bool:
"""Check if current access token is expired"""
if not self._tokens or 'expiry' not in self._tokens:
return True
expiry_str = self._tokens['expiry']
# Handle different expiry formats
try: try:
if expiry_str.endswith('Z'): logger.debug(f"OneDrive OAuth loading credentials from: {self.token_file}")
expiry_dt = datetime.fromisoformat(expiry_str[:-1]) if os.path.exists(self.token_file):
else: logger.debug(f"Token file exists, reading: {self.token_file}")
expiry_dt = datetime.fromisoformat(expiry_str)
# Read the token file
# Add 5-minute buffer async with aiofiles.open(self.token_file, "r") as f:
import datetime as dt cache_data = await f.read()
now = datetime.now() logger.debug(f"Read {len(cache_data)} chars from token file")
return now >= (expiry_dt - dt.timedelta(minutes=5))
except: if cache_data.strip():
return True # 1) Try legacy flat JSON first
try:
json_data = json.loads(cache_data)
if isinstance(json_data, dict) and "refresh_token" in json_data:
if self.allow_json_refresh:
logger.debug(
"Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
)
return await self._refresh_from_json_token(json_data)
else:
logger.warning(
"Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
"Delete the file and re-auth."
)
return False
except json.JSONDecodeError:
logger.debug("Token file is not flat JSON; attempting MSAL cache format")
# 2) Try MSAL cache format
logger.debug("Attempting MSAL cache deserialization")
self.token_cache.deserialize(cache_data)
# Get accounts from loaded cache
accounts = self.app.get_accounts()
logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
if accounts:
self._current_account = accounts[0]
logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
# Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
if result and "access_token" in result:
logger.debug("Silent token acquisition successful")
await self.save_cache()
return True
else:
error_msg = (result or {}).get("error") or "No result"
logger.warning(f"Silent token acquisition failed: {error_msg}")
else:
logger.debug(f"Token file {self.token_file} is empty")
else:
logger.debug(f"Token file does not exist: {self.token_file}")
async def _refresh_access_token(self) -> bool:
"""Refresh the access token using refresh token"""
if not self._tokens or 'refresh_token' not in self._tokens:
return False return False
data = { except Exception as e:
'client_id': self.client_id, logger.error(f"Failed to load OneDrive credentials: {e}")
'client_secret': self.client_secret, import traceback
'refresh_token': self._tokens['refresh_token'], traceback.print_exc()
'grant_type': 'refresh_token',
'scope': ' '.join(self.SCOPES)
}
async with httpx.AsyncClient() as client:
try:
response = await client.post(self.TOKEN_ENDPOINT, data=data)
response.raise_for_status()
token_data = response.json()
# Update tokens
self._tokens['token'] = token_data['access_token']
if 'refresh_token' in token_data:
self._tokens['refresh_token'] = token_data['refresh_token']
# Calculate expiry
expires_in = token_data.get('expires_in', 3600)
import datetime as dt
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens['expiry'] = expiry.isoformat()
await self._save_tokens()
print("Access token refreshed successfully")
return True
except Exception as e:
print(f"Failed to refresh token: {e}")
return False
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials"""
if not self._tokens:
return False return False
# If token is expired, try to refresh async def _refresh_from_json_token(self, token_data: dict) -> bool:
if self._is_token_expired(): """
print("Token expired, attempting refresh...") Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
if await self._refresh_access_token(): Prefer using an MSAL cache file and acquire_token_silent(); this path is only for migrating older files.
return True """
else: try:
refresh_token = token_data.get("refresh_token")
if not refresh_token:
logger.error("No refresh_token found in JSON file - cannot refresh")
logger.error("You must re-authenticate interactively to obtain a valid token")
return False return False
return True
def get_access_token(self) -> str: # Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
"""Get current access token""" refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
if not self._tokens or 'token' not in self._tokens: logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
raise ValueError("No access token available")
if self._is_token_expired():
raise ValueError("Access token expired and refresh failed")
return self._tokens['token']
async def revoke_credentials(self): result = self.app.acquire_token_by_refresh_token(
"""Clear tokens""" refresh_token=refresh_token,
self._tokens = None scopes=refresh_scopes,
if os.path.exists(self.token_file): )
os.remove(self.token_file)
# Keep these methods for compatibility with your existing OAuth flow if result and "access_token" in result:
def create_authorization_url(self, redirect_uri: str) -> str: logger.debug("Successfully refreshed token via legacy JSON path")
"""Create authorization URL for OAuth flow""" await self.save_cache()
from urllib.parse import urlencode
accounts = self.app.get_accounts()
params = { logger.debug(f"After refresh, found {len(accounts)} accounts")
'client_id': self.client_id, if accounts:
'response_type': 'code', self._current_account = accounts[0]
'redirect_uri': redirect_uri, logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
'scope': ' '.join(self.SCOPES), return True
'response_mode': 'query'
# Error handling
err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
logger.error(f"Refresh token failed: {err}")
if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
logger.warning(
"Refresh denied due to unauthorized/expired scopes or invalid grant. "
"Delete the token file and perform interactive sign-in with correct scopes."
)
return False
except Exception as e:
logger.error(f"Exception during refresh from JSON token: {e}")
import traceback
traceback.print_exc()
return False
async def save_cache(self):
"""Persist the token cache to file."""
try:
# Ensure parent directory exists
parent = os.path.dirname(os.path.abspath(self.token_file))
if parent and not os.path.exists(parent):
os.makedirs(parent, exist_ok=True)
cache_data = self.token_cache.serialize()
if cache_data:
async with aiofiles.open(self.token_file, "w") as f:
await f.write(cache_data)
logger.debug(f"Token cache saved to {self.token_file}")
except Exception as e:
logger.error(f"Failed to save token cache: {e}")
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
"""Create authorization URL for OAuth flow."""
# Store redirect URI for later use in callback
self._redirect_uri = redirect_uri
kwargs: Dict[str, Any] = {
# Interactive auth includes offline_access
"scopes": self.AUTH_SCOPES,
"redirect_uri": redirect_uri,
"prompt": "consent", # ensure refresh token on first run
} }
if state:
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" kwargs["state"] = state # Optional CSRF protection
return f"{auth_url}?{urlencode(params)}"
auth_url = self.app.get_authorization_request_url(**kwargs)
logger.debug(f"Generated auth URL: {auth_url}")
logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
return auth_url
async def handle_authorization_callback( async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str self, authorization_code: str, redirect_uri: str
) -> bool: ) -> bool:
"""Handle OAuth callback and exchange code for tokens""" """Handle OAuth callback and exchange code for tokens."""
data = { try:
'client_id': self.client_id, result = self.app.acquire_token_by_authorization_code(
'client_secret': self.client_secret, authorization_code,
'code': authorization_code, scopes=self.AUTH_SCOPES, # same as authorize step
'grant_type': 'authorization_code', redirect_uri=redirect_uri,
'redirect_uri': redirect_uri, )
'scope': ' '.join(self.SCOPES)
}
async with httpx.AsyncClient() as client: if result and "access_token" in result:
try: accounts = self.app.get_accounts()
response = await client.post(self.TOKEN_ENDPOINT, data=data) if accounts:
response.raise_for_status() self._current_account = accounts[0]
token_data = response.json()
# Store tokens in our format await self.save_cache()
import datetime as dt logger.info("OneDrive OAuth authorization successful")
expires_in = token_data.get('expires_in', 3600)
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens = {
'token': token_data['access_token'],
'refresh_token': token_data['refresh_token'],
'scopes': self.SCOPES,
'expiry': expiry.isoformat()
}
await self._save_tokens()
print("Authorization successful, tokens saved")
return True return True
except Exception as e: error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
print(f"Authorization failed: {e}") logger.error(f"OneDrive OAuth authorization failed: {error_msg}")
return False return False
except Exception as e:
logger.error(f"Exception during OneDrive OAuth authorization: {e}")
return False
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials."""
try:
# First try to load credentials if we haven't already
if not self._current_account:
await self.load_credentials()
# Try to get a token (MSAL will refresh if needed)
if self._current_account:
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
if result and "access_token" in result:
return True
else:
error_msg = (result or {}).get("error") or "No result returned"
logger.debug(f"Token acquisition failed for current account: {error_msg}")
# Fallback: try without specific account
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
if result and "access_token" in result:
accounts = self.app.get_accounts()
if accounts:
self._current_account = accounts[0]
return True
return False
except Exception as e:
logger.error(f"Authentication check failed: {e}")
return False
def get_access_token(self) -> str:
"""Get an access token for Microsoft Graph."""
try:
# Try with current account first
if self._current_account:
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
if result and "access_token" in result:
return result["access_token"]
# Fallback: try without specific account
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
if result and "access_token" in result:
return result["access_token"]
# If we get here, authentication has failed
error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
raise ValueError(f"Failed to acquire access token: {error_msg}")
except Exception as e:
logger.error(f"Failed to get access token: {e}")
raise
async def revoke_credentials(self):
"""Clear token cache and remove token file."""
try:
# Clear in-memory state
self._current_account = None
self.token_cache = msal.SerializableTokenCache()
# Recreate MSAL app with fresh cache
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority,
token_cache=self.token_cache,
)
# Remove token file
if os.path.exists(self.token_file):
os.remove(self.token_file)
logger.info(f"Removed OneDrive token file: {self.token_file}")
except Exception as e:
logger.error(f"Failed to revoke OneDrive credentials: {e}")
def get_service(self) -> str:
"""Return an access token (Graph client is just the bearer)."""
return self.get_access_token()

View file

@ -1,241 +1,564 @@
import logging
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any, Optional
from urllib.parse import urlparse
from datetime import datetime
import httpx import httpx
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from ..base import BaseConnector, ConnectorDocument, DocumentACL from ..base import BaseConnector, ConnectorDocument, DocumentACL
from .oauth import SharePointOAuth from .oauth import SharePointOAuth
logger = logging.getLogger(__name__)
class SharePointConnector(BaseConnector): class SharePointConnector(BaseConnector):
"""SharePoint Sites connector using Microsoft Graph API""" """SharePoint connector using MSAL-based OAuth for authentication"""
# Required BaseConnector class attributes
CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID" CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET" CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
# Connector metadata # Connector metadata
CONNECTOR_NAME = "SharePoint" CONNECTOR_NAME = "SharePoint"
CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents" CONNECTOR_DESCRIPTION = "Connect to SharePoint to sync documents and files"
CONNECTOR_ICON = "sharepoint" CONNECTOR_ICON = "sharepoint"
def __init__(self, config: Dict[str, Any]): def __init__(self, config: Dict[str, Any]):
super().__init__(config) super().__init__(config) # Fix: Call parent init first
def __init__(self, config: Dict[str, Any]):
logger.debug(f"SharePoint connector __init__ called with config type: {type(config)}")
logger.debug(f"SharePoint connector __init__ config value: {config}")
# Ensure we always pass a valid config to the base class
if config is None:
logger.debug("Config was None, using empty dict")
config = {}
try:
logger.debug("Calling super().__init__")
super().__init__(config) # Now safe to call with empty dict instead of None
logger.debug("super().__init__ completed successfully")
except Exception as e:
logger.error(f"super().__init__ failed: {e}")
raise
# Initialize with defaults that allow the connector to be listed
self.client_id = None
self.client_secret = None
self.tenant_id = config.get("tenant_id", "common")
self.sharepoint_url = config.get("sharepoint_url")
self.redirect_uri = config.get("redirect_uri", "http://localhost")
# Try to get credentials, but don't fail if they're missing
try:
logger.debug("Attempting to get client_id")
self.client_id = self.get_client_id()
logger.debug(f"Got client_id: {self.client_id is not None}")
except Exception as e:
logger.debug(f"Failed to get client_id: {e}")
pass # Credentials not available, that's OK for listing
try:
logger.debug("Attempting to get client_secret")
self.client_secret = self.get_client_secret()
logger.debug(f"Got client_secret: {self.client_secret is not None}")
except Exception as e:
logger.debug(f"Failed to get client_secret: {e}")
pass # Credentials not available, that's OK for listing
# Token file setup
project_root = Path(__file__).resolve().parent.parent.parent.parent project_root = Path(__file__).resolve().parent.parent.parent.parent
token_file = config.get("token_file") or str(project_root / "onedrive_token.json") token_file = config.get("token_file") or str(project_root / "sharepoint_token.json")
self.oauth = SharePointOAuth( Path(token_file).parent.mkdir(parents=True, exist_ok=True)
client_id=self.get_client_id(),
client_secret=self.get_client_secret(), # Only initialize OAuth if we have credentials
token_file=token_file, if self.client_id and self.client_secret:
) connection_id = config.get("connection_id", "default")
self.subscription_id = config.get("subscription_id") or config.get(
"webhook_channel_id" # Use token_file from config if provided, otherwise generate one
) if config.get("token_file"):
self.base_url = "https://graph.microsoft.com/v1.0" oauth_token_file = config["token_file"]
# SharePoint site configuration
self.site_id = config.get("site_id") # Required for SharePoint
async def authenticate(self) -> bool:
    """Verify stored OAuth credentials and mark the connector authenticated.

    Returns True (and sets the internal auth flag) only when the OAuth
    helper reports valid credentials; otherwise leaves state untouched.
    """
    if not await self.oauth.is_authenticated():
        return False
    self._authenticated = True
    return True
async def setup_subscription(self) -> str:
    """Create a Microsoft Graph change subscription on the site drive root.

    Requires prior authentication and a ``webhook_url`` entry in the config.

    Returns:
        The Graph subscription ID (also stored on ``self.subscription_id``).

    Raises:
        ValueError: If not authenticated or ``webhook_url`` is missing.
        httpx.HTTPStatusError: If the Graph subscription request fails.
    """
    if not self._authenticated:
        raise ValueError("Not authenticated")
    webhook_url = self.config.get("webhook_url")
    if not webhook_url:
        raise ValueError("webhook_url required in config for subscriptions")
    # Naive UTC time with a literal "Z" suffix, ~2 days out.
    # NOTE(review): presumably chosen to stay under Graph's subscription
    # lifetime cap — confirm against current Graph limits.
    expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
    body = {
        "changeType": "created,updated,deleted",
        "notificationUrl": webhook_url,
        "resource": f"/sites/{self.site_id}/drive/root",
        "expirationDateTime": expiration,
        # Random clientState; webhook handlers can use it to vet notifications.
        "clientState": str(uuid.uuid4()),
    }
    token = self.oauth.get_access_token()
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{self.base_url}/subscriptions",
            json=body,
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        data = resp.json()
    self.subscription_id = data["id"]
    return self.subscription_id
async def list_files(
    self, page_token: Optional[str] = None, limit: int = 100
) -> Dict[str, Any]:
    """List files (folders excluded) in the SharePoint site's drive root.

    Args:
        page_token: Graph ``$skiptoken`` from a previous page, if any.
        limit: Maximum number of items to request per page (``$top``).

    Returns:
        Dict with ``files`` (list of file summary dicts) and
        ``nextPageToken`` (``$skiptoken`` for the next page, or None).

    Raises:
        ValueError: If the connector has not authenticated.
        httpx.HTTPStatusError: If the Graph request fails.
    """
    if not self._authenticated:
        raise ValueError("Not authenticated")
    params = {"$top": str(limit)}
    if page_token:
        params["$skiptoken"] = page_token
    token = self.oauth.get_access_token()
    async with httpx.AsyncClient() as client:
        resp = await client.get(
            f"{self.base_url}/sites/{self.site_id}/drive/root/children",
            params=params,
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        data = resp.json()
    files = []
    for item in data.get("value", []):
        # Items with a "file" facet are files; folders lack it and are skipped.
        if item.get("file"):
            files.append(
                {
                    "id": item["id"],
                    "name": item["name"],
                    "mimeType": item.get("file", {}).get(
                        "mimeType", "application/octet-stream"
                    ),
                    "webViewLink": item.get("webUrl"),
                    "createdTime": item.get("createdDateTime"),
                    "modifiedTime": item.get("lastModifiedDateTime"),
                }
            )
    next_token = None
    next_link = data.get("@odata.nextLink")
    if next_link:
        # Graph returns a full next-page URL; extract its $skiptoken param.
        from urllib.parse import urlparse, parse_qs

        parsed = urlparse(next_link)
        next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0]
    return {"files": files, "nextPageToken": next_token}
async def get_file_content(self, file_id: str) -> ConnectorDocument:
if not self._authenticated:
raise ValueError("Not authenticated")
token = self.oauth.get_access_token()
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient() as client:
meta_resp = await client.get(
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}",
headers=headers,
)
meta_resp.raise_for_status()
metadata = meta_resp.json()
content_resp = await client.get(
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/content",
headers=headers,
)
content = content_resp.content
# Handle the possibility of this being a redirect
if content_resp.status_code in (301, 302, 303, 307, 308):
redirect_url = content_resp.headers.get("Location")
if redirect_url:
content_resp = await client.get(redirect_url)
content_resp.raise_for_status()
content = content_resp.content
else: else:
content_resp.raise_for_status() oauth_token_file = f"sharepoint_token_{connection_id}.json"
perm_resp = await client.get( authority = f"https://login.microsoftonline.com/{self.tenant_id}" if self.tenant_id != "common" else "https://login.microsoftonline.com/common"
f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/permissions",
headers=headers, self.oauth = SharePointOAuth(
client_id=self.client_id,
client_secret=self.client_secret,
token_file=oauth_token_file,
authority=authority
) )
perm_resp.raise_for_status() else:
permissions = perm_resp.json() self.oauth = None
acl = self._parse_permissions(metadata, permissions) # Track subscription ID for webhooks
modified = datetime.fromisoformat( self._subscription_id: Optional[str] = None
metadata["lastModifiedDateTime"].replace("Z", "+00:00")
).replace(tzinfo=None) # Add Graph API defaults similar to Google Drive flags
created = datetime.fromisoformat( self._graph_api_version = "v1.0"
metadata["createdDateTime"].replace("Z", "+00:00") self._default_params = {
).replace(tzinfo=None) "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl"
}
document = ConnectorDocument(
id=metadata["id"], @property
filename=metadata["name"], def _graph_base_url(self) -> str:
mimetype=metadata.get("file", {}).get( """Base URL for Microsoft Graph API calls"""
"mimeType", "application/octet-stream" return f"https://graph.microsoft.com/{self._graph_api_version}"
),
content=content, def emit(self, doc: ConnectorDocument) -> None:
source_url=metadata.get("webUrl"), """
acl=acl, Emit a ConnectorDocument instance.
modified_time=modified, Override this method to integrate with your ingestion pipeline.
created_time=created, """
metadata={"size": metadata.get("size")}, logger.debug(f"Emitting SharePoint document: {doc.id} ({doc.filename})")
)
return document async def authenticate(self) -> bool:
"""Test authentication - BaseConnector interface"""
def _parse_permissions( logger.debug(f"SharePoint authenticate() called, oauth is None: {self.oauth is None}")
self, metadata: Dict[str, Any], permissions: Dict[str, Any] try:
) -> DocumentACL: if not self.oauth:
acl = DocumentACL() logger.debug("SharePoint authentication failed: OAuth not initialized")
owner = metadata.get("createdBy", {}).get("user", {}).get("email") self._authenticated = False
if owner: return False
acl.owner = owner
for perm in permissions.get("value", []): logger.debug("Loading SharePoint credentials...")
role = perm.get("roles", ["read"])[0] # Try to load existing credentials first
grantee = perm.get("grantedToV2") or perm.get("grantedTo") load_result = await self.oauth.load_credentials()
if not grantee: logger.debug(f"Load credentials result: {load_result}")
continue
user = grantee.get("user") logger.debug("Checking SharePoint authentication status...")
if user and user.get("email"): authenticated = await self.oauth.is_authenticated()
acl.user_permissions[user["email"]] = role logger.debug(f"SharePoint is_authenticated result: {authenticated}")
group = grantee.get("group")
if group and group.get("email"): self._authenticated = authenticated
acl.group_permissions[group["email"]] = role return authenticated
return acl except Exception as e:
logger.error(f"SharePoint authentication failed: {e}")
def handle_webhook_validation( import traceback
self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str] traceback.print_exc()
) -> Optional[str]: self._authenticated = False
"""Handle Microsoft Graph webhook validation"""
if request_method == "GET":
validation_token = query_params.get("validationtoken") or query_params.get(
"validationToken"
)
if validation_token:
return validation_token
return None
def extract_webhook_channel_id(
self, payload: Dict[str, Any], headers: Dict[str, str]
) -> Optional[str]:
"""Extract SharePoint subscription ID from webhook payload"""
values = payload.get("value", [])
return values[0].get("subscriptionId") if values else None
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
values = payload.get("value", [])
file_ids = []
for item in values:
resource_data = item.get("resourceData", {})
file_id = resource_data.get("id")
if file_id:
file_ids.append(file_id)
return file_ids
async def cleanup_subscription(
self, subscription_id: str, resource_id: str = None
) -> bool:
if not self._authenticated:
return False return False
token = self.oauth.get_access_token()
async with httpx.AsyncClient() as client: def get_auth_url(self) -> str:
resp = await client.delete( """Get OAuth authorization URL"""
f"{self.base_url}/subscriptions/{subscription_id}", if not self.oauth:
headers={"Authorization": f"Bearer {token}"}, raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
return self.oauth.create_authorization_url(self.redirect_uri)
async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]:
"""Handle OAuth callback"""
if not self.oauth:
raise RuntimeError("SharePoint OAuth not initialized - missing credentials")
try:
success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri)
if success:
self._authenticated = True
return {"status": "success"}
else:
raise ValueError("OAuth callback failed")
except Exception as e:
logger.error(f"OAuth callback failed: {e}")
raise
def sync_once(self) -> None:
    """Perform a one-shot, blocking sync of SharePoint files and emit documents.

    Lists up to 1000 files, fetches each one's full content, and passes the
    resulting ConnectorDocument to ``self.emit``. Per-file failures are logged
    and skipped; a failure of the listing itself is logged and re-raised.

    Must be called from a context with no running asyncio event loop
    (``asyncio.run`` creates its own).
    """
    import asyncio

    async def _async_sync():
        try:
            # Get the file listing first; individual files are best-effort.
            file_list = await self.list_files(max_files=1000)  # Adjust as needed
            files = file_list.get("files", [])
            for file_info in files:
                try:
                    file_id = file_info.get("id")
                    if not file_id:
                        continue
                    # Fetch full document content and hand it to the pipeline.
                    doc = await self.get_file_content(file_id)
                    self.emit(doc)
                except Exception as e:
                    logger.error(f"Failed to sync SharePoint file {file_info.get('name', 'unknown')}: {e}")
                    continue
        except Exception as e:
            logger.error(f"SharePoint sync_once failed: {e}")
            raise

    # Fix: dropped the dead pre-Python-3.7 fallback that used the deprecated
    # asyncio.get_event_loop(); asyncio.run always exists on supported versions.
    asyncio.run(_async_sync())
async def setup_subscription(self) -> str:
    """Set up real-time subscription for file changes - BaseConnector interface.

    Registers a Microsoft Graph change subscription on the configured
    SharePoint site drive, or on ``/me/drive/root`` when no site URL parses.

    Returns:
        The Graph subscription ID, or the sentinel "no-webhook-configured"
        when the config has no ``webhook_url``.

    Raises:
        RuntimeError: If authentication fails during setup.
        ValueError: If Graph returns no subscription ID.
        httpx.HTTPStatusError: If the subscription request fails.
    """
    webhook_url = self.config.get('webhook_url')
    if not webhook_url:
        logger.warning("No webhook URL configured, skipping SharePoint subscription setup")
        return "no-webhook-configured"
    try:
        # Ensure we're authenticated
        if not await self.authenticate():
            raise RuntimeError("SharePoint authentication failed during subscription setup")
        token = self.oauth.get_access_token()
        # Microsoft Graph subscription for SharePoint site
        site_info = self._parse_sharepoint_url()
        if site_info:
            # Graph site-path addressing: host:/sites/name:/drive/root
            resource = f"sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root"
        else:
            # No parseable site URL - fall back to the signed-in user's drive.
            resource = "/me/drive/root"
        subscription_data = {
            "changeType": "created,updated,deleted",
            "notificationUrl": f"{webhook_url}/webhook/sharepoint",
            "resource": resource,
            "expirationDateTime": self._get_subscription_expiry(),
            "clientState": f"sharepoint_{self.tenant_id}"
        }
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json"
        }
        url = f"{self._graph_base_url}/subscriptions"
        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=subscription_data, headers=headers, timeout=30)
            response.raise_for_status()
            result = response.json()
        subscription_id = result.get("id")
        if subscription_id:
            # Remember the subscription so cleanup/renewal can reference it.
            self._subscription_id = subscription_id
            logger.info(f"SharePoint subscription created: {subscription_id}")
            return subscription_id
        else:
            raise ValueError("No subscription ID returned from Microsoft Graph")
    except Exception as e:
        logger.error(f"Failed to setup SharePoint subscription: {e}")
        raise
def _get_subscription_expiry(self) -> str:
"""Get subscription expiry time (max 3 days for Graph API)"""
from datetime import datetime, timedelta
expiry = datetime.utcnow() + timedelta(days=3) # 3 days max for Graph
return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
def _parse_sharepoint_url(self) -> Optional[Dict[str, str]]:
"""Parse SharePoint URL to extract site information for Graph API"""
if not self.sharepoint_url:
return None
try:
parsed = urlparse(self.sharepoint_url)
# Extract hostname and site name from URL like: https://contoso.sharepoint.com/sites/teamsite
host_name = parsed.netloc
path_parts = parsed.path.strip('/').split('/')
if len(path_parts) >= 2 and path_parts[0] == 'sites':
site_name = path_parts[1]
return {
"host_name": host_name,
"site_name": site_name
}
except Exception as e:
logger.warning(f"Could not parse SharePoint URL {self.sharepoint_url}: {e}")
return None
    async def list_files(self, page_token: Optional[str] = None, max_files: Optional[int] = None) -> Dict[str, Any]:
        """List files via Microsoft Graph - BaseConnector interface.

        Note: the request targets ``.../root/children``, so only the drive
        root's immediate children are listed (folders are filtered out, and
        their contents are not recursed into).

        Args:
            page_token: Graph ``$skiptoken`` from a previous page, if continuing.
            max_files: Page size requested via ``$top`` (defaults to 100).

        Returns:
            Dict with "files" (list of normalized file dicts) and
            "next_page_token" (str or None). On any failure, an empty result is
            returned instead of raising.
        """
        try:
            # Ensure authentication
            if not await self.authenticate():
                raise RuntimeError("SharePoint authentication failed during file listing")
            files = []
            max_files_value = max_files if max_files is not None else 100
            # Build Graph API URL for the site or fallback to user's OneDrive
            site_info = self._parse_sharepoint_url()
            if site_info:
                base_url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root/children"
            else:
                base_url = f"{self._graph_base_url}/me/drive/root/children"
            params = dict(self._default_params)
            params["$top"] = max_files_value
            if page_token:
                params["$skiptoken"] = page_token
            response = await self._make_graph_request(base_url, params=params)
            data = response.json()
            items = data.get("value", [])
            for item in items:
                # Only include files, not folders (folders lack the "file" facet)
                if item.get("file"):
                    files.append({
                        "id": item.get("id", ""),
                        "name": item.get("name", ""),
                        "path": f"/drive/items/{item.get('id')}",
                        "size": int(item.get("size", 0)),
                        "modified": item.get("lastModifiedDateTime"),
                        "created": item.get("createdDateTime"),
                        "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
                        "url": item.get("webUrl", ""),
                        # Short-lived pre-authenticated link, when Graph provides one
                        "download_url": item.get("@microsoft.graph.downloadUrl")
                    })
            # Check for next page: Graph returns a full nextLink URL, so extract
            # just the $skiptoken to hand back to the caller.
            next_page_token = None
            next_link = data.get("@odata.nextLink")
            if next_link:
                from urllib.parse import urlparse, parse_qs
                parsed = urlparse(next_link)
                query_params = parse_qs(parsed.query)
                if "$skiptoken" in query_params:
                    next_page_token = query_params["$skiptoken"][0]
            return {
                "files": files,
                "next_page_token": next_page_token
            }
        except Exception as e:
            logger.error(f"Failed to list SharePoint files: {e}")
            return {"files": [], "next_page_token": None}  # Return empty result instead of raising
async def get_file_content(self, file_id: str) -> ConnectorDocument:
"""Get file content and metadata - BaseConnector interface"""
try:
# Ensure authentication
if not await self.authenticate():
raise RuntimeError("SharePoint authentication failed during file content retrieval")
# First get file metadata using Graph API
file_metadata = await self._get_file_metadata_by_id(file_id)
if not file_metadata:
raise ValueError(f"File not found: {file_id}")
# Download file content
download_url = file_metadata.get("download_url")
if download_url:
content = await self._download_file_from_url(download_url)
else:
content = await self._download_file_content(file_id)
# Create ACL from metadata
acl = DocumentACL(
owner="", # Graph API requires additional calls for detailed permissions
user_permissions={},
group_permissions={}
) )
return resp.status_code in (200, 204)
# Parse dates
modified_time = self._parse_graph_date(file_metadata.get("modified"))
created_time = self._parse_graph_date(file_metadata.get("created"))
return ConnectorDocument(
id=file_id,
filename=file_metadata.get("name", ""),
mimetype=file_metadata.get("mime_type", "application/octet-stream"),
content=content,
source_url=file_metadata.get("url", ""),
acl=acl,
modified_time=modified_time,
created_time=created_time,
metadata={
"sharepoint_path": file_metadata.get("path", ""),
"sharepoint_url": self.sharepoint_url,
"size": file_metadata.get("size", 0)
}
)
except Exception as e:
logger.error(f"Failed to get SharePoint file content {file_id}: {e}")
raise
    async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]:
        """Fetch normalized metadata for a drive item by its Graph ID.

        Args:
            file_id: Microsoft Graph drive-item ID.

        Returns:
            A dict with id/name/path/size/modified/created/mime_type/url/
            download_url, or None when the item is a folder, is missing, or the
            request fails.
        """
        try:
            # Address the item via the configured site's drive when possible,
            # otherwise fall back to the signed-in user's OneDrive.
            site_info = self._parse_sharepoint_url()
            if site_info:
                url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}"
            else:
                url = f"{self._graph_base_url}/me/drive/items/{file_id}"
            params = dict(self._default_params)
            response = await self._make_graph_request(url, params=params)
            item = response.json()
            # Only drive items with a "file" facet are files; folders lack it.
            if item.get("file"):
                return {
                    "id": file_id,
                    "name": item.get("name", ""),
                    "path": f"/drive/items/{file_id}",
                    "size": int(item.get("size", 0)),
                    "modified": item.get("lastModifiedDateTime"),
                    "created": item.get("createdDateTime"),
                    "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))),
                    "url": item.get("webUrl", ""),
                    # Short-lived pre-authenticated download link, when provided.
                    "download_url": item.get("@microsoft.graph.downloadUrl")
                }
            return None
        except Exception as e:
            logger.error(f"Failed to get file metadata for {file_id}: {e}")
            return None
async def _download_file_content(self, file_id: str) -> bytes:
"""Download file content by file ID using Graph API"""
try:
site_info = self._parse_sharepoint_url()
if site_info:
url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}/content"
else:
url = f"{self._graph_base_url}/me/drive/items/{file_id}/content"
token = self.oauth.get_access_token()
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient() as client:
response = await client.get(url, headers=headers, timeout=60)
response.raise_for_status()
return response.content
except Exception as e:
logger.error(f"Failed to download file content for {file_id}: {e}")
raise
async def _download_file_from_url(self, download_url: str) -> bytes:
"""Download file content from direct download URL"""
try:
async with httpx.AsyncClient() as client:
response = await client.get(download_url, timeout=60)
response.raise_for_status()
return response.content
except Exception as e:
logger.error(f"Failed to download from URL {download_url}: {e}")
raise
def _parse_graph_date(self, date_str: Optional[str]) -> datetime:
"""Parse Microsoft Graph date string to datetime"""
if not date_str:
return datetime.now()
try:
if date_str.endswith('Z'):
return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None)
else:
return datetime.fromisoformat(date_str.replace('T', ' '))
except (ValueError, AttributeError):
return datetime.now()
async def _make_graph_request(self, url: str, method: str = "GET",
data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
"""Make authenticated API request to Microsoft Graph"""
token = self.oauth.get_access_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
async with httpx.AsyncClient() as client:
if method.upper() == "GET":
response = await client.get(url, headers=headers, params=params, timeout=30)
elif method.upper() == "POST":
response = await client.post(url, headers=headers, json=data, timeout=30)
elif method.upper() == "DELETE":
response = await client.delete(url, headers=headers, timeout=30)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response
def _get_mime_type(self, filename: str) -> str:
"""Get MIME type based on file extension"""
import mimetypes
mime_type, _ = mimetypes.guess_type(filename)
return mime_type or "application/octet-stream"
# Webhook methods - BaseConnector interface
def handle_webhook_validation(self, request_method: str, headers: Dict[str, str],
query_params: Dict[str, str]) -> Optional[str]:
"""Handle webhook validation (Graph API specific)"""
if request_method == "POST" and "validationToken" in query_params:
return query_params["validationToken"]
return None
def extract_webhook_channel_id(self, payload: Dict[str, Any],
headers: Dict[str, str]) -> Optional[str]:
"""Extract channel/subscription ID from webhook payload"""
notifications = payload.get("value", [])
if notifications:
return notifications[0].get("subscriptionId")
return None
async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
"""Handle webhook notification and return affected file IDs"""
affected_files = []
# Process Microsoft Graph webhook payload
notifications = payload.get("value", [])
for notification in notifications:
resource = notification.get("resource")
if resource and "/drive/items/" in resource:
file_id = resource.split("/drive/items/")[-1]
affected_files.append(file_id)
return affected_files
    async def cleanup_subscription(self, subscription_id: str) -> bool:
        """Delete a Graph change-notification subscription - BaseConnector interface.

        Args:
            subscription_id: ID returned by setup_subscription, or the sentinel
                "no-webhook-configured" when no subscription was ever created.

        Returns:
            True when the subscription is gone (deleted, already absent, or
            never created); False on auth failure, unexpected response, or error.
        """
        if subscription_id == "no-webhook-configured":
            logger.info("No subscription to cleanup (webhook was not configured)")
            return True
        try:
            # Ensure authentication
            if not await self.authenticate():
                logger.error("SharePoint authentication failed during subscription cleanup")
                return False
            token = self.oauth.get_access_token()
            headers = {"Authorization": f"Bearer {token}"}
            url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
            async with httpx.AsyncClient() as client:
                response = await client.delete(url, headers=headers, timeout=30)
                # 404 counts as success: the subscription is already gone.
                if response.status_code in [200, 204, 404]:
                    logger.info(f"SharePoint subscription {subscription_id} cleaned up successfully")
                    return True
                else:
                    logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
                    return False
        except Exception as e:
            logger.error(f"Failed to cleanup SharePoint subscription {subscription_id}: {e}")
            return False

View file

@ -1,19 +1,28 @@
import os import os
import json import json
import logging
from typing import Optional, Dict, Any
import aiofiles import aiofiles
from datetime import datetime import msal
import httpx
logger = logging.getLogger(__name__)
class SharePointOAuth: class SharePointOAuth:
"""Direct token management for SharePoint, bypassing MSAL cache format""" """Handles Microsoft Graph OAuth authentication flow following Google Drive pattern."""
SCOPES = [ # Reserved scopes that must NOT be sent on token or silent calls
"offline_access", RESERVED_SCOPES = {"openid", "profile", "offline_access"}
"Files.Read.All",
"Sites.Read.All",
]
# For PERSONAL Microsoft Accounts (OneDrive consumer):
# - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
# - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
SCOPES = AUTH_SCOPES # Backward compatibility alias
# Kept for reference; MSAL derives endpoints from `authority`
AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token" TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
@ -22,173 +31,299 @@ class SharePointOAuth:
client_id: str, client_id: str,
client_secret: str, client_secret: str,
token_file: str = "sharepoint_token.json", token_file: str = "sharepoint_token.json",
authority: str = "https://login.microsoftonline.com/common", # Keep for compatibility authority: str = "https://login.microsoftonline.com/common",
allow_json_refresh: bool = True,
): ):
"""
Initialize SharePointOAuth.
Args:
client_id: Azure AD application (client) ID.
client_secret: Azure AD application client secret.
token_file: Path to persisted token cache file (MSAL cache format).
authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
or tenant-specific for work/school.
allow_json_refresh: If True, permit one-time migration from legacy flat JSON
{"access_token","refresh_token",...}. Otherwise refuse it.
"""
self.client_id = client_id self.client_id = client_id
self.client_secret = client_secret self.client_secret = client_secret
self.token_file = token_file self.token_file = token_file
self.authority = authority # Keep for compatibility but not used self.authority = authority
self._tokens = None self.allow_json_refresh = allow_json_refresh
self._load_tokens() self.token_cache = msal.SerializableTokenCache()
self._current_account = None
def _load_tokens(self): # Initialize MSAL Confidential Client
"""Load tokens from file""" self.app = msal.ConfidentialClientApplication(
if os.path.exists(self.token_file): client_id=self.client_id,
with open(self.token_file, "r") as f: client_credential=self.client_secret,
self._tokens = json.loads(f.read()) authority=self.authority,
print(f"Loaded tokens from {self.token_file}") token_cache=self.token_cache,
else: )
print(f"No token file found at {self.token_file}")
async def save_cache(self): async def load_credentials(self) -> bool:
"""Persist tokens to file (renamed for compatibility)""" """Load existing credentials from token file (async)."""
await self._save_tokens()
async def _save_tokens(self):
"""Save tokens to file"""
if self._tokens:
async with aiofiles.open(self.token_file, "w") as f:
await f.write(json.dumps(self._tokens, indent=2))
def _is_token_expired(self) -> bool:
"""Check if current access token is expired"""
if not self._tokens or 'expiry' not in self._tokens:
return True
expiry_str = self._tokens['expiry']
# Handle different expiry formats
try: try:
if expiry_str.endswith('Z'): logger.debug(f"SharePoint OAuth loading credentials from: {self.token_file}")
expiry_dt = datetime.fromisoformat(expiry_str[:-1]) if os.path.exists(self.token_file):
else: logger.debug(f"Token file exists, reading: {self.token_file}")
expiry_dt = datetime.fromisoformat(expiry_str)
# Read the token file
# Add 5-minute buffer async with aiofiles.open(self.token_file, "r") as f:
import datetime as dt cache_data = await f.read()
now = datetime.now() logger.debug(f"Read {len(cache_data)} chars from token file")
return now >= (expiry_dt - dt.timedelta(minutes=5))
except: if cache_data.strip():
return True # 1) Try legacy flat JSON first
try:
json_data = json.loads(cache_data)
if isinstance(json_data, dict) and "refresh_token" in json_data:
if self.allow_json_refresh:
logger.debug(
"Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
)
return await self._refresh_from_json_token(json_data)
else:
logger.warning(
"Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
"Delete the file and re-auth."
)
return False
except json.JSONDecodeError:
logger.debug("Token file is not flat JSON; attempting MSAL cache format")
# 2) Try MSAL cache format
logger.debug("Attempting MSAL cache deserialization")
self.token_cache.deserialize(cache_data)
# Get accounts from loaded cache
accounts = self.app.get_accounts()
logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
if accounts:
self._current_account = accounts[0]
logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
# IMPORTANT: Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
if result and "access_token" in result:
logger.debug("Silent token acquisition successful")
await self.save_cache()
return True
else:
error_msg = (result or {}).get("error") or "No result"
logger.warning(f"Silent token acquisition failed: {error_msg}")
else:
logger.debug(f"Token file {self.token_file} is empty")
else:
logger.debug(f"Token file does not exist: {self.token_file}")
async def _refresh_access_token(self) -> bool:
"""Refresh the access token using refresh token"""
if not self._tokens or 'refresh_token' not in self._tokens:
return False return False
data = { except Exception as e:
'client_id': self.client_id, logger.error(f"Failed to load SharePoint credentials: {e}")
'client_secret': self.client_secret, import traceback
'refresh_token': self._tokens['refresh_token'], traceback.print_exc()
'grant_type': 'refresh_token', return False
'scope': ' '.join(self.SCOPES)
}
async with httpx.AsyncClient() as client: async def _refresh_from_json_token(self, token_data: dict) -> bool:
try: """
response = await client.post(self.TOKEN_ENDPOINT, data=data) Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
response.raise_for_status()
token_data = response.json()
# Update tokens Notes:
self._tokens['token'] = token_data['access_token'] - Prefer using an MSAL cache file and acquire_token_silent().
if 'refresh_token' in token_data: - This path is only for migrating older refresh_token JSON files.
self._tokens['refresh_token'] = token_data['refresh_token'] """
try:
# Calculate expiry refresh_token = token_data.get("refresh_token")
expires_in = token_data.get('expires_in', 3600) if not refresh_token:
import datetime as dt logger.error("No refresh_token found in JSON file - cannot refresh")
expiry = datetime.now() + dt.timedelta(seconds=expires_in) logger.error("You must re-authenticate interactively to obtain a valid token")
self._tokens['expiry'] = expiry.isoformat()
await self._save_tokens()
print("Access token refreshed successfully")
return True
except Exception as e:
print(f"Failed to refresh token: {e}")
return False return False
def create_authorization_url(self, redirect_uri: str) -> str: # Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
"""Create authorization URL for OAuth flow""" refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
from urllib.parse import urlencode logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
params = { result = self.app.acquire_token_by_refresh_token(
'client_id': self.client_id, refresh_token=refresh_token,
'response_type': 'code', scopes=refresh_scopes,
'redirect_uri': redirect_uri, )
'scope': ' '.join(self.SCOPES),
'response_mode': 'query' if result and "access_token" in result:
logger.debug("Successfully refreshed token via legacy JSON path")
await self.save_cache()
accounts = self.app.get_accounts()
logger.debug(f"After refresh, found {len(accounts)} accounts")
if accounts:
self._current_account = accounts[0]
logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
return True
# Error handling
err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
logger.error(f"Refresh token failed: {err}")
if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
logger.warning(
"Refresh denied due to unauthorized/expired scopes or invalid grant. "
"Delete the token file and perform interactive sign-in with correct scopes."
)
return False
except Exception as e:
logger.error(f"Exception during refresh from JSON token: {e}")
import traceback
traceback.print_exc()
return False
async def save_cache(self):
"""Persist the token cache to file."""
try:
# Ensure parent directory exists
parent = os.path.dirname(os.path.abspath(self.token_file))
if parent and not os.path.exists(parent):
os.makedirs(parent, exist_ok=True)
cache_data = self.token_cache.serialize()
if cache_data:
async with aiofiles.open(self.token_file, "w") as f:
await f.write(cache_data)
logger.debug(f"Token cache saved to {self.token_file}")
except Exception as e:
logger.error(f"Failed to save token cache: {e}")
def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
"""Create authorization URL for OAuth flow."""
# Store redirect URI for later use in callback
self._redirect_uri = redirect_uri
kwargs: Dict[str, Any] = {
# IMPORTANT: interactive auth includes offline_access
"scopes": self.AUTH_SCOPES,
"redirect_uri": redirect_uri,
"prompt": "consent", # ensure refresh token on first run
} }
if state:
auth_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" kwargs["state"] = state # Optional CSRF protection
return f"{auth_url}?{urlencode(params)}"
auth_url = self.app.get_authorization_request_url(**kwargs)
logger.debug(f"Generated auth URL: {auth_url}")
logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
return auth_url
async def handle_authorization_callback( async def handle_authorization_callback(
self, authorization_code: str, redirect_uri: str self, authorization_code: str, redirect_uri: str
) -> bool: ) -> bool:
"""Handle OAuth callback and exchange code for tokens""" """Handle OAuth callback and exchange code for tokens."""
data = { try:
'client_id': self.client_id, # For code exchange, we pass the same auth scopes as used in the authorize step
'client_secret': self.client_secret, result = self.app.acquire_token_by_authorization_code(
'code': authorization_code, authorization_code,
'grant_type': 'authorization_code', scopes=self.AUTH_SCOPES,
'redirect_uri': redirect_uri, redirect_uri=redirect_uri,
'scope': ' '.join(self.SCOPES) )
}
async with httpx.AsyncClient() as client: if result and "access_token" in result:
try: # Store the account for future use
response = await client.post(self.TOKEN_ENDPOINT, data=data) accounts = self.app.get_accounts()
response.raise_for_status() if accounts:
token_data = response.json() self._current_account = accounts[0]
# Store tokens in our format await self.save_cache()
import datetime as dt logger.info("SharePoint OAuth authorization successful")
expires_in = token_data.get('expires_in', 3600)
expiry = datetime.now() + dt.timedelta(seconds=expires_in)
self._tokens = {
'token': token_data['access_token'],
'refresh_token': token_data['refresh_token'],
'scopes': self.SCOPES,
'expiry': expiry.isoformat()
}
await self._save_tokens()
print("Authorization successful, tokens saved")
return True return True
except Exception as e: error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
print(f"Authorization failed: {e}") logger.error(f"SharePoint OAuth authorization failed: {error_msg}")
return False
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials"""
if not self._tokens:
return False return False
# If token is expired, try to refresh except Exception as e:
if self._is_token_expired(): logger.error(f"Exception during SharePoint OAuth authorization: {e}")
print("Token expired, attempting refresh...") return False
if await self._refresh_access_token():
async def is_authenticated(self) -> bool:
"""Check if we have valid credentials (simplified like Google Drive)."""
try:
# First try to load credentials if we haven't already
if not self._current_account:
await self.load_credentials()
# If we have an account, try to get a token (MSAL will refresh if needed)
if self._current_account:
# IMPORTANT: use RESOURCE_SCOPES here
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
if result and "access_token" in result:
return True
else:
error_msg = (result or {}).get("error") or "No result returned"
logger.debug(f"Token acquisition failed for current account: {error_msg}")
# Fallback: try without specific account
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
if result and "access_token" in result:
# Update current account if this worked
accounts = self.app.get_accounts()
if accounts:
self._current_account = accounts[0]
return True return True
else:
return False return False
return True except Exception as e:
logger.error(f"Authentication check failed: {e}")
return False
def get_access_token(self) -> str: def get_access_token(self) -> str:
"""Get current access token""" """Get an access token for Microsoft Graph (simplified like Google Drive)."""
if not self._tokens or 'token' not in self._tokens: try:
raise ValueError("No access token available") # Try with current account first
if self._current_account:
if self._is_token_expired(): result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
raise ValueError("Access token expired and refresh failed") if result and "access_token" in result:
return result["access_token"]
return self._tokens['token']
# Fallback: try without specific account
result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
if result and "access_token" in result:
return result["access_token"]
# If we get here, authentication has failed
error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
raise ValueError(f"Failed to acquire access token: {error_msg}")
except Exception as e:
logger.error(f"Failed to get access token: {e}")
raise
async def revoke_credentials(self): async def revoke_credentials(self):
"""Clear tokens""" """Clear token cache and remove token file (like Google Drive)."""
self._tokens = None try:
if os.path.exists(self.token_file): # Clear in-memory state
os.remove(self.token_file) self._current_account = None
self.token_cache = msal.SerializableTokenCache()
# Recreate MSAL app with fresh cache
self.app = msal.ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority,
token_cache=self.token_cache,
)
# Remove token file
if os.path.exists(self.token_file):
os.remove(self.token_file)
logger.info(f"Removed SharePoint token file: {self.token_file}")
except Exception as e:
logger.error(f"Failed to revoke SharePoint credentials: {e}")
def get_service(self) -> str:
"""Return an access token (Graph doesn't need a generated client like Google Drive)."""
return self.get_access_token()