diff --git a/.env.example b/.env.example index 9fda0544..cbb994e4 100644 --- a/.env.example +++ b/.env.example @@ -7,6 +7,9 @@ OPENSEARCH_PASSWORD=OSisgendb1! # make here https://console.cloud.google.com/apis/credentials GOOGLE_OAUTH_CLIENT_ID= GOOGLE_OAUTH_CLIENT_SECRET= +# Azure app registration credentials for SharePoint/OneDrive +MICROSOFT_GRAPH_OAUTH_CLIENT_ID= +MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET= # Optional dns routable from google (etc.) to handle continous ingest (ngrok works) WEBHOOK_BASE_URL= OPENAI_API_KEY= diff --git a/frontend/src/app/connectors/page.tsx b/frontend/src/app/connectors/page.tsx index a7600f94..3516338d 100644 --- a/frontend/src/app/connectors/page.tsx +++ b/frontend/src/app/connectors/page.tsx @@ -1,6 +1,6 @@ "use client" -import { useState, useEffect, Suspense } from "react" +import { useState, useEffect, useCallback, Suspense } from "react" import { useSearchParams } from "next/navigation" import { Button } from "@/components/ui/button" import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card" @@ -52,23 +52,43 @@ function ConnectorsPage() { const [syncResults, setSyncResults] = useState<{[key: string]: SyncResult | null}>({}) const [maxFiles, setMaxFiles] = useState(10) - // Function definitions first - const checkConnectorStatuses = async () => { - // Initialize connectors list - setConnectors([ - { - id: "google_drive", - name: "Google Drive", - description: "Connect your Google Drive to automatically sync documents", - icon:
G
, - status: "not_connected", - type: "google_drive" - }, - ]) + // Helper function to get connector icon + const getConnectorIcon = (iconName: string) => { + const iconMap: { [key: string]: React.ReactElement } = { + 'google-drive':
G
, + 'sharepoint':
SP
, + 'onedrive':
OD
, + } + return iconMap[iconName] ||
?
+ } + // Function definitions first + const checkConnectorStatuses = useCallback(async () => { try { + // Fetch available connectors from backend + const connectorsResponse = await fetch('/api/connectors') + if (!connectorsResponse.ok) { + throw new Error('Failed to load connectors') + } + + const connectorsResult = await connectorsResponse.json() + const connectorTypes = Object.keys(connectorsResult.connectors) + + // Initialize connectors list with metadata from backend + const initialConnectors = connectorTypes + .filter(type => connectorsResult.connectors[type].available) // Only show available connectors + .map(type => ({ + id: type, + name: connectorsResult.connectors[type].name, + description: connectorsResult.connectors[type].description, + icon: getConnectorIcon(connectorsResult.connectors[type].icon), + status: "not_connected" as const, + type: type + })) + + setConnectors(initialConnectors) + // Check status for each connector type - const connectorTypes = ["google_drive"] for (const connectorType of connectorTypes) { const response = await fetch(`/api/connectors/${connectorType}/status`) @@ -92,7 +112,7 @@ function ConnectorsPage() { } catch (error) { console.error('Failed to check connector statuses:', error) } - } + }, [setConnectors]) const handleConnect = async (connector: Connector) => { setIsConnecting(connector.id) @@ -110,8 +130,8 @@ function ConnectorsPage() { 'Content-Type': 'application/json', }, body: JSON.stringify({ - provider: connector.type.replace('_drive', ''), // "google_drive" -> "google" - purpose: "data_source", + connector_type: connector.type, + purpose: "data_source", name: `${connector.name} Connection`, redirect_uri: redirectUri }), @@ -262,7 +282,7 @@ function ConnectorsPage() { url.searchParams.delete('oauth_success') window.history.replaceState({}, '', url.toString()) } - }, [searchParams, isAuthenticated]) + }, [searchParams, isAuthenticated, checkConnectorStatuses]) return (
diff --git a/frontend/src/app/knowledge-sources/page.tsx b/frontend/src/app/knowledge-sources/page.tsx index 30f8ce38..efc68d45 100644 --- a/frontend/src/app/knowledge-sources/page.tsx +++ b/frontend/src/app/knowledge-sources/page.tsx @@ -1,6 +1,6 @@ "use client" -import { useState, useEffect, Suspense } from "react" +import { useState, useEffect, useCallback, Suspense } from "react" import { useSearchParams } from "next/navigation" import { Button } from "@/components/ui/button" import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card" @@ -150,27 +150,59 @@ function KnowledgeSourcesPage() { } } - // Connector functions - const checkConnectorStatuses = async () => { - setConnectors([ - { - id: "google_drive", - name: "Google Drive", - description: "Connect your Google Drive to automatically sync documents", - icon: ( -
- G -
- ), - status: "not_connected", - type: "google_drive" - }, - ]) + // Helper function to get connector icon + const getConnectorIcon = (iconName: string) => { + const iconMap: { [key: string]: React.ReactElement } = { + 'google-drive': ( +
+ G +
+ ), + 'sharepoint': ( +
+ SP +
+ ), + 'onedrive': ( +
+ OD +
+ ), + } + return iconMap[iconName] || ( +
+ ? +
+ ) + } + // Connector functions + const checkConnectorStatuses = useCallback(async () => { try { - const connectorTypes = ["google_drive"] + // Fetch available connectors from backend + const connectorsResponse = await fetch('/api/connectors') + if (!connectorsResponse.ok) { + throw new Error('Failed to load connectors') + } + + const connectorsResult = await connectorsResponse.json() + const connectorTypes = Object.keys(connectorsResult.connectors) + + // Initialize connectors list with metadata from backend + const initialConnectors = connectorTypes + .filter(type => connectorsResult.connectors[type].available) // Only show available connectors + .map(type => ({ + id: type, + name: connectorsResult.connectors[type].name, + description: connectorsResult.connectors[type].description, + icon: getConnectorIcon(connectorsResult.connectors[type].icon), + status: "not_connected" as const, + type: type + })) + + setConnectors(initialConnectors) + + // Check status for each connector type for (const connectorType of connectorTypes) { const response = await fetch(`/api/connectors/${connectorType}/status`) @@ -194,18 +226,27 @@ function KnowledgeSourcesPage() { } catch (error) { console.error('Failed to check connector statuses:', error) } - } + }, []) const handleConnect = async (connector: Connector) => { setIsConnecting(connector.id) setSyncResults(prev => ({ ...prev, [connector.id]: null })) try { - const response = await fetch(`/api/connectors/${connector.type}/connect`, { + // Use the shared auth callback URL, same as connectors page + const redirectUri = `${window.location.origin}/auth/callback` + + const response = await fetch('/api/auth/init', { method: 'POST', headers: { 'Content-Type': 'application/json', }, + body: JSON.stringify({ + connector_type: connector.type, + purpose: "data_source", + name: `${connector.name} Connection`, + redirect_uri: redirectUri + }), }) if (response.ok) { @@ -305,7 +346,7 @@ function KnowledgeSourcesPage() { 
url.searchParams.delete('oauth_success') window.history.replaceState({}, '', url.toString()) } - }, [searchParams, isAuthenticated]) + }, [searchParams, isAuthenticated, checkConnectorStatuses]) // Fetch global stats using match-all wildcard const fetchStats = async () => { diff --git a/frontend/src/contexts/auth-context.tsx b/frontend/src/contexts/auth-context.tsx index aafb80b2..597dda99 100644 --- a/frontend/src/contexts/auth-context.tsx +++ b/frontend/src/contexts/auth-context.tsx @@ -68,7 +68,7 @@ export function AuthProvider({ children }: AuthProviderProps) { 'Content-Type': 'application/json', }, body: JSON.stringify({ - provider: 'google', + connector_type: 'google_drive', purpose: 'app_auth', name: 'App Authentication', redirect_uri: redirectUri diff --git a/frontend/src/contexts/knowledge-filter-context.tsx b/frontend/src/contexts/knowledge-filter-context.tsx index c78b34df..2b4068e5 100644 --- a/frontend/src/contexts/knowledge-filter-context.tsx +++ b/frontend/src/contexts/knowledge-filter-context.tsx @@ -1,6 +1,6 @@ "use client" -import React, { createContext, useContext, useState, useEffect, ReactNode } from 'react' +import React, { createContext, useContext, useState, ReactNode } from 'react' interface KnowledgeFilter { id: string @@ -61,9 +61,6 @@ export function KnowledgeFilterProvider({ children }: KnowledgeFilterProviderPro const parsed = JSON.parse(filter.query_data) as ParsedQueryData setParsedFilterData(parsed) - // Store in localStorage for persistence across page reloads - localStorage.setItem('selectedKnowledgeFilter', JSON.stringify(filter)) - // Auto-open panel when filter is selected setIsPanelOpen(true) } catch (error) { @@ -72,7 +69,6 @@ export function KnowledgeFilterProvider({ children }: KnowledgeFilterProviderPro } } else { setParsedFilterData(null) - localStorage.removeItem('selectedKnowledgeFilter') setIsPanelOpen(false) } } @@ -93,19 +89,6 @@ export function KnowledgeFilterProvider({ children }: KnowledgeFilterProviderPro 
setIsPanelOpen(false) // Close panel but keep filter selected } - // Load persisted filter on mount - useEffect(() => { - try { - const saved = localStorage.getItem('selectedKnowledgeFilter') - if (saved) { - const filter = JSON.parse(saved) as KnowledgeFilter - setSelectedFilter(filter) - } - } catch (error) { - console.error('Error loading persisted filter:', error) - localStorage.removeItem('selectedKnowledgeFilter') - } - }, []) const value: KnowledgeFilterContextType = { selectedFilter, diff --git a/pyproject.toml b/pyproject.toml index 1f109b88..437e96f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "google-api-python-client>=2.143.0", "google-auth-httplib2>=0.2.0", "google-auth-oauthlib>=1.2.0", + "msal>=1.29.0", "httpx>=0.27.0", "opensearch-py[async]>=3.0.0", "pyjwt>=2.8.0", diff --git a/src/api/auth.py b/src/api/auth.py index fb9943dd..a56d4d67 100644 --- a/src/api/auth.py +++ b/src/api/auth.py @@ -5,16 +5,16 @@ async def auth_init(request: Request, auth_service, session_manager): """Initialize OAuth flow for authentication or data source connection""" try: data = await request.json() - provider = data.get("provider") + connector_type = data.get("connector_type") purpose = data.get("purpose", "data_source") - connection_name = data.get("name", f"{provider}_{purpose}") + connection_name = data.get("name", f"{connector_type}_{purpose}") redirect_uri = data.get("redirect_uri") user = getattr(request.state, 'user', None) user_id = user.user_id if user else None result = await auth_service.init_oauth( - provider, purpose, connection_name, redirect_uri, user_id + connector_type, purpose, connection_name, redirect_uri, user_id ) return JSONResponse(result) diff --git a/src/api/connectors.py b/src/api/connectors.py index 0c04b020..ba511f19 100644 --- a/src/api/connectors.py +++ b/src/api/connectors.py @@ -1,5 +1,14 @@ from starlette.requests import Request -from starlette.responses import JSONResponse +from starlette.responses 
import JSONResponse, PlainTextResponse + +async def list_connectors(request: Request, connector_service, session_manager): + """List available connector types with metadata""" + try: + connector_types = connector_service.connection_manager.get_available_connector_types() + return JSONResponse({"connectors": connector_types}) + except Exception as e: + print(f"Error listing connectors: {e}") + return JSONResponse({"error": str(e)}, status_code=500) async def connector_sync(request: Request, connector_service, session_manager): """Sync files from all active connections of a connector type""" @@ -85,7 +94,29 @@ async def connector_status(request: Request, connector_service, session_manager) async def connector_webhook(request: Request, connector_service, session_manager): """Handle webhook notifications from any connector type""" connector_type = request.path_params.get("connector_type") - + + # Handle webhook validation (connector-specific) + temp_config = {"token_file": "temp.json"} + from connectors.connection_manager import ConnectionConfig + temp_connection = ConnectionConfig( + connection_id="temp", + connector_type=connector_type, + name="temp", + config=temp_config + ) + try: + temp_connector = connector_service.connection_manager._create_connector(temp_connection) + validation_response = temp_connector.handle_webhook_validation( + request.method, + dict(request.headers), + dict(request.query_params) + ) + if validation_response: + return PlainTextResponse(validation_response) + except (NotImplementedError, ValueError): + # Connector type not found or validation not needed + pass + try: # Get the raw payload and headers payload = {} @@ -109,8 +140,13 @@ async def connector_webhook(request: Request, connector_service, session_manager print(f"[WEBHOOK] {connector_type} notification received") - # Extract channel/subscription ID from headers (Google Drive specific) - channel_id = headers.get('x-goog-channel-id') + # Extract channel/subscription ID using 
connector-specific method + try: + temp_connector = connector_service.connection_manager._create_connector(temp_connection) + channel_id = temp_connector.extract_webhook_channel_id(payload, headers) + except (NotImplementedError, ValueError): + channel_id = None + if not channel_id: print(f"[WEBHOOK] No channel ID found in {connector_type} webhook") return JSONResponse({"status": "ignored", "reason": "no_channel_id"}) @@ -132,7 +168,6 @@ async def connector_webhook(request: Request, connector_service, session_manager connector = await connector_service._get_connector(active_connections[0].connection_id) if connector: print(f"[WEBHOOK] Cancelling unknown subscription {channel_id}") - resource_id = headers.get('x-goog-resource-id') await connector.cleanup_subscription(channel_id, resource_id) print(f"[WEBHOOK] Successfully cancelled unknown subscription {channel_id}") diff --git a/src/api/search.py b/src/api/search.py index b3035b64..b97adecd 100644 --- a/src/api/search.py +++ b/src/api/search.py @@ -3,27 +3,25 @@ from starlette.responses import JSONResponse async def search(request: Request, search_service, session_manager): """Search for documents""" - payload = await request.json() - query = payload.get("query") - if not query: - return JSONResponse({"error": "Query is required"}, status_code=400) - - filters = payload.get("filters", {}) # Optional filters, defaults to empty dict - limit = payload.get("limit", 10) # Optional limit, defaults to 10 - score_threshold = payload.get("scoreThreshold", 0) # Optional score threshold, defaults to 0 - - user = request.state.user - # Extract JWT token from cookie for OpenSearch OIDC auth - jwt_token = request.cookies.get("auth_token") - - result = await search_service.search(query, user_id=user.user_id, jwt_token=jwt_token, filters=filters, limit=limit, score_threshold=score_threshold) - - # Return appropriate HTTP status codes - if result.get("success"): + try: + payload = await request.json() + query = payload.get("query") 
+ if not query: + return JSONResponse({"error": "Query is required"}, status_code=400) + + filters = payload.get("filters", {}) # Optional filters, defaults to empty dict + limit = payload.get("limit", 10) # Optional limit, defaults to 10 + score_threshold = payload.get("scoreThreshold", 0) # Optional score threshold, defaults to 0 + + user = request.state.user + # Extract JWT token from cookie for OpenSearch OIDC auth + jwt_token = request.cookies.get("auth_token") + + result = await search_service.search(query, user_id=user.user_id, jwt_token=jwt_token, filters=filters, limit=limit, score_threshold=score_threshold) return JSONResponse(result, status_code=200) - else: - error_msg = result.get("error", "") + except Exception as e: + error_msg = str(e) if "AuthenticationException" in error_msg or "access denied" in error_msg.lower(): - return JSONResponse(result, status_code=403) + return JSONResponse({"error": error_msg}, status_code=403) else: - return JSONResponse(result, status_code=500) \ No newline at end of file + return JSONResponse({"error": error_msg}, status_code=500) \ No newline at end of file diff --git a/src/api/upload.py b/src/api/upload.py index 5b99861d..e15ad48a 100644 --- a/src/api/upload.py +++ b/src/api/upload.py @@ -4,22 +4,20 @@ from starlette.responses import JSONResponse async def upload(request: Request, document_service, session_manager): """Upload a single file""" - form = await request.form() - upload_file = form["file"] - user = request.state.user - jwt_token = request.cookies.get("auth_token") - - result = await document_service.process_upload_file(upload_file, owner_user_id=user.user_id, jwt_token=jwt_token) - - # Return appropriate HTTP status codes - if result.get("success"): + try: + form = await request.form() + upload_file = form["file"] + user = request.state.user + jwt_token = request.cookies.get("auth_token") + + result = await document_service.process_upload_file(upload_file, owner_user_id=user.user_id, jwt_token=jwt_token) 
return JSONResponse(result, status_code=201) # Created - else: - error_msg = result.get("error", "") + except Exception as e: + error_msg = str(e) if "AuthenticationException" in error_msg or "access denied" in error_msg.lower(): - return JSONResponse(result, status_code=403) + return JSONResponse({"error": error_msg}, status_code=403) else: - return JSONResponse(result, status_code=500) + return JSONResponse({"error": error_msg}, status_code=500) async def upload_path(request: Request, task_service, session_manager): """Upload all files from a directory path""" diff --git a/src/connectors/__init__.py b/src/connectors/__init__.py index 4f25233d..87d929ff 100644 --- a/src/connectors/__init__.py +++ b/src/connectors/__init__.py @@ -1,4 +1,6 @@ from .base import BaseConnector from .google_drive import GoogleDriveConnector +from .sharepoint import SharePointConnector +from .onedrive import OneDriveConnector -__all__ = ["BaseConnector", "GoogleDriveConnector"] \ No newline at end of file +__all__ = ["BaseConnector", "GoogleDriveConnector", "SharePointConnector", "OneDriveConnector"] diff --git a/src/connectors/base.py b/src/connectors/base.py index 44e01860..45f39257 100644 --- a/src/connectors/base.py +++ b/src/connectors/base.py @@ -54,6 +54,11 @@ class BaseConnector(ABC): CLIENT_ID_ENV_VAR: str = None CLIENT_SECRET_ENV_VAR: str = None + # Connector metadata for UI + CONNECTOR_NAME: str = None + CONNECTOR_DESCRIPTION: str = None + CONNECTOR_ICON: str = None # Icon identifier or emoji + def __init__(self, config: Dict[str, Any]): self.config = config self._authenticated = False @@ -105,6 +110,17 @@ class BaseConnector(ABC): """Handle webhook notification. Returns list of affected file IDs.""" pass + def handle_webhook_validation(self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]) -> Optional[str]: + """Handle webhook validation (e.g., for subscription setup). + Returns validation response if applicable, None otherwise. 
+ Default implementation returns None (no validation needed).""" + return None + + def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]: + """Extract channel/subscription ID from webhook payload/headers. + Must be implemented by each connector.""" + raise NotImplementedError(f"{self.__class__.__name__} must implement extract_webhook_channel_id") + @abstractmethod async def cleanup_subscription(self, subscription_id: str) -> bool: """Clean up subscription""" diff --git a/src/connectors/connection_manager.py b/src/connectors/connection_manager.py index 208903b6..19799cb3 100644 --- a/src/connectors/connection_manager.py +++ b/src/connectors/connection_manager.py @@ -9,6 +9,8 @@ from pathlib import Path from .base import BaseConnector from .google_drive import GoogleDriveConnector +from .sharepoint import SharePointConnector +from .onedrive import OneDriveConnector @dataclass @@ -186,10 +188,54 @@ class ConnectionManager: return None + def get_available_connector_types(self) -> Dict[str, Dict[str, str]]: + """Get available connector types with their metadata""" + return { + "google_drive": { + "name": GoogleDriveConnector.CONNECTOR_NAME, + "description": GoogleDriveConnector.CONNECTOR_DESCRIPTION, + "icon": GoogleDriveConnector.CONNECTOR_ICON, + "available": self._is_connector_available("google_drive") + }, + "sharepoint": { + "name": SharePointConnector.CONNECTOR_NAME, + "description": SharePointConnector.CONNECTOR_DESCRIPTION, + "icon": SharePointConnector.CONNECTOR_ICON, + "available": self._is_connector_available("sharepoint") + }, + "onedrive": { + "name": OneDriveConnector.CONNECTOR_NAME, + "description": OneDriveConnector.CONNECTOR_DESCRIPTION, + "icon": OneDriveConnector.CONNECTOR_ICON, + "available": self._is_connector_available("onedrive") + } + } + + def _is_connector_available(self, connector_type: str) -> bool: + """Check if a connector type is available (has required env vars)""" + try: + temp_config = 
ConnectionConfig( + connection_id="temp", + connector_type=connector_type, + name="temp", + config={} + ) + connector = self._create_connector(temp_config) + # Try to get credentials to check if env vars are set + connector.get_client_id() + connector.get_client_secret() + return True + except (ValueError, NotImplementedError): + return False + def _create_connector(self, config: ConnectionConfig) -> BaseConnector: """Factory method to create connector instances""" if config.connector_type == "google_drive": return GoogleDriveConnector(config.config) + elif config.connector_type == "sharepoint": + return SharePointConnector(config.config) + elif config.connector_type == "onedrive": + return OneDriveConnector(config.config) elif config.connector_type == "box": # Future: BoxConnector(config.config) raise NotImplementedError("Box connector not implemented yet") diff --git a/src/connectors/google_drive/connector.py b/src/connectors/google_drive/connector.py index 912bf630..6ff0590e 100644 --- a/src/connectors/google_drive/connector.py +++ b/src/connectors/google_drive/connector.py @@ -133,6 +133,11 @@ class GoogleDriveConnector(BaseConnector): CLIENT_ID_ENV_VAR = "GOOGLE_OAUTH_CLIENT_ID" CLIENT_SECRET_ENV_VAR = "GOOGLE_OAUTH_CLIENT_SECRET" + # Connector metadata + CONNECTOR_NAME = "Google Drive" + CONNECTOR_DESCRIPTION = "Connect your Google Drive to automatically sync documents" + CONNECTOR_ICON = "google-drive" + # Supported file types that can be processed by docling SUPPORTED_MIMETYPES = { 'application/pdf', @@ -363,6 +368,10 @@ class GoogleDriveConnector(BaseConnector): group_permissions=group_permissions ) + def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]: + """Extract Google Drive channel ID from webhook headers""" + return headers.get('x-goog-channel-id') + async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: """Handle Google Drive webhook notification""" if not self._authenticated: diff 
--git a/src/connectors/google_drive/oauth.py b/src/connectors/google_drive/oauth.py index bcd013d1..bd22176c 100644 --- a/src/connectors/google_drive/oauth.py +++ b/src/connectors/google_drive/oauth.py @@ -13,10 +13,16 @@ class GoogleDriveOAuth: """Handles Google Drive OAuth authentication flow""" SCOPES = [ + 'openid', + 'email', + 'profile', 'https://www.googleapis.com/auth/drive.readonly', 'https://www.googleapis.com/auth/drive.metadata.readonly' ] + AUTH_ENDPOINT = "https://accounts.google.com/o/oauth2/v2/auth" + TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token" + def __init__(self, client_id: str = None, client_secret: str = None, token_file: str = "token.json"): self.client_id = client_id self.client_secret = client_secret diff --git a/src/connectors/onedrive/__init__.py b/src/connectors/onedrive/__init__.py new file mode 100644 index 00000000..6d18c294 --- /dev/null +++ b/src/connectors/onedrive/__init__.py @@ -0,0 +1,4 @@ +from .connector import OneDriveConnector +from .oauth import OneDriveOAuth + +__all__ = ["OneDriveConnector", "OneDriveOAuth"] \ No newline at end of file diff --git a/src/connectors/onedrive/connector.py b/src/connectors/onedrive/connector.py new file mode 100644 index 00000000..f9ff27ef --- /dev/null +++ b/src/connectors/onedrive/connector.py @@ -0,0 +1,193 @@ +import httpx +import uuid +from datetime import datetime, timedelta +from typing import Dict, List, Any, Optional + +from ..base import BaseConnector, ConnectorDocument, DocumentACL +from .oauth import OneDriveOAuth + + +class OneDriveConnector(BaseConnector): + """OneDrive connector using Microsoft Graph API""" + + CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID" + CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET" + + # Connector metadata + CONNECTOR_NAME = "OneDrive" + CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents" + CONNECTOR_ICON = "onedrive" + + def __init__(self, config: Dict[str, Any]): + super().__init__(config) + 
self.oauth = OneDriveOAuth( + client_id=self.get_client_id(), + client_secret=self.get_client_secret(), + token_file=config.get("token_file", "onedrive_token.json"), + ) + self.subscription_id = config.get("subscription_id") or config.get("webhook_channel_id") + self.base_url = "https://graph.microsoft.com/v1.0" + + async def authenticate(self) -> bool: + if await self.oauth.is_authenticated(): + self._authenticated = True + return True + return False + + async def setup_subscription(self) -> str: + if not self._authenticated: + raise ValueError("Not authenticated") + + webhook_url = self.config.get("webhook_url") + if not webhook_url: + raise ValueError("webhook_url required in config for subscriptions") + + expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z" + body = { + "changeType": "created,updated,deleted", + "notificationUrl": webhook_url, + "resource": "/me/drive/root", + "expirationDateTime": expiration, + "clientState": str(uuid.uuid4()), + } + + token = self.oauth.get_access_token() + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{self.base_url}/subscriptions", + json=body, + headers={"Authorization": f"Bearer {token}"}, + ) + resp.raise_for_status() + data = resp.json() + + self.subscription_id = data["id"] + return self.subscription_id + + async def list_files(self, page_token: Optional[str] = None, limit: int = 100) -> Dict[str, Any]: + if not self._authenticated: + raise ValueError("Not authenticated") + + params = {"$top": str(limit)} + if page_token: + params["$skiptoken"] = page_token + + token = self.oauth.get_access_token() + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/me/drive/root/children", + params=params, + headers={"Authorization": f"Bearer {token}"}, + ) + resp.raise_for_status() + data = resp.json() + + files = [] + for item in data.get("value", []): + if item.get("file"): + files.append({ + "id": item["id"], + "name": item["name"], + "mimeType": 
item.get("file", {}).get("mimeType", "application/octet-stream"), + "webViewLink": item.get("webUrl"), + "createdTime": item.get("createdDateTime"), + "modifiedTime": item.get("lastModifiedDateTime"), + }) + + next_token = None + next_link = data.get("@odata.nextLink") + if next_link: + from urllib.parse import urlparse, parse_qs + + parsed = urlparse(next_link) + next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0] + + return {"files": files, "nextPageToken": next_token} + + async def get_file_content(self, file_id: str) -> ConnectorDocument: + if not self._authenticated: + raise ValueError("Not authenticated") + + token = self.oauth.get_access_token() + headers = {"Authorization": f"Bearer {token}"} + async with httpx.AsyncClient() as client: + meta_resp = await client.get(f"{self.base_url}/me/drive/items/{file_id}", headers=headers) + meta_resp.raise_for_status() + metadata = meta_resp.json() + + content_resp = await client.get(f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers) + content_resp.raise_for_status() + content = content_resp.content + + perm_resp = await client.get(f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers) + perm_resp.raise_for_status() + permissions = perm_resp.json() + + acl = self._parse_permissions(metadata, permissions) + modified = datetime.fromisoformat(metadata["lastModifiedDateTime"].replace("Z", "+00:00")).replace(tzinfo=None) + created = datetime.fromisoformat(metadata["createdDateTime"].replace("Z", "+00:00")).replace(tzinfo=None) + + document = ConnectorDocument( + id=metadata["id"], + filename=metadata["name"], + mimetype=metadata.get("file", {}).get("mimeType", "application/octet-stream"), + content=content, + source_url=metadata.get("webUrl"), + acl=acl, + modified_time=modified, + created_time=created, + metadata={"size": metadata.get("size")}, + ) + return document + + def _parse_permissions(self, metadata: Dict[str, Any], permissions: Dict[str, Any]) -> DocumentACL: + 
acl = DocumentACL() + owner = metadata.get("createdBy", {}).get("user", {}).get("email") + if owner: + acl.owner = owner + for perm in permissions.get("value", []): + role = perm.get("roles", ["read"])[0] + grantee = perm.get("grantedToV2") or perm.get("grantedTo") + if not grantee: + continue + user = grantee.get("user") + if user and user.get("email"): + acl.user_permissions[user["email"]] = role + group = grantee.get("group") + if group and group.get("email"): + acl.group_permissions[group["email"]] = role + return acl + + def handle_webhook_validation(self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]) -> Optional[str]: + """Handle Microsoft Graph webhook validation""" + if request_method == "GET": + validation_token = query_params.get("validationtoken") or query_params.get("validationToken") + if validation_token: + return validation_token + return None + + def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]: + """Extract SharePoint subscription ID from webhook payload""" + values = payload.get('value', []) + return values[0].get('subscriptionId') if values else None + + async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: + values = payload.get("value", []) + file_ids = [] + for item in values: + resource_data = item.get("resourceData", {}) + file_id = resource_data.get("id") + if file_id: + file_ids.append(file_id) + return file_ids + + async def cleanup_subscription(self, subscription_id: str, resource_id: str = None) -> bool: + if not self._authenticated: + return False + token = self.oauth.get_access_token() + async with httpx.AsyncClient() as client: + resp = await client.delete( + f"{self.base_url}/subscriptions/{subscription_id}", + headers={"Authorization": f"Bearer {token}"}, + ) + return resp.status_code in (200, 204) diff --git a/src/connectors/onedrive/oauth.py b/src/connectors/onedrive/oauth.py new file mode 100644 index 00000000..ff4f080c --- 
/dev/null +++ b/src/connectors/onedrive/oauth.py @@ -0,0 +1,83 @@ +import os +import aiofiles +from typing import Optional +import msal + + +class OneDriveOAuth: + """Handles Microsoft Graph OAuth authentication flow""" + + SCOPES = [ + "offline_access", + "Files.Read.All", + ] + + AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" + TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token" + + def __init__(self, client_id: str, client_secret: str, token_file: str = "onedrive_token.json", authority: str = "https://login.microsoftonline.com/common"): + self.client_id = client_id + self.client_secret = client_secret + self.token_file = token_file + self.authority = authority + self.token_cache = msal.SerializableTokenCache() + + # Load existing cache if available + if os.path.exists(self.token_file): + with open(self.token_file, "r") as f: + self.token_cache.deserialize(f.read()) + + self.app = msal.ConfidentialClientApplication( + client_id=self.client_id, + client_credential=self.client_secret, + authority=self.authority, + token_cache=self.token_cache, + ) + + async def save_cache(self): + """Persist the token cache to file""" + async with aiofiles.open(self.token_file, "w") as f: + await f.write(self.token_cache.serialize()) + + def create_authorization_url(self, redirect_uri: str) -> str: + """Create authorization URL for OAuth flow""" + return self.app.get_authorization_request_url(self.SCOPES, redirect_uri=redirect_uri) + + async def handle_authorization_callback(self, authorization_code: str, redirect_uri: str) -> bool: + """Handle OAuth callback and exchange code for tokens""" + result = self.app.acquire_token_by_authorization_code( + authorization_code, + scopes=self.SCOPES, + redirect_uri=redirect_uri, + ) + if "access_token" in result: + await self.save_cache() + return True + raise ValueError(result.get("error_description") or "Authorization failed") + + async def is_authenticated(self) -> bool: + 
"""Check if we have valid credentials""" + accounts = self.app.get_accounts() + if not accounts: + return False + result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0]) + if result and "access_token" in result: + await self.save_cache() + return True + return False + + def get_access_token(self) -> str: + """Get an access token for Microsoft Graph""" + accounts = self.app.get_accounts() + if not accounts: + raise ValueError("Not authenticated") + result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0]) + if not result or "access_token" not in result: + raise ValueError((result or {}).get("error_description") or "Failed to acquire access token") + return result["access_token"] + + async def revoke_credentials(self): + """Clear token cache and remove token file""" + self.token_cache.clear() + if os.path.exists(self.token_file): + os.remove(self.token_file) diff --git a/src/connectors/service.py b/src/connectors/service.py index c38a4c24..a8262217 100644 --- a/src/connectors/service.py +++ b/src/connectors/service.py @@ -5,6 +5,8 @@ from typing import Dict, Any, List, Optional from .base import BaseConnector, ConnectorDocument from .google_drive import GoogleDriveConnector +from .sharepoint import SharePointConnector +from .onedrive import OneDriveConnector from .connection_manager import ConnectionManager @@ -28,7 +30,7 @@ class ConnectorService: """Get a connector by connection ID""" return await self.connection_manager.get_connector(connection_id) - async def process_connector_document(self, document: ConnectorDocument, owner_user_id: str, jwt_token: str = None) -> Dict[str, Any]: + async def process_connector_document(self, document: ConnectorDocument, owner_user_id: str, connector_type: str, jwt_token: str = None) -> Dict[str, Any]: """Process a document from a connector using existing processing pipeline""" # Create temporary file from document content @@ -54,7 +56,7 @@ class ConnectorService: # If successfully indexed, update the indexed documents with 
connector metadata if result["status"] == "indexed": # Update all chunks with connector-specific metadata - await self._update_connector_metadata(document, owner_user_id, jwt_token) + await self._update_connector_metadata(document, owner_user_id, connector_type, jwt_token) return { **result, @@ -66,7 +68,7 @@ class ConnectorService: # Clean up temporary file os.unlink(tmp_file.name) - async def _update_connector_metadata(self, document: ConnectorDocument, owner_user_id: str, jwt_token: str = None): + async def _update_connector_metadata(self, document: ConnectorDocument, owner_user_id: str, connector_type: str, jwt_token: str = None): """Update indexed chunks with connector-specific metadata""" # Find all chunks for this document query = { @@ -86,7 +88,7 @@ class ConnectorService: update_body = { "doc": { "source_url": document.source_url, - "connector_type": "google_drive", # Could be passed as parameter + "connector_type": connector_type, # Additional ACL info beyond owner (already set by process_file_common) "allowed_users": document.acl.allowed_users, "allowed_groups": document.acl.allowed_groups, diff --git a/src/connectors/sharepoint/__init__.py b/src/connectors/sharepoint/__init__.py new file mode 100644 index 00000000..dfd631e5 --- /dev/null +++ b/src/connectors/sharepoint/__init__.py @@ -0,0 +1,4 @@ +from .connector import SharePointConnector +from .oauth import SharePointOAuth + +__all__ = ["SharePointConnector", "SharePointOAuth"] diff --git a/src/connectors/sharepoint/connector.py b/src/connectors/sharepoint/connector.py new file mode 100644 index 00000000..7b4731d6 --- /dev/null +++ b/src/connectors/sharepoint/connector.py @@ -0,0 +1,196 @@ +import httpx +import uuid +from datetime import datetime, timedelta +from typing import Dict, List, Any, Optional + +from ..base import BaseConnector, ConnectorDocument, DocumentACL +from .oauth import SharePointOAuth + + +class SharePointConnector(BaseConnector): + """SharePoint Sites connector using Microsoft 
Graph API""" + + CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID" + CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET" + + # Connector metadata + CONNECTOR_NAME = "SharePoint" + CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents" + CONNECTOR_ICON = "sharepoint" + + def __init__(self, config: Dict[str, Any]): + super().__init__(config) + self.oauth = SharePointOAuth( + client_id=self.get_client_id(), + client_secret=self.get_client_secret(), + token_file=config.get("token_file", "sharepoint_token.json"), + ) + self.subscription_id = config.get("subscription_id") or config.get("webhook_channel_id") + self.base_url = "https://graph.microsoft.com/v1.0" + + # SharePoint site configuration + self.site_id = config.get("site_id") # Required for SharePoint + + async def authenticate(self) -> bool: + if await self.oauth.is_authenticated(): + self._authenticated = True + return True + return False + + async def setup_subscription(self) -> str: + if not self._authenticated: + raise ValueError("Not authenticated") + + webhook_url = self.config.get("webhook_url") + if not webhook_url: + raise ValueError("webhook_url required in config for subscriptions") + + expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z" + body = { + "changeType": "created,updated,deleted", + "notificationUrl": webhook_url, + "resource": f"/sites/{self.site_id}/drive/root", + "expirationDateTime": expiration, + "clientState": str(uuid.uuid4()), + } + + token = self.oauth.get_access_token() + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{self.base_url}/subscriptions", + json=body, + headers={"Authorization": f"Bearer {token}"}, + ) + resp.raise_for_status() + data = resp.json() + + self.subscription_id = data["id"] + return self.subscription_id + + async def list_files(self, page_token: Optional[str] = None, limit: int = 100) -> Dict[str, Any]: + if not self._authenticated: + raise ValueError("Not authenticated") + + 
params = {"$top": str(limit)} + if page_token: + params["$skiptoken"] = page_token + + token = self.oauth.get_access_token() + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{self.base_url}/sites/{self.site_id}/drive/root/children", + params=params, + headers={"Authorization": f"Bearer {token}"}, + ) + resp.raise_for_status() + data = resp.json() + + files = [] + for item in data.get("value", []): + if item.get("file"): + files.append({ + "id": item["id"], + "name": item["name"], + "mimeType": item.get("file", {}).get("mimeType", "application/octet-stream"), + "webViewLink": item.get("webUrl"), + "createdTime": item.get("createdDateTime"), + "modifiedTime": item.get("lastModifiedDateTime"), + }) + + next_token = None + next_link = data.get("@odata.nextLink") + if next_link: + from urllib.parse import urlparse, parse_qs + + parsed = urlparse(next_link) + next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0] + + return {"files": files, "nextPageToken": next_token} + + async def get_file_content(self, file_id: str) -> ConnectorDocument: + if not self._authenticated: + raise ValueError("Not authenticated") + + token = self.oauth.get_access_token() + headers = {"Authorization": f"Bearer {token}"} + async with httpx.AsyncClient() as client: + meta_resp = await client.get(f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}", headers=headers) + meta_resp.raise_for_status() + metadata = meta_resp.json() + + content_resp = await client.get(f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/content", headers=headers) + content_resp.raise_for_status() + content = content_resp.content + + perm_resp = await client.get(f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/permissions", headers=headers) + perm_resp.raise_for_status() + permissions = perm_resp.json() + + acl = self._parse_permissions(metadata, permissions) + modified = datetime.fromisoformat(metadata["lastModifiedDateTime"].replace("Z", 
"+00:00")).replace(tzinfo=None) + created = datetime.fromisoformat(metadata["createdDateTime"].replace("Z", "+00:00")).replace(tzinfo=None) + + document = ConnectorDocument( + id=metadata["id"], + filename=metadata["name"], + mimetype=metadata.get("file", {}).get("mimeType", "application/octet-stream"), + content=content, + source_url=metadata.get("webUrl"), + acl=acl, + modified_time=modified, + created_time=created, + metadata={"size": metadata.get("size")}, + ) + return document + + def _parse_permissions(self, metadata: Dict[str, Any], permissions: Dict[str, Any]) -> DocumentACL: + acl = DocumentACL() + owner = metadata.get("createdBy", {}).get("user", {}).get("email") + if owner: + acl.owner = owner + for perm in permissions.get("value", []): + role = (perm.get("roles") or ["read"])[0] + grantee = perm.get("grantedToV2") or perm.get("grantedTo") + if not grantee: + continue + user = grantee.get("user") + if user and user.get("email"): + acl.user_permissions[user["email"]] = role + group = grantee.get("group") + if group and group.get("email"): + acl.group_permissions[group["email"]] = role + return acl + + def handle_webhook_validation(self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str]) -> Optional[str]: + """Handle Microsoft Graph webhook validation""" + if request_method == "GET": + validation_token = query_params.get("validationtoken") or query_params.get("validationToken") + if validation_token: + return validation_token + return None + + def extract_webhook_channel_id(self, payload: Dict[str, Any], headers: Dict[str, str]) -> Optional[str]: + """Extract SharePoint subscription ID from webhook payload""" + values = payload.get('value', []) + return values[0].get('subscriptionId') if values else None + + async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: + values = payload.get("value", []) + file_ids = [] + for item in values: + resource_data = item.get("resourceData", {}) + file_id = resource_data.get("id") + if 
file_id: + file_ids.append(file_id) + return file_ids + + async def cleanup_subscription(self, subscription_id: str, resource_id: str = None) -> bool: + if not self._authenticated: + return False + token = self.oauth.get_access_token() + async with httpx.AsyncClient() as client: + resp = await client.delete( + f"{self.base_url}/subscriptions/{subscription_id}", + headers={"Authorization": f"Bearer {token}"}, + ) + return resp.status_code in (200, 204) diff --git a/src/connectors/sharepoint/oauth.py b/src/connectors/sharepoint/oauth.py new file mode 100644 index 00000000..fdb566f8 --- /dev/null +++ b/src/connectors/sharepoint/oauth.py @@ -0,0 +1,84 @@ +import os +import aiofiles +from typing import Optional +import msal + + +class SharePointOAuth: + """Handles Microsoft Graph OAuth authentication flow""" + + SCOPES = [ + "offline_access", + "Files.Read.All", + "Sites.Read.All", + ] + + AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" + TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token" + + def __init__(self, client_id: str, client_secret: str, token_file: str = "sharepoint_token.json", authority: str = "https://login.microsoftonline.com/common"): + self.client_id = client_id + self.client_secret = client_secret + self.token_file = token_file + self.authority = authority + self.token_cache = msal.SerializableTokenCache() + + # Load existing cache if available + if os.path.exists(self.token_file): + with open(self.token_file, "r") as f: + self.token_cache.deserialize(f.read()) + + self.app = msal.ConfidentialClientApplication( + client_id=self.client_id, + client_credential=self.client_secret, + authority=self.authority, + token_cache=self.token_cache, + ) + + async def save_cache(self): + """Persist the token cache to file""" + async with aiofiles.open(self.token_file, "w") as f: + await f.write(self.token_cache.serialize()) + + def create_authorization_url(self, redirect_uri: str) -> str: + """Create 
authorization URL for OAuth flow""" + return self.app.get_authorization_request_url(self.SCOPES, redirect_uri=redirect_uri) + + async def handle_authorization_callback(self, authorization_code: str, redirect_uri: str) -> bool: + """Handle OAuth callback and exchange code for tokens""" + result = self.app.acquire_token_by_authorization_code( + authorization_code, + scopes=self.SCOPES, + redirect_uri=redirect_uri, + ) + if "access_token" in result: + await self.save_cache() + return True + raise ValueError(result.get("error_description") or "Authorization failed") + + async def is_authenticated(self) -> bool: + """Check if we have valid credentials""" + accounts = self.app.get_accounts() + if not accounts: + return False + result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0]) + if result and "access_token" in result: + await self.save_cache() + return True + return False + + def get_access_token(self) -> str: + """Get an access token for Microsoft Graph""" + accounts = self.app.get_accounts() + if not accounts: + raise ValueError("Not authenticated") + result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0]) + if not result or "access_token" not in result: + raise ValueError((result or {}).get("error_description") or "Failed to acquire access token") + return result["access_token"] + + async def revoke_credentials(self): + """Clear token cache and remove token file""" + self.token_cache.clear() + if os.path.exists(self.token_file): + os.remove(self.token_file) diff --git a/src/main.py b/src/main.py index 518ba7a6..3d9c5524 100644 --- a/src/main.py +++ b/src/main.py @@ -278,6 +278,13 @@ def create_app(): ), methods=["POST"]), # Connector endpoints + Route("/connectors", + require_auth(services['session_manager'])( + partial(connectors.list_connectors, + connector_service=services['connector_service'], + session_manager=services['session_manager']) + ), methods=["GET"]), + Route("/connectors/{connector_type}/sync", require_auth(services['session_manager'])( 
partial(connectors.connector_sync, diff --git a/src/models/processors.py b/src/models/processors.py index f8326c72..ef33984f 100644 --- a/src/models/processors.py +++ b/src/models/processors.py @@ -63,9 +63,10 @@ class ConnectorFileProcessor(TaskProcessor): file_id = item # item is the connector file ID file_info = self.file_info_map.get(file_id) - # Get the connector + # Get the connector and connection info connector = await self.connector_service.get_connector(self.connection_id) - if not connector: + connection = await self.connector_service.connection_manager.get_connection(self.connection_id) + if not connector or not connection: raise ValueError(f"Connection '{self.connection_id}' not found") # Get file content from connector (the connector will fetch metadata if needed) @@ -76,7 +77,7 @@ class ConnectorFileProcessor(TaskProcessor): raise ValueError("user_id not provided to ConnectorFileProcessor") # Process using existing pipeline - result = await self.connector_service.process_connector_document(document, self.user_id) + result = await self.connector_service.process_connector_document(document, self.user_id, connection.connector_type) file_task.status = TaskStatus.COMPLETED file_task.result = result diff --git a/src/services/auth_service.py b/src/services/auth_service.py index c22a60c9..70c1d8b7 100644 --- a/src/services/auth_service.py +++ b/src/services/auth_service.py @@ -8,6 +8,12 @@ from typing import Optional from config.settings import WEBHOOK_BASE_URL from session_manager import SessionManager +from connectors.google_drive.oauth import GoogleDriveOAuth +from connectors.onedrive.oauth import OneDriveOAuth +from connectors.sharepoint.oauth import SharePointOAuth +from connectors.google_drive import GoogleDriveConnector +from connectors.onedrive import OneDriveConnector +from connectors.sharepoint import SharePointConnector class AuthService: def __init__(self, session_manager: SessionManager, connector_service=None): @@ -15,11 +21,16 @@ class 
AuthService: self.connector_service = connector_service self.used_auth_codes = set() # Track used authorization codes - async def init_oauth(self, provider: str, purpose: str, connection_name: str, + async def init_oauth(self, connector_type: str, purpose: str, connection_name: str, redirect_uri: str, user_id: str = None) -> dict: """Initialize OAuth flow for authentication or data source connection""" - if provider != "google": - raise ValueError("Unsupported provider") + # Validate connector_type based on purpose + if purpose == "app_auth" and connector_type != "google_drive": + raise ValueError("Only Google login supported for app authentication") + elif purpose == "data_source" and connector_type not in ["google_drive", "onedrive", "sharepoint"]: + raise ValueError(f"Unsupported connector type: {connector_type}") + elif purpose not in ["app_auth", "data_source"]: + raise ValueError(f"Unsupported purpose: {purpose}") if not redirect_uri: raise ValueError("redirect_uri is required") @@ -27,20 +38,19 @@ class AuthService: # We'll validate client credentials when creating the connector # Create connection configuration - token_file = f"{provider}_{purpose}_{uuid.uuid4().hex[:8]}.json" + token_file = f"{connector_type}_{purpose}_{uuid.uuid4().hex[:8]}.json" config = { "token_file": token_file, - "provider": provider, + "connector_type": connector_type, "purpose": purpose, "redirect_uri": redirect_uri } # Only add webhook URL if WEBHOOK_BASE_URL is configured if WEBHOOK_BASE_URL: - config["webhook_url"] = f"{WEBHOOK_BASE_URL}/connectors/{provider}_drive/webhook" + config["webhook_url"] = f"{WEBHOOK_BASE_URL}/connectors/{connector_type}/webhook" - # Create connection in manager (always use _drive connector type as it handles OAuth) - connector_type = f"{provider}_drive" + # Create connection in manager connection_id = await self.connector_service.connection_manager.create_connection( connector_type=connector_type, name=connection_name, @@ -48,25 +58,38 @@ class 
AuthService: user_id=user_id ) - # Return OAuth configuration for client-side flow - scopes = [ - 'openid', 'email', 'profile', - 'https://www.googleapis.com/auth/drive.readonly', - 'https://www.googleapis.com/auth/drive.metadata.readonly' - ] - - # Get client_id from environment variable (same as connector would do) + # Get OAuth configuration from connector and OAuth classes import os - client_id = os.getenv("GOOGLE_OAUTH_CLIENT_ID") + + # Map connector types to their connector and OAuth classes + connector_class_map = { + "google_drive": (GoogleDriveConnector, GoogleDriveOAuth), + "onedrive": (OneDriveConnector, OneDriveOAuth), + "sharepoint": (SharePointConnector, SharePointOAuth) + } + + connector_class, oauth_class = connector_class_map.get(connector_type, (None, None)) + if not connector_class or not oauth_class: + raise ValueError(f"No classes found for connector type: {connector_type}") + + # Get scopes from OAuth class + scopes = oauth_class.SCOPES + + # Get endpoints from OAuth class + auth_endpoint = oauth_class.AUTH_ENDPOINT + token_endpoint = oauth_class.TOKEN_ENDPOINT + + # Get client_id from environment variable using connector's env var name + client_id = os.getenv(connector_class.CLIENT_ID_ENV_VAR) if not client_id: - raise ValueError("GOOGLE_OAUTH_CLIENT_ID environment variable not set") + raise ValueError(f"{connector_class.CLIENT_ID_ENV_VAR} environment variable not set") oauth_config = { "client_id": client_id, "scopes": scopes, "redirect_uri": redirect_uri, - "authorization_endpoint": "https://accounts.google.com/o/oauth2/v2/auth", - "token_endpoint": "https://oauth2.googleapis.com/token" + "authorization_endpoint": auth_endpoint, + "token_endpoint": token_endpoint } return { @@ -98,10 +121,23 @@ class AuthService: if not redirect_uri: raise ValueError("Redirect URI not found in connection config") - token_url = "https://oauth2.googleapis.com/token" - # Get connector to access client credentials + # Get connector to access client credentials 
and endpoints connector = self.connector_service.connection_manager._create_connector(connection_config) + # Get token endpoint from connector type + connector_type = connection_config.connector_type + connector_class_map = { + "google_drive": (GoogleDriveConnector, GoogleDriveOAuth), + "onedrive": (OneDriveConnector, OneDriveOAuth), + "sharepoint": (SharePointConnector, SharePointOAuth) + } + + connector_class, oauth_class = connector_class_map.get(connector_type, (None, None)) + if not connector_class or not oauth_class: + raise ValueError(f"No classes found for connector type: {connector_type}") + + token_url = oauth_class.TOKEN_ENDPOINT + token_payload = { "code": authorization_code, "client_id": connector.get_client_id(), @@ -119,14 +155,18 @@ class AuthService: token_data = token_response.json() # Store tokens in the token file (without client_secret) + # Use actual scopes from OAuth response + granted_scopes = token_data.get("scope") + if not granted_scopes: + raise ValueError(f"OAuth provider for {connector_type} did not return granted scopes in token response") + + # OAuth providers typically return scopes as a space-separated string + scopes = granted_scopes.split(" ") if isinstance(granted_scopes, str) else granted_scopes + token_file_data = { "token": token_data["access_token"], "refresh_token": token_data.get("refresh_token"), - "scopes": [ - "openid", "email", "profile", - "https://www.googleapis.com/auth/drive.readonly", - "https://www.googleapis.com/auth/drive.metadata.readonly" - ] + "scopes": scopes } # Add expiry if provided diff --git a/uv.lock b/uv.lock index 9505fee7..275fd012 100644 --- a/uv.lock +++ b/uv.lock @@ -989,6 +989,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, ] +[[package]] 
+name = "msal" +version = "1.33.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d5/da/81acbe0c1fd7e9e4ec35f55dadeba9833a847b9a6ba2e2d1e4432da901dd/msal-1.33.0.tar.gz", hash = "sha256:836ad80faa3e25a7d71015c990ce61f704a87328b1e73bcbb0623a18cbf17510", size = 153801, upload-time = "2025-07-22T19:36:33.693Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/5b/fbc73e91f7727ae1e79b21ed833308e99dc11cc1cd3d4717f579775de5e9/msal-1.33.0-py3-none-any.whl", hash = "sha256:c0cd41cecf8eaed733ee7e3be9e040291eba53b0f262d3ae9c58f38b04244273", size = 116853, upload-time = "2025-07-22T19:36:32.403Z" }, +] + [[package]] name = "multidict" version = "6.6.3" @@ -1328,6 +1342,7 @@ dependencies = [ { name = "google-auth-httplib2" }, { name = "google-auth-oauthlib" }, { name = "httpx" }, + { name = "msal" }, { name = "opensearch-py", extra = ["async"] }, { name = "pyjwt" }, { name = "python-multipart" }, @@ -1346,6 +1361,7 @@ requires-dist = [ { name = "google-auth-httplib2", specifier = ">=0.2.0" }, { name = "google-auth-oauthlib", specifier = ">=1.2.0" }, { name = "httpx", specifier = ">=0.27.0" }, + { name = "msal", specifier = ">=1.29.0" }, { name = "opensearch-py", extras = ["async"], specifier = ">=3.0.0" }, { name = "pyjwt", specifier = ">=2.8.0" }, { name = "python-multipart", specifier = ">=0.0.20" }, @@ -1661,6 +1677,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, ] +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pylatexenc" version = "2.10"