diff --git a/frontend/src/app/connectors/page.tsx b/frontend/src/app/connectors/page.tsx index 26e1a88a..ad70ec90 100644 --- a/frontend/src/app/connectors/page.tsx +++ b/frontend/src/app/connectors/page.tsx @@ -1,20 +1,14 @@ -"use client" +"use client"; import React, { useState } from "react"; -import { GoogleDrivePicker } from "@/components/google-drive-picker" -import { useTask } from "@/contexts/task-context" +import { UnifiedCloudPicker, CloudFile } from "@/components/cloud-picker"; +import { useTask } from "@/contexts/task-context"; -interface GoogleDriveFile { - id: string; - name: string; - mimeType: string; - webViewLink?: string; - iconLink?: string; -} +// CloudFile interface is now imported from the unified cloud picker export default function ConnectorsPage() { - const { addTask } = useTask() - const [selectedFiles, setSelectedFiles] = useState([]); + const { addTask } = useTask(); + const [selectedFiles, setSelectedFiles] = useState([]); const [isSyncing, setIsSyncing] = useState(false); const [syncResult, setSyncResult] = useState<{ processed?: number; @@ -25,16 +19,19 @@ export default function ConnectorsPage() { errors?: number; } | null>(null); - const handleFileSelection = (files: GoogleDriveFile[]) => { + const handleFileSelection = (files: CloudFile[]) => { setSelectedFiles(files); }; - const handleSync = async (connector: { connectionId: string, type: string }) => { - if (!connector.connectionId || selectedFiles.length === 0) return - - setIsSyncing(true) - setSyncResult(null) - + const handleSync = async (connector: { + connectionId: string; + type: string; + }) => { + if (!connector.connectionId || selectedFiles.length === 0) return; + + setIsSyncing(true); + setSyncResult(null); + try { const syncBody: { connection_id: string; @@ -42,54 +39,55 @@ export default function ConnectorsPage() { selected_files?: string[]; } = { connection_id: connector.connectionId, - selected_files: selectedFiles.map(file => file.id) - } - + selected_files: selectedFiles.map(file => file.id), + }; + const response = await fetch(`/api/connectors/${connector.type}/sync`, { - method: 'POST', + method: "POST", headers: { - 'Content-Type': 'application/json', + "Content-Type": "application/json", }, body: JSON.stringify(syncBody), - }) - - const result = await response.json() - + }); + + const result = await response.json(); + if (response.status === 201) { - const taskId = result.task_id + const taskId = result.task_id; if (taskId) { - addTask(taskId) - setSyncResult({ - processed: 0, + addTask(taskId); + setSyncResult({ + processed: 0, total: selectedFiles.length, - status: 'started' - }) + status: "started", + }); } } else if (response.ok) { - setSyncResult(result) + setSyncResult(result); } else { - console.error('Sync failed:', result.error) - setSyncResult({ error: result.error || 'Sync failed' }) + console.error("Sync failed:", result.error); + setSyncResult({ error: result.error || "Sync failed" }); } } catch (error) { - console.error('Sync error:', error) - setSyncResult({ error: 'Network error occurred' }) + console.error("Sync error:", error); + setSyncResult({ error: "Network error occurred" }); } finally { - setIsSyncing(false) + setIsSyncing(false); } }; return (

Connectors

- +

- This is a demo page for the Google Drive picker component. - For full connector functionality, visit the Settings page. + This is a demo page for the unified cloud picker component. For full + connector functionality, visit the Settings page.

- - 0 && (
- - + {syncResult && (
{syncResult.error ? (
Error: {syncResult.error}
- ) : syncResult.status === 'started' ? ( + ) : syncResult.status === "started" ? (
- Sync started for {syncResult.total} files. Check the task notification for progress. + Sync started for {syncResult.total} files. Check the task + notification for progress.
) : (
diff --git a/frontend/src/app/upload/[provider]/page.tsx b/frontend/src/app/upload/[provider]/page.tsx index ea2d319e..522a45b7 100644 --- a/frontend/src/app/upload/[provider]/page.tsx +++ b/frontend/src/app/upload/[provider]/page.tsx @@ -4,29 +4,12 @@ import { useState, useEffect } from "react"; import { useParams, useRouter } from "next/navigation"; import { Button } from "@/components/ui/button"; import { ArrowLeft, AlertCircle } from "lucide-react"; -import { GoogleDrivePicker } from "@/components/google-drive-picker"; -import { OneDrivePicker } from "@/components/onedrive-picker"; +import { UnifiedCloudPicker, CloudFile } from "@/components/cloud-picker"; +import type { IngestSettings } from "@/components/cloud-picker/types"; import { useTask } from "@/contexts/task-context"; import { Toast } from "@/components/ui/toast"; -interface GoogleDriveFile { - id: string; - name: string; - mimeType: string; - webViewLink?: string; - iconLink?: string; -} - -interface OneDriveFile { - id: string; - name: string; - mimeType?: string; - webUrl?: string; - driveItem?: { - file?: { mimeType: string }; - folder?: unknown; - }; -} +// CloudFile interface is now imported from the unified cloud picker interface CloudConnector { id: string; @@ -35,6 +18,7 @@ interface CloudConnector { status: "not_connected" | "connecting" | "connected" | "error"; type: string; connectionId?: string; + clientId: string; hasAccessToken: boolean; accessTokenError?: string; } @@ -49,14 +33,19 @@ export default function UploadProviderPage() { const [isLoading, setIsLoading] = useState(true); const [error, setError] = useState(null); const [accessToken, setAccessToken] = useState(null); - const [selectedFiles, setSelectedFiles] = useState< - GoogleDriveFile[] | OneDriveFile[] - >([]); + const [selectedFiles, setSelectedFiles] = useState([]); const [isIngesting, setIsIngesting] = useState(false); const [currentSyncTaskId, setCurrentSyncTaskId] = useState( null ); const [showSuccessToast, setShowSuccessToast] = useState(false); + const [ingestSettings, setIngestSettings] = useState({ + chunkSize: 1000, + chunkOverlap: 200, + ocr: false, + pictureDescriptions: false, + embeddingModel: "text-embedding-3-small", + }); useEffect(() => { const fetchConnectorInfo = async () => { @@ -129,6 +118,7 @@ export default function UploadProviderPage() { status: isConnected ? "connected" : "not_connected", type: provider, connectionId: activeConnection?.connection_id, + clientId: activeConnection?.client_id, hasAccessToken, accessTokenError, }); @@ -159,13 +149,6 @@ export default function UploadProviderPage() { // Task completed successfully, show toast and redirect setIsIngesting(false); setShowSuccessToast(true); - - // Dispatch knowledge updated event to refresh the knowledge table - console.log( - "Cloud provider task completed, dispatching knowledgeUpdated event" - ); - window.dispatchEvent(new CustomEvent("knowledgeUpdated")); - setTimeout(() => { router.push("/knowledge"); }, 2000); // 2 second delay to let user see toast @@ -176,20 +159,12 @@ export default function UploadProviderPage() { } }, [tasks, currentSyncTaskId, router]); - const handleFileSelected = (files: GoogleDriveFile[] | OneDriveFile[]) => { + const handleFileSelected = (files: CloudFile[]) => { setSelectedFiles(files); console.log(`Selected ${files.length} files from ${provider}:`, files); // You can add additional handling here like triggering sync, etc. 
}; - const handleGoogleDriveFileSelected = (files: GoogleDriveFile[]) => { - handleFileSelected(files); - }; - - const handleOneDriveFileSelected = (files: OneDriveFile[]) => { - handleFileSelected(files); - }; - const handleSync = async (connector: CloudConnector) => { if (!connector.connectionId || selectedFiles.length === 0) return; @@ -200,9 +175,11 @@ export default function UploadProviderPage() { connection_id: string; max_files?: number; selected_files?: string[]; + settings?: IngestSettings; } = { connection_id: connector.connectionId, selected_files: selectedFiles.map(file => file.id), + settings: ingestSettings, }; const response = await fetch(`/api/connectors/${connector.type}/sync`, { @@ -353,48 +330,49 @@ export default function UploadProviderPage() {
-

Add Cloud Knowledge

+

+ Add from {getProviderDisplayName()} +

- {connector.type === "google_drive" && ( - - )} - - {(connector.type === "onedrive" || connector.type === "sharepoint") && ( - - )} +
- {selectedFiles.length > 0 && ( -
-
- -
+
+
+ +
- )} +
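{/* Extraction dropped the action markup above; presumably the ingest trigger,
    roughly: <Button onClick={() => handleSync(connector)}
    disabled={isIngesting || selectedFiles.length === 0}>Ingest</Button>
    (hypothetical reconstruction, not the original JSX). */}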
{/* Success toast notification */} void - onFileSelected?: (files: GoogleDriveFile[] | OneDriveFile[], connectorType: string) => void + isOpen: boolean; + onOpenChange: (open: boolean) => void; + onFileSelected?: (files: CloudFile[], connectorType: string) => void; } -export function CloudConnectorsDialog({ - isOpen, +export function CloudConnectorsDialog({ + isOpen, onOpenChange, - onFileSelected + onFileSelected, }: CloudConnectorsDialogProps) { - const [connectors, setConnectors] = useState([]) - const [isLoading, setIsLoading] = useState(true) - const [selectedFiles, setSelectedFiles] = useState<{[connectorId: string]: GoogleDriveFile[] | OneDriveFile[]}>({}) - const [connectorAccessTokens, setConnectorAccessTokens] = useState<{[connectorType: string]: string}>({}) - const [activePickerType, setActivePickerType] = useState(null) + const [connectors, setConnectors] = useState([]); + const [isLoading, setIsLoading] = useState(true); + const [selectedFiles, setSelectedFiles] = useState<{ + [connectorId: string]: CloudFile[]; + }>({}); + const [connectorAccessTokens, setConnectorAccessTokens] = useState<{ + [connectorType: string]: string; + }>({}); + const [activePickerType, setActivePickerType] = useState(null); const getConnectorIcon = (iconName: string) => { const iconMap: { [key: string]: React.ReactElement } = { - 'google-drive': ( + "google-drive": (
G
), - 'sharepoint': ( + sharepoint: (
SP
), - 'onedrive': ( + onedrive: (
OD
), - } - return iconMap[iconName] || ( -
- ? -
- ) - } + }; + return ( + iconMap[iconName] || ( +
+ ? +
+ ) + ); + }; const fetchConnectorStatuses = useCallback(async () => { - if (!isOpen) return - - setIsLoading(true) + if (!isOpen) return; + + setIsLoading(true); try { // Fetch available connectors from backend - const connectorsResponse = await fetch('/api/connectors') + const connectorsResponse = await fetch("/api/connectors"); if (!connectorsResponse.ok) { - throw new Error('Failed to load connectors') + throw new Error("Failed to load connectors"); } - - const connectorsResult = await connectorsResponse.json() - const connectorTypes = Object.keys(connectorsResult.connectors) - + + const connectorsResult = await connectorsResponse.json(); + const connectorTypes = Object.keys(connectorsResult.connectors); + // Filter to only cloud connectors - const cloudConnectorTypes = connectorTypes.filter(type => - ['google_drive', 'onedrive', 'sharepoint'].includes(type) && - connectorsResult.connectors[type].available - ) - + const cloudConnectorTypes = connectorTypes.filter( + type => + ["google_drive", "onedrive", "sharepoint"].includes(type) && + connectorsResult.connectors[type].available + ); + // Initialize connectors list const initialConnectors = cloudConnectorTypes.map(type => ({ id: type, @@ -115,82 +105,95 @@ export function CloudConnectorsDialog({ status: "not_connected" as const, type: type, hasAccessToken: false, - accessTokenError: undefined - })) - - setConnectors(initialConnectors) + accessTokenError: undefined, + clientId: "", + })); + + setConnectors(initialConnectors); // Check status for each cloud connector type for (const connectorType of cloudConnectorTypes) { try { - const response = await fetch(`/api/connectors/${connectorType}/status`) + const response = await fetch( + `/api/connectors/${connectorType}/status` + ); if (response.ok) { - const data = await response.json() - const connections = data.connections || [] - const activeConnection = connections.find((conn: { connection_id: string; is_active: boolean }) => conn.is_active) - const isConnected = activeConnection !== undefined - - let hasAccessToken = false - let accessTokenError: string | undefined = undefined + const data = await response.json(); + const connections = data.connections || []; + const activeConnection = connections.find( + (conn: { connection_id: string; is_active: boolean }) => + conn.is_active + ); + const isConnected = activeConnection !== undefined; + + let hasAccessToken = false; + let accessTokenError: string | undefined = undefined; // Try to get access token for connected connectors if (isConnected && activeConnection) { try { - const tokenResponse = await fetch(`/api/connectors/${connectorType}/token?connection_id=${activeConnection.connection_id}`) + const tokenResponse = await fetch( + `/api/connectors/${connectorType}/token?connection_id=${activeConnection.connection_id}` + ); if (tokenResponse.ok) { - const tokenData = await tokenResponse.json() + const tokenData = await tokenResponse.json(); if (tokenData.access_token) { - hasAccessToken = true + hasAccessToken = true; setConnectorAccessTokens(prev => ({ ...prev, - [connectorType]: tokenData.access_token - })) + [connectorType]: tokenData.access_token, + })); } } else { - const errorData = await tokenResponse.json().catch(() => ({ error: 'Token unavailable' })) - accessTokenError = errorData.error || 'Access token unavailable' + const errorData = await tokenResponse + .json() + .catch(() => ({ error: "Token unavailable" })); + accessTokenError = + errorData.error || "Access token unavailable"; } } catch { - accessTokenError = 'Failed to 
fetch access token' + accessTokenError = "Failed to fetch access token"; } } - - setConnectors(prev => prev.map(c => - c.type === connectorType - ? { - ...c, - status: isConnected ? "connected" : "not_connected", - connectionId: activeConnection?.connection_id, - hasAccessToken, - accessTokenError - } - : c - )) + + setConnectors(prev => + prev.map(c => + c.type === connectorType + ? { + ...c, + status: isConnected ? "connected" : "not_connected", + connectionId: activeConnection?.connection_id, + clientId: activeConnection?.client_id, + hasAccessToken, + accessTokenError, + } + : c + ) + ); } } catch (error) { - console.error(`Failed to check status for ${connectorType}:`, error) + console.error(`Failed to check status for ${connectorType}:`, error); } } } catch (error) { - console.error('Failed to load cloud connectors:', error) + console.error("Failed to load cloud connectors:", error); } finally { - setIsLoading(false) + setIsLoading(false); } - }, [isOpen]) + }, [isOpen]); - const handleFileSelection = (connectorId: string, files: GoogleDriveFile[] | OneDriveFile[]) => { + const handleFileSelection = (connectorId: string, files: CloudFile[]) => { setSelectedFiles(prev => ({ ...prev, - [connectorId]: files - })) - - onFileSelected?.(files, connectorId) - } + [connectorId]: files, + })); + + onFileSelected?.(files, connectorId); + }; useEffect(() => { - fetchConnectorStatuses() - }, [fetchConnectorStatuses]) - + fetchConnectorStatuses(); + }, [fetchConnectorStatuses]); return ( @@ -218,19 +221,24 @@ export function CloudConnectorsDialog({
{connectors .filter(connector => connector.status === "connected") - .map((connector) => ( + .map(connector => (
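/* Per-connector markup lost in extraction; each connected entry presumably
   renders getConnectorIcon(connector.icon) and connector.name, with a picker
   entry point wired to handleFileSelection(connector.id, files)
   (hypothetical sketch, not the original JSX). */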
)}
- ) -} \ No newline at end of file + ); +} diff --git a/frontend/src/components/cloud-picker/file-item.tsx b/frontend/src/components/cloud-picker/file-item.tsx new file mode 100644 index 00000000..3f6b5ab5 --- /dev/null +++ b/frontend/src/components/cloud-picker/file-item.tsx @@ -0,0 +1,67 @@ +"use client"; + +import { Badge } from "@/components/ui/badge"; +import { FileText, Folder, Trash } from "lucide-react"; +import { CloudFile } from "./types"; + +interface FileItemProps { + file: CloudFile; + onRemove: (fileId: string) => void; +} + +const getFileIcon = (mimeType: string) => { + if (mimeType.includes("folder")) { + return ; + } + return ; +}; + +const getMimeTypeLabel = (mimeType: string) => { + const typeMap: { [key: string]: string } = { + "application/vnd.google-apps.document": "Google Doc", + "application/vnd.google-apps.spreadsheet": "Google Sheet", + "application/vnd.google-apps.presentation": "Google Slides", + "application/vnd.google-apps.folder": "Folder", + "application/pdf": "PDF", + "text/plain": "Text", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": + "Word Doc", + "application/vnd.openxmlformats-officedocument.presentationml.presentation": + "PowerPoint", + }; + + return typeMap[mimeType] || mimeType?.split("/").pop() || "Document"; +}; + +const formatFileSize = (bytes?: number) => { + if (!bytes) return ""; + const sizes = ["B", "KB", "MB", "GB", "TB"]; + if (bytes === 0) return "0 B"; + const i = Math.floor(Math.log(bytes) / Math.log(1024)); + return `${(bytes / Math.pow(1024, i)).toFixed(1)} ${sizes[i]}`; +}; + +export const FileItem = ({ file, onRemove }: FileItemProps) => ( +
+
+ {getFileIcon(file.mimeType)} + {file.name} + + {getMimeTypeLabel(file.mimeType)} + +
+
+ + {formatFileSize(file.size) || "—"} + + + onRemove(file.id)} + /> +
+
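{/* Illustrative: formatFileSize(1536) renders "1.5 KB"; for a missing size it
    returns "" and the "—" fallback above is shown instead. */}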
+); diff --git a/frontend/src/components/cloud-picker/file-list.tsx b/frontend/src/components/cloud-picker/file-list.tsx new file mode 100644 index 00000000..775d78c4 --- /dev/null +++ b/frontend/src/components/cloud-picker/file-list.tsx @@ -0,0 +1,42 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { CloudFile } from "./types"; +import { FileItem } from "./file-item"; + +interface FileListProps { + files: CloudFile[]; + onClearAll: () => void; + onRemoveFile: (fileId: string) => void; +} + +export const FileList = ({ + files, + onClearAll, + onRemoveFile, +}: FileListProps) => { + if (files.length === 0) { + return null; + } + + return ( +
+
+

Added files

+ +
+
+ {files.map(file => ( + + ))} +
+
+ ); +}; diff --git a/frontend/src/components/cloud-picker/index.ts b/frontend/src/components/cloud-picker/index.ts new file mode 100644 index 00000000..ef7aa74b --- /dev/null +++ b/frontend/src/components/cloud-picker/index.ts @@ -0,0 +1,7 @@ +export { UnifiedCloudPicker } from "./unified-cloud-picker"; +export { PickerHeader } from "./picker-header"; +export { FileList } from "./file-list"; +export { FileItem } from "./file-item"; +export { IngestSettings } from "./ingest-settings"; +export * from "./types"; +export * from "./provider-handlers"; diff --git a/frontend/src/components/cloud-picker/ingest-settings.tsx b/frontend/src/components/cloud-picker/ingest-settings.tsx new file mode 100644 index 00000000..d5843a2a --- /dev/null +++ b/frontend/src/components/cloud-picker/ingest-settings.tsx @@ -0,0 +1,139 @@ +"use client"; + +import { Input } from "@/components/ui/input"; +import { Switch } from "@/components/ui/switch"; +import { + Collapsible, + CollapsibleContent, + CollapsibleTrigger, +} from "@/components/ui/collapsible"; +import { ChevronRight, Info } from "lucide-react"; +import { IngestSettings as IngestSettingsType } from "./types"; + +interface IngestSettingsProps { + isOpen: boolean; + onOpenChange: (open: boolean) => void; + settings?: IngestSettingsType; + onSettingsChange?: (settings: IngestSettingsType) => void; +} + +export const IngestSettings = ({ + isOpen, + onOpenChange, + settings, + onSettingsChange, +}: IngestSettingsProps) => { + // Default settings + const defaultSettings: IngestSettingsType = { + chunkSize: 1000, + chunkOverlap: 200, + ocr: false, + pictureDescriptions: false, + embeddingModel: "text-embedding-3-small", + }; + + // Use provided settings or defaults + const currentSettings = settings || defaultSettings; + + const handleSettingsChange = (newSettings: Partial) => { + const updatedSettings = { ...currentSettings, ...newSettings }; + onSettingsChange?.(updatedSettings); + }; + + return ( + + +
+ + Ingest settings +
+
+ + +
+
+
+
Chunk size
+ + handleSettingsChange({ + chunkSize: parseInt(e.target.value) || 0, + }) + } + /> +
+
+
Chunk overlap
+ + handleSettingsChange({ + chunkOverlap: parseInt(e.target.value) || 0, + }) + } + /> +
+
+ +
+
+
OCR
+
+ Extracts text from images/PDFs. Ingest is slower when enabled. +
+
+ + handleSettingsChange({ ocr: checked }) + } + /> +
+ +
+
+
+ Picture descriptions +
+
+ Adds captions for images. Ingest is more expensive when enabled. +
+
+ + handleSettingsChange({ pictureDescriptions: checked }) + } + /> +
+ +
+
+ Embedding model + +
+ + handleSettingsChange({ embeddingModel: e.target.value }) + } + placeholder="text-embedding-3-small" + /> +
+
+
+
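{/* Edits here propagate through onSettingsChange; the upload page keeps them
    in state and forwards them verbatim as the "settings" field of the POST
    body sent to /api/connectors/{provider}/sync. */}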
+ ); +}; diff --git a/frontend/src/components/cloud-picker/picker-header.tsx b/frontend/src/components/cloud-picker/picker-header.tsx new file mode 100644 index 00000000..05dcaebd --- /dev/null +++ b/frontend/src/components/cloud-picker/picker-header.tsx @@ -0,0 +1,70 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { Plus } from "lucide-react"; +import { CloudProvider } from "./types"; + +interface PickerHeaderProps { + provider: CloudProvider; + onAddFiles: () => void; + isPickerLoaded: boolean; + isPickerOpen: boolean; + accessToken?: string; + isAuthenticated: boolean; +} + +const getProviderName = (provider: CloudProvider): string => { + switch (provider) { + case "google_drive": + return "Google Drive"; + case "onedrive": + return "OneDrive"; + case "sharepoint": + return "SharePoint"; + default: + return "Cloud Storage"; + } +}; + +export const PickerHeader = ({ + provider, + onAddFiles, + isPickerLoaded, + isPickerOpen, + accessToken, + isAuthenticated, +}: PickerHeaderProps) => { + if (!isAuthenticated) { + return ( +
+ Please connect to {getProviderName(provider)} first to select specific + files. +
+ ); + } + + return ( + + +

+ Select files from {getProviderName(provider)} to ingest. +

+ +
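{/* The stripped markup above presumably held the add-files action, e.g.
    <Button onClick={onAddFiles} disabled={!isPickerLoaded || isPickerOpen}>
      <Plus /> Add files
    </Button>
    (hypothetical reconstruction, not the original JSX). */}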
+ csv, json, pdf,{" "} + +16 more{" "} + 150 MB max +
+
+
+ ); +}; diff --git a/frontend/src/components/cloud-picker/provider-handlers.ts b/frontend/src/components/cloud-picker/provider-handlers.ts new file mode 100644 index 00000000..4a39312f --- /dev/null +++ b/frontend/src/components/cloud-picker/provider-handlers.ts @@ -0,0 +1,245 @@ +"use client"; + +import { + CloudFile, + CloudProvider, + GooglePickerData, + GooglePickerDocument, +} from "./types"; + +export class GoogleDriveHandler { + private accessToken: string; + private onPickerStateChange?: (isOpen: boolean) => void; + + constructor( + accessToken: string, + onPickerStateChange?: (isOpen: boolean) => void + ) { + this.accessToken = accessToken; + this.onPickerStateChange = onPickerStateChange; + } + + async loadPickerApi(): Promise { + return new Promise(resolve => { + if (typeof window !== "undefined" && window.gapi) { + window.gapi.load("picker", { + callback: () => resolve(true), + onerror: () => resolve(false), + }); + } else { + // Load Google API script + const script = document.createElement("script"); + script.src = "https://apis.google.com/js/api.js"; + script.async = true; + script.defer = true; + script.onload = () => { + window.gapi.load("picker", { + callback: () => resolve(true), + onerror: () => resolve(false), + }); + }; + script.onerror = () => resolve(false); + document.head.appendChild(script); + } + }); + } + + openPicker(onFileSelected: (files: CloudFile[]) => void): void { + if (!window.google?.picker) { + return; + } + + try { + this.onPickerStateChange?.(true); + + const picker = new window.google.picker.PickerBuilder() + .addView(window.google.picker.ViewId.DOCS) + .addView(window.google.picker.ViewId.FOLDERS) + .setOAuthToken(this.accessToken) + .enableFeature(window.google.picker.Feature.MULTISELECT_ENABLED) + .setTitle("Select files from Google Drive") + .setCallback(data => this.pickerCallback(data, onFileSelected)) + .build(); + + picker.setVisible(true); + + // Apply z-index fix + setTimeout(() => { + const pickerElements = document.querySelectorAll( + ".picker-dialog, .goog-modalpopup" + ); + pickerElements.forEach(el => { + (el as HTMLElement).style.zIndex = "10000"; + }); + const bgElements = document.querySelectorAll( + ".picker-dialog-bg, .goog-modalpopup-bg" + ); + bgElements.forEach(el => { + (el as HTMLElement).style.zIndex = "9999"; + }); + }, 100); + } catch (error) { + console.error("Error creating picker:", error); + this.onPickerStateChange?.(false); + } + } + + private async pickerCallback( + data: GooglePickerData, + onFileSelected: (files: CloudFile[]) => void + ): Promise { + if (data.action === window.google.picker.Action.PICKED) { + const files: CloudFile[] = data.docs.map((doc: GooglePickerDocument) => ({ + id: doc[window.google.picker.Document.ID], + name: doc[window.google.picker.Document.NAME], + mimeType: doc[window.google.picker.Document.MIME_TYPE], + webViewLink: doc[window.google.picker.Document.URL], + iconLink: doc[window.google.picker.Document.ICON_URL], + size: doc["sizeBytes"] ? 
parseInt(doc["sizeBytes"]) : undefined, + modifiedTime: doc["lastEditedUtc"], + isFolder: + doc[window.google.picker.Document.MIME_TYPE] === + "application/vnd.google-apps.folder", + })); + + // Enrich with additional file data if needed + if (files.some(f => !f.size && !f.isFolder)) { + try { + const enrichedFiles = await Promise.all( + files.map(async file => { + if (!file.size && !file.isFolder) { + try { + const response = await fetch( + `https://www.googleapis.com/drive/v3/files/${file.id}?fields=size,modifiedTime`, + { + headers: { + Authorization: `Bearer ${this.accessToken}`, + }, + } + ); + if (response.ok) { + const fileDetails = await response.json(); + return { + ...file, + size: fileDetails.size + ? parseInt(fileDetails.size) + : undefined, + modifiedTime: + fileDetails.modifiedTime || file.modifiedTime, + }; + } + } catch (error) { + console.warn("Failed to fetch file details:", error); + } + } + return file; + }) + ); + onFileSelected(enrichedFiles); + } catch (error) { + console.warn("Failed to enrich file data:", error); + onFileSelected(files); + } + } else { + onFileSelected(files); + } + } + + this.onPickerStateChange?.(false); + } +} + +export class OneDriveHandler { + private accessToken: string; + private clientId: string; + private provider: CloudProvider; + private baseUrl?: string; + + constructor( + accessToken: string, + clientId: string, + provider: CloudProvider = "onedrive", + baseUrl?: string + ) { + this.accessToken = accessToken; + this.clientId = clientId; + this.provider = provider; + this.baseUrl = baseUrl; + } + + async loadPickerApi(): Promise { + return new Promise(resolve => { + const script = document.createElement("script"); + script.src = "https://js.live.net/v7.2/OneDrive.js"; + script.onload = () => resolve(true); + script.onerror = () => resolve(false); + document.head.appendChild(script); + }); + } + + openPicker(onFileSelected: (files: CloudFile[]) => void): void { + if (!window.OneDrive) { + return; + } + + window.OneDrive.open({ + clientId: this.clientId, + action: "query", + multiSelect: true, + advanced: { + endpointHint: "api.onedrive.com", + accessToken: this.accessToken, + }, + success: (response: any) => { + const newFiles: CloudFile[] = + response.value?.map((item: any, index: number) => ({ + id: item.id, + name: + item.name || + `${this.getProviderName()} File ${index + 1} (${item.id.slice( + -8 + )})`, + mimeType: item.file?.mimeType || "application/octet-stream", + webUrl: item.webUrl || "", + downloadUrl: item["@microsoft.graph.downloadUrl"] || "", + size: item.size, + modifiedTime: item.lastModifiedDateTime, + isFolder: !!item.folder, + })) || []; + + onFileSelected(newFiles); + }, + cancel: () => { + console.log("Picker cancelled"); + }, + error: (error: any) => { + console.error("Picker error:", error); + }, + }); + } + + private getProviderName(): string { + return this.provider === "sharepoint" ? 
"SharePoint" : "OneDrive"; + } +} + +export const createProviderHandler = ( + provider: CloudProvider, + accessToken: string, + onPickerStateChange?: (isOpen: boolean) => void, + clientId?: string, + baseUrl?: string +) => { + switch (provider) { + case "google_drive": + return new GoogleDriveHandler(accessToken, onPickerStateChange); + case "onedrive": + case "sharepoint": + if (!clientId) { + throw new Error("Client ID required for OneDrive/SharePoint"); + } + return new OneDriveHandler(accessToken, clientId, provider, baseUrl); + default: + throw new Error(`Unsupported provider: ${provider}`); + } +}; diff --git a/frontend/src/components/cloud-picker/types.ts b/frontend/src/components/cloud-picker/types.ts new file mode 100644 index 00000000..ca346bf0 --- /dev/null +++ b/frontend/src/components/cloud-picker/types.ts @@ -0,0 +1,106 @@ +export interface CloudFile { + id: string; + name: string; + mimeType: string; + webViewLink?: string; + iconLink?: string; + size?: number; + modifiedTime?: string; + isFolder?: boolean; + webUrl?: string; + downloadUrl?: string; +} + +export type CloudProvider = "google_drive" | "onedrive" | "sharepoint"; + +export interface UnifiedCloudPickerProps { + provider: CloudProvider; + onFileSelected: (files: CloudFile[]) => void; + selectedFiles?: CloudFile[]; + isAuthenticated: boolean; + accessToken?: string; + onPickerStateChange?: (isOpen: boolean) => void; + // OneDrive/SharePoint specific props + clientId?: string; + baseUrl?: string; + // Ingest settings + onSettingsChange?: (settings: IngestSettings) => void; +} + +export interface GoogleAPI { + load: ( + api: string, + options: { callback: () => void; onerror?: () => void } + ) => void; +} + +export interface GooglePickerData { + action: string; + docs: GooglePickerDocument[]; +} + +export interface GooglePickerDocument { + [key: string]: string; +} + +declare global { + interface Window { + gapi: GoogleAPI; + google: { + picker: { + api: { + load: (callback: () => void) => void; + }; + PickerBuilder: new () => GooglePickerBuilder; + ViewId: { + DOCS: string; + FOLDERS: string; + DOCS_IMAGES_AND_VIDEOS: string; + DOCUMENTS: string; + PRESENTATIONS: string; + SPREADSHEETS: string; + }; + Feature: { + MULTISELECT_ENABLED: string; + NAV_HIDDEN: string; + SIMPLE_UPLOAD_ENABLED: string; + }; + Action: { + PICKED: string; + CANCEL: string; + }; + Document: { + ID: string; + NAME: string; + MIME_TYPE: string; + URL: string; + ICON_URL: string; + }; + }; + }; + OneDrive?: any; + } +} + +export interface GooglePickerBuilder { + addView: (view: string) => GooglePickerBuilder; + setOAuthToken: (token: string) => GooglePickerBuilder; + setCallback: ( + callback: (data: GooglePickerData) => void + ) => GooglePickerBuilder; + enableFeature: (feature: string) => GooglePickerBuilder; + setTitle: (title: string) => GooglePickerBuilder; + build: () => GooglePicker; +} + +export interface GooglePicker { + setVisible: (visible: boolean) => void; +} + +export interface IngestSettings { + chunkSize: number; + chunkOverlap: number; + ocr: boolean; + pictureDescriptions: boolean; + embeddingModel: string; +} diff --git a/frontend/src/components/cloud-picker/unified-cloud-picker.tsx b/frontend/src/components/cloud-picker/unified-cloud-picker.tsx new file mode 100644 index 00000000..fd77698f --- /dev/null +++ b/frontend/src/components/cloud-picker/unified-cloud-picker.tsx @@ -0,0 +1,195 @@ +"use client"; + +import { useState, useEffect } from "react"; +import { + UnifiedCloudPickerProps, + CloudFile, + IngestSettings as 
IngestSettingsType, +} from "./types"; +import { PickerHeader } from "./picker-header"; +import { FileList } from "./file-list"; +import { IngestSettings } from "./ingest-settings"; +import { createProviderHandler } from "./provider-handlers"; + +export const UnifiedCloudPicker = ({ + provider, + onFileSelected, + selectedFiles = [], + isAuthenticated, + accessToken, + onPickerStateChange, + clientId, + baseUrl, + onSettingsChange, +}: UnifiedCloudPickerProps) => { + const [isPickerLoaded, setIsPickerLoaded] = useState(false); + const [isPickerOpen, setIsPickerOpen] = useState(false); + const [isIngestSettingsOpen, setIsIngestSettingsOpen] = useState(false); + const [isLoadingBaseUrl, setIsLoadingBaseUrl] = useState(false); + const [autoBaseUrl, setAutoBaseUrl] = useState(undefined); + + // Settings state with defaults + const [ingestSettings, setIngestSettings] = useState({ + chunkSize: 1000, + chunkOverlap: 200, + ocr: false, + pictureDescriptions: false, + embeddingModel: "text-embedding-3-small", + }); + + // Handle settings changes and notify parent + const handleSettingsChange = (newSettings: IngestSettingsType) => { + setIngestSettings(newSettings); + onSettingsChange?.(newSettings); + }; + + const effectiveBaseUrl = baseUrl || autoBaseUrl; + + // Auto-detect base URL for OneDrive personal accounts + useEffect(() => { + if ( + (provider === "onedrive" || provider === "sharepoint") && + !baseUrl && + accessToken && + !autoBaseUrl + ) { + const getBaseUrl = async () => { + setIsLoadingBaseUrl(true); + try { + setAutoBaseUrl("https://onedrive.live.com/picker"); + } catch (error) { + console.error("Auto-detect baseUrl failed:", error); + } finally { + setIsLoadingBaseUrl(false); + } + }; + + getBaseUrl(); + } + }, [accessToken, baseUrl, autoBaseUrl, provider]); + + // Load picker API + useEffect(() => { + if (!accessToken || !isAuthenticated) return; + + const loadApi = async () => { + try { + const handler = createProviderHandler( + provider, + accessToken, + onPickerStateChange, + clientId, + effectiveBaseUrl + ); + const loaded = await handler.loadPickerApi(); + setIsPickerLoaded(loaded); + } catch (error) { + console.error("Failed to create provider handler:", error); + setIsPickerLoaded(false); + } + }; + + loadApi(); + }, [ + accessToken, + isAuthenticated, + provider, + clientId, + effectiveBaseUrl, + onPickerStateChange, + ]); + + const handleAddFiles = () => { + if (!isPickerLoaded || !accessToken) { + return; + } + + if ((provider === "onedrive" || provider === "sharepoint") && !clientId) { + console.error("Client ID required for OneDrive/SharePoint"); + return; + } + + try { + setIsPickerOpen(true); + onPickerStateChange?.(true); + + const handler = createProviderHandler( + provider, + accessToken, + isOpen => { + setIsPickerOpen(isOpen); + onPickerStateChange?.(isOpen); + }, + clientId, + effectiveBaseUrl + ); + + handler.openPicker((files: CloudFile[]) => { + // Merge new files with existing ones, avoiding duplicates + const existingIds = new Set(selectedFiles.map(f => f.id)); + const newFiles = files.filter(f => !existingIds.has(f.id)); + onFileSelected([...selectedFiles, ...newFiles]); + }); + } catch (error) { + console.error("Error opening picker:", error); + setIsPickerOpen(false); + onPickerStateChange?.(false); + } + }; + + const handleRemoveFile = (fileId: string) => { + const updatedFiles = selectedFiles.filter(file => file.id !== fileId); + onFileSelected(updatedFiles); + }; + + const handleClearAll = () => { + onFileSelected([]); + }; + + if (isLoadingBaseUrl) { 
+ return ( +
+ Loading... +
+ ); + } + + if ( + (provider === "onedrive" || provider === "sharepoint") && + !clientId && + isAuthenticated + ) { + return ( +
+ Configuration required: Client ID missing for{" "} + {provider === "sharepoint" ? "SharePoint" : "OneDrive"}. +
+ ); + } + + return ( +
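/* The returned markup was lost in extraction; a plausible composition given
   the prop interfaces defined in this diff (hypothetical, not the original JSX):
     <PickerHeader provider={provider} onAddFiles={handleAddFiles}
       isPickerLoaded={isPickerLoaded} isPickerOpen={isPickerOpen}
       accessToken={accessToken} isAuthenticated={isAuthenticated} />
     <FileList files={selectedFiles} onClearAll={handleClearAll}
       onRemoveFile={handleRemoveFile} />
     <IngestSettings isOpen={isIngestSettingsOpen}
       onOpenChange={setIsIngestSettingsOpen}
       settings={ingestSettings} onSettingsChange={handleSettingsChange} /> */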
+ + + + + +
+ ); +}; diff --git a/frontend/src/components/google-drive-picker.tsx b/frontend/src/components/google-drive-picker.tsx deleted file mode 100644 index c9dee19a..00000000 --- a/frontend/src/components/google-drive-picker.tsx +++ /dev/null @@ -1,341 +0,0 @@ -"use client" - -import { useState, useEffect } from "react" -import { Button } from "@/components/ui/button" -import { Badge } from "@/components/ui/badge" -import { FileText, Folder, Plus, Trash2 } from "lucide-react" -import { Card, CardContent } from "@/components/ui/card" - -interface GoogleDrivePickerProps { - onFileSelected: (files: GoogleDriveFile[]) => void - selectedFiles?: GoogleDriveFile[] - isAuthenticated: boolean - accessToken?: string - onPickerStateChange?: (isOpen: boolean) => void -} - -interface GoogleDriveFile { - id: string - name: string - mimeType: string - webViewLink?: string - iconLink?: string - size?: number - modifiedTime?: string - isFolder?: boolean -} - -interface GoogleAPI { - load: (api: string, options: { callback: () => void; onerror?: () => void }) => void -} - -interface GooglePickerData { - action: string - docs: GooglePickerDocument[] -} - -interface GooglePickerDocument { - [key: string]: string -} - -declare global { - interface Window { - gapi: GoogleAPI - google: { - picker: { - api: { - load: (callback: () => void) => void - } - PickerBuilder: new () => GooglePickerBuilder - ViewId: { - DOCS: string - FOLDERS: string - DOCS_IMAGES_AND_VIDEOS: string - DOCUMENTS: string - PRESENTATIONS: string - SPREADSHEETS: string - } - Feature: { - MULTISELECT_ENABLED: string - NAV_HIDDEN: string - SIMPLE_UPLOAD_ENABLED: string - } - Action: { - PICKED: string - CANCEL: string - } - Document: { - ID: string - NAME: string - MIME_TYPE: string - URL: string - ICON_URL: string - } - } - } - } -} - -interface GooglePickerBuilder { - addView: (view: string) => GooglePickerBuilder - setOAuthToken: (token: string) => GooglePickerBuilder - setCallback: (callback: (data: GooglePickerData) => void) => GooglePickerBuilder - enableFeature: (feature: string) => GooglePickerBuilder - setTitle: (title: string) => GooglePickerBuilder - build: () => GooglePicker -} - -interface GooglePicker { - setVisible: (visible: boolean) => void -} - -export function GoogleDrivePicker({ - onFileSelected, - selectedFiles = [], - isAuthenticated, - accessToken, - onPickerStateChange -}: GoogleDrivePickerProps) { - const [isPickerLoaded, setIsPickerLoaded] = useState(false) - const [isPickerOpen, setIsPickerOpen] = useState(false) - - useEffect(() => { - const loadPickerApi = () => { - if (typeof window !== 'undefined' && window.gapi) { - window.gapi.load('picker', { - callback: () => { - setIsPickerLoaded(true) - }, - onerror: () => { - console.error('Failed to load Google Picker API') - } - }) - } - } - - // Load Google API script if not already loaded - if (typeof window !== 'undefined') { - if (!window.gapi) { - const script = document.createElement('script') - script.src = 'https://apis.google.com/js/api.js' - script.async = true - script.defer = true - script.onload = loadPickerApi - script.onerror = () => { - console.error('Failed to load Google API script') - } - document.head.appendChild(script) - - return () => { - if (document.head.contains(script)) { - document.head.removeChild(script) - } - } - } else { - loadPickerApi() - } - } - }, []) - - - const openPicker = () => { - if (!isPickerLoaded || !accessToken || !window.google?.picker) { - return - } - - try { - setIsPickerOpen(true) - onPickerStateChange?.(true) - - // Create 
picker with higher z-index and focus handling - const picker = new window.google.picker.PickerBuilder() - .addView(window.google.picker.ViewId.DOCS) - .addView(window.google.picker.ViewId.FOLDERS) - .setOAuthToken(accessToken) - .enableFeature(window.google.picker.Feature.MULTISELECT_ENABLED) - .setTitle('Select files from Google Drive') - .setCallback(pickerCallback) - .build() - - picker.setVisible(true) - - // Apply z-index fix after a short delay to ensure picker is rendered - setTimeout(() => { - const pickerElements = document.querySelectorAll('.picker-dialog, .goog-modalpopup') - pickerElements.forEach(el => { - (el as HTMLElement).style.zIndex = '10000' - }) - const bgElements = document.querySelectorAll('.picker-dialog-bg, .goog-modalpopup-bg') - bgElements.forEach(el => { - (el as HTMLElement).style.zIndex = '9999' - }) - }, 100) - - } catch (error) { - console.error('Error creating picker:', error) - setIsPickerOpen(false) - onPickerStateChange?.(false) - } - } - - const pickerCallback = async (data: GooglePickerData) => { - if (data.action === window.google.picker.Action.PICKED) { - const files: GoogleDriveFile[] = data.docs.map((doc: GooglePickerDocument) => ({ - id: doc[window.google.picker.Document.ID], - name: doc[window.google.picker.Document.NAME], - mimeType: doc[window.google.picker.Document.MIME_TYPE], - webViewLink: doc[window.google.picker.Document.URL], - iconLink: doc[window.google.picker.Document.ICON_URL], - size: doc['sizeBytes'] ? parseInt(doc['sizeBytes']) : undefined, - modifiedTime: doc['lastEditedUtc'], - isFolder: doc[window.google.picker.Document.MIME_TYPE] === 'application/vnd.google-apps.folder' - })) - - // If size is still missing, try to fetch it via Google Drive API - if (accessToken && files.some(f => !f.size && !f.isFolder)) { - try { - const enrichedFiles = await Promise.all(files.map(async (file) => { - if (!file.size && !file.isFolder) { - try { - const response = await fetch(`https://www.googleapis.com/drive/v3/files/${file.id}?fields=size,modifiedTime`, { - headers: { - 'Authorization': `Bearer ${accessToken}` - } - }) - if (response.ok) { - const fileDetails = await response.json() - return { - ...file, - size: fileDetails.size ? 
parseInt(fileDetails.size) : undefined, - modifiedTime: fileDetails.modifiedTime || file.modifiedTime - } - } - } catch (error) { - console.warn('Failed to fetch file details:', error) - } - } - return file - })) - onFileSelected(enrichedFiles) - } catch (error) { - console.warn('Failed to enrich file data:', error) - onFileSelected(files) - } - } else { - onFileSelected(files) - } - } - - setIsPickerOpen(false) - onPickerStateChange?.(false) - } - - const removeFile = (fileId: string) => { - const updatedFiles = selectedFiles.filter(file => file.id !== fileId) - onFileSelected(updatedFiles) - } - - const getFileIcon = (mimeType: string) => { - if (mimeType.includes('folder')) { - return - } - return - } - - const getMimeTypeLabel = (mimeType: string) => { - const typeMap: { [key: string]: string } = { - 'application/vnd.google-apps.document': 'Google Doc', - 'application/vnd.google-apps.spreadsheet': 'Google Sheet', - 'application/vnd.google-apps.presentation': 'Google Slides', - 'application/vnd.google-apps.folder': 'Folder', - 'application/pdf': 'PDF', - 'text/plain': 'Text', - 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'Word Doc', - 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'PowerPoint' - } - - return typeMap[mimeType] || 'Document' - } - - const formatFileSize = (bytes?: number) => { - if (!bytes) return '' - const sizes = ['B', 'KB', 'MB', 'GB', 'TB'] - if (bytes === 0) return '0 B' - const i = Math.floor(Math.log(bytes) / Math.log(1024)) - return `${(bytes / Math.pow(1024, i)).toFixed(1)} ${sizes[i]}` - } - - if (!isAuthenticated) { - return ( -
- Please connect to Google Drive first to select specific files. -
- ) - } - - return ( -
- - -

- Select files from Google Drive to ingest. -

- -
-
- - {selectedFiles.length > 0 && ( -
-
-

- Added files -

- -
-
- {selectedFiles.map((file) => ( -
-
- {getFileIcon(file.mimeType)} - {file.name} - - {getMimeTypeLabel(file.mimeType)} - -
-
- {formatFileSize(file.size)} - -
-
- ))} -
- -
- )} -
- ) -} diff --git a/frontend/src/components/onedrive-picker.tsx b/frontend/src/components/onedrive-picker.tsx deleted file mode 100644 index 739491c1..00000000 --- a/frontend/src/components/onedrive-picker.tsx +++ /dev/null @@ -1,320 +0,0 @@ -"use client" - -import { useState, useEffect } from "react" -import { Button } from "@/components/ui/button" -import { Badge } from "@/components/ui/badge" -import { FileText, Folder, Trash2, X } from "lucide-react" - -interface OneDrivePickerProps { - onFileSelected: (files: OneDriveFile[]) => void - selectedFiles?: OneDriveFile[] - isAuthenticated: boolean - accessToken?: string - connectorType?: "onedrive" | "sharepoint" - onPickerStateChange?: (isOpen: boolean) => void -} - -interface OneDriveFile { - id: string - name: string - mimeType?: string - webUrl?: string - driveItem?: { - file?: { mimeType: string } - folder?: unknown - } -} - -interface GraphResponse { - value: OneDriveFile[] -} - -declare global { - interface Window { - mgt?: { - Providers?: { - globalProvider?: unknown - } - } - } -} - -export function OneDrivePicker({ - onFileSelected, - selectedFiles = [], - isAuthenticated, - accessToken, - connectorType = "onedrive", - onPickerStateChange -}: OneDrivePickerProps) { - const [isLoading, setIsLoading] = useState(false) - const [files, setFiles] = useState([]) - const [isPickerOpen, setIsPickerOpen] = useState(false) - const [currentPath, setCurrentPath] = useState( - connectorType === "sharepoint" ? 'sites?search=' : 'me/drive/root/children' - ) - const [breadcrumbs, setBreadcrumbs] = useState<{id: string, name: string}[]>([ - {id: 'root', name: connectorType === "sharepoint" ? 'SharePoint' : 'OneDrive'} - ]) - - useEffect(() => { - const loadMGT = async () => { - if (typeof window !== 'undefined' && !window.mgt) { - try { - await import('@microsoft/mgt-components') - await import('@microsoft/mgt-msal2-provider') - - // For simplicity, we'll use direct Graph API calls instead of MGT components - // MGT provider initialization would go here if needed - } catch { - console.warn('MGT not available, falling back to direct API calls') - } - } - } - - loadMGT() - }, [accessToken]) - - - const fetchFiles = async (path: string = currentPath) => { - if (!accessToken) return - - setIsLoading(true) - try { - const response = await fetch(`https://graph.microsoft.com/v1.0/${path}`, { - headers: { - 'Authorization': `Bearer ${accessToken}`, - 'Content-Type': 'application/json' - } - }) - - if (response.ok) { - const data: GraphResponse = await response.json() - setFiles(data.value || []) - } else { - console.error('Failed to fetch OneDrive files:', response.statusText) - } - } catch (error) { - console.error('Error fetching OneDrive files:', error) - } finally { - setIsLoading(false) - } - } - - const openPicker = () => { - if (!accessToken) return - - setIsPickerOpen(true) - onPickerStateChange?.(true) - fetchFiles() - } - - const closePicker = () => { - setIsPickerOpen(false) - onPickerStateChange?.(false) - setFiles([]) - setCurrentPath( - connectorType === "sharepoint" ? 'sites?search=' : 'me/drive/root/children' - ) - setBreadcrumbs([ - {id: 'root', name: connectorType === "sharepoint" ? 
'SharePoint' : 'OneDrive'} - ]) - } - - const handleFileClick = (file: OneDriveFile) => { - if (file.driveItem?.folder) { - // Navigate to folder - const newPath = `me/drive/items/${file.id}/children` - setCurrentPath(newPath) - setBreadcrumbs([...breadcrumbs, {id: file.id, name: file.name}]) - fetchFiles(newPath) - } else { - // Select file - const isAlreadySelected = selectedFiles.some(f => f.id === file.id) - if (!isAlreadySelected) { - onFileSelected([...selectedFiles, file]) - } - } - } - - const navigateToBreadcrumb = (index: number) => { - if (index === 0) { - setCurrentPath('me/drive/root/children') - setBreadcrumbs([{id: 'root', name: 'OneDrive'}]) - fetchFiles('me/drive/root/children') - } else { - const targetCrumb = breadcrumbs[index] - const newPath = `me/drive/items/${targetCrumb.id}/children` - setCurrentPath(newPath) - setBreadcrumbs(breadcrumbs.slice(0, index + 1)) - fetchFiles(newPath) - } - } - - const removeFile = (fileId: string) => { - const updatedFiles = selectedFiles.filter(file => file.id !== fileId) - onFileSelected(updatedFiles) - } - - const getFileIcon = (file: OneDriveFile) => { - if (file.driveItem?.folder) { - return - } - return - } - - const getMimeTypeLabel = (file: OneDriveFile) => { - const mimeType = file.driveItem?.file?.mimeType || file.mimeType || '' - const typeMap: { [key: string]: string } = { - 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'Word Doc', - 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': 'Excel', - 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'PowerPoint', - 'application/pdf': 'PDF', - 'text/plain': 'Text', - 'image/jpeg': 'Image', - 'image/png': 'Image', - } - - if (file.driveItem?.folder) return 'Folder' - return typeMap[mimeType] || 'Document' - } - - const serviceName = connectorType === "sharepoint" ? "SharePoint" : "OneDrive" - - if (!isAuthenticated) { - return ( -
- Please connect to {serviceName} first to select specific files. -
- ) - } - - return ( -
-
-
-

{serviceName} File Selection

-

- Choose specific files to sync instead of syncing everything -

-
- -
- - {/* Status message when access token is missing */} - {isAuthenticated && !accessToken && ( -
-
Access token unavailable
-
The file picker requires an access token. Try disconnecting and reconnecting your {serviceName} account.
-
- )} - - {/* File Picker Modal */} - {isPickerOpen && ( -
-
-
-

Select Files from {serviceName}

- -
- - {/* Breadcrumbs */} -
- {breadcrumbs.map((crumb, index) => ( -
- {index > 0 && /} - -
- ))} -
- - {/* File List */} -
- {isLoading ? ( -
Loading...
- ) : files.length === 0 ? ( -
No files found
- ) : ( -
- {files.map((file) => ( -
handleFileClick(file)} - > -
- {getFileIcon(file)} - {file.name} - - {getMimeTypeLabel(file)} - -
- {selectedFiles.some(f => f.id === file.id) && ( - Selected - )} -
- ))} -
- )} -
-
-
- )} - - {selectedFiles.length > 0 && ( -
-

- Selected files ({selectedFiles.length}): -

-
- {selectedFiles.map((file) => ( -
-
- {getFileIcon(file)} - {file.name} - - {getMimeTypeLabel(file)} - -
- -
- ))} -
- -
- )} -
- ) -} \ No newline at end of file diff --git a/src/api/connectors.py b/src/api/connectors.py index 4c9eab49..5fbcec86 100644 --- a/src/api/connectors.py +++ b/src/api/connectors.py @@ -127,6 +127,23 @@ async def connector_status(request: Request, connector_service, session_manager) user_id=user.user_id, connector_type=connector_type ) + # Get the connector for each connection + connection_client_ids = {} + for connection in connections: + try: + connector = await connector_service._get_connector(connection.connection_id) + if connector is not None: + connection_client_ids[connection.connection_id] = connector.get_client_id() + else: + connection_client_ids[connection.connection_id] = None + except Exception as e: + logger.warning( + "Could not get connector for connection", + connection_id=connection.connection_id, + error=str(e), + ) + connection.connector = None + # Check if there are any active connections active_connections = [conn for conn in connections if conn.is_active] has_authenticated_connection = len(active_connections) > 0 @@ -140,6 +157,7 @@ async def connector_status(request: Request, connector_service, session_manager) { "connection_id": conn.connection_id, "name": conn.name, + "client_id": connection_client_ids.get(conn.connection_id), "is_active": conn.is_active, "created_at": conn.created_at.isoformat(), "last_sync": conn.last_sync.isoformat() if conn.last_sync else None, @@ -323,8 +341,8 @@ async def connector_webhook(request: Request, connector_service, session_manager ) async def connector_token(request: Request, connector_service, session_manager): - """Get access token for connector API calls (e.g., Google Picker)""" - connector_type = request.path_params.get("connector_type") + """Get access token for connector API calls (e.g., Pickers).""" + url_connector_type = request.path_params.get("connector_type") connection_id = request.query_params.get("connection_id") if not connection_id: @@ -333,37 +351,81 @@ async def connector_token(request: Request, connector_service, session_manager): user = request.state.user try: - # Get the connection and verify it belongs to the user + # 1) Load the connection and verify ownership connection = await connector_service.connection_manager.get_connection(connection_id) if not connection or connection.user_id != user.user_id: return JSONResponse({"error": "Connection not found"}, status_code=404) - # Get the connector instance + # 2) Get the ACTUAL connector instance/type for this connection_id connector = await connector_service._get_connector(connection_id) if not connector: - return JSONResponse({"error": f"Connector not available - authentication may have failed for {connector_type}"}, status_code=404) + return JSONResponse( + {"error": f"Connector not available - authentication may have failed for {url_connector_type}"}, + status_code=404, + ) - # For Google Drive, get the access token - if connector_type == "google_drive" and hasattr(connector, 'oauth'): + real_type = getattr(connector, "type", None) or getattr(connection, "connector_type", None) + if real_type is None: + return JSONResponse({"error": "Unable to determine connector type"}, status_code=500) + + # Optional: warn if URL path type disagrees with real type + if url_connector_type and url_connector_type != real_type: + # You can downgrade this to debug if you expect cross-routing. 
+ return JSONResponse( + { + "error": "Connector type mismatch", + "detail": { + "requested_type": url_connector_type, + "actual_type": real_type, + "hint": "Call the token endpoint using the correct connector_type for this connection_id.", + }, + }, + status_code=400, + ) + + # 3) Branch by the actual connector type + # GOOGLE DRIVE (google-auth) + if real_type == "google_drive" and hasattr(connector, "oauth"): await connector.oauth.load_credentials() if connector.oauth.creds and connector.oauth.creds.valid: - return JSONResponse({ - "access_token": connector.oauth.creds.token, - "expires_in": (connector.oauth.creds.expiry.timestamp() - - __import__('time').time()) if connector.oauth.creds.expiry else None - }) - else: - return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401) - - # For OneDrive and SharePoint, get the access token - elif connector_type in ["onedrive", "sharepoint"] and hasattr(connector, 'oauth'): + expires_in = None + try: + if connector.oauth.creds.expiry: + import time + expires_in = max(0, int(connector.oauth.creds.expiry.timestamp() - time.time())) + except Exception: + expires_in = None + + return JSONResponse( + { + "access_token": connector.oauth.creds.token, + "expires_in": expires_in, + } + ) + return JSONResponse({"error": "Invalid or expired credentials"}, status_code=401) + + # ONEDRIVE / SHAREPOINT (MSAL or custom) + if real_type in ("onedrive", "sharepoint") and hasattr(connector, "oauth"): + # Ensure cache/credentials are loaded before trying to use them try: + # Prefer a dedicated is_authenticated() that loads cache internally + if hasattr(connector.oauth, "is_authenticated"): + ok = await connector.oauth.is_authenticated() + else: + # Fallback: try to load credentials explicitly if available + ok = True + if hasattr(connector.oauth, "load_credentials"): + ok = await connector.oauth.load_credentials() + + if not ok: + return JSONResponse({"error": "Not authenticated"}, status_code=401) + + # Now safe to fetch access token access_token = connector.oauth.get_access_token() - return JSONResponse({ - "access_token": access_token, - "expires_in": None # MSAL handles token expiry internally - }) + # MSAL result has expiry, but we’re returning a raw token; keep expires_in None for simplicity + return JSONResponse({"access_token": access_token, "expires_in": None}) except ValueError as e: + # Typical when acquire_token_silent fails (e.g., needs re-auth) return JSONResponse({"error": f"Failed to get access token: {str(e)}"}, status_code=401) except Exception as e: return JSONResponse({"error": f"Authentication error: {str(e)}"}, status_code=500) @@ -371,7 +433,5 @@ async def connector_token(request: Request, connector_service, session_manager): return JSONResponse({"error": "Token not available for this connector type"}, status_code=400) except Exception as e: - logger.error("Error getting connector token", error=str(e)) + logger.error("Error getting connector token", exc_info=True) return JSONResponse({"error": str(e)}, status_code=500) - - diff --git a/src/connectors/connection_manager.py b/src/connectors/connection_manager.py index 2e70ee1f..07ebd5ee 100644 --- a/src/connectors/connection_manager.py +++ b/src/connectors/connection_manager.py @@ -294,32 +294,39 @@ class ConnectionManager: async def get_connector(self, connection_id: str) -> Optional[BaseConnector]: """Get an active connector instance""" + logger.debug(f"Getting connector for connection_id: {connection_id}") + # Return cached connector if available if connection_id in 
self.active_connectors: connector = self.active_connectors[connection_id] if connector.is_authenticated: + logger.debug(f"Returning cached authenticated connector for {connection_id}") return connector else: # Remove unauthenticated connector from cache + logger.debug(f"Removing unauthenticated connector from cache for {connection_id}") del self.active_connectors[connection_id] # Try to create and authenticate connector connection_config = self.connections.get(connection_id) if not connection_config or not connection_config.is_active: + logger.debug(f"No active connection config found for {connection_id}") return None + logger.debug(f"Creating connector for {connection_config.connector_type}") connector = self._create_connector(connection_config) - if await connector.authenticate(): + + logger.debug(f"Attempting authentication for {connection_id}") + auth_result = await connector.authenticate() + logger.debug(f"Authentication result for {connection_id}: {auth_result}") + + if auth_result: self.active_connectors[connection_id] = connector - - # Setup webhook subscription if not already set up - await self._setup_webhook_if_needed( - connection_id, connection_config, connector - ) - + # ... rest of the method return connector - - return None + else: + logger.warning(f"Authentication failed for {connection_id}") + return None def get_available_connector_types(self) -> Dict[str, Dict[str, Any]]: """Get available connector types with their metadata""" @@ -363,20 +370,23 @@ class ConnectionManager: def _create_connector(self, config: ConnectionConfig) -> BaseConnector: """Factory method to create connector instances""" - if config.connector_type == "google_drive": - return GoogleDriveConnector(config.config) - elif config.connector_type == "sharepoint": - return SharePointConnector(config.config) - elif config.connector_type == "onedrive": - return OneDriveConnector(config.config) - elif config.connector_type == "box": - # Future: BoxConnector(config.config) - raise NotImplementedError("Box connector not implemented yet") - elif config.connector_type == "dropbox": - # Future: DropboxConnector(config.config) - raise NotImplementedError("Dropbox connector not implemented yet") - else: - raise ValueError(f"Unknown connector type: {config.connector_type}") + try: + if config.connector_type == "google_drive": + return GoogleDriveConnector(config.config) + elif config.connector_type == "sharepoint": + return SharePointConnector(config.config) + elif config.connector_type == "onedrive": + return OneDriveConnector(config.config) + elif config.connector_type == "box": + raise NotImplementedError("Box connector not implemented yet") + elif config.connector_type == "dropbox": + raise NotImplementedError("Dropbox connector not implemented yet") + else: + raise ValueError(f"Unknown connector type: {config.connector_type}") + except Exception as e: + logger.error(f"Failed to create {config.connector_type} connector: {e}") + # Re-raise the exception so caller can handle appropriately + raise async def update_last_sync(self, connection_id: str): """Update the last sync timestamp for a connection""" diff --git a/src/connectors/google_drive/connector.py b/src/connectors/google_drive/connector.py index 0aa4234a..37ebdc8a 100644 --- a/src/connectors/google_drive/connector.py +++ b/src/connectors/google_drive/connector.py @@ -477,7 +477,7 @@ class GoogleDriveConnector(BaseConnector): "next_page_token": None, # no more pages } except Exception as e: - # Optionally log error with your base class logger + # Log the 
error
             try:
                 logger.error(f"GoogleDriveConnector.list_files failed: {e}")
             except Exception:
@@ -495,7 +495,6 @@ class GoogleDriveConnector(BaseConnector):
         try:
             blob = self._download_file_bytes(meta)
         except Exception as e:
-            # Use your base class logger if available
             try:
                 logger.error(f"Download failed for {file_id}: {e}")
             except Exception:
@@ -562,7 +561,6 @@ class GoogleDriveConnector(BaseConnector):
             if not self.cfg.changes_page_token:
                 self.cfg.changes_page_token = self.get_start_page_token()
         except Exception as e:
-            # Optional: use your base logger
             try:
                 logger.error(f"Failed to get start page token: {e}")
             except Exception:
@@ -593,7 +591,6 @@ class GoogleDriveConnector(BaseConnector):
         expiration = result.get("expiration")
 
         # Persist in-memory so cleanup can stop this channel later.
-        # If your project has a persistence layer, save these values there.
         self._active_channel = {
             "channel_id": channel_id,
             "resource_id": resource_id,
@@ -803,7 +800,7 @@ class GoogleDriveConnector(BaseConnector):
         """
         Perform a one-shot sync of the currently selected scope and emit documents.
 
-        Emits ConnectorDocument instances (adapt to your BaseConnector ingestion).
+        Emits ConnectorDocument instances.
         """
         items = self._iter_selected_items()
         for meta in items:
diff --git a/src/connectors/onedrive/connector.py b/src/connectors/onedrive/connector.py
index 8b800b3d..3ef4bdaf 100644
--- a/src/connectors/onedrive/connector.py
+++ b/src/connectors/onedrive/connector.py
@@ -1,223 +1,494 @@
+import logging
+from pathlib import Path
+from typing import List, Dict, Any, Optional
+from datetime import datetime
 import httpx
-import uuid
-from datetime import datetime, timedelta
-from typing import Dict, List, Any, Optional
 from ..base import BaseConnector, ConnectorDocument, DocumentACL
 from .oauth import OneDriveOAuth
 
+logger = logging.getLogger(__name__)
+
 
 class OneDriveConnector(BaseConnector):
-    """OneDrive connector using Microsoft Graph API"""
+    """OneDrive connector using MSAL-based OAuth for authentication."""
 
+    # Required BaseConnector class attributes
     CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
     CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
 
     # Connector metadata
     CONNECTOR_NAME = "OneDrive"
-    CONNECTOR_DESCRIPTION = "Connect your personal OneDrive to sync documents"
+    CONNECTOR_DESCRIPTION = "Connect to OneDrive (personal) to sync documents and files"
     CONNECTOR_ICON = "onedrive"
 
     def __init__(self, config: Dict[str, Any]):
-        super().__init__(config)
-        self.oauth = OneDriveOAuth(
-            client_id=self.get_client_id(),
-            client_secret=self.get_client_secret(),
-            token_file=config.get("token_file", "onedrive_token.json"),
-        )
-        self.subscription_id = config.get("subscription_id") or config.get(
-            "webhook_channel_id"
-        )
-        self.base_url = "https://graph.microsoft.com/v1.0"
-    async def authenticate(self) -> bool:
-        if await self.oauth.is_authenticated():
-            self._authenticated = True
-            return True
-        return False
+        logger.debug(f"OneDrive connector __init__ called with config type: {type(config)}")
-    async def setup_subscription(self) -> str:
-        if not self._authenticated:
-            raise ValueError("Not authenticated")
+        logger.debug(f"OneDrive connector __init__ config value: {config}")
-        webhook_url = self.config.get("webhook_url")
-        if not webhook_url:
-            raise ValueError("webhook_url required in config for subscriptions")
+        if config is None:
+            logger.debug("Config was None, using empty dict")
+            config = {}
-        expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z"
-        body = {
-            "changeType": "created,updated,deleted",
-            "notificationUrl": webhook_url,
-            "resource": "/me/drive/root",
-            "expirationDateTime": expiration,
-            "clientState": str(uuid.uuid4()),
+        try:
+            logger.debug("Calling super().__init__")
+            super().__init__(config)
+            
logger.debug("super().__init__ completed successfully") + except Exception as e: + logger.error(f"super().__init__ failed: {e}") + raise - expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z" - body = { - "changeType": "created,updated,deleted", - "notificationUrl": webhook_url, - "resource": "/me/drive/root", - "expirationDateTime": expiration, - "clientState": str(uuid.uuid4()), + # Initialize with defaults that allow the connector to be listed + self.client_id = None + self.client_secret = None + self.redirect_uri = config.get("redirect_uri", "http://localhost") + + # Try to get credentials, but don't fail if they're missing + try: + self.client_id = self.get_client_id() + logger.debug(f"Got client_id: {self.client_id is not None}") + except Exception as e: + logger.debug(f"Failed to get client_id: {e}") + + try: + self.client_secret = self.get_client_secret() + logger.debug(f"Got client_secret: {self.client_secret is not None}") + except Exception as e: + logger.debug(f"Failed to get client_secret: {e}") + + # Token file setup + project_root = Path(__file__).resolve().parent.parent.parent.parent + token_file = config.get("token_file") or str(project_root / "onedrive_token.json") + Path(token_file).parent.mkdir(parents=True, exist_ok=True) + + # Only initialize OAuth if we have credentials + if self.client_id and self.client_secret: + connection_id = config.get("connection_id", "default") + + # Use token_file from config if provided, otherwise generate one + if config.get("token_file"): + oauth_token_file = config["token_file"] + else: + # Use a per-connection cache file to avoid collisions with other connectors + oauth_token_file = f"onedrive_token_{connection_id}.json" + + # MSA & org both work via /common for OneDrive personal testing + authority = "https://login.microsoftonline.com/common" + + self.oauth = OneDriveOAuth( + client_id=self.client_id, + client_secret=self.client_secret, + token_file=oauth_token_file, + authority=authority, + allow_json_refresh=True, # allows one-time migration from legacy JSON if present + ) + else: + self.oauth = None + + # Track subscription ID for webhooks (note: change notifications might not be available for personal accounts) + self._subscription_id: Optional[str] = None + + # Graph API defaults + self._graph_api_version = "v1.0" + self._default_params = { + "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl" } - token = self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.post( - f"{self.base_url}/subscriptions", - json=body, - headers={"Authorization": f"Bearer {token}"}, - ) - resp.raise_for_status() - data = resp.json() + @property + def _graph_base_url(self) -> str: + """Base URL for Microsoft Graph API calls.""" + return f"https://graph.microsoft.com/{self._graph_api_version}" - self.subscription_id = data["id"] - return self.subscription_id + def emit(self, doc: ConnectorDocument) -> None: + """Emit a ConnectorDocument instance.""" + logger.debug(f"Emitting OneDrive document: {doc.id} ({doc.filename})") + + async def authenticate(self) -> bool: + """Test authentication - BaseConnector interface.""" + logger.debug(f"OneDrive authenticate() called, oauth is None: {self.oauth is None}") + try: + if not self.oauth: + logger.debug("OneDrive authentication failed: OAuth not initialized") + self._authenticated = False + return False + + logger.debug("Loading OneDrive credentials...") + load_result = await 
self.oauth.load_credentials() + logger.debug(f"Load credentials result: {load_result}") + + logger.debug("Checking OneDrive authentication status...") + authenticated = await self.oauth.is_authenticated() + logger.debug(f"OneDrive is_authenticated result: {authenticated}") + + self._authenticated = authenticated + return authenticated + except Exception as e: + logger.error(f"OneDrive authentication failed: {e}") + import traceback + traceback.print_exc() + self._authenticated = False + return False + + def get_auth_url(self) -> str: + """Get OAuth authorization URL.""" + if not self.oauth: + raise RuntimeError("OneDrive OAuth not initialized - missing credentials") + return self.oauth.create_authorization_url(self.redirect_uri) + + async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]: + """Handle OAuth callback.""" + if not self.oauth: + raise RuntimeError("OneDrive OAuth not initialized - missing credentials") + try: + success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri) + if success: + self._authenticated = True + return {"status": "success"} + else: + raise ValueError("OAuth callback failed") + except Exception as e: + logger.error(f"OAuth callback failed: {e}") + raise + + def sync_once(self) -> None: + """ + Perform a one-shot sync of OneDrive files and emit documents. + """ + import asyncio + + async def _async_sync(): + try: + file_list = await self.list_files(max_files=1000) + files = file_list.get("files", []) + for file_info in files: + try: + file_id = file_info.get("id") + if not file_id: + continue + doc = await self.get_file_content(file_id) + self.emit(doc) + except Exception as e: + logger.error(f"Failed to sync OneDrive file {file_info.get('name', 'unknown')}: {e}") + continue + except Exception as e: + logger.error(f"OneDrive sync_once failed: {e}") + raise + + if hasattr(asyncio, 'run'): + asyncio.run(_async_sync()) + else: + loop = asyncio.get_event_loop() + loop.run_until_complete(_async_sync()) + + async def setup_subscription(self) -> str: + """ + Set up real-time subscription for file changes. + NOTE: Change notifications may not be available for personal OneDrive accounts. 
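+
+        Returns the Graph subscription ID on success, or the sentinel string
+        "no-webhook-configured" when no webhook_url is present in config.
+
+        Usage sketch (the webhook host below is a placeholder, not a real endpoint):
+
+            connector = OneDriveConnector({"webhook_url": "https://example.com"})
+            sub_id = await connector.setup_subscription()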
+ """ + webhook_url = self.config.get('webhook_url') + if not webhook_url: + logger.warning("No webhook URL configured, skipping OneDrive subscription setup") + return "no-webhook-configured" + + try: + if not await self.authenticate(): + raise RuntimeError("OneDrive authentication failed during subscription setup") + + token = self.oauth.get_access_token() + + # For OneDrive personal we target the user's drive + resource = "/me/drive/root" + + subscription_data = { + "changeType": "created,updated,deleted", + "notificationUrl": f"{webhook_url}/webhook/onedrive", + "resource": resource, + "expirationDateTime": self._get_subscription_expiry(), + "clientState": "onedrive_personal", + } + + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + + url = f"{self._graph_base_url}/subscriptions" + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=subscription_data, headers=headers, timeout=30) + response.raise_for_status() + + result = response.json() + subscription_id = result.get("id") + + if subscription_id: + self._subscription_id = subscription_id + logger.info(f"OneDrive subscription created: {subscription_id}") + return subscription_id + else: + raise ValueError("No subscription ID returned from Microsoft Graph") + + except Exception as e: + logger.error(f"Failed to setup OneDrive subscription: {e}") + raise + + def _get_subscription_expiry(self) -> str: + """Get subscription expiry time (Graph caps duration; often <= 3 days).""" + from datetime import datetime, timedelta + expiry = datetime.utcnow() + timedelta(days=3) + return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ") async def list_files( - self, page_token: Optional[str] = None, limit: int = 100 + self, + page_token: Optional[str] = None, + max_files: Optional[int] = None, + **kwargs ) -> Dict[str, Any]: - if not self._authenticated: - raise ValueError("Not authenticated") + """List files from OneDrive using Microsoft Graph.""" + try: + if not await self.authenticate(): + raise RuntimeError("OneDrive authentication failed during file listing") - params = {"$top": str(limit)} - if page_token: - params["$skiptoken"] = page_token + files: List[Dict[str, Any]] = [] + max_files_value = max_files if max_files is not None else 100 - token = self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{self.base_url}/me/drive/root/children", - params=params, - headers={"Authorization": f"Bearer {token}"}, - ) - resp.raise_for_status() - data = resp.json() + base_url = f"{self._graph_base_url}/me/drive/root/children" - files = [] - for item in data.get("value", []): - if item.get("file"): - files.append( - { - "id": item["id"], - "name": item["name"], - "mimeType": item.get("file", {}).get( - "mimeType", "application/octet-stream" - ), - "webViewLink": item.get("webUrl"), - "createdTime": item.get("createdDateTime"), - "modifiedTime": item.get("lastModifiedDateTime"), - } - ) + params = dict(self._default_params) + params["$top"] = str(max_files_value) - next_token = None - next_link = data.get("@odata.nextLink") - if next_link: - from urllib.parse import urlparse, parse_qs + if page_token: + params["$skiptoken"] = page_token - parsed = urlparse(next_link) - next_token = parse_qs(parsed.query).get("$skiptoken", [None])[0] + response = await self._make_graph_request(base_url, params=params) + data = response.json() - return {"files": files, "nextPageToken": next_token} + items = data.get("value", []) + for item in items: + if 
item.get("file"): # include files only + files.append({ + "id": item.get("id", ""), + "name": item.get("name", ""), + "path": f"/drive/items/{item.get('id')}", + "size": int(item.get("size", 0)), + "modified": item.get("lastModifiedDateTime"), + "created": item.get("createdDateTime"), + "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))), + "url": item.get("webUrl", ""), + "download_url": item.get("@microsoft.graph.downloadUrl"), + }) + + # Next page + next_page_token = None + next_link = data.get("@odata.nextLink") + if next_link: + from urllib.parse import urlparse, parse_qs + parsed = urlparse(next_link) + query_params = parse_qs(parsed.query) + if "$skiptoken" in query_params: + next_page_token = query_params["$skiptoken"][0] + + return {"files": files, "next_page_token": next_page_token} + + except Exception as e: + logger.error(f"Failed to list OneDrive files: {e}") + return {"files": [], "next_page_token": None} async def get_file_content(self, file_id: str) -> ConnectorDocument: - if not self._authenticated: - raise ValueError("Not authenticated") + """Get file content and metadata.""" + try: + if not await self.authenticate(): + raise RuntimeError("OneDrive authentication failed during file content retrieval") + file_metadata = await self._get_file_metadata_by_id(file_id) + if not file_metadata: + raise ValueError(f"File not found: {file_id}") + + download_url = file_metadata.get("download_url") + if download_url: + content = await self._download_file_from_url(download_url) + else: + content = await self._download_file_content(file_id) + + acl = DocumentACL( + owner="", + user_permissions={}, + group_permissions={}, + ) + + modified_time = self._parse_graph_date(file_metadata.get("modified")) + created_time = self._parse_graph_date(file_metadata.get("created")) + + return ConnectorDocument( + id=file_id, + filename=file_metadata.get("name", ""), + mimetype=file_metadata.get("mime_type", "application/octet-stream"), + content=content, + source_url=file_metadata.get("url", ""), + acl=acl, + modified_time=modified_time, + created_time=created_time, + metadata={ + "onedrive_path": file_metadata.get("path", ""), + "size": file_metadata.get("size", 0), + }, + ) + + except Exception as e: + logger.error(f"Failed to get OneDrive file content {file_id}: {e}") + raise + + async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]: + """Get file metadata by ID using Graph API.""" + try: + url = f"{self._graph_base_url}/me/drive/items/{file_id}" + params = dict(self._default_params) + + response = await self._make_graph_request(url, params=params) + item = response.json() + + if item.get("file"): + return { + "id": file_id, + "name": item.get("name", ""), + "path": f"/drive/items/{file_id}", + "size": int(item.get("size", 0)), + "modified": item.get("lastModifiedDateTime"), + "created": item.get("createdDateTime"), + "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))), + "url": item.get("webUrl", ""), + "download_url": item.get("@microsoft.graph.downloadUrl"), + } + + return None + + except Exception as e: + logger.error(f"Failed to get file metadata for {file_id}: {e}") + return None + + async def _download_file_content(self, file_id: str) -> bytes: + """Download file content by file ID using Graph API.""" + try: + url = f"{self._graph_base_url}/me/drive/items/{file_id}/content" + token = self.oauth.get_access_token() + headers = {"Authorization": f"Bearer {token}"} + + async with 
httpx.AsyncClient() as client: + response = await client.get(url, headers=headers, timeout=60) + response.raise_for_status() + return response.content + + except Exception as e: + logger.error(f"Failed to download file content for {file_id}: {e}") + raise + + async def _download_file_from_url(self, download_url: str) -> bytes: + """Download file content from direct download URL.""" + try: + async with httpx.AsyncClient() as client: + response = await client.get(download_url, timeout=60) + response.raise_for_status() + return response.content + except Exception as e: + logger.error(f"Failed to download from URL {download_url}: {e}") + raise + + def _parse_graph_date(self, date_str: Optional[str]) -> datetime: + """Parse Microsoft Graph date string to datetime.""" + if not date_str: + return datetime.now() + try: + if date_str.endswith('Z'): + return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None) + else: + return datetime.fromisoformat(date_str.replace('T', ' ')) + except (ValueError, AttributeError): + return datetime.now() + + async def _make_graph_request(self, url: str, method: str = "GET", + data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response: + """Make authenticated API request to Microsoft Graph.""" token = self.oauth.get_access_token() - headers = {"Authorization": f"Bearer {token}"} + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + async with httpx.AsyncClient() as client: - meta_resp = await client.get( - f"{self.base_url}/me/drive/items/{file_id}", headers=headers - ) - meta_resp.raise_for_status() - metadata = meta_resp.json() + if method.upper() == "GET": + response = await client.get(url, headers=headers, params=params, timeout=30) + elif method.upper() == "POST": + response = await client.post(url, headers=headers, json=data, timeout=30) + elif method.upper() == "DELETE": + response = await client.delete(url, headers=headers, timeout=30) + else: + raise ValueError(f"Unsupported HTTP method: {method}") - content_resp = await client.get( - f"{self.base_url}/me/drive/items/{file_id}/content", headers=headers - ) - content_resp.raise_for_status() - content = content_resp.content + response.raise_for_status() + return response - perm_resp = await client.get( - f"{self.base_url}/me/drive/items/{file_id}/permissions", headers=headers - ) - perm_resp.raise_for_status() - permissions = perm_resp.json() + def _get_mime_type(self, filename: str) -> str: + """Get MIME type based on file extension.""" + import mimetypes + mime_type, _ = mimetypes.guess_type(filename) + return mime_type or "application/octet-stream" - acl = self._parse_permissions(metadata, permissions) - modified = datetime.fromisoformat( - metadata["lastModifiedDateTime"].replace("Z", "+00:00") - ).replace(tzinfo=None) - created = datetime.fromisoformat( - metadata["createdDateTime"].replace("Z", "+00:00") - ).replace(tzinfo=None) - - document = ConnectorDocument( - id=metadata["id"], - filename=metadata["name"], - mimetype=metadata.get("file", {}).get( - "mimeType", "application/octet-stream" - ), - content=content, - source_url=metadata.get("webUrl"), - acl=acl, - modified_time=modified, - created_time=created, - metadata={"size": metadata.get("size")}, - ) - return document - - def _parse_permissions( - self, metadata: Dict[str, Any], permissions: Dict[str, Any] - ) -> DocumentACL: - acl = DocumentACL() - owner = metadata.get("createdBy", {}).get("user", {}).get("email") - if owner: - acl.owner = owner - for perm in 
permissions.get("value", []): - role = perm.get("roles", ["read"])[0] - grantee = perm.get("grantedToV2") or perm.get("grantedTo") - if not grantee: - continue - user = grantee.get("user") - if user and user.get("email"): - acl.user_permissions[user["email"]] = role - group = grantee.get("group") - if group and group.get("email"): - acl.group_permissions[group["email"]] = role - return acl - - def handle_webhook_validation( - self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str] - ) -> Optional[str]: - """Handle Microsoft Graph webhook validation""" - if request_method == "GET": - validation_token = query_params.get("validationtoken") or query_params.get( - "validationToken" - ) - if validation_token: - return validation_token + # Webhook methods - BaseConnector interface + def handle_webhook_validation(self, request_method: str, + headers: Dict[str, str], + query_params: Dict[str, str]) -> Optional[str]: + """Handle webhook validation (Graph API specific).""" + if request_method == "POST" and "validationToken" in query_params: + return query_params["validationToken"] return None - def extract_webhook_channel_id( - self, payload: Dict[str, Any], headers: Dict[str, str] - ) -> Optional[str]: - """Extract SharePoint subscription ID from webhook payload""" - values = payload.get("value", []) - return values[0].get("subscriptionId") if values else None + def extract_webhook_channel_id(self, payload: Dict[str, Any], + headers: Dict[str, str]) -> Optional[str]: + """Extract channel/subscription ID from webhook payload.""" + notifications = payload.get("value", []) + if notifications: + return notifications[0].get("subscriptionId") + return None async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: - values = payload.get("value", []) - file_ids = [] - for item in values: - resource_data = item.get("resourceData", {}) - file_id = resource_data.get("id") - if file_id: - file_ids.append(file_id) - return file_ids + """Handle webhook notification and return affected file IDs.""" + affected_files: List[str] = [] + notifications = payload.get("value", []) + for notification in notifications: + resource = notification.get("resource") + if resource and "/drive/items/" in resource: + file_id = resource.split("/drive/items/")[-1] + affected_files.append(file_id) + return affected_files - async def cleanup_subscription( - self, subscription_id: str, resource_id: str = None - ) -> bool: - if not self._authenticated: + async def cleanup_subscription(self, subscription_id: str) -> bool: + """Clean up subscription - BaseConnector interface.""" + if subscription_id == "no-webhook-configured": + logger.info("No subscription to cleanup (webhook was not configured)") + return True + + try: + if not await self.authenticate(): + logger.error("OneDrive authentication failed during subscription cleanup") + return False + + token = self.oauth.get_access_token() + headers = {"Authorization": f"Bearer {token}"} + + url = f"{self._graph_base_url}/subscriptions/{subscription_id}" + + async with httpx.AsyncClient() as client: + response = await client.delete(url, headers=headers, timeout=30) + + if response.status_code in [200, 204, 404]: + logger.info(f"OneDrive subscription {subscription_id} cleaned up successfully") + return True + else: + logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}") + return False + + except Exception as e: + logger.error(f"Failed to cleanup OneDrive subscription {subscription_id}: {e}") return False - token = 
self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.delete( - f"{self.base_url}/subscriptions/{subscription_id}", - headers={"Authorization": f"Bearer {token}"}, - ) - return resp.status_code in (200, 204) diff --git a/src/connectors/onedrive/oauth.py b/src/connectors/onedrive/oauth.py index a81124e6..a2c94d15 100644 --- a/src/connectors/onedrive/oauth.py +++ b/src/connectors/onedrive/oauth.py @@ -1,17 +1,28 @@ import os +import json +import logging +from typing import Optional, Dict, Any + import aiofiles -from typing import Optional import msal +logger = logging.getLogger(__name__) + class OneDriveOAuth: - """Handles Microsoft Graph OAuth authentication flow""" + """Handles Microsoft Graph OAuth for OneDrive (personal Microsoft accounts by default).""" - SCOPES = [ - "offline_access", - "Files.Read.All", - ] + # Reserved scopes that must NOT be sent on token or silent calls + RESERVED_SCOPES = {"openid", "profile", "offline_access"} + # For PERSONAL Microsoft Accounts (OneDrive consumer): + # - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance) + # - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths + AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"] + RESOURCE_SCOPES = ["User.Read", "Files.Read.All"] + SCOPES = AUTH_SCOPES # Backward-compat alias if something references .SCOPES + + # Kept for reference; MSAL derives endpoints from `authority` AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token" @@ -21,18 +32,29 @@ class OneDriveOAuth: client_secret: str, token_file: str = "onedrive_token.json", authority: str = "https://login.microsoftonline.com/common", + allow_json_refresh: bool = True, ): + """ + Initialize OneDriveOAuth. + + Args: + client_id: Azure AD application (client) ID. + client_secret: Azure AD application client secret. + token_file: Path to persisted token cache file (MSAL cache format). + authority: Usually "https://login.microsoftonline.com/common" for MSA + org, + or tenant-specific for work/school. + allow_json_refresh: If True, permit one-time migration from legacy flat JSON + {"access_token","refresh_token",...}. Otherwise refuse it. 
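+
+        A minimal construction sketch (placeholder values, not real credentials):
+
+            oauth = OneDriveOAuth(
+                client_id="<azure-app-client-id>",
+                client_secret="<azure-app-client-secret>",
+                token_file="onedrive_token_default.json",
+            )
+            auth_url = oauth.create_authorization_url("http://localhost")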
+ """ self.client_id = client_id self.client_secret = client_secret self.token_file = token_file self.authority = authority + self.allow_json_refresh = allow_json_refresh self.token_cache = msal.SerializableTokenCache() + self._current_account = None - # Load existing cache if available - if os.path.exists(self.token_file): - with open(self.token_file, "r") as f: - self.token_cache.deserialize(f.read()) - + # Initialize MSAL Confidential Client self.app = msal.ConfidentialClientApplication( client_id=self.client_id, client_credential=self.client_secret, @@ -40,56 +62,261 @@ class OneDriveOAuth: token_cache=self.token_cache, ) - async def save_cache(self): - """Persist the token cache to file""" - async with aiofiles.open(self.token_file, "w") as f: - await f.write(self.token_cache.serialize()) + async def load_credentials(self) -> bool: + """Load existing credentials from token file (async).""" + try: + logger.debug(f"OneDrive OAuth loading credentials from: {self.token_file}") + if os.path.exists(self.token_file): + logger.debug(f"Token file exists, reading: {self.token_file}") - def create_authorization_url(self, redirect_uri: str) -> str: - """Create authorization URL for OAuth flow""" - return self.app.get_authorization_request_url( - self.SCOPES, redirect_uri=redirect_uri - ) + # Read the token file + async with aiofiles.open(self.token_file, "r") as f: + cache_data = await f.read() + logger.debug(f"Read {len(cache_data)} chars from token file") + + if cache_data.strip(): + # 1) Try legacy flat JSON first + try: + json_data = json.loads(cache_data) + if isinstance(json_data, dict) and "refresh_token" in json_data: + if self.allow_json_refresh: + logger.debug( + "Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh" + ) + return await self._refresh_from_json_token(json_data) + else: + logger.warning( + "Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. " + "Delete the file and re-auth." + ) + return False + except json.JSONDecodeError: + logger.debug("Token file is not flat JSON; attempting MSAL cache format") + + # 2) Try MSAL cache format + logger.debug("Attempting MSAL cache deserialization") + self.token_cache.deserialize(cache_data) + + # Get accounts from loaded cache + accounts = self.app.get_accounts() + logger.debug(f"Found {len(accounts)} accounts in MSAL cache") + if accounts: + self._current_account = accounts[0] + logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}") + + # Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account) + logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}") + if result and "access_token" in result: + logger.debug("Silent token acquisition successful") + await self.save_cache() + return True + else: + error_msg = (result or {}).get("error") or "No result" + logger.warning(f"Silent token acquisition failed: {error_msg}") + else: + logger.debug(f"Token file {self.token_file} is empty") + else: + logger.debug(f"Token file does not exist: {self.token_file}") + + return False + + except Exception as e: + logger.error(f"Failed to load OneDrive credentials: {e}") + import traceback + traceback.print_exc() + return False + + async def _refresh_from_json_token(self, token_data: dict) -> bool: + """ + Use refresh token from a legacy JSON file to get new tokens (one-time migration path). 
+ Prefer using an MSAL cache file and acquire_token_silent(); this path is only for migrating older files. + """ + try: + refresh_token = token_data.get("refresh_token") + if not refresh_token: + logger.error("No refresh_token found in JSON file - cannot refresh") + logger.error("You must re-authenticate interactively to obtain a valid token") + return False + + # Use only RESOURCE_SCOPES when refreshing (no reserved scopes) + refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES] + logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}") + + result = self.app.acquire_token_by_refresh_token( + refresh_token=refresh_token, + scopes=refresh_scopes, + ) + + if result and "access_token" in result: + logger.debug("Successfully refreshed token via legacy JSON path") + await self.save_cache() + + accounts = self.app.get_accounts() + logger.debug(f"After refresh, found {len(accounts)} accounts") + if accounts: + self._current_account = accounts[0] + logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}") + return True + + # Error handling + err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error" + logger.error(f"Refresh token failed: {err}") + + if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")): + logger.warning( + "Refresh denied due to unauthorized/expired scopes or invalid grant. " + "Delete the token file and perform interactive sign-in with correct scopes." + ) + + return False + + except Exception as e: + logger.error(f"Exception during refresh from JSON token: {e}") + import traceback + traceback.print_exc() + return False + + async def save_cache(self): + """Persist the token cache to file.""" + try: + # Ensure parent directory exists + parent = os.path.dirname(os.path.abspath(self.token_file)) + if parent and not os.path.exists(parent): + os.makedirs(parent, exist_ok=True) + + cache_data = self.token_cache.serialize() + if cache_data: + async with aiofiles.open(self.token_file, "w") as f: + await f.write(cache_data) + logger.debug(f"Token cache saved to {self.token_file}") + except Exception as e: + logger.error(f"Failed to save token cache: {e}") + + def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str: + """Create authorization URL for OAuth flow.""" + # Store redirect URI for later use in callback + self._redirect_uri = redirect_uri + + kwargs: Dict[str, Any] = { + # Interactive auth includes offline_access + "scopes": self.AUTH_SCOPES, + "redirect_uri": redirect_uri, + "prompt": "consent", # ensure refresh token on first run + } + if state: + kwargs["state"] = state # Optional CSRF protection + + auth_url = self.app.get_authorization_request_url(**kwargs) + + logger.debug(f"Generated auth URL: {auth_url}") + logger.debug(f"Auth scopes: {self.AUTH_SCOPES}") + + return auth_url async def handle_authorization_callback( self, authorization_code: str, redirect_uri: str ) -> bool: - """Handle OAuth callback and exchange code for tokens""" - result = self.app.acquire_token_by_authorization_code( - authorization_code, - scopes=self.SCOPES, - redirect_uri=redirect_uri, - ) - if "access_token" in result: - await self.save_cache() - return True - raise ValueError(result.get("error_description") or "Authorization failed") + """Handle OAuth callback and exchange code for tokens.""" + try: + result = self.app.acquire_token_by_authorization_code( + authorization_code, + 
scopes=self.AUTH_SCOPES, # same as authorize step + redirect_uri=redirect_uri, + ) + + if result and "access_token" in result: + accounts = self.app.get_accounts() + if accounts: + self._current_account = accounts[0] + + await self.save_cache() + logger.info("OneDrive OAuth authorization successful") + return True + + error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error" + logger.error(f"OneDrive OAuth authorization failed: {error_msg}") + return False + + except Exception as e: + logger.error(f"Exception during OneDrive OAuth authorization: {e}") + return False async def is_authenticated(self) -> bool: - """Check if we have valid credentials""" - accounts = self.app.get_accounts() - if not accounts: + """Check if we have valid credentials.""" + try: + # First try to load credentials if we haven't already + if not self._current_account: + await self.load_credentials() + + # Try to get a token (MSAL will refresh if needed) + if self._current_account: + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account) + if result and "access_token" in result: + return True + else: + error_msg = (result or {}).get("error") or "No result returned" + logger.debug(f"Token acquisition failed for current account: {error_msg}") + + # Fallback: try without specific account + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None) + if result and "access_token" in result: + accounts = self.app.get_accounts() + if accounts: + self._current_account = accounts[0] + return True + + return False + + except Exception as e: + logger.error(f"Authentication check failed: {e}") return False - result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0]) - if "access_token" in result: - await self.save_cache() - return True - return False def get_access_token(self) -> str: - """Get an access token for Microsoft Graph""" - accounts = self.app.get_accounts() - if not accounts: - raise ValueError("Not authenticated") - result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0]) - if "access_token" not in result: - raise ValueError( - result.get("error_description") or "Failed to acquire access token" - ) - return result["access_token"] + """Get an access token for Microsoft Graph.""" + try: + # Try with current account first + if self._current_account: + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account) + if result and "access_token" in result: + return result["access_token"] + + # Fallback: try without specific account + result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None) + if result and "access_token" in result: + return result["access_token"] + + # If we get here, authentication has failed + error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication" + raise ValueError(f"Failed to acquire access token: {error_msg}") + + except Exception as e: + logger.error(f"Failed to get access token: {e}") + raise async def revoke_credentials(self): - """Clear token cache and remove token file""" - self.token_cache.clear() - if os.path.exists(self.token_file): - os.remove(self.token_file) + """Clear token cache and remove token file.""" + try: + # Clear in-memory state + self._current_account = None + self.token_cache = msal.SerializableTokenCache() + + # Recreate MSAL app with fresh cache + self.app = msal.ConfidentialClientApplication( + client_id=self.client_id, + 
client_credential=self.client_secret,
+                authority=self.authority,
+                token_cache=self.token_cache,
+            )
+
+            # Remove token file
+            if os.path.exists(self.token_file):
+                os.remove(self.token_file)
+                logger.info(f"Removed OneDrive token file: {self.token_file}")
+
+        except Exception as e:
+            logger.error(f"Failed to revoke OneDrive credentials: {e}")
+
+    def get_service(self) -> str:
+        """Return an access token (Graph client is just the bearer)."""
+        return self.get_access_token()
diff --git a/src/connectors/sharepoint/connector.py b/src/connectors/sharepoint/connector.py
index 7135cc8e..f8283062 100644
--- a/src/connectors/sharepoint/connector.py
+++ b/src/connectors/sharepoint/connector.py
@@ -1,229 +1,567 @@
+import logging
+from pathlib import Path
+from typing import List, Dict, Any, Optional
+from urllib.parse import urlparse
+from datetime import datetime
 import httpx
-import uuid
-from datetime import datetime, timedelta
-from typing import Dict, List, Any, Optional
 from ..base import BaseConnector, ConnectorDocument, DocumentACL
 from .oauth import SharePointOAuth
 
+logger = logging.getLogger(__name__)
+
 
 class SharePointConnector(BaseConnector):
-    """SharePoint Sites connector using Microsoft Graph API"""
+    """SharePoint connector using MSAL-based OAuth for authentication"""
 
+    # Required BaseConnector class attributes
     CLIENT_ID_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_ID"
     CLIENT_SECRET_ENV_VAR = "MICROSOFT_GRAPH_OAUTH_CLIENT_SECRET"
-
+    # Connector metadata
     CONNECTOR_NAME = "SharePoint"
-    CONNECTOR_DESCRIPTION = "Connect to SharePoint sites to sync team documents"
+    CONNECTOR_DESCRIPTION = "Connect to SharePoint to sync documents and files"
     CONNECTOR_ICON = "sharepoint"
-
+
     def __init__(self, config: Dict[str, Any]):
-        super().__init__(config)
-        self.oauth = SharePointOAuth(
-            client_id=self.get_client_id(),
-            client_secret=self.get_client_secret(),
-            token_file=config.get("token_file", "sharepoint_token.json"),
-        )
-        self.subscription_id = config.get("subscription_id") or config.get(
-            "webhook_channel_id"
-        )
-        self.base_url = "https://graph.microsoft.com/v1.0"
-        # SharePoint site configuration
-        self.site_id = config.get("site_id")  # Required for SharePoint
+        logger.debug(f"SharePoint connector __init__ called with config type: {type(config)}")
+        logger.debug(f"SharePoint connector __init__ config value: {config}")
+
+        # Ensure we always pass a valid config to the base class
+        if config is None:
+            logger.debug("Config was None, using empty dict")
+            config = {}
+
+        try:
+            logger.debug("Calling super().__init__")
+            super().__init__(config)  # Now safe to call with empty dict instead of None
+            logger.debug("super().__init__ completed successfully")
+        except Exception as e:
+            logger.error(f"super().__init__ failed: {e}")
+            raise
+
+        # Initialize with defaults that allow the connector to be listed
+        self.client_id = None
+        self.client_secret = None
+        self.tenant_id = config.get("tenant_id", "common")
+        self.sharepoint_url = config.get("sharepoint_url")
+        self.redirect_uri = config.get("redirect_uri", "http://localhost")
+
+        # Try to get credentials, but don't fail if they're missing
+        try:
+            logger.debug("Attempting to get client_id")
+            self.client_id = self.get_client_id()
+            logger.debug(f"Got client_id: {self.client_id is not None}")
+        except Exception as e:
+            logger.debug(f"Failed to get client_id: {e}")
+            pass  # Credentials not available, that's OK for listing
+
+        try:
+            logger.debug("Attempting to get client_secret")
+            self.client_secret = self.get_client_secret()
+            logger.debug(f"Got 
client_secret: {self.client_secret is not None}") + except Exception as e: + logger.debug(f"Failed to get client_secret: {e}") + pass # Credentials not available, that's OK for listing - async def authenticate(self) -> bool: - if await self.oauth.is_authenticated(): - self._authenticated = True - return True - return False - - async def setup_subscription(self) -> str: - if not self._authenticated: - raise ValueError("Not authenticated") - - webhook_url = self.config.get("webhook_url") - if not webhook_url: - raise ValueError("webhook_url required in config for subscriptions") - - expiration = (datetime.utcnow() + timedelta(days=2)).isoformat() + "Z" - body = { - "changeType": "created,updated,deleted", - "notificationUrl": webhook_url, - "resource": f"/sites/{self.site_id}/drive/root", - "expirationDateTime": expiration, - "clientState": str(uuid.uuid4()), + # Token file setup + project_root = Path(__file__).resolve().parent.parent.parent.parent + token_file = config.get("token_file") or str(project_root / "sharepoint_token.json") + Path(token_file).parent.mkdir(parents=True, exist_ok=True) + + # Only initialize OAuth if we have credentials + if self.client_id and self.client_secret: + connection_id = config.get("connection_id", "default") + + # Use token_file from config if provided, otherwise generate one + if config.get("token_file"): + oauth_token_file = config["token_file"] + else: + oauth_token_file = f"sharepoint_token_{connection_id}.json" + + authority = f"https://login.microsoftonline.com/{self.tenant_id}" if self.tenant_id != "common" else "https://login.microsoftonline.com/common" + + self.oauth = SharePointOAuth( + client_id=self.client_id, + client_secret=self.client_secret, + token_file=oauth_token_file, + authority=authority + ) + else: + self.oauth = None + + # Track subscription ID for webhooks + self._subscription_id: Optional[str] = None + + # Add Graph API defaults similar to Google Drive flags + self._graph_api_version = "v1.0" + self._default_params = { + "$select": "id,name,size,lastModifiedDateTime,createdDateTime,webUrl,file,folder,@microsoft.graph.downloadUrl" } - - token = self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.post( - f"{self.base_url}/subscriptions", - json=body, - headers={"Authorization": f"Bearer {token}"}, - ) - resp.raise_for_status() - data = resp.json() - - self.subscription_id = data["id"] - return self.subscription_id - - async def list_files( - self, page_token: Optional[str] = None, limit: int = 100 - ) -> Dict[str, Any]: - if not self._authenticated: - raise ValueError("Not authenticated") - - params = {"$top": str(limit)} - if page_token: - params["$skiptoken"] = page_token - - token = self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.get( - f"{self.base_url}/sites/{self.site_id}/drive/root/children", - params=params, - headers={"Authorization": f"Bearer {token}"}, - ) - resp.raise_for_status() - data = resp.json() - - files = [] - for item in data.get("value", []): - if item.get("file"): - files.append( - { - "id": item["id"], - "name": item["name"], - "mimeType": item.get("file", {}).get( - "mimeType", "application/octet-stream" - ), - "webViewLink": item.get("webUrl"), - "createdTime": item.get("createdDateTime"), - "modifiedTime": item.get("lastModifiedDateTime"), - } - ) - - next_token = None - next_link = data.get("@odata.nextLink") - if next_link: - from urllib.parse import urlparse, parse_qs - - parsed = urlparse(next_link) - next_token 
= parse_qs(parsed.query).get("$skiptoken", [None])[0] - - return {"files": files, "nextPageToken": next_token} - - async def get_file_content(self, file_id: str) -> ConnectorDocument: - if not self._authenticated: - raise ValueError("Not authenticated") - - token = self.oauth.get_access_token() - headers = {"Authorization": f"Bearer {token}"} - async with httpx.AsyncClient() as client: - meta_resp = await client.get( - f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}", - headers=headers, - ) - meta_resp.raise_for_status() - metadata = meta_resp.json() - - content_resp = await client.get( - f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/content", - headers=headers, - ) - content_resp.raise_for_status() - content = content_resp.content - - perm_resp = await client.get( - f"{self.base_url}/sites/{self.site_id}/drive/items/{file_id}/permissions", - headers=headers, - ) - perm_resp.raise_for_status() - permissions = perm_resp.json() - - acl = self._parse_permissions(metadata, permissions) - modified = datetime.fromisoformat( - metadata["lastModifiedDateTime"].replace("Z", "+00:00") - ).replace(tzinfo=None) - created = datetime.fromisoformat( - metadata["createdDateTime"].replace("Z", "+00:00") - ).replace(tzinfo=None) - - document = ConnectorDocument( - id=metadata["id"], - filename=metadata["name"], - mimetype=metadata.get("file", {}).get( - "mimeType", "application/octet-stream" - ), - content=content, - source_url=metadata.get("webUrl"), - acl=acl, - modified_time=modified, - created_time=created, - metadata={"size": metadata.get("size")}, - ) - return document - - def _parse_permissions( - self, metadata: Dict[str, Any], permissions: Dict[str, Any] - ) -> DocumentACL: - acl = DocumentACL() - owner = metadata.get("createdBy", {}).get("user", {}).get("email") - if owner: - acl.owner = owner - for perm in permissions.get("value", []): - role = perm.get("roles", ["read"])[0] - grantee = perm.get("grantedToV2") or perm.get("grantedTo") - if not grantee: - continue - user = grantee.get("user") - if user and user.get("email"): - acl.user_permissions[user["email"]] = role - group = grantee.get("group") - if group and group.get("email"): - acl.group_permissions[group["email"]] = role - return acl - - def handle_webhook_validation( - self, request_method: str, headers: Dict[str, str], query_params: Dict[str, str] - ) -> Optional[str]: - """Handle Microsoft Graph webhook validation""" - if request_method == "GET": - validation_token = query_params.get("validationtoken") or query_params.get( - "validationToken" - ) - if validation_token: - return validation_token - return None - - def extract_webhook_channel_id( - self, payload: Dict[str, Any], headers: Dict[str, str] - ) -> Optional[str]: - """Extract SharePoint subscription ID from webhook payload""" - values = payload.get("value", []) - return values[0].get("subscriptionId") if values else None - - async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]: - values = payload.get("value", []) - file_ids = [] - for item in values: - resource_data = item.get("resourceData", {}) - file_id = resource_data.get("id") - if file_id: - file_ids.append(file_id) - return file_ids - - async def cleanup_subscription( - self, subscription_id: str, resource_id: str = None - ) -> bool: - if not self._authenticated: + + @property + def _graph_base_url(self) -> str: + """Base URL for Microsoft Graph API calls""" + return f"https://graph.microsoft.com/{self._graph_api_version}" + + def emit(self, doc: ConnectorDocument) -> None: + 
""" + Emit a ConnectorDocument instance. + """ + logger.debug(f"Emitting SharePoint document: {doc.id} ({doc.filename})") + + async def authenticate(self) -> bool: + """Test authentication - BaseConnector interface""" + logger.debug(f"SharePoint authenticate() called, oauth is None: {self.oauth is None}") + try: + if not self.oauth: + logger.debug("SharePoint authentication failed: OAuth not initialized") + self._authenticated = False + return False + + logger.debug("Loading SharePoint credentials...") + # Try to load existing credentials first + load_result = await self.oauth.load_credentials() + logger.debug(f"Load credentials result: {load_result}") + + logger.debug("Checking SharePoint authentication status...") + authenticated = await self.oauth.is_authenticated() + logger.debug(f"SharePoint is_authenticated result: {authenticated}") + + self._authenticated = authenticated + return authenticated + except Exception as e: + logger.error(f"SharePoint authentication failed: {e}") + import traceback + traceback.print_exc() + self._authenticated = False return False - token = self.oauth.get_access_token() - async with httpx.AsyncClient() as client: - resp = await client.delete( - f"{self.base_url}/subscriptions/{subscription_id}", - headers={"Authorization": f"Bearer {token}"}, + + def get_auth_url(self) -> str: + """Get OAuth authorization URL""" + if not self.oauth: + raise RuntimeError("SharePoint OAuth not initialized - missing credentials") + return self.oauth.create_authorization_url(self.redirect_uri) + + async def handle_oauth_callback(self, auth_code: str) -> Dict[str, Any]: + """Handle OAuth callback""" + if not self.oauth: + raise RuntimeError("SharePoint OAuth not initialized - missing credentials") + try: + success = await self.oauth.handle_authorization_callback(auth_code, self.redirect_uri) + if success: + self._authenticated = True + return {"status": "success"} + else: + raise ValueError("OAuth callback failed") + except Exception as e: + logger.error(f"OAuth callback failed: {e}") + raise + + def sync_once(self) -> None: + """ + Perform a one-shot sync of SharePoint files and emit documents. + This method mirrors the Google Drive connector's sync_once functionality. 
+ """ + import asyncio + + async def _async_sync(): + try: + # Get list of files + file_list = await self.list_files(max_files=1000) # Adjust as needed + files = file_list.get("files", []) + + for file_info in files: + try: + file_id = file_info.get("id") + if not file_id: + continue + + # Get full document content + doc = await self.get_file_content(file_id) + self.emit(doc) + + except Exception as e: + logger.error(f"Failed to sync SharePoint file {file_info.get('name', 'unknown')}: {e}") + continue + + except Exception as e: + logger.error(f"SharePoint sync_once failed: {e}") + raise + + # Run the async sync + if hasattr(asyncio, 'run'): + asyncio.run(_async_sync()) + else: + # Python < 3.7 compatibility + loop = asyncio.get_event_loop() + loop.run_until_complete(_async_sync()) + + async def setup_subscription(self) -> str: + """Set up real-time subscription for file changes - BaseConnector interface""" + webhook_url = self.config.get('webhook_url') + if not webhook_url: + logger.warning("No webhook URL configured, skipping SharePoint subscription setup") + return "no-webhook-configured" + + try: + # Ensure we're authenticated + if not await self.authenticate(): + raise RuntimeError("SharePoint authentication failed during subscription setup") + + token = self.oauth.get_access_token() + + # Microsoft Graph subscription for SharePoint site + site_info = self._parse_sharepoint_url() + if site_info: + resource = f"sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root" + else: + resource = "/me/drive/root" + + subscription_data = { + "changeType": "created,updated,deleted", + "notificationUrl": f"{webhook_url}/webhook/sharepoint", + "resource": resource, + "expirationDateTime": self._get_subscription_expiry(), + "clientState": f"sharepoint_{self.tenant_id}" + } + + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + + url = f"{self._graph_base_url}/subscriptions" + + async with httpx.AsyncClient() as client: + response = await client.post(url, json=subscription_data, headers=headers, timeout=30) + response.raise_for_status() + + result = response.json() + subscription_id = result.get("id") + + if subscription_id: + self._subscription_id = subscription_id + logger.info(f"SharePoint subscription created: {subscription_id}") + return subscription_id + else: + raise ValueError("No subscription ID returned from Microsoft Graph") + + except Exception as e: + logger.error(f"Failed to setup SharePoint subscription: {e}") + raise + + def _get_subscription_expiry(self) -> str: + """Get subscription expiry time (max 3 days for Graph API)""" + from datetime import datetime, timedelta + expiry = datetime.utcnow() + timedelta(days=3) # 3 days max for Graph + return expiry.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + + def _parse_sharepoint_url(self) -> Optional[Dict[str, str]]: + """Parse SharePoint URL to extract site information for Graph API""" + if not self.sharepoint_url: + return None + + try: + parsed = urlparse(self.sharepoint_url) + # Extract hostname and site name from URL like: https://contoso.sharepoint.com/sites/teamsite + host_name = parsed.netloc + path_parts = parsed.path.strip('/').split('/') + + if len(path_parts) >= 2 and path_parts[0] == 'sites': + site_name = path_parts[1] + return { + "host_name": host_name, + "site_name": site_name + } + except Exception as e: + logger.warning(f"Could not parse SharePoint URL {self.sharepoint_url}: {e}") + + return None + + async def list_files( + self, + page_token: Optional[str] = None, + 
max_files: Optional[int] = None, + **kwargs + ) -> Dict[str, Any]: + """List all files using Microsoft Graph API - BaseConnector interface""" + try: + # Ensure authentication + if not await self.authenticate(): + raise RuntimeError("SharePoint authentication failed during file listing") + + files = [] + max_files_value = max_files if max_files is not None else 100 + + # Build Graph API URL for the site or fallback to user's OneDrive + site_info = self._parse_sharepoint_url() + if site_info: + base_url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/root/children" + else: + base_url = f"{self._graph_base_url}/me/drive/root/children" + + params = dict(self._default_params) + params["$top"] = str(max_files_value) + + if page_token: + params["$skiptoken"] = page_token + + response = await self._make_graph_request(base_url, params=params) + data = response.json() + + items = data.get("value", []) + for item in items: + # Only include files, not folders + if item.get("file"): + files.append({ + "id": item.get("id", ""), + "name": item.get("name", ""), + "path": f"/drive/items/{item.get('id')}", + "size": int(item.get("size", 0)), + "modified": item.get("lastModifiedDateTime"), + "created": item.get("createdDateTime"), + "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))), + "url": item.get("webUrl", ""), + "download_url": item.get("@microsoft.graph.downloadUrl") + }) + + # Check for next page + next_page_token = None + next_link = data.get("@odata.nextLink") + if next_link: + from urllib.parse import urlparse, parse_qs + parsed = urlparse(next_link) + query_params = parse_qs(parsed.query) + if "$skiptoken" in query_params: + next_page_token = query_params["$skiptoken"][0] + + return { + "files": files, + "next_page_token": next_page_token + } + + except Exception as e: + logger.error(f"Failed to list SharePoint files: {e}") + return {"files": [], "next_page_token": None} # Return empty result instead of raising + + async def get_file_content(self, file_id: str) -> ConnectorDocument: + """Get file content and metadata - BaseConnector interface""" + try: + # Ensure authentication + if not await self.authenticate(): + raise RuntimeError("SharePoint authentication failed during file content retrieval") + + # First get file metadata using Graph API + file_metadata = await self._get_file_metadata_by_id(file_id) + + if not file_metadata: + raise ValueError(f"File not found: {file_id}") + + # Download file content + download_url = file_metadata.get("download_url") + if download_url: + content = await self._download_file_from_url(download_url) + else: + content = await self._download_file_content(file_id) + + # Create ACL from metadata + acl = DocumentACL( + owner="", # Graph API requires additional calls for detailed permissions + user_permissions={}, + group_permissions={} ) - return resp.status_code in (200, 204) + + # Parse dates + modified_time = self._parse_graph_date(file_metadata.get("modified")) + created_time = self._parse_graph_date(file_metadata.get("created")) + + return ConnectorDocument( + id=file_id, + filename=file_metadata.get("name", ""), + mimetype=file_metadata.get("mime_type", "application/octet-stream"), + content=content, + source_url=file_metadata.get("url", ""), + acl=acl, + modified_time=modified_time, + created_time=created_time, + metadata={ + "sharepoint_path": file_metadata.get("path", ""), + "sharepoint_url": self.sharepoint_url, + "size": file_metadata.get("size", 0) + } + ) + + 
except Exception as e: + logger.error(f"Failed to get SharePoint file content {file_id}: {e}") + raise + + async def _get_file_metadata_by_id(self, file_id: str) -> Optional[Dict[str, Any]]: + """Get file metadata by ID using Graph API""" + try: + # Try site-specific path first, then fallback to user drive + site_info = self._parse_sharepoint_url() + if site_info: + url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}" + else: + url = f"{self._graph_base_url}/me/drive/items/{file_id}" + + params = dict(self._default_params) + + response = await self._make_graph_request(url, params=params) + item = response.json() + + if item.get("file"): + return { + "id": file_id, + "name": item.get("name", ""), + "path": f"/drive/items/{file_id}", + "size": int(item.get("size", 0)), + "modified": item.get("lastModifiedDateTime"), + "created": item.get("createdDateTime"), + "mime_type": item.get("file", {}).get("mimeType", self._get_mime_type(item.get("name", ""))), + "url": item.get("webUrl", ""), + "download_url": item.get("@microsoft.graph.downloadUrl") + } + + return None + + except Exception as e: + logger.error(f"Failed to get file metadata for {file_id}: {e}") + return None + + async def _download_file_content(self, file_id: str) -> bytes: + """Download file content by file ID using Graph API""" + try: + site_info = self._parse_sharepoint_url() + if site_info: + url = f"{self._graph_base_url}/sites/{site_info['host_name']}:/sites/{site_info['site_name']}:/drive/items/{file_id}/content" + else: + url = f"{self._graph_base_url}/me/drive/items/{file_id}/content" + + token = self.oauth.get_access_token() + headers = {"Authorization": f"Bearer {token}"} + + async with httpx.AsyncClient() as client: + response = await client.get(url, headers=headers, timeout=60) + response.raise_for_status() + return response.content + + except Exception as e: + logger.error(f"Failed to download file content for {file_id}: {e}") + raise + + async def _download_file_from_url(self, download_url: str) -> bytes: + """Download file content from direct download URL""" + try: + async with httpx.AsyncClient() as client: + response = await client.get(download_url, timeout=60) + response.raise_for_status() + return response.content + except Exception as e: + logger.error(f"Failed to download from URL {download_url}: {e}") + raise + + def _parse_graph_date(self, date_str: Optional[str]) -> datetime: + """Parse Microsoft Graph date string to datetime""" + if not date_str: + return datetime.now() + + try: + if date_str.endswith('Z'): + return datetime.fromisoformat(date_str[:-1]).replace(tzinfo=None) + else: + return datetime.fromisoformat(date_str.replace('T', ' ')) + except (ValueError, AttributeError): + return datetime.now() + + async def _make_graph_request(self, url: str, method: str = "GET", + data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response: + """Make authenticated API request to Microsoft Graph""" + token = self.oauth.get_access_token() + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + + async with httpx.AsyncClient() as client: + if method.upper() == "GET": + response = await client.get(url, headers=headers, params=params, timeout=30) + elif method.upper() == "POST": + response = await client.post(url, headers=headers, json=data, timeout=30) + elif method.upper() == "DELETE": + response = await client.delete(url, headers=headers, timeout=30) + else: + raise ValueError(f"Unsupported HTTP 
+
+    async def _make_graph_request(self, url: str, method: str = "GET",
+                                  data: Optional[Dict] = None, params: Optional[Dict] = None) -> httpx.Response:
+        """Make authenticated API request to Microsoft Graph"""
+        token = self.oauth.get_access_token()
+        headers = {
+            "Authorization": f"Bearer {token}",
+            "Content-Type": "application/json"
+        }
+
+        async with httpx.AsyncClient() as client:
+            if method.upper() == "GET":
+                response = await client.get(url, headers=headers, params=params, timeout=30)
+            elif method.upper() == "POST":
+                response = await client.post(url, headers=headers, json=data, timeout=30)
+            elif method.upper() == "DELETE":
+                response = await client.delete(url, headers=headers, timeout=30)
+            else:
+                raise ValueError(f"Unsupported HTTP method: {method}")
+
+            response.raise_for_status()
+            return response
+
+    def _get_mime_type(self, filename: str) -> str:
+        """Get MIME type based on file extension"""
+        import mimetypes
+        mime_type, _ = mimetypes.guess_type(filename)
+        return mime_type or "application/octet-stream"
+
+    # Webhook methods - BaseConnector interface
+    def handle_webhook_validation(self, request_method: str, headers: Dict[str, str],
+                                  query_params: Dict[str, str]) -> Optional[str]:
+        """Handle webhook validation (Graph API specific)"""
+        if request_method == "POST" and "validationToken" in query_params:
+            return query_params["validationToken"]
+        return None
+
+    def extract_webhook_channel_id(self, payload: Dict[str, Any],
+                                   headers: Dict[str, str]) -> Optional[str]:
+        """Extract channel/subscription ID from webhook payload"""
+        notifications = payload.get("value", [])
+        if notifications:
+            return notifications[0].get("subscriptionId")
+        return None
+
+    async def handle_webhook(self, payload: Dict[str, Any]) -> List[str]:
+        """Handle webhook notification and return affected file IDs"""
+        affected_files = []
+
+        # Process Microsoft Graph webhook payload
+        notifications = payload.get("value", [])
+        for notification in notifications:
+            resource = notification.get("resource")
+            if resource and "/drive/items/" in resource:
+                file_id = resource.split("/drive/items/")[-1]
+                affected_files.append(file_id)
+
+        return affected_files
+
+    async def cleanup_subscription(self, subscription_id: str) -> bool:
+        """Clean up subscription - BaseConnector interface"""
+        if subscription_id == "no-webhook-configured":
+            logger.info("No subscription to clean up (webhook was not configured)")
+            return True
+
+        try:
+            # Ensure authentication
+            if not await self.authenticate():
+                logger.error("SharePoint authentication failed during subscription cleanup")
+                return False
+
+            token = self.oauth.get_access_token()
+            headers = {"Authorization": f"Bearer {token}"}
+
+            url = f"{self._graph_base_url}/subscriptions/{subscription_id}"
+
+            async with httpx.AsyncClient() as client:
+                response = await client.delete(url, headers=headers, timeout=30)
+
+            if response.status_code in [200, 204, 404]:
+                logger.info(f"SharePoint subscription {subscription_id} cleaned up successfully")
+                return True
+            else:
+                logger.warning(f"Unexpected response cleaning up subscription: {response.status_code}")
+                return False
+
+        except Exception as e:
+            logger.error(f"Failed to clean up SharePoint subscription {subscription_id}: {e}")
+            return False
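+
+    # For context: the subscription that cleanup_subscription() tears down is
+    # created with POST /subscriptions. A hedged sketch; the notificationUrl,
+    # resource, and expiration below are illustrative assumptions (Graph caps
+    # subscription lifetimes per resource and requires periodic renewal):
+    #
+    #     body = {
+    #         "changeType": "updated",
+    #         "notificationUrl": "https://example.com/api/webhooks/sharepoint",
+    #         "resource": "/me/drive/root",
+    #         "expirationDateTime": "2025-01-04T00:00:00Z",
+    #         "clientState": "opaque-client-secret",
+    #     }
+    #     await self._make_graph_request(
+    #         f"{self._graph_base_url}/subscriptions", method="POST", data=body
+    #     )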
diff --git a/src/connectors/sharepoint/oauth.py b/src/connectors/sharepoint/oauth.py
index fa7424e9..4a96581f 100644
--- a/src/connectors/sharepoint/oauth.py
+++ b/src/connectors/sharepoint/oauth.py
@@ -1,18 +1,28 @@
 import os
+import json
+import logging
+from typing import Optional, Dict, Any
+
 import aiofiles
-from typing import Optional
 import msal
 
+logger = logging.getLogger(__name__)
+
 
 class SharePointOAuth:
-    """Handles Microsoft Graph OAuth authentication flow"""
+    """Handles Microsoft Graph OAuth authentication flow following Google Drive pattern."""
 
-    SCOPES = [
-        "offline_access",
-        "Files.Read.All",
-        "Sites.Read.All",
-    ]
+    # Reserved scopes that must NOT be sent on token or silent calls
+    RESERVED_SCOPES = {"openid", "profile", "offline_access"}
+    # For PERSONAL Microsoft Accounts (OneDrive consumer):
+    # - Use AUTH_SCOPES for interactive auth (consent + refresh token issuance)
+    # - Use RESOURCE_SCOPES for acquire_token_silent / refresh paths
+    AUTH_SCOPES = ["User.Read", "Files.Read.All", "offline_access"]
+    RESOURCE_SCOPES = ["User.Read", "Files.Read.All"]
+    SCOPES = AUTH_SCOPES  # Backward compatibility alias
+
+    # Kept for reference; MSAL derives endpoints from `authority`
     AUTH_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
     TOKEN_ENDPOINT = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
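+
+    # Why the split matters: MSAL adds the reserved scopes itself, and passing
+    # "offline_access" on silent/refresh calls can cause a scope mismatch for
+    # personal accounts. A rule-of-thumb sketch of correct usage:
+    #
+    #     app.get_authorization_request_url(AUTH_SCOPES, ...)        # interactive
+    #     app.acquire_token_silent(RESOURCE_SCOPES, account=acct)    # silent/refresh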
@@ -22,18 +32,29 @@ class SharePointOAuth:
         client_secret: str,
         token_file: str = "sharepoint_token.json",
         authority: str = "https://login.microsoftonline.com/common",
+        allow_json_refresh: bool = True,
     ):
+        """
+        Initialize SharePointOAuth.
+
+        Args:
+            client_id: Azure AD application (client) ID.
+            client_secret: Azure AD application client secret.
+            token_file: Path to persisted token cache file (MSAL cache format).
+            authority: Usually "https://login.microsoftonline.com/common" for MSA + org,
+                or tenant-specific for work/school.
+            allow_json_refresh: If True, permit one-time migration from legacy flat JSON
+                {"access_token","refresh_token",...}. Otherwise refuse it.
+        """
         self.client_id = client_id
         self.client_secret = client_secret
         self.token_file = token_file
         self.authority = authority
+        self.allow_json_refresh = allow_json_refresh
         self.token_cache = msal.SerializableTokenCache()
+        self._current_account = None
 
-        # Load existing cache if available
-        if os.path.exists(self.token_file):
-            with open(self.token_file, "r") as f:
-                self.token_cache.deserialize(f.read())
-
+        # Initialize MSAL Confidential Client
         self.app = msal.ConfidentialClientApplication(
             client_id=self.client_id,
             client_credential=self.client_secret,
@@ -41,56 +62,268 @@ class SharePointOAuth:
             token_cache=self.token_cache,
         )
 
-    async def save_cache(self):
-        """Persist the token cache to file"""
-        async with aiofiles.open(self.token_file, "w") as f:
-            await f.write(self.token_cache.serialize())
+    async def load_credentials(self) -> bool:
+        """Load existing credentials from token file (async)."""
+        try:
+            logger.debug(f"SharePoint OAuth loading credentials from: {self.token_file}")
+            if os.path.exists(self.token_file):
+                logger.debug(f"Token file exists, reading: {self.token_file}")
 
-    def create_authorization_url(self, redirect_uri: str) -> str:
-        """Create authorization URL for OAuth flow"""
-        return self.app.get_authorization_request_url(
-            self.SCOPES, redirect_uri=redirect_uri
-        )
+                # Read the token file
+                async with aiofiles.open(self.token_file, "r") as f:
+                    cache_data = await f.read()
+                logger.debug(f"Read {len(cache_data)} chars from token file")
+
+                if cache_data.strip():
+                    # 1) Try legacy flat JSON first
+                    try:
+                        json_data = json.loads(cache_data)
+                        if isinstance(json_data, dict) and "refresh_token" in json_data:
+                            if self.allow_json_refresh:
+                                logger.debug(
+                                    "Found legacy JSON refresh_token and allow_json_refresh=True; attempting migration refresh"
+                                )
+                                return await self._refresh_from_json_token(json_data)
+                            else:
+                                logger.warning(
+                                    "Token file contains a legacy JSON refresh_token, but allow_json_refresh=False. "
+                                    "Delete the file and re-auth."
+                                )
+                                return False
+                    except json.JSONDecodeError:
+                        logger.debug("Token file is not flat JSON; attempting MSAL cache format")
+
+                    # 2) Try MSAL cache format
+                    logger.debug("Attempting MSAL cache deserialization")
+                    self.token_cache.deserialize(cache_data)
+
+                    # Get accounts from loaded cache
+                    accounts = self.app.get_accounts()
+                    logger.debug(f"Found {len(accounts)} accounts in MSAL cache")
+                    if accounts:
+                        self._current_account = accounts[0]
+                        logger.debug(f"Set current account: {self._current_account.get('username', 'no username')}")
+
+                        # IMPORTANT: Use RESOURCE_SCOPES (no reserved scopes) for silent acquisition
+                        result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+                        logger.debug(f"Silent token acquisition result keys: {list(result.keys()) if result else 'None'}")
+                        if result and "access_token" in result:
+                            logger.debug("Silent token acquisition successful")
+                            await self.save_cache()
+                            return True
+                        else:
+                            error_msg = (result or {}).get("error") or "No result"
+                            logger.warning(f"Silent token acquisition failed: {error_msg}")
+                else:
+                    logger.debug(f"Token file {self.token_file} is empty")
+            else:
+                logger.debug(f"Token file does not exist: {self.token_file}")
+
+            return False
+
+        except Exception as e:
+            logger.error(f"Failed to load SharePoint credentials: {e}")
+            import traceback
+            traceback.print_exc()
+            return False
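+
+    # The legacy flat JSON accepted by the migration path above looks roughly
+    # like this (field values are illustrative; extra fields are ignored):
+    #
+    #     {
+    #         "access_token": "eyJ0eXAi...",
+    #         "refresh_token": "0.AXoA...",
+    #         "token_type": "Bearer",
+    #         "expires_in": 3599
+    #     }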
+
+    async def _refresh_from_json_token(self, token_data: dict) -> bool:
+        """
+        Use refresh token from a legacy JSON file to get new tokens (one-time migration path).
+
+        Notes:
+        - Prefer using an MSAL cache file and acquire_token_silent().
+        - This path is only for migrating older refresh_token JSON files.
+        """
+        try:
+            refresh_token = token_data.get("refresh_token")
+            if not refresh_token:
+                logger.error("No refresh_token found in JSON file - cannot refresh")
+                logger.error("You must re-authenticate interactively to obtain a valid token")
+                return False
+
+            # Use only RESOURCE_SCOPES when refreshing (no reserved scopes)
+            refresh_scopes = [s for s in self.RESOURCE_SCOPES if s not in self.RESERVED_SCOPES]
+            logger.debug(f"Using refresh token; refresh scopes = {refresh_scopes}")
+
+            result = self.app.acquire_token_by_refresh_token(
+                refresh_token=refresh_token,
+                scopes=refresh_scopes,
+            )
+
+            if result and "access_token" in result:
+                logger.debug("Successfully refreshed token via legacy JSON path")
+                await self.save_cache()
+
+                accounts = self.app.get_accounts()
+                logger.debug(f"After refresh, found {len(accounts)} accounts")
+                if accounts:
+                    self._current_account = accounts[0]
+                    logger.debug(f"Set current account after refresh: {self._current_account.get('username', 'no username')}")
+                return True
+
+            # Error handling
+            err = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
+            logger.error(f"Refresh token failed: {err}")
+
+            if any(code in err for code in ("AADSTS70000", "invalid_grant", "interaction_required")):
+                logger.warning(
+                    "Refresh denied due to unauthorized/expired scopes or invalid grant. "
+                    "Delete the token file and perform interactive sign-in with correct scopes."
+                )
+
+            return False
+
+        except Exception as e:
+            logger.error(f"Exception during refresh from JSON token: {e}")
+            import traceback
+            traceback.print_exc()
+            return False
+
+    async def save_cache(self):
+        """Persist the token cache to file."""
+        try:
+            # Ensure parent directory exists
+            parent = os.path.dirname(os.path.abspath(self.token_file))
+            if parent and not os.path.exists(parent):
+                os.makedirs(parent, exist_ok=True)
+
+            cache_data = self.token_cache.serialize()
+            if cache_data:
+                async with aiofiles.open(self.token_file, "w") as f:
+                    await f.write(cache_data)
+                logger.debug(f"Token cache saved to {self.token_file}")
+        except Exception as e:
+            logger.error(f"Failed to save token cache: {e}")
+
+    def create_authorization_url(self, redirect_uri: str, state: Optional[str] = None) -> str:
+        """Create authorization URL for OAuth flow."""
+        # Store redirect URI for later use in callback
+        self._redirect_uri = redirect_uri
+
+        kwargs: Dict[str, Any] = {
+            # IMPORTANT: interactive auth includes offline_access
+            "scopes": self.AUTH_SCOPES,
+            "redirect_uri": redirect_uri,
+            "prompt": "consent",  # ensure refresh token on first run
+        }
+        if state:
+            kwargs["state"] = state  # Optional CSRF protection
+
+        auth_url = self.app.get_authorization_request_url(**kwargs)
+
+        logger.debug(f"Generated auth URL: {auth_url}")
+        logger.debug(f"Auth scopes: {self.AUTH_SCOPES}")
+
+        return auth_url
 
     async def handle_authorization_callback(
         self, authorization_code: str, redirect_uri: str
     ) -> bool:
-        """Handle OAuth callback and exchange code for tokens"""
-        result = self.app.acquire_token_by_authorization_code(
-            authorization_code,
-            scopes=self.SCOPES,
-            redirect_uri=redirect_uri,
-        )
-        if "access_token" in result:
-            await self.save_cache()
-            return True
-        raise ValueError(result.get("error_description") or "Authorization failed")
+        """Handle OAuth callback and exchange code for tokens."""
+        try:
+            # For code exchange, we pass the same auth scopes as used in the authorize step
+            result = self.app.acquire_token_by_authorization_code(
+                authorization_code,
+                scopes=self.AUTH_SCOPES,
+                redirect_uri=redirect_uri,
+            )
+
+            if result and "access_token" in result:
+                # Store the account for future use
+                accounts = self.app.get_accounts()
+                if accounts:
+                    self._current_account = accounts[0]
+
+                await self.save_cache()
+                logger.info("SharePoint OAuth authorization successful")
+                return True
+
+            error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "Unknown error"
+            logger.error(f"SharePoint OAuth authorization failed: {error_msg}")
+            return False
+
+        except Exception as e:
+            logger.error(f"Exception during SharePoint OAuth authorization: {e}")
+            return False
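+
+    # End-to-end interactive flow, as a hedged sketch of hypothetical web-handler
+    # code (the `oauth` instance, redirect URI, and `csrf_token` are assumptions,
+    # not part of this PR):
+    #
+    #     url = oauth.create_authorization_url(
+    #         redirect_uri="https://example.com/oauth/callback", state=csrf_token
+    #     )
+    #     # ...user signs in; Microsoft redirects back with ?code=...&state=...
+    #     ok = await oauth.handle_authorization_callback(code, redirect_uri)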
 
     async def is_authenticated(self) -> bool:
-        """Check if we have valid credentials"""
-        accounts = self.app.get_accounts()
-        if not accounts:
+        """Check if we have valid credentials (simplified like Google Drive)."""
+        try:
+            # First try to load credentials if we haven't already
+            if not self._current_account:
+                await self.load_credentials()
+
+            # If we have an account, try to get a token (MSAL will refresh if needed)
+            if self._current_account:
+                # IMPORTANT: use RESOURCE_SCOPES here
+                result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+                if result and "access_token" in result:
+                    return True
+                else:
+                    error_msg = (result or {}).get("error") or "No result returned"
+                    logger.debug(f"Token acquisition failed for current account: {error_msg}")
+
+            # Fallback: try without specific account
+            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
+            if result and "access_token" in result:
+                # Update current account if this worked
+                accounts = self.app.get_accounts()
+                if accounts:
+                    self._current_account = accounts[0]
+                return True
+
+            return False
+
+        except Exception as e:
+            logger.error(f"Authentication check failed: {e}")
             return False
-        result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
-        if "access_token" in result:
-            await self.save_cache()
-            return True
-        return False
 
     def get_access_token(self) -> str:
-        """Get an access token for Microsoft Graph"""
-        accounts = self.app.get_accounts()
-        if not accounts:
-            raise ValueError("Not authenticated")
-        result = self.app.acquire_token_silent(self.SCOPES, account=accounts[0])
-        if "access_token" not in result:
-            raise ValueError(
-                result.get("error_description") or "Failed to acquire access token"
-            )
-        return result["access_token"]
+        """Get an access token for Microsoft Graph (simplified like Google Drive)."""
+        try:
+            # Try with current account first
+            if self._current_account:
+                result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=self._current_account)
+                if result and "access_token" in result:
+                    return result["access_token"]
+
+            # Fallback: try without specific account
+            result = self.app.acquire_token_silent(self.RESOURCE_SCOPES, account=None)
+            if result and "access_token" in result:
+                return result["access_token"]
+
+            # If we get here, authentication has failed
+            error_msg = (result or {}).get("error_description") or (result or {}).get("error") or "No valid authentication"
+            raise ValueError(f"Failed to acquire access token: {error_msg}")
+
+        except Exception as e:
+            logger.error(f"Failed to get access token: {e}")
+            raise
 
     async def revoke_credentials(self):
-        """Clear token cache and remove token file"""
-        self.token_cache.clear()
-        if os.path.exists(self.token_file):
-            os.remove(self.token_file)
+        """Clear token cache and remove token file (like Google Drive)."""
+        try:
+            # Clear in-memory state
+            self._current_account = None
+            self.token_cache = msal.SerializableTokenCache()
+
+            # Recreate MSAL app with fresh cache
+            self.app = msal.ConfidentialClientApplication(
+                client_id=self.client_id,
+                client_credential=self.client_secret,
+                authority=self.authority,
+                token_cache=self.token_cache,
+            )
+
+            # Remove token file
+            if os.path.exists(self.token_file):
+                os.remove(self.token_file)
+                logger.info(f"Removed SharePoint token file: {self.token_file}")
+
+        except Exception as e:
+            logger.error(f"Failed to revoke SharePoint credentials: {e}")
+
+    def get_service(self) -> str:
+        """Return an access token (Graph doesn't need a generated client like Google Drive)."""
+        return self.get_access_token()
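+
+# A minimal smoke-test sketch for the class above; illustrative only, and the
+# environment-variable names are assumptions, not part of this PR:
+#
+#     import asyncio, os
+#
+#     async def main():
+#         oauth = SharePointOAuth(
+#             client_id=os.environ["SHAREPOINT_CLIENT_ID"],
+#             client_secret=os.environ["SHAREPOINT_CLIENT_SECRET"],
+#         )
+#         if await oauth.is_authenticated():
+#             print("token:", oauth.get_access_token()[:16], "...")
+#         else:
+#             print("visit:", oauth.create_authorization_url("http://localhost:8000/cb"))
+#
+#     asyncio.run(main())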